diff --git a/src/utils/draft_downloader.py b/src/utils/draft_downloader.py index 59a7eee..33f8de4 100644 --- a/src/utils/draft_downloader.py +++ b/src/utils/draft_downloader.py @@ -19,14 +19,76 @@ _REQUEST_CONNECT_TIMEOUT = 10 _REQUEST_READ_TIMEOUT = 30 _MAX_RETRIES = 5 -# 网关/上游暂不可用,适合退避重试(含用户关心的 503) -_RETRYABLE_GATEWAY_HTTP_STATUSES = frozenset({502, 503, 504}) +# 网关/限流等暂时不可用,退避重试有效(与 desktop-client 一致;不含 500 等持久故障) +_RETRYABLE_TRANSIENT_HTTP_STATUSES = frozenset({408, 429, 502, 503, 504}) +_TRANSIENT_HTTP_BACKOFF_MAX_SECONDS = 30 +_DEFAULT_NETWORK_RETRY_DELAY_SECONDS = 1.0 + +_NON_RETRYABLE_NETWORK_MARKERS = ( + "name or service not known", + "getaddrinfo failed", + "nodename nor servname provided", + "connection refused", + "failed to establish a new connection", +) -def _sleep_gateway_backoff(retry_no: int) -> None: - """网关类错误退避:1s 起指数增长,上限 32s。retry_no 为从 1 开始的本次重试序号。""" - delay = min(2 ** (retry_no - 1), 32) - time.sleep(delay) +def _is_retryable_http_status(status_code: int) -> bool: + return status_code in _RETRYABLE_TRANSIENT_HTTP_STATUSES + + +def _is_retryable_request_exception(exc: requests.exceptions.RequestException) -> bool: + if isinstance( + exc, + (requests.exceptions.Timeout, requests.exceptions.ChunkedEncodingError), + ): + return True + if isinstance(exc, requests.exceptions.ConnectionError): + msg = str(exc).lower() + if any(marker in msg for marker in _NON_RETRYABLE_NETWORK_MARKERS): + return False + return True + return False + + +def _parse_retry_after_seconds(headers) -> Optional[float]: + raw = headers.get("Retry-After") or headers.get("retry-after") + if raw is None or raw == "": + return None + try: + seconds = int(raw) + if seconds >= 0: + return min(float(seconds), _TRANSIENT_HTTP_BACKOFF_MAX_SECONDS) + except (TypeError, ValueError): + pass + try: + from email.utils import parsedate_to_datetime + from datetime import datetime, timezone + + retry_at = parsedate_to_datetime(str(raw)) + if retry_at.tzinfo is None: + retry_at = retry_at.replace(tzinfo=timezone.utc) + delta = (retry_at - datetime.now(timezone.utc)).total_seconds() + return min(max(0.0, delta), _TRANSIENT_HTTP_BACKOFF_MAX_SECONDS) + except (TypeError, ValueError, OverflowError): + return None + return None + + +def _sleep_transient_http_backoff( + retry_no: int, response: Optional[requests.Response] = None +) -> None: + """限流/网关错误退避:优先 Retry-After,否则 1s 起指数增长,上限 30s。""" + if response is not None: + delay = _parse_retry_after_seconds(response.headers) + if delay is not None: + time.sleep(delay) + return + time.sleep(min(2 ** (retry_no - 1), _TRANSIENT_HTTP_BACKOFF_MAX_SECONDS)) + + +def _sleep_network_retry_backoff() -> None: + time.sleep(_DEFAULT_NETWORK_RETRY_DELAY_SECONDS) def safe_write_file(file_path: str, file_content: bytes, is_binary: bool = True): @@ -152,18 +214,18 @@ def get_draft_files_list(draft_url: str) -> list: if response.status_code != 200: if ( - response.status_code in _RETRYABLE_GATEWAY_HTTP_STATUSES + _is_retryable_http_status(response.status_code) and attempt < _MAX_RETRIES ): retry_no = attempt + 1 logger.warning( - "Gateway HTTP %s while fetching draft file list, retry (%s/%s)", + "Transient HTTP %s while fetching draft file list, retry (%s/%s)", response.status_code, retry_no, _MAX_RETRIES, ) response.close() - _sleep_gateway_backoff(retry_no) + _sleep_transient_http_backoff(retry_no, response) continue logger.error( f"Failed to get draft file list, HTTP status: {response.status_code}" @@ -190,18 +252,22 @@ def get_draft_files_list(draft_url: str) -> list: logger.info(f"Fetched {len(files)} draft file(s)") return files except requests.exceptions.RequestException as e: - if attempt >= _MAX_RETRIES: - logger.error( - f"Network error while fetching draft file list after {_MAX_RETRIES} retries: {e}" - ) + if not _is_retryable_request_exception(e) or attempt >= _MAX_RETRIES: + if attempt >= _MAX_RETRIES: + logger.error( + f"Network error while fetching draft file list after {_MAX_RETRIES} retries: {e}" + ) + else: + logger.error( + f"Network error while fetching draft file list is not retryable: {e}" + ) return [] retry_no = attempt + 1 - backoff_seconds = retry_no logger.warning( f"Network error while fetching draft file list, retry ({retry_no}/{_MAX_RETRIES}): {e}" ) - time.sleep(backoff_seconds) + _sleep_network_retry_backoff() except Exception as e: logger.error(f"Unexpected error while fetching draft file list: {e}") return [] @@ -281,29 +347,31 @@ def download_single_file(file_url: str, target_dir: str) -> bool: ) try: if response.status_code != 200: - if response.status_code in _RETRYABLE_GATEWAY_HTTP_STATUSES: - retry_count += 1 - if retry_count > max_retries: - logger.error( - "Gateway HTTP %s, download failed after %s retries, URL: %s", - response.status_code, - max_retries, - file_url, - ) - return False - logger.warning( - "Gateway HTTP %s, retry (%s/%s), URL: %s", + if not _is_retryable_http_status(response.status_code): + logger.error( + "Download failed (HTTP %s, not retryable), URL: %s", + response.status_code, + file_url, + ) + return False + retry_count += 1 + if retry_count > max_retries: + logger.error( + "Transient HTTP %s, download failed after %s retries, URL: %s", response.status_code, - retry_count, max_retries, file_url, ) - _sleep_gateway_backoff(retry_count) - continue - logger.error( - f"Download failed, HTTP status: {response.status_code}, URL: {file_url}" + return False + logger.warning( + "Transient HTTP %s, retry (%s/%s), URL: %s", + response.status_code, + retry_count, + max_retries, + file_url, ) - return False + _sleep_transient_http_backoff(retry_count, response) + continue parent_dir = os.path.dirname(full_file_path) if parent_dir: @@ -324,17 +392,21 @@ def download_single_file(file_url: str, target_dir: str) -> bool: return True except requests.exceptions.RequestException as e: + if not _is_retryable_request_exception(e): + logger.error( + f"Network error is not retryable: {e}, URL: {file_url}" + ) + return False retry_count += 1 if retry_count > max_retries: logger.error( f"Network error, download failed after {max_retries} retries: {e}, URL: {file_url}" ) return False - else: - logger.warning( - f"Network error, retry ({retry_count}/{max_retries}): {e}, URL: {file_url}" - ) - time.sleep(1 * retry_count) # 递增延迟 + logger.warning( + f"Network error, retry ({retry_count}/{max_retries}): {e}, URL: {file_url}" + ) + _sleep_network_retry_backoff() except OSError as e: logger.error(f"File write error, download failed: {e}, URL: {file_url}") return False @@ -534,6 +606,12 @@ def _download_remote_material( stream=True, ) if response.status_code != 200: + if not _is_retryable_http_status(response.status_code): + logger.error( + f"Remote material download failed (HTTP {response.status_code}), " + f"not retryable: {file_url}" + ) + return None if attempt >= _MAX_RETRIES: logger.error( f"Remote material download failed (HTTP {response.status_code}) " @@ -542,12 +620,10 @@ def _download_remote_material( return None retry_no = attempt + 1 logger.warning( - f"Remote material download non-200, retry ({retry_no}/{_MAX_RETRIES}): {file_url}" + f"Remote material download transient HTTP {response.status_code}, " + f"retry ({retry_no}/{_MAX_RETRIES}): {file_url}" ) - if response.status_code in _RETRYABLE_GATEWAY_HTTP_STATUSES: - _sleep_gateway_backoff(retry_no) - else: - time.sleep(retry_no) + _sleep_transient_http_backoff(retry_no, response) response.close() continue @@ -566,6 +642,11 @@ def _download_remote_material( f.write(chunk) return local_path except requests.exceptions.RequestException as e: + if not _is_retryable_request_exception(e): + logger.error( + f"Remote material download failed, not retryable: {file_url}, error: {e}" + ) + return None if attempt >= _MAX_RETRIES: logger.error( f"Remote material download failed after {_MAX_RETRIES} retries: {file_url}, error: {e}" @@ -576,7 +657,7 @@ def _download_remote_material( f"Remote material download network error, retry ({retry_no}/{_MAX_RETRIES}): " f"{file_url}, {e}" ) - time.sleep(retry_no) + _sleep_network_retry_backoff() except OSError as e: logger.error(f"Failed to write remote material to disk: {file_url}, {e}") return None @@ -596,6 +677,12 @@ def _download_remote_file(file_url: str, local_path: str) -> bool: stream=True, ) if response.status_code != 200: + if not _is_retryable_http_status(response.status_code): + logger.error( + f"Remote material download failed (HTTP {response.status_code}), " + f"not retryable: {file_url}" + ) + return False if attempt >= _MAX_RETRIES: logger.error( f"Remote material download failed (HTTP {response.status_code}) " @@ -604,12 +691,10 @@ def _download_remote_file(file_url: str, local_path: str) -> bool: return False retry_no = attempt + 1 logger.warning( - f"Remote material download non-200, retry ({retry_no}/{_MAX_RETRIES}): {file_url}" + f"Remote material download transient HTTP {response.status_code}, " + f"retry ({retry_no}/{_MAX_RETRIES}): {file_url}" ) - if response.status_code in _RETRYABLE_GATEWAY_HTTP_STATUSES: - _sleep_gateway_backoff(retry_no) - else: - time.sleep(retry_no) + _sleep_transient_http_backoff(retry_no, response) response.close() continue @@ -622,6 +707,11 @@ def _download_remote_file(file_url: str, local_path: str) -> bool: f.write(chunk) return True except requests.exceptions.RequestException as e: + if not _is_retryable_request_exception(e): + logger.error( + f"Remote material download failed, not retryable: {file_url}, error: {e}" + ) + return False if attempt >= _MAX_RETRIES: logger.error( f"Remote material download failed after {_MAX_RETRIES} retries: {file_url}, error: {e}" @@ -632,7 +722,7 @@ def _download_remote_file(file_url: str, local_path: str) -> bool: f"Remote material download network error, retry ({retry_no}/{_MAX_RETRIES}): " f"{file_url}, {e}" ) - time.sleep(retry_no) + _sleep_network_retry_backoff() except OSError as e: logger.error(f"Failed to write remote material to disk: {local_path}, {e}") return False diff --git a/tests/test_draft_downloader_remote_materials.py b/tests/test_draft_downloader_remote_materials.py index d4bfd0f..670f5ea 100644 --- a/tests/test_draft_downloader_remote_materials.py +++ b/tests/test_draft_downloader_remote_materials.py @@ -82,6 +82,45 @@ class TestDownloadRemoteMaterial: assert local_path is not None assert local_path.endswith(".mp4") + def test_non200_404_does_not_retry(self, no_sleep) -> None: + bad = MagicMock() + bad.status_code = 404 + bad.close = MagicMock() + with tempfile.TemporaryDirectory() as td: + with patch.object(dd, "requests") as m_req: + m_req.get.return_value = bad + m_req.exceptions = requests.exceptions + assert dd._download_remote_material( + "https://cdn.example.com/miss.png", td, "images", "x", ".png" + ) is None + assert m_req.get.call_count == 1 + + +class TestRetryHelpers: + @pytest.mark.parametrize( + "status, expected", + [ + (404, False), + (400, False), + (500, False), + (408, True), + (429, True), + (502, True), + (503, True), + (504, True), + ], + ) + def test_is_retryable_http_status(self, status: int, expected: bool) -> None: + assert dd._is_retryable_http_status(status) is expected + + def test_connection_refused_not_retryable(self) -> None: + exc = requests.exceptions.ConnectionError("Connection refused") + assert dd._is_retryable_request_exception(exc) is False + + def test_read_timeout_retryable(self) -> None: + exc = requests.exceptions.ReadTimeout("read timed out") + assert dd._is_retryable_request_exception(exc) is True + class TestDownloadRemoteFile: def _ok_response(self, content: bytes = b"data") -> MagicMock: @@ -124,14 +163,25 @@ class TestDownloadRemoteFile: if os.path.isfile(out): os.remove(out) - def test_returns_false_after_exhausting_retries(self, no_sleep) -> None: - with patch.object(dd, "_MAX_RETRIES", 2): - with patch.object(dd, "requests") as m_req: - m_req.get.side_effect = requests.exceptions.ConnectionError("refused") - m_req.exceptions = requests.exceptions - out = os.path.join(tempfile.gettempdir(), "t_dl_fail.bin") - assert dd._download_remote_file("https://x.test/c.mp4", out) is False - assert m_req.get.call_count == 3 + def test_returns_false_immediately_on_connection_refused(self, no_sleep) -> None: + with patch.object(dd, "requests") as m_req: + m_req.get.side_effect = requests.exceptions.ConnectionError("Connection refused") + m_req.exceptions = requests.exceptions + out = os.path.join(tempfile.gettempdir(), "t_dl_fail.bin") + assert dd._download_remote_file("https://x.test/c.mp4", out) is False + assert m_req.get.call_count == 1 + + def test_non200_404_does_not_retry(self, no_sleep) -> None: + bad = MagicMock() + bad.status_code = 404 + bad.close = MagicMock() + out = os.path.join(tempfile.gettempdir(), "t_dl_404.bin") + with patch.object(dd, "requests") as m_req: + m_req.get.return_value = bad + m_req.exceptions = requests.exceptions + assert dd._download_remote_file("https://x.test/miss.mp4", out) is False + assert m_req.get.call_count == 1 + bad.close.assert_not_called() def test_non200_retries_then_success(self, no_sleep) -> None: bad = MagicMock() @@ -388,6 +438,19 @@ class TestDownloadSingleFile: with open(out, "rb") as f: assert f.read() == b"ok" + def test_connection_refused_does_not_retry(self, no_sleep) -> None: + file_url = ( + f"{self._BASE}/app/output/draft/20251204214904ccb1af38/miss.bin" + ) + with tempfile.TemporaryDirectory() as td: + with patch.object(dd, "requests") as m_req: + m_req.get.side_effect = requests.exceptions.ConnectionError( + "Connection refused" + ) + m_req.exceptions = requests.exceptions + assert dd.download_single_file(file_url, td) is False + assert m_req.get.call_count == 1 + def test_returns_false_after_exhausting_retries(self, no_sleep) -> None: file_url = ( f"{self._BASE}/app/output/draft/20251204214904ccb1af38/miss.bin"