mirror of
https://github.com/TheSmallHanCat/flow2api.git
synced 2026-06-02 12:51:38 +08:00
feat: 优化远程打码并发调度与观测能力
This commit is contained in:
162
src/api/admin.py
162
src/api/admin.py
@@ -1,5 +1,8 @@
|
||||
"""Admin API routes"""
|
||||
import asyncio
|
||||
import json
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Header, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
@@ -7,6 +10,7 @@ from typing import Optional, List, Dict, Any
|
||||
import secrets
|
||||
import time
|
||||
import re
|
||||
from urllib.parse import urlparse
|
||||
from curl_cffi.requests import AsyncSession
|
||||
from ..core.auth import AuthManager
|
||||
from ..core.database import Database
|
||||
@@ -98,6 +102,70 @@ def _build_proxy_map(proxy_url: str) -> Optional[Dict[str, str]]:
|
||||
return {"http": normalized, "https": normalized}
|
||||
|
||||
|
||||
def _normalize_http_base_url(base_url: str) -> str:
|
||||
normalized = (base_url or "").strip().rstrip("/")
|
||||
if not normalized:
|
||||
raise RuntimeError("远程打码服务地址未配置")
|
||||
|
||||
parsed = urlparse(normalized)
|
||||
if parsed.scheme not in {"http", "https"} or not parsed.netloc:
|
||||
raise RuntimeError("远程打码服务地址格式错误,必须是 http(s)://host[:port]")
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
def _get_remote_browser_client_config() -> tuple[str, str, int]:
|
||||
base_url = _normalize_http_base_url(config.remote_browser_base_url)
|
||||
api_key = (config.remote_browser_api_key or "").strip()
|
||||
if not api_key:
|
||||
raise RuntimeError("远程打码服务 API Key 未配置")
|
||||
timeout = max(5, int(config.remote_browser_timeout or 60))
|
||||
return base_url, api_key, timeout
|
||||
|
||||
|
||||
def _sync_json_http_request(
|
||||
method: str,
|
||||
url: str,
|
||||
headers: Dict[str, str],
|
||||
payload: Optional[Dict[str, Any]],
|
||||
timeout: int,
|
||||
) -> tuple[int, Optional[Any], str]:
|
||||
req_headers = dict(headers or {})
|
||||
req_headers.setdefault("Accept", "application/json")
|
||||
|
||||
data = None
|
||||
if payload is not None:
|
||||
data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
|
||||
req_headers["Content-Type"] = "application/json; charset=utf-8"
|
||||
|
||||
request = urllib.request.Request(
|
||||
url=url,
|
||||
data=data,
|
||||
headers=req_headers,
|
||||
method=(method or "GET").upper(),
|
||||
)
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(request, timeout=timeout) as response:
|
||||
status_code = int(response.getcode() or 0)
|
||||
raw_body = response.read()
|
||||
except urllib.error.HTTPError as e:
|
||||
status_code = int(getattr(e, "code", 500))
|
||||
raw_body = e.read() if hasattr(e, "read") else b""
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"远程打码服务请求失败: {e}") from e
|
||||
|
||||
text = raw_body.decode("utf-8", errors="replace") if raw_body else ""
|
||||
parsed: Optional[Any] = None
|
||||
if text:
|
||||
try:
|
||||
parsed = json.loads(text)
|
||||
except Exception:
|
||||
parsed = None
|
||||
|
||||
return status_code, parsed, text
|
||||
|
||||
|
||||
async def _resolve_score_test_verify_proxy(
|
||||
captcha_method: str,
|
||||
browser_proxy_enabled: bool,
|
||||
@@ -208,6 +276,46 @@ async def _solve_recaptcha_with_api_service(
|
||||
raise RuntimeError(f"{method} 获取 token 超时")
|
||||
|
||||
|
||||
async def _score_test_with_remote_browser_service(
|
||||
website_url: str,
|
||||
website_key: str,
|
||||
verify_url: str,
|
||||
action: str,
|
||||
enterprise: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""调用远程有头打码服务执行页面内打码+分数校验。"""
|
||||
base_url, api_key, timeout = _get_remote_browser_client_config()
|
||||
endpoint = f"{base_url}/api/v1/custom-score"
|
||||
request_payload = {
|
||||
"website_url": website_url,
|
||||
"website_key": website_key,
|
||||
"verify_url": verify_url,
|
||||
"action": action,
|
||||
"enterprise": enterprise,
|
||||
}
|
||||
|
||||
status_code, response_payload, response_text = await asyncio.to_thread(
|
||||
_sync_json_http_request,
|
||||
"POST",
|
||||
endpoint,
|
||||
{"Authorization": f"Bearer {api_key}"},
|
||||
request_payload,
|
||||
timeout,
|
||||
)
|
||||
|
||||
if status_code >= 400:
|
||||
detail = ""
|
||||
if isinstance(response_payload, dict):
|
||||
detail = response_payload.get("detail") or response_payload.get("message") or str(response_payload)
|
||||
if not detail:
|
||||
detail = (response_text or "").strip()
|
||||
raise RuntimeError(f"远程打码服务请求失败 (HTTP {status_code}): {detail or '未知错误'}")
|
||||
|
||||
if not isinstance(response_payload, dict):
|
||||
raise RuntimeError("远程打码服务返回格式错误")
|
||||
return response_payload
|
||||
|
||||
|
||||
def set_dependencies(tm: TokenManager, pm: ProxyManager, database: Database, cm: Optional[ConcurrencyManager] = None):
|
||||
"""Set service instances"""
|
||||
global token_manager, proxy_manager, db, concurrency_manager
|
||||
@@ -1196,6 +1304,9 @@ async def update_captcha_config(
|
||||
ezcaptcha_base_url = request.get("ezcaptcha_base_url")
|
||||
capsolver_api_key = request.get("capsolver_api_key")
|
||||
capsolver_base_url = request.get("capsolver_base_url")
|
||||
remote_browser_base_url = request.get("remote_browser_base_url")
|
||||
remote_browser_api_key = request.get("remote_browser_api_key")
|
||||
remote_browser_timeout = request.get("remote_browser_timeout", 60)
|
||||
browser_proxy_enabled = request.get("browser_proxy_enabled", False)
|
||||
browser_proxy_url = request.get("browser_proxy_url", "")
|
||||
browser_count = request.get("browser_count", 1)
|
||||
@@ -1206,6 +1317,23 @@ async def update_captcha_config(
|
||||
if not is_valid:
|
||||
return {"success": False, "message": error_msg}
|
||||
|
||||
if remote_browser_base_url:
|
||||
try:
|
||||
remote_browser_base_url = _normalize_http_base_url(remote_browser_base_url)
|
||||
except RuntimeError as e:
|
||||
return {"success": False, "message": str(e)}
|
||||
|
||||
try:
|
||||
remote_browser_timeout = max(5, int(remote_browser_timeout or 60))
|
||||
except Exception:
|
||||
return {"success": False, "message": "远程打码超时时间必须是整数秒"}
|
||||
|
||||
if captcha_method == "remote_browser":
|
||||
if not (remote_browser_base_url or "").strip():
|
||||
return {"success": False, "message": "remote_browser 模式需要配置远程打码服务地址"}
|
||||
if not (remote_browser_api_key or "").strip():
|
||||
return {"success": False, "message": "remote_browser 模式需要配置远程打码服务 API Key"}
|
||||
|
||||
await db.update_captcha_config(
|
||||
captcha_method=captcha_method,
|
||||
yescaptcha_api_key=yescaptcha_api_key,
|
||||
@@ -1216,6 +1344,9 @@ async def update_captcha_config(
|
||||
ezcaptcha_base_url=ezcaptcha_base_url,
|
||||
capsolver_api_key=capsolver_api_key,
|
||||
capsolver_base_url=capsolver_base_url,
|
||||
remote_browser_base_url=remote_browser_base_url,
|
||||
remote_browser_api_key=remote_browser_api_key,
|
||||
remote_browser_timeout=remote_browser_timeout,
|
||||
browser_proxy_enabled=browser_proxy_enabled,
|
||||
browser_proxy_url=browser_proxy_url if browser_proxy_enabled else None,
|
||||
browser_count=max(1, int(browser_count)) if browser_count else 1
|
||||
@@ -1250,6 +1381,9 @@ async def get_captcha_config(token: str = Depends(verify_admin_token)):
|
||||
"ezcaptcha_base_url": captcha_config.ezcaptcha_base_url,
|
||||
"capsolver_api_key": captcha_config.capsolver_api_key,
|
||||
"capsolver_base_url": captcha_config.capsolver_base_url,
|
||||
"remote_browser_base_url": captcha_config.remote_browser_base_url,
|
||||
"remote_browser_api_key": captcha_config.remote_browser_api_key,
|
||||
"remote_browser_timeout": captcha_config.remote_browser_timeout,
|
||||
"browser_proxy_enabled": captcha_config.browser_proxy_enabled,
|
||||
"browser_proxy_url": captcha_config.browser_proxy_url or "",
|
||||
"browser_count": captcha_config.browser_count
|
||||
@@ -1286,7 +1420,7 @@ async def test_captcha_score(
|
||||
verify_proxy_source = "none"
|
||||
verify_proxy_url = ""
|
||||
verify_impersonate = "chrome120"
|
||||
page_verify_only = captcha_method in {"browser", "personal"}
|
||||
page_verify_only = captcha_method in {"browser", "personal", "remote_browser"}
|
||||
verify_mode = "browser_page" if page_verify_only else "server_post"
|
||||
|
||||
try:
|
||||
@@ -1339,6 +1473,26 @@ async def test_captcha_score(
|
||||
verify_proxy_used = bool(browser_proxy_enabled and browser_proxy_url)
|
||||
verify_proxy_source = "captcha_browser_proxy" if verify_proxy_used else "browser_direct"
|
||||
verify_proxy_url = browser_proxy_url if verify_proxy_used else ""
|
||||
elif captcha_method == "remote_browser":
|
||||
score_payload = await _score_test_with_remote_browser_service(
|
||||
website_url=website_url,
|
||||
website_key=website_key,
|
||||
verify_url=verify_url,
|
||||
action=action,
|
||||
enterprise=enterprise,
|
||||
)
|
||||
if isinstance(score_payload, dict):
|
||||
if score_payload.get("success") is False:
|
||||
raise RuntimeError(score_payload.get("message") or "远程打码分数测试失败")
|
||||
token_value = score_payload.get("token")
|
||||
verify_elapsed_ms = int(score_payload.get("verify_elapsed_ms") or 0)
|
||||
verify_http_status = score_payload.get("verify_http_status")
|
||||
verify_result = score_payload.get("verify_result") if isinstance(score_payload.get("verify_result"), dict) else {}
|
||||
verify_mode = score_payload.get("verify_mode") or "remote_browser_page"
|
||||
score_token_elapsed = score_payload.get("token_elapsed_ms")
|
||||
if isinstance(score_token_elapsed, (int, float)):
|
||||
token_elapsed_ms = int(score_token_elapsed)
|
||||
fingerprint = score_payload.get("fingerprint") if isinstance(score_payload.get("fingerprint"), dict) else None
|
||||
elif captcha_method in SUPPORTED_API_CAPTCHA_METHODS:
|
||||
token_value = await _solve_recaptcha_with_api_service(
|
||||
method=captcha_method,
|
||||
@@ -1363,6 +1517,12 @@ async def test_captcha_score(
|
||||
if token_elapsed_ms <= 0:
|
||||
token_elapsed_ms = int((time.time() - token_start) * 1000)
|
||||
|
||||
# 远程有头打码的 custom-score 可能由页面内直接完成校验,
|
||||
# 在部分实现里不会显式回传 token,本地按 verify_result 兜底判定。
|
||||
if captcha_method == "remote_browser" and not token_value and isinstance(verify_result, dict):
|
||||
if verify_result.get("success") is True:
|
||||
token_value = verify_result.get("token") or verify_result.get("gRecaptchaResponse") or "__verified_by_remote__"
|
||||
|
||||
if not token_value:
|
||||
return {
|
||||
"success": False,
|
||||
|
||||
@@ -54,11 +54,142 @@ class Config:
|
||||
|
||||
@property
|
||||
def flow_timeout(self) -> int:
|
||||
return self._config["flow"]["timeout"]
|
||||
timeout = self._config.get("flow", {}).get("timeout", 120)
|
||||
try:
|
||||
return max(5, int(timeout))
|
||||
except Exception:
|
||||
return 120
|
||||
|
||||
@property
|
||||
def flow_max_retries(self) -> int:
|
||||
return self._config["flow"]["max_retries"]
|
||||
retries = self._config.get("flow", {}).get("max_retries", 3)
|
||||
try:
|
||||
return max(1, int(retries))
|
||||
except Exception:
|
||||
return 3
|
||||
|
||||
@property
|
||||
def flow_image_request_timeout(self) -> int:
|
||||
"""图片生成单次 HTTP 请求超时(秒)。"""
|
||||
default_timeout = min(self.flow_timeout, 40)
|
||||
timeout = self._config.get("flow", {}).get(
|
||||
"image_request_timeout",
|
||||
default_timeout
|
||||
)
|
||||
try:
|
||||
return max(5, int(timeout))
|
||||
except Exception:
|
||||
return self.flow_timeout
|
||||
|
||||
@property
|
||||
def flow_image_timeout_retry_count(self) -> int:
|
||||
"""图片生成遇到网络超时时的快速重试次数。"""
|
||||
retry_count = self._config.get("flow", {}).get("image_timeout_retry_count", 1)
|
||||
try:
|
||||
return max(0, min(3, int(retry_count)))
|
||||
except Exception:
|
||||
return 1
|
||||
|
||||
@property
|
||||
def flow_image_timeout_retry_delay(self) -> float:
|
||||
"""图片生成网络超时重试前等待秒数。"""
|
||||
delay = self._config.get("flow", {}).get("image_timeout_retry_delay", 0.8)
|
||||
try:
|
||||
return max(0.0, min(5.0, float(delay)))
|
||||
except Exception:
|
||||
return 0.8
|
||||
|
||||
@property
|
||||
def flow_image_timeout_use_media_proxy_fallback(self) -> bool:
|
||||
"""网络超时时是否切换媒体代理重试。"""
|
||||
return bool(
|
||||
self._config.get("flow", {}).get(
|
||||
"image_timeout_use_media_proxy_fallback",
|
||||
True
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def flow_image_prefer_media_proxy(self) -> bool:
|
||||
"""图片生成是否优先走媒体代理链路。"""
|
||||
return bool(
|
||||
self._config.get("flow", {}).get(
|
||||
"image_prefer_media_proxy",
|
||||
False
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def flow_image_slot_wait_timeout(self) -> float:
|
||||
"""图片硬并发槽位等待超时(秒)。"""
|
||||
timeout = self._config.get("flow", {}).get("image_slot_wait_timeout", 120)
|
||||
try:
|
||||
return max(1.0, min(600.0, float(timeout)))
|
||||
except Exception:
|
||||
return 120.0
|
||||
|
||||
@property
|
||||
def flow_image_launch_soft_limit(self) -> int:
|
||||
"""图片生成前置发车软并发上限(0 表示关闭软整形,仅使用硬并发)。"""
|
||||
value = self._config.get("flow", {}).get("image_launch_soft_limit", 0)
|
||||
try:
|
||||
return max(0, min(200, int(value)))
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
@property
|
||||
def flow_image_launch_wait_timeout(self) -> float:
|
||||
"""图片前置发车软并发等待超时(秒)。"""
|
||||
timeout = self._config.get("flow", {}).get("image_launch_wait_timeout", 180)
|
||||
try:
|
||||
return max(1.0, min(600.0, float(timeout)))
|
||||
except Exception:
|
||||
return 180.0
|
||||
|
||||
@property
|
||||
def flow_image_launch_stagger_ms(self) -> int:
|
||||
"""图片请求前置发车间隔(毫秒),用于平滑同批突发。"""
|
||||
value = self._config.get("flow", {}).get("image_launch_stagger_ms", 0)
|
||||
try:
|
||||
return max(0, min(5000, int(value)))
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
@property
|
||||
def flow_video_slot_wait_timeout(self) -> float:
|
||||
"""视频硬并发槽位等待超时(秒)。"""
|
||||
timeout = self._config.get("flow", {}).get("video_slot_wait_timeout", 120)
|
||||
try:
|
||||
return max(1.0, min(600.0, float(timeout)))
|
||||
except Exception:
|
||||
return 120.0
|
||||
|
||||
@property
|
||||
def flow_video_launch_soft_limit(self) -> int:
|
||||
"""视频生成前置发车软并发上限(0 表示关闭软整形,仅使用硬并发)。"""
|
||||
value = self._config.get("flow", {}).get("video_launch_soft_limit", 0)
|
||||
try:
|
||||
return max(0, min(200, int(value)))
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
@property
|
||||
def flow_video_launch_wait_timeout(self) -> float:
|
||||
"""视频前置发车软并发等待超时(秒)。"""
|
||||
timeout = self._config.get("flow", {}).get("video_launch_wait_timeout", 180)
|
||||
try:
|
||||
return max(1.0, min(600.0, float(timeout)))
|
||||
except Exception:
|
||||
return 180.0
|
||||
|
||||
@property
|
||||
def flow_video_launch_stagger_ms(self) -> int:
|
||||
"""视频请求前置发车间隔(毫秒),用于平滑同批突发。"""
|
||||
value = self._config.get("flow", {}).get("video_launch_stagger_ms", 0)
|
||||
try:
|
||||
return max(0, min(5000, int(value)))
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
@property
|
||||
def poll_interval(self) -> float:
|
||||
@@ -213,6 +344,15 @@ class Config:
|
||||
self._config["captcha"] = {}
|
||||
self._config["captcha"]["browser_launch_background"] = bool(enabled)
|
||||
|
||||
@property
|
||||
def browser_recaptcha_settle_seconds(self) -> float:
|
||||
"""有头打码在 reload/clr 就绪后的额外等待秒数。"""
|
||||
value = self._config.get("captcha", {}).get("browser_recaptcha_settle_seconds", 3.0)
|
||||
try:
|
||||
return max(0.0, min(10.0, float(value)))
|
||||
except Exception:
|
||||
return 3.0
|
||||
|
||||
@property
|
||||
def yescaptcha_api_key(self) -> str:
|
||||
"""Get YesCaptcha API key"""
|
||||
@@ -301,6 +441,47 @@ class Config:
|
||||
self._config["captcha"] = {}
|
||||
self._config["captcha"]["capsolver_base_url"] = base_url
|
||||
|
||||
@property
|
||||
def remote_browser_base_url(self) -> str:
|
||||
"""Get remote browser captcha service base URL"""
|
||||
return self._config.get("captcha", {}).get("remote_browser_base_url", "")
|
||||
|
||||
def set_remote_browser_base_url(self, base_url: str):
|
||||
"""Set remote browser captcha service base URL"""
|
||||
if "captcha" not in self._config:
|
||||
self._config["captcha"] = {}
|
||||
self._config["captcha"]["remote_browser_base_url"] = (base_url or "").strip()
|
||||
|
||||
@property
|
||||
def remote_browser_api_key(self) -> str:
|
||||
"""Get remote browser captcha service API key"""
|
||||
return self._config.get("captcha", {}).get("remote_browser_api_key", "")
|
||||
|
||||
def set_remote_browser_api_key(self, api_key: str):
|
||||
"""Set remote browser captcha service API key"""
|
||||
if "captcha" not in self._config:
|
||||
self._config["captcha"] = {}
|
||||
self._config["captcha"]["remote_browser_api_key"] = (api_key or "").strip()
|
||||
|
||||
@property
|
||||
def remote_browser_timeout(self) -> int:
|
||||
"""Get remote browser captcha request timeout (seconds)"""
|
||||
timeout = self._config.get("captcha", {}).get("remote_browser_timeout", 60)
|
||||
try:
|
||||
return max(5, int(timeout))
|
||||
except Exception:
|
||||
return 60
|
||||
|
||||
def set_remote_browser_timeout(self, timeout: int):
|
||||
"""Set remote browser captcha request timeout (seconds)"""
|
||||
if "captcha" not in self._config:
|
||||
self._config["captcha"] = {}
|
||||
try:
|
||||
normalized = max(5, int(timeout))
|
||||
except Exception:
|
||||
normalized = 60
|
||||
self._config["captcha"]["remote_browser_timeout"] = normalized
|
||||
|
||||
|
||||
# Global config instance
|
||||
config = Config()
|
||||
|
||||
@@ -166,17 +166,37 @@ class Database:
|
||||
captcha_method = "browser"
|
||||
yescaptcha_api_key = ""
|
||||
yescaptcha_base_url = "https://api.yescaptcha.com"
|
||||
remote_browser_base_url = ""
|
||||
remote_browser_api_key = ""
|
||||
remote_browser_timeout = 60
|
||||
|
||||
if config_dict:
|
||||
captcha_config = config_dict.get("captcha", {})
|
||||
captcha_method = captcha_config.get("captcha_method", "browser")
|
||||
yescaptcha_api_key = captcha_config.get("yescaptcha_api_key", "")
|
||||
yescaptcha_base_url = captcha_config.get("yescaptcha_base_url", "https://api.yescaptcha.com")
|
||||
remote_browser_base_url = captcha_config.get("remote_browser_base_url", "")
|
||||
remote_browser_api_key = captcha_config.get("remote_browser_api_key", "")
|
||||
remote_browser_timeout = captcha_config.get("remote_browser_timeout", 60)
|
||||
try:
|
||||
remote_browser_timeout = max(5, int(remote_browser_timeout))
|
||||
except Exception:
|
||||
remote_browser_timeout = 60
|
||||
|
||||
await db.execute("""
|
||||
INSERT INTO captcha_config (id, captcha_method, yescaptcha_api_key, yescaptcha_base_url)
|
||||
VALUES (1, ?, ?, ?)
|
||||
""", (captcha_method, yescaptcha_api_key, yescaptcha_base_url))
|
||||
INSERT INTO captcha_config (
|
||||
id, captcha_method, yescaptcha_api_key, yescaptcha_base_url,
|
||||
remote_browser_base_url, remote_browser_api_key, remote_browser_timeout
|
||||
)
|
||||
VALUES (1, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
captcha_method,
|
||||
yescaptcha_api_key,
|
||||
yescaptcha_base_url,
|
||||
remote_browser_base_url,
|
||||
remote_browser_api_key,
|
||||
remote_browser_timeout,
|
||||
))
|
||||
|
||||
# Ensure plugin_config has a row
|
||||
cursor = await db.execute("SELECT COUNT(*) FROM plugin_config")
|
||||
@@ -247,6 +267,9 @@ class Database:
|
||||
ezcaptcha_base_url TEXT DEFAULT 'https://api.ez-captcha.com',
|
||||
capsolver_api_key TEXT DEFAULT '',
|
||||
capsolver_base_url TEXT DEFAULT 'https://api.capsolver.com',
|
||||
remote_browser_base_url TEXT DEFAULT '',
|
||||
remote_browser_api_key TEXT DEFAULT '',
|
||||
remote_browser_timeout INTEGER DEFAULT 60,
|
||||
website_key TEXT DEFAULT '6LdsFiUsAAAAAIjVDZcuLhaHiDn5nnHVXVRQGeMV',
|
||||
page_action TEXT DEFAULT 'IMAGE_GENERATION',
|
||||
browser_proxy_enabled BOOLEAN DEFAULT 0,
|
||||
@@ -332,6 +355,9 @@ class Database:
|
||||
("capsolver_api_key", "TEXT DEFAULT ''"),
|
||||
("capsolver_base_url", "TEXT DEFAULT 'https://api.capsolver.com'"),
|
||||
("browser_count", "INTEGER DEFAULT 1"),
|
||||
("remote_browser_base_url", "TEXT DEFAULT ''"),
|
||||
("remote_browser_api_key", "TEXT DEFAULT ''"),
|
||||
("remote_browser_timeout", "INTEGER DEFAULT 60"),
|
||||
]
|
||||
|
||||
for col_name, col_type in captcha_columns_to_add:
|
||||
@@ -553,6 +579,9 @@ class Database:
|
||||
ezcaptcha_base_url TEXT DEFAULT 'https://api.ez-captcha.com',
|
||||
capsolver_api_key TEXT DEFAULT '',
|
||||
capsolver_base_url TEXT DEFAULT 'https://api.capsolver.com',
|
||||
remote_browser_base_url TEXT DEFAULT '',
|
||||
remote_browser_api_key TEXT DEFAULT '',
|
||||
remote_browser_timeout INTEGER DEFAULT 60,
|
||||
website_key TEXT DEFAULT '6LdsFiUsAAAAAIjVDZcuLhaHiDn5nnHVXVRQGeMV',
|
||||
page_action TEXT DEFAULT 'IMAGE_GENERATION',
|
||||
|
||||
@@ -1292,6 +1321,9 @@ class Database:
|
||||
config.set_ezcaptcha_base_url(captcha_config.ezcaptcha_base_url)
|
||||
config.set_capsolver_api_key(captcha_config.capsolver_api_key)
|
||||
config.set_capsolver_base_url(captcha_config.capsolver_base_url)
|
||||
config.set_remote_browser_base_url(captcha_config.remote_browser_base_url)
|
||||
config.set_remote_browser_api_key(captcha_config.remote_browser_api_key)
|
||||
config.set_remote_browser_timeout(captcha_config.remote_browser_timeout)
|
||||
|
||||
# Cache config operations
|
||||
async def get_cache_config(self) -> CacheConfig:
|
||||
@@ -1418,6 +1450,9 @@ class Database:
|
||||
ezcaptcha_base_url: str = None,
|
||||
capsolver_api_key: str = None,
|
||||
capsolver_base_url: str = None,
|
||||
remote_browser_base_url: str = None,
|
||||
remote_browser_api_key: str = None,
|
||||
remote_browser_timeout: int = None,
|
||||
browser_proxy_enabled: bool = None,
|
||||
browser_proxy_url: str = None,
|
||||
browser_count: int = None
|
||||
@@ -1439,9 +1474,13 @@ class Database:
|
||||
new_ez_url = ezcaptcha_base_url if ezcaptcha_base_url is not None else current.get("ezcaptcha_base_url", "https://api.ez-captcha.com")
|
||||
new_cs_key = capsolver_api_key if capsolver_api_key is not None else current.get("capsolver_api_key", "")
|
||||
new_cs_url = capsolver_base_url if capsolver_base_url is not None else current.get("capsolver_base_url", "https://api.capsolver.com")
|
||||
new_remote_base_url = remote_browser_base_url if remote_browser_base_url is not None else current.get("remote_browser_base_url", "")
|
||||
new_remote_api_key = remote_browser_api_key if remote_browser_api_key is not None else current.get("remote_browser_api_key", "")
|
||||
new_remote_timeout = remote_browser_timeout if remote_browser_timeout is not None else current.get("remote_browser_timeout", 60)
|
||||
new_proxy_enabled = browser_proxy_enabled if browser_proxy_enabled is not None else current.get("browser_proxy_enabled", False)
|
||||
new_proxy_url = browser_proxy_url if browser_proxy_url is not None else current.get("browser_proxy_url")
|
||||
new_browser_count = browser_count if browser_count is not None else current.get("browser_count", 1)
|
||||
new_remote_timeout = max(5, int(new_remote_timeout)) if new_remote_timeout is not None else 60
|
||||
|
||||
await db.execute("""
|
||||
UPDATE captcha_config
|
||||
@@ -1449,10 +1488,13 @@ class Database:
|
||||
capmonster_api_key = ?, capmonster_base_url = ?,
|
||||
ezcaptcha_api_key = ?, ezcaptcha_base_url = ?,
|
||||
capsolver_api_key = ?, capsolver_base_url = ?,
|
||||
remote_browser_base_url = ?, remote_browser_api_key = ?, remote_browser_timeout = ?,
|
||||
browser_proxy_enabled = ?, browser_proxy_url = ?, browser_count = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = 1
|
||||
""", (new_method, new_yes_key, new_yes_url, new_cap_key, new_cap_url,
|
||||
new_ez_key, new_ez_url, new_cs_key, new_cs_url, new_proxy_enabled, new_proxy_url, new_browser_count))
|
||||
new_ez_key, new_ez_url, new_cs_key, new_cs_url,
|
||||
(new_remote_base_url or "").strip(), (new_remote_api_key or "").strip(), new_remote_timeout,
|
||||
new_proxy_enabled, new_proxy_url, new_browser_count))
|
||||
else:
|
||||
new_method = captcha_method if captcha_method is not None else "yescaptcha"
|
||||
new_yes_key = yescaptcha_api_key if yescaptcha_api_key is not None else ""
|
||||
@@ -1463,17 +1505,25 @@ class Database:
|
||||
new_ez_url = ezcaptcha_base_url if ezcaptcha_base_url is not None else "https://api.ez-captcha.com"
|
||||
new_cs_key = capsolver_api_key if capsolver_api_key is not None else ""
|
||||
new_cs_url = capsolver_base_url if capsolver_base_url is not None else "https://api.capsolver.com"
|
||||
new_remote_base_url = remote_browser_base_url if remote_browser_base_url is not None else ""
|
||||
new_remote_api_key = remote_browser_api_key if remote_browser_api_key is not None else ""
|
||||
new_remote_timeout = remote_browser_timeout if remote_browser_timeout is not None else 60
|
||||
new_proxy_enabled = browser_proxy_enabled if browser_proxy_enabled is not None else False
|
||||
new_proxy_url = browser_proxy_url
|
||||
new_browser_count = browser_count if browser_count is not None else 1
|
||||
new_remote_timeout = max(5, int(new_remote_timeout))
|
||||
|
||||
await db.execute("""
|
||||
INSERT INTO captcha_config (id, captcha_method, yescaptcha_api_key, yescaptcha_base_url,
|
||||
capmonster_api_key, capmonster_base_url, ezcaptcha_api_key, ezcaptcha_base_url,
|
||||
capsolver_api_key, capsolver_base_url, browser_proxy_enabled, browser_proxy_url, browser_count)
|
||||
VALUES (1, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
capsolver_api_key, capsolver_base_url,
|
||||
remote_browser_base_url, remote_browser_api_key, remote_browser_timeout,
|
||||
browser_proxy_enabled, browser_proxy_url, browser_count)
|
||||
VALUES (1, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (new_method, new_yes_key, new_yes_url, new_cap_key, new_cap_url,
|
||||
new_ez_key, new_ez_url, new_cs_key, new_cs_url, new_proxy_enabled, new_proxy_url, new_browser_count))
|
||||
new_ez_key, new_ez_url, new_cs_key, new_cs_url,
|
||||
(new_remote_base_url or "").strip(), (new_remote_api_key or "").strip(), new_remote_timeout,
|
||||
new_proxy_enabled, new_proxy_url, new_browser_count))
|
||||
|
||||
await db.commit()
|
||||
|
||||
|
||||
@@ -152,7 +152,7 @@ class DebugConfig(BaseModel):
|
||||
class CaptchaConfig(BaseModel):
|
||||
"""Captcha configuration"""
|
||||
id: int = 1
|
||||
captcha_method: str = "browser" # yescaptcha, capmonster, ezcaptcha, capsolver 或 browser
|
||||
captcha_method: str = "browser" # yescaptcha/capmonster/ezcaptcha/capsolver/browser/personal/remote_browser
|
||||
yescaptcha_api_key: str = ""
|
||||
yescaptcha_base_url: str = "https://api.yescaptcha.com"
|
||||
capmonster_api_key: str = ""
|
||||
@@ -161,6 +161,9 @@ class CaptchaConfig(BaseModel):
|
||||
ezcaptcha_base_url: str = "https://api.ez-captcha.com"
|
||||
capsolver_api_key: str = ""
|
||||
capsolver_base_url: str = "https://api.capsolver.com"
|
||||
remote_browser_base_url: str = ""
|
||||
remote_browser_api_key: str = ""
|
||||
remote_browser_timeout: int = 60
|
||||
website_key: str = "6LdsFiUsAAAAAIjVDZcuLhaHiDn5nnHVXVRQGeMV"
|
||||
page_action: str = "IMAGE_GENERATION"
|
||||
browser_proxy_enabled: bool = False # 浏览器打码是否启用代理
|
||||
|
||||
@@ -78,6 +78,9 @@ async def lifespan(app: FastAPI):
|
||||
config.set_ezcaptcha_base_url(captcha_config.ezcaptcha_base_url)
|
||||
config.set_capsolver_api_key(captcha_config.capsolver_api_key)
|
||||
config.set_capsolver_base_url(captcha_config.capsolver_base_url)
|
||||
config.set_remote_browser_base_url(captcha_config.remote_browser_base_url)
|
||||
config.set_remote_browser_api_key(captcha_config.remote_browser_api_key)
|
||||
config.set_remote_browser_timeout(captcha_config.remote_browser_timeout)
|
||||
|
||||
# Initialize browser captcha service if needed
|
||||
browser_service = None
|
||||
|
||||
@@ -12,8 +12,9 @@ import asyncio
|
||||
import time
|
||||
import re
|
||||
import random
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, List
|
||||
from typing import Optional, Dict, Any, List, Union
|
||||
from datetime import datetime
|
||||
from urllib.parse import urlparse, unquote, parse_qs
|
||||
|
||||
@@ -361,8 +362,9 @@ class TokenBrowser:
|
||||
self._last_fingerprint: Optional[Dict[str, Any]] = None
|
||||
self._browser_proxy_active = False
|
||||
# 打码成功后延迟关闭浏览器:等待上游图片/视频请求完成通知
|
||||
self._pending_release_events: List[asyncio.Event] = []
|
||||
self._pending_release_tasks: List[asyncio.Task] = []
|
||||
# request_ref -> {"event": asyncio.Event, "task": asyncio.Task}
|
||||
# 使用请求级句柄避免高并发下“按顺序 pop”导致的错配关闭。
|
||||
self._pending_release_entries: Dict[str, Dict[str, Any]] = {}
|
||||
self._pending_release_lock = asyncio.Lock()
|
||||
|
||||
async def _create_browser(self, token_proxy_url: Optional[str] = None) -> tuple:
|
||||
@@ -653,6 +655,7 @@ class TokenBrowser:
|
||||
|
||||
async def _wait_and_close_after_request(
|
||||
self,
|
||||
request_ref: str,
|
||||
release_event: asyncio.Event,
|
||||
wait_timeout: int,
|
||||
playwright,
|
||||
@@ -677,14 +680,10 @@ class TokenBrowser:
|
||||
finally:
|
||||
await self._close_browser(playwright, browser, context)
|
||||
debug_logger.log_info(
|
||||
f"[BrowserCaptcha] Token-{self.token_id} {close_reason},浏览器已关闭 (action={action})"
|
||||
f"[BrowserCaptcha] Token-{self.token_id} {close_reason},浏览器已关闭 (action={action}, request_ref={request_ref[:8]})"
|
||||
)
|
||||
async with self._pending_release_lock:
|
||||
current_task = asyncio.current_task()
|
||||
if current_task in self._pending_release_tasks:
|
||||
self._pending_release_tasks.remove(current_task)
|
||||
if release_event in self._pending_release_events:
|
||||
self._pending_release_events.remove(release_event)
|
||||
self._pending_release_entries.pop(request_ref, None)
|
||||
|
||||
async def _defer_browser_close_until_request_done(
|
||||
self,
|
||||
@@ -692,7 +691,7 @@ class TokenBrowser:
|
||||
browser,
|
||||
context,
|
||||
action: str
|
||||
):
|
||||
) -> str:
|
||||
"""打码成功后延迟关闭浏览器,等待 Flow 请求结束通知。"""
|
||||
flow_timeout = int(getattr(config, "flow_timeout", 300) or 300)
|
||||
upsample_timeout = int(getattr(config, "upsample_timeout", 300) or 300)
|
||||
@@ -703,9 +702,11 @@ class TokenBrowser:
|
||||
else:
|
||||
# 视频请求默认超时更长,给更大的缓冲避免“请求未结束就关闭”
|
||||
wait_timeout = max(flow_timeout + 300, 1800)
|
||||
request_ref = uuid.uuid4().hex
|
||||
release_event = asyncio.Event()
|
||||
release_task = asyncio.create_task(
|
||||
self._wait_and_close_after_request(
|
||||
request_ref=request_ref,
|
||||
release_event=release_event,
|
||||
wait_timeout=wait_timeout,
|
||||
playwright=playwright,
|
||||
@@ -716,34 +717,63 @@ class TokenBrowser:
|
||||
)
|
||||
|
||||
async with self._pending_release_lock:
|
||||
self._pending_release_events.append(release_event)
|
||||
self._pending_release_tasks.append(release_task)
|
||||
self._pending_release_entries[request_ref] = {
|
||||
"event": release_event,
|
||||
"task": release_task,
|
||||
}
|
||||
debug_logger.log_info(
|
||||
f"[BrowserCaptcha] Token-{self.token_id} 打码成功后进入延迟关闭,等待上游请求完成 (action={action}, timeout={wait_timeout}s)"
|
||||
f"[BrowserCaptcha] Token-{self.token_id} 打码成功后进入延迟关闭,等待上游请求完成 "
|
||||
f"(action={action}, timeout={wait_timeout}s, request_ref={request_ref[:8]})"
|
||||
)
|
||||
return request_ref
|
||||
|
||||
async def notify_generation_request_finished(self):
|
||||
async def notify_generation_request_finished(self, request_ref: Optional[str] = None):
|
||||
"""通知当前 Token 对应的上游图片/视频请求已结束。"""
|
||||
async with self._pending_release_lock:
|
||||
release_event = self._pending_release_events.pop(0) if self._pending_release_events else None
|
||||
release_event = None
|
||||
matched_ref = request_ref
|
||||
if matched_ref and matched_ref in self._pending_release_entries:
|
||||
entry = self._pending_release_entries.pop(matched_ref)
|
||||
release_event = entry.get("event")
|
||||
elif not matched_ref and self._pending_release_entries:
|
||||
# 兼容旧调用方(无 request_ref),仅回收最早待释放项,避免一次性影响全部请求。
|
||||
matched_ref = next(iter(self._pending_release_entries.keys()))
|
||||
entry = self._pending_release_entries.pop(matched_ref)
|
||||
release_event = entry.get("event")
|
||||
if release_event and not release_event.is_set():
|
||||
release_event.set()
|
||||
debug_logger.log_info(
|
||||
f"[BrowserCaptcha] Token-{self.token_id} 收到上游请求完成通知,开始关闭浏览器"
|
||||
f"[BrowserCaptcha] Token-{self.token_id} 收到上游请求完成通知,开始关闭浏览器 "
|
||||
f"(request_ref={(matched_ref or 'unknown')[:8]})"
|
||||
)
|
||||
|
||||
async def force_close_pending_browser(self):
|
||||
async def force_close_pending_browser(self, request_ref: Optional[str] = None, close_all: bool = False):
|
||||
"""强制关闭待释放浏览器(服务关闭时调用)。"""
|
||||
async with self._pending_release_lock:
|
||||
release_events = list(self._pending_release_events)
|
||||
release_tasks = list(self._pending_release_tasks)
|
||||
self._pending_release_events.clear()
|
||||
self._pending_release_tasks.clear()
|
||||
entries: List[Dict[str, Any]] = []
|
||||
if close_all:
|
||||
entries = list(self._pending_release_entries.values())
|
||||
self._pending_release_entries.clear()
|
||||
elif request_ref and request_ref in self._pending_release_entries:
|
||||
entry = self._pending_release_entries.pop(request_ref)
|
||||
entries = [entry]
|
||||
elif self._pending_release_entries:
|
||||
# 兼容旧调用方(无 request_ref)时,仅关闭最早的一项,避免误伤其它并发请求。
|
||||
first_ref = next(iter(self._pending_release_entries.keys()))
|
||||
entry = self._pending_release_entries.pop(first_ref)
|
||||
entries = [entry]
|
||||
|
||||
release_events = [entry.get("event") for entry in entries if isinstance(entry, dict)]
|
||||
release_tasks = [entry.get("task") for entry in entries if isinstance(entry, dict)]
|
||||
|
||||
for release_event in release_events:
|
||||
if not release_event:
|
||||
continue
|
||||
if not release_event.is_set():
|
||||
release_event.set()
|
||||
for release_task in release_tasks:
|
||||
if not release_task:
|
||||
continue
|
||||
try:
|
||||
await asyncio.wait_for(release_task, timeout=5)
|
||||
except Exception:
|
||||
@@ -1091,7 +1121,7 @@ class TokenBrowser:
|
||||
website_key: str,
|
||||
action: str = "IMAGE_GENERATION",
|
||||
token_proxy_url: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
) -> tuple[Optional[str], Optional[str]]:
|
||||
"""获取 Token:启动新浏览器 -> 打码 -> 关闭浏览器"""
|
||||
async with self._semaphore:
|
||||
MAX_RETRIES = 3
|
||||
@@ -1113,7 +1143,7 @@ class TokenBrowser:
|
||||
self._solve_count += 1
|
||||
debug_logger.log_info(f"[BrowserCaptcha] Token-{self.token_id} 获取成功 ({(time.time()-start_ts)*1000:.0f}ms)")
|
||||
# 不立即关闭浏览器:等待图片/视频请求结束后再关闭
|
||||
await self._defer_browser_close_until_request_done(
|
||||
request_ref = await self._defer_browser_close_until_request_done(
|
||||
playwright=playwright,
|
||||
browser=browser,
|
||||
context=context,
|
||||
@@ -1122,7 +1152,7 @@ class TokenBrowser:
|
||||
playwright = None
|
||||
browser = None
|
||||
context = None
|
||||
return token
|
||||
return token, request_ref
|
||||
|
||||
self._error_count += 1
|
||||
debug_logger.log_warning(f"[BrowserCaptcha] Token-{self.token_id} 尝试 {attempt+1}/{MAX_RETRIES} 失败")
|
||||
@@ -1138,7 +1168,7 @@ class TokenBrowser:
|
||||
if attempt < MAX_RETRIES - 1:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
return None
|
||||
return None, None
|
||||
|
||||
async def get_custom_token(
|
||||
self,
|
||||
@@ -1364,6 +1394,32 @@ class BrowserCaptchaService:
|
||||
self._round_robin_index += 1
|
||||
return browser_id
|
||||
|
||||
@staticmethod
|
||||
def _compose_browser_ref(browser_id: int, request_ref: Optional[str]) -> Union[int, str]:
|
||||
"""将 browser_id 与 request_ref 合并为可回传的请求句柄。"""
|
||||
if request_ref:
|
||||
return f"{browser_id}:{request_ref}"
|
||||
return browser_id
|
||||
|
||||
@staticmethod
|
||||
def _parse_browser_ref(browser_ref: Optional[Union[int, str]]) -> tuple[Optional[int], Optional[str]]:
|
||||
"""解析请求句柄,兼容旧的纯 int browser_id。"""
|
||||
if browser_ref is None:
|
||||
return None, None
|
||||
|
||||
if isinstance(browser_ref, int):
|
||||
return browser_ref, None
|
||||
|
||||
if isinstance(browser_ref, str):
|
||||
raw = browser_ref.strip()
|
||||
if raw.isdigit():
|
||||
return int(raw), None
|
||||
browser_id_part, sep, request_ref = raw.partition(":")
|
||||
if sep and browser_id_part.isdigit() and request_ref:
|
||||
return int(browser_id_part), request_ref
|
||||
|
||||
return None, None
|
||||
|
||||
async def _resolve_token_proxy_url(self, token_id: Optional[int]) -> Optional[str]:
|
||||
"""读取 token 级打码代理,为空时回退全局配置。"""
|
||||
if not token_id or not self.db:
|
||||
@@ -1376,7 +1432,7 @@ class BrowserCaptchaService:
|
||||
debug_logger.log_warning(f"[BrowserCaptcha] 读取 token({token_id}) 打码代理失败: {e}")
|
||||
return None
|
||||
|
||||
async def get_token(self, project_id: str, action: str = "IMAGE_GENERATION", token_id: int = None) -> tuple[Optional[str], int]:
|
||||
async def get_token(self, project_id: str, action: str = "IMAGE_GENERATION", token_id: int = None) -> tuple[Optional[str], Union[int, str]]:
|
||||
"""获取 reCAPTCHA Token(轮询分配到不同浏览器)
|
||||
|
||||
Args:
|
||||
@@ -1385,7 +1441,7 @@ class BrowserCaptchaService:
|
||||
token_id: 业务 token id(仅用于读取 token 级打码代理)
|
||||
|
||||
Returns:
|
||||
(token, browser_id) 元组,调用方失败时用 browser_id 调用 report_error
|
||||
(token, browser_ref) 元组,browser_ref 包含 browser_id 与请求级 request_ref
|
||||
"""
|
||||
# 检查服务是否可用
|
||||
self._check_available()
|
||||
@@ -1400,7 +1456,7 @@ class BrowserCaptchaService:
|
||||
browser_id = self._get_next_browser_id()
|
||||
browser = await self._get_or_create_browser(browser_id)
|
||||
|
||||
token = await browser.get_token(
|
||||
token, request_ref = await browser.get_token(
|
||||
project_id,
|
||||
self.website_key,
|
||||
action,
|
||||
@@ -1413,13 +1469,13 @@ class BrowserCaptchaService:
|
||||
self._stats["gen_fail"] += 1
|
||||
|
||||
self._log_stats()
|
||||
return token, browser_id
|
||||
return token, self._compose_browser_ref(browser_id, request_ref)
|
||||
|
||||
# 无并发限制时直接执行
|
||||
browser_id = self._get_next_browser_id()
|
||||
browser = await self._get_or_create_browser(browser_id)
|
||||
|
||||
token = await browser.get_token(
|
||||
token, request_ref = await browser.get_token(
|
||||
project_id,
|
||||
self.website_key,
|
||||
action,
|
||||
@@ -1432,7 +1488,7 @@ class BrowserCaptchaService:
|
||||
self._stats["gen_fail"] += 1
|
||||
|
||||
self._log_stats()
|
||||
return token, browser_id
|
||||
return token, self._compose_browser_ref(browser_id, request_ref)
|
||||
|
||||
async def get_custom_token(
|
||||
self,
|
||||
@@ -1501,20 +1557,26 @@ class BrowserCaptchaService:
|
||||
)
|
||||
return payload, browser_id
|
||||
|
||||
async def get_fingerprint(self, browser_id: int) -> Optional[Dict[str, Any]]:
|
||||
async def get_fingerprint(self, browser_ref: Optional[Union[int, str]]) -> Optional[Dict[str, Any]]:
|
||||
"""获取指定浏览器最近一次打码时的指纹快照。"""
|
||||
browser_id, _ = self._parse_browser_ref(browser_ref)
|
||||
if browser_id is None:
|
||||
return None
|
||||
|
||||
async with self._browsers_lock:
|
||||
browser = self._browsers.get(browser_id)
|
||||
if not browser:
|
||||
return None
|
||||
return browser.get_last_fingerprint()
|
||||
|
||||
async def report_error(self, browser_id: int = None, error_reason: Optional[str] = None):
|
||||
async def report_error(self, browser_ref: Optional[Union[int, str]] = None, error_reason: Optional[str] = None):
|
||||
"""上层举报当前请求失败,必要时提前回收待释放浏览器。
|
||||
|
||||
Args:
|
||||
browser_id: 浏览器 ID(当前架构下每次都是新浏览器,此参数仅用于日志)
|
||||
browser_ref: 浏览器请求句柄(browser_id[:request_ref])
|
||||
"""
|
||||
browser_id, request_ref = self._parse_browser_ref(browser_ref)
|
||||
|
||||
async with self._browsers_lock:
|
||||
browser = self._browsers.get(browser_id) if browser_id is not None else None
|
||||
error_lower = (error_reason or "").lower()
|
||||
@@ -1527,12 +1589,17 @@ class BrowserCaptchaService:
|
||||
|
||||
if browser:
|
||||
try:
|
||||
await browser.force_close_pending_browser()
|
||||
if request_ref:
|
||||
await browser.force_close_pending_browser(request_ref=request_ref)
|
||||
else:
|
||||
# 未携带 request_ref 时只回收一项,避免高并发下误关其它请求链路。
|
||||
await browser.force_close_pending_browser()
|
||||
except Exception as e:
|
||||
debug_logger.log_warning(f"[BrowserCaptcha] 浏览器 {browser_id} 失败后提前关闭异常: {e}")
|
||||
|
||||
async def report_request_finished(self, browser_id: int = None):
|
||||
async def report_request_finished(self, browser_ref: Optional[Union[int, str]] = None):
|
||||
"""上层通知:图片/视频请求已完成,可关闭对应打码浏览器。"""
|
||||
browser_id, request_ref = self._parse_browser_ref(browser_ref)
|
||||
if browser_id is None:
|
||||
return
|
||||
|
||||
@@ -1540,7 +1607,7 @@ class BrowserCaptchaService:
|
||||
browser = self._browsers.get(browser_id)
|
||||
|
||||
if browser:
|
||||
await browser.notify_generation_request_finished()
|
||||
await browser.notify_generation_request_finished(request_ref=request_ref)
|
||||
|
||||
async def remove_browser(self, browser_id: int):
|
||||
async with self._browsers_lock:
|
||||
@@ -1554,7 +1621,7 @@ class BrowserCaptchaService:
|
||||
|
||||
for browser in browsers:
|
||||
try:
|
||||
await browser.force_close_pending_browser()
|
||||
await browser.force_close_pending_browser(close_all=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Concurrency manager for token-based rate limiting"""
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
from ..core.logger import debug_logger
|
||||
|
||||
@@ -118,6 +119,40 @@ class ConcurrencyManager:
|
||||
debug_logger.log_info(f"Token {token_id} acquired image slot (inflight: {new_inflight}/{limit})")
|
||||
return True
|
||||
|
||||
async def wait_acquire_image(self, token_id: int, timeout_seconds: float) -> tuple[bool, int]:
|
||||
"""等待获取图片硬并发槽位,避免请求在短暂竞争下直接失败。"""
|
||||
wait_started = time.monotonic()
|
||||
timeout_seconds = max(1.0, float(timeout_seconds or 1.0))
|
||||
deadline = wait_started + timeout_seconds
|
||||
|
||||
while True:
|
||||
if await self.acquire_image(token_id):
|
||||
waited_ms = int((time.monotonic() - wait_started) * 1000)
|
||||
return True, waited_ms
|
||||
|
||||
if time.monotonic() >= deadline:
|
||||
waited_ms = int((time.monotonic() - wait_started) * 1000)
|
||||
return False, waited_ms
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
async def wait_acquire_video(self, token_id: int, timeout_seconds: float) -> tuple[bool, int]:
|
||||
"""等待获取视频硬并发槽位,避免请求在短暂竞争下直接失败。"""
|
||||
wait_started = time.monotonic()
|
||||
timeout_seconds = max(1.0, float(timeout_seconds or 1.0))
|
||||
deadline = wait_started + timeout_seconds
|
||||
|
||||
while True:
|
||||
if await self.acquire_video(token_id):
|
||||
waited_ms = int((time.monotonic() - wait_started) * 1000)
|
||||
return True, waited_ms
|
||||
|
||||
if time.monotonic() >= deadline:
|
||||
waited_ms = int((time.monotonic() - wait_started) * 1000)
|
||||
return False, waited_ms
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
async def acquire_video(self, token_id: int) -> bool:
|
||||
"""
|
||||
Acquire video concurrency slot
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -715,6 +715,13 @@ class GenerationHandler:
|
||||
token = None
|
||||
generation_type = None
|
||||
token_slot_reserved = False
|
||||
pending_token_state = {"active": False}
|
||||
request_id = f"gen-{int(start_time * 1000)}-{id(asyncio.current_task())}"
|
||||
perf_trace: Dict[str, Any] = {
|
||||
"request_id": request_id,
|
||||
"model": model,
|
||||
"status": "processing",
|
||||
}
|
||||
self._last_generated_url = None
|
||||
self._last_generation_assets = None
|
||||
|
||||
@@ -762,19 +769,25 @@ class GenerationHandler:
|
||||
|
||||
# 2. 选择Token
|
||||
debug_logger.log_info(f"[GENERATION] 正在选择可用Token...")
|
||||
token_select_started_at = time.time()
|
||||
|
||||
if generation_type == "image":
|
||||
token = await self.load_balancer.select_token(
|
||||
for_image_generation=True,
|
||||
model=model,
|
||||
reserve=self.concurrency_manager is not None
|
||||
reserve=False,
|
||||
enforce_concurrency_filter=False,
|
||||
track_pending=True,
|
||||
)
|
||||
else:
|
||||
token = await self.load_balancer.select_token(
|
||||
for_video_generation=True,
|
||||
model=model,
|
||||
reserve=self.concurrency_manager is not None
|
||||
reserve=False,
|
||||
enforce_concurrency_filter=False,
|
||||
track_pending=True,
|
||||
)
|
||||
perf_trace["token_select_ms"] = int((time.time() - token_select_started_at) * 1000)
|
||||
|
||||
if not token:
|
||||
error_msg = self._get_no_token_error_message(generation_type)
|
||||
@@ -784,7 +797,7 @@ class GenerationHandler:
|
||||
yield self._create_error_response(error_msg)
|
||||
return
|
||||
|
||||
token_slot_reserved = self.concurrency_manager is not None
|
||||
token_slot_reserved = False
|
||||
debug_logger.log_info(f"[GENERATION] 已选择Token: {token.id} ({token.email})")
|
||||
|
||||
try:
|
||||
@@ -793,7 +806,9 @@ class GenerationHandler:
|
||||
if stream:
|
||||
yield self._create_stream_chunk("初始化生成环境...\n")
|
||||
|
||||
ensure_at_started_at = time.time()
|
||||
token = await self.token_manager.ensure_valid_token(token)
|
||||
perf_trace["ensure_at_ms"] = int((time.time() - ensure_at_started_at) * 1000)
|
||||
if not token:
|
||||
error_msg = "Token AT无效或刷新失败"
|
||||
debug_logger.log_error(f"[GENERATION] {error_msg}")
|
||||
@@ -805,16 +820,22 @@ class GenerationHandler:
|
||||
# 4. 确保Project存在
|
||||
debug_logger.log_info(f"[GENERATION] 检查/创建Project...")
|
||||
|
||||
ensure_project_started_at = time.time()
|
||||
project_id = await self.token_manager.ensure_project_exists(token.id)
|
||||
perf_trace["ensure_project_ms"] = int((time.time() - ensure_project_started_at) * 1000)
|
||||
debug_logger.log_info(f"[GENERATION] Project ID: {project_id}")
|
||||
|
||||
# 5. 根据类型处理
|
||||
generation_pipeline_started_at = time.time()
|
||||
if generation_type == "image":
|
||||
debug_logger.log_info(f"[GENERATION] 开始图片生成流程...")
|
||||
slot_reserved_for_handler = token_slot_reserved
|
||||
token_slot_reserved = False
|
||||
async for chunk in self._handle_image_generation(
|
||||
token, project_id, model_config, prompt, images, stream, slot_reserved=slot_reserved_for_handler
|
||||
token, project_id, model_config, prompt, images, stream,
|
||||
slot_reserved=slot_reserved_for_handler,
|
||||
perf_trace=perf_trace,
|
||||
pending_token_state=pending_token_state
|
||||
):
|
||||
yield chunk
|
||||
else: # video
|
||||
@@ -822,9 +843,13 @@ class GenerationHandler:
|
||||
slot_reserved_for_handler = token_slot_reserved
|
||||
token_slot_reserved = False
|
||||
async for chunk in self._handle_video_generation(
|
||||
token, project_id, model_config, prompt, images, stream, slot_reserved=slot_reserved_for_handler
|
||||
token, project_id, model_config, prompt, images, stream,
|
||||
slot_reserved=slot_reserved_for_handler,
|
||||
perf_trace=perf_trace,
|
||||
pending_token_state=pending_token_state
|
||||
):
|
||||
yield chunk
|
||||
perf_trace["generation_pipeline_ms"] = int((time.time() - generation_pipeline_started_at) * 1000)
|
||||
|
||||
# 6. 记录使用
|
||||
is_video = (generation_type == "video")
|
||||
@@ -837,6 +862,8 @@ class GenerationHandler:
|
||||
|
||||
# 7. 记录成功日志
|
||||
duration = time.time() - start_time
|
||||
perf_trace["status"] = "success"
|
||||
perf_trace["total_ms"] = int(duration * 1000)
|
||||
# 日志中保留更完整的 prompt,避免管理页只看到过短内容
|
||||
prompt_for_log = prompt if len(prompt) <= 2000 else f"{prompt[:2000]}...(truncated)"
|
||||
|
||||
@@ -844,7 +871,8 @@ class GenerationHandler:
|
||||
response_data = {
|
||||
"status": "success",
|
||||
"model": model,
|
||||
"prompt": prompt_for_log
|
||||
"prompt": prompt_for_log,
|
||||
"performance": perf_trace
|
||||
}
|
||||
|
||||
# 添加生成的URL(如果有)
|
||||
@@ -856,6 +884,19 @@ class GenerationHandler:
|
||||
# 清除临时存储,避免污染后续请求
|
||||
self._last_generated_url = None
|
||||
self._last_generation_assets = None
|
||||
image_perf = perf_trace.get("image_generation", {}) if isinstance(perf_trace, dict) else {}
|
||||
video_perf = perf_trace.get("video_generation", {}) if isinstance(perf_trace, dict) else {}
|
||||
debug_logger.log_info(
|
||||
f"[PERF] [{request_id}] total={perf_trace.get('total_ms', 0)}ms, "
|
||||
f"select={perf_trace.get('token_select_ms', 0)}ms, "
|
||||
f"ensure_at={perf_trace.get('ensure_at_ms', 0)}ms, "
|
||||
f"project={perf_trace.get('ensure_project_ms', 0)}ms, "
|
||||
f"pipeline={perf_trace.get('generation_pipeline_ms', 0)}ms, "
|
||||
f"slot_wait={image_perf.get('slot_wait_ms', 0)}ms, "
|
||||
f"launch_queue={image_perf.get('launch_queue_wait_ms', 0)}ms, "
|
||||
f"launch_stagger={image_perf.get('launch_stagger_wait_ms', 0)}ms, "
|
||||
f"video_slot_wait={video_perf.get('slot_wait_ms', 0)}ms"
|
||||
)
|
||||
|
||||
await self._log_request(
|
||||
token.id,
|
||||
@@ -878,16 +919,27 @@ class GenerationHandler:
|
||||
|
||||
# 记录失败日志
|
||||
duration = time.time() - start_time
|
||||
perf_trace["status"] = "failed"
|
||||
perf_trace["total_ms"] = int(duration * 1000)
|
||||
perf_trace["error"] = error_msg
|
||||
prompt_for_log = prompt if len(prompt) <= 2000 else f"{prompt[:2000]}...(truncated)"
|
||||
await self._log_request(
|
||||
token.id if token else None,
|
||||
f"generate_{generation_type if model_config else 'unknown'}",
|
||||
{"model": model, "prompt": prompt_for_log, "has_images": images is not None and len(images) > 0},
|
||||
{"error": error_msg},
|
||||
{"error": error_msg, "performance": perf_trace},
|
||||
500,
|
||||
duration
|
||||
)
|
||||
finally:
|
||||
if pending_token_state.get("active") and token and self.load_balancer:
|
||||
await self.load_balancer.release_pending(
|
||||
token.id,
|
||||
for_image_generation=(generation_type == "image"),
|
||||
for_video_generation=(generation_type == "video"),
|
||||
)
|
||||
pending_token_state["active"] = False
|
||||
|
||||
if token_slot_reserved and token and self.concurrency_manager:
|
||||
if generation_type == "image":
|
||||
await self.concurrency_manager.release_image(token.id)
|
||||
@@ -909,21 +961,34 @@ class GenerationHandler:
|
||||
prompt: str,
|
||||
images: Optional[List[bytes]],
|
||||
stream: bool,
|
||||
slot_reserved: bool = False
|
||||
slot_reserved: bool = False,
|
||||
perf_trace: Optional[Dict[str, Any]] = None,
|
||||
pending_token_state: Optional[Dict[str, bool]] = None
|
||||
) -> AsyncGenerator:
|
||||
"""处理图片生成 (同步返回)"""
|
||||
|
||||
slot_acquired = False
|
||||
image_trace: Optional[Dict[str, Any]] = None
|
||||
if isinstance(perf_trace, dict):
|
||||
image_trace = perf_trace.setdefault("image_generation", {})
|
||||
image_trace["input_image_count"] = len(images) if images else 0
|
||||
|
||||
# 获取并发槽位
|
||||
if self.concurrency_manager and not slot_reserved:
|
||||
if not await self.concurrency_manager.acquire_image(token.id):
|
||||
slot_ok, slot_wait_ms = await self.concurrency_manager.wait_acquire_image(
|
||||
token.id,
|
||||
timeout_seconds=config.flow_image_slot_wait_timeout
|
||||
)
|
||||
if image_trace is not None:
|
||||
image_trace["slot_wait_ms"] = slot_wait_ms
|
||||
if not slot_ok:
|
||||
yield self._create_error_response("图片并发限制已达上限")
|
||||
return
|
||||
slot_acquired = True
|
||||
|
||||
try:
|
||||
# 上传图片 (如果有)
|
||||
upload_started_at = time.time()
|
||||
image_inputs = []
|
||||
if images and len(images) > 0:
|
||||
if stream:
|
||||
@@ -943,20 +1008,32 @@ class GenerationHandler:
|
||||
})
|
||||
if stream:
|
||||
yield self._create_stream_chunk(f"已上传第 {idx + 1}/{len(images)} 张图片\n")
|
||||
if image_trace is not None:
|
||||
image_trace["upload_images_ms"] = int((time.time() - upload_started_at) * 1000)
|
||||
|
||||
# 调用生成API
|
||||
if stream:
|
||||
yield self._create_stream_chunk("正在生成图片...\n")
|
||||
|
||||
result, generation_session_id = await self.flow_client.generate_image(
|
||||
generate_started_at = time.time()
|
||||
result, generation_session_id, upstream_trace = await self.flow_client.generate_image(
|
||||
at=token.at,
|
||||
project_id=project_id,
|
||||
prompt=prompt,
|
||||
model_name=model_config["model_name"],
|
||||
aspect_ratio=model_config["aspect_ratio"],
|
||||
image_inputs=image_inputs,
|
||||
token_id=token.id
|
||||
token_id=token.id,
|
||||
token_image_concurrency=token.image_concurrency,
|
||||
)
|
||||
if image_trace is not None:
|
||||
image_trace["generate_api_ms"] = int((time.time() - generate_started_at) * 1000)
|
||||
image_trace["upstream_trace"] = upstream_trace
|
||||
attempts = upstream_trace.get("generation_attempts") if isinstance(upstream_trace, dict) else None
|
||||
if isinstance(attempts, list) and attempts:
|
||||
first_attempt = attempts[0] if isinstance(attempts[0], dict) else {}
|
||||
image_trace["launch_queue_wait_ms"] = int(first_attempt.get("launch_queue_ms") or 0)
|
||||
image_trace["launch_stagger_wait_ms"] = int(first_attempt.get("launch_stagger_ms") or 0)
|
||||
|
||||
# 提取URL和mediaId
|
||||
media = result.get("media", [])
|
||||
@@ -974,6 +1051,7 @@ class GenerationHandler:
|
||||
# 检查是否需要 upsample
|
||||
upsample_resolution = model_config.get("upsample")
|
||||
if upsample_resolution and media_id:
|
||||
upsample_started_at = time.time()
|
||||
resolution_name = "4K" if "4K" in upsample_resolution else "2K"
|
||||
if stream:
|
||||
yield self._create_stream_chunk(f"正在放大图片到 {resolution_name}...\n")
|
||||
@@ -1030,6 +1108,8 @@ class GenerationHandler:
|
||||
local_url,
|
||||
media_type="image"
|
||||
)
|
||||
if image_trace is not None:
|
||||
image_trace["upsample_ms"] = int((time.time() - upsample_started_at) * 1000)
|
||||
return
|
||||
except Exception as e:
|
||||
debug_logger.log_error(f"Failed to cache {resolution_name} image: {str(e)}")
|
||||
@@ -1050,6 +1130,8 @@ class GenerationHandler:
|
||||
base64_url,
|
||||
media_type="image"
|
||||
)
|
||||
if image_trace is not None:
|
||||
image_trace["upsample_ms"] = int((time.time() - upsample_started_at) * 1000)
|
||||
return
|
||||
else:
|
||||
debug_logger.log_warning("[UPSAMPLE] 返回结果为空")
|
||||
@@ -1073,9 +1155,12 @@ class GenerationHandler:
|
||||
if stream:
|
||||
yield self._create_stream_chunk(f"⚠️ 放大失败: {error_str},返回原图...\n")
|
||||
break
|
||||
if image_trace is not None:
|
||||
image_trace["upsample_ms"] = int((time.time() - upsample_started_at) * 1000)
|
||||
|
||||
# 缓存图片 (如果启用)
|
||||
local_url = image_url
|
||||
cache_started_at = time.time()
|
||||
if config.cache_enabled:
|
||||
try:
|
||||
if stream:
|
||||
@@ -1093,6 +1178,8 @@ class GenerationHandler:
|
||||
else:
|
||||
if stream:
|
||||
yield self._create_stream_chunk("缓存已关闭,正在返回源链接...\n")
|
||||
if image_trace is not None:
|
||||
image_trace["cache_image_ms"] = int((time.time() - cache_started_at) * 1000)
|
||||
|
||||
# 返回结果
|
||||
# 存储URL用于日志记录
|
||||
@@ -1127,15 +1214,27 @@ class GenerationHandler:
|
||||
prompt: str,
|
||||
images: Optional[List[bytes]],
|
||||
stream: bool,
|
||||
slot_reserved: bool = False
|
||||
slot_reserved: bool = False,
|
||||
perf_trace: Optional[Dict[str, Any]] = None,
|
||||
pending_token_state: Optional[Dict[str, bool]] = None
|
||||
) -> AsyncGenerator:
|
||||
"""处理视频生成 (异步轮询)"""
|
||||
|
||||
slot_acquired = False
|
||||
video_trace: Optional[Dict[str, Any]] = None
|
||||
if isinstance(perf_trace, dict):
|
||||
video_trace = perf_trace.setdefault("video_generation", {})
|
||||
video_trace["input_image_count"] = len(images) if images else 0
|
||||
|
||||
# 获取并发槽位
|
||||
if self.concurrency_manager and not slot_reserved:
|
||||
if not await self.concurrency_manager.acquire_video(token.id):
|
||||
slot_ok, slot_wait_ms = await self.concurrency_manager.wait_acquire_video(
|
||||
token.id,
|
||||
timeout_seconds=config.flow_video_slot_wait_timeout
|
||||
)
|
||||
if video_trace is not None:
|
||||
video_trace["slot_wait_ms"] = slot_wait_ms
|
||||
if not slot_ok:
|
||||
yield self._create_error_response("视频并发限制已达上限")
|
||||
return
|
||||
slot_acquired = True
|
||||
@@ -1260,6 +1359,7 @@ class GenerationHandler:
|
||||
# ========== 调用生成API ==========
|
||||
if stream:
|
||||
yield self._create_stream_chunk("提交视频生成任务...\n")
|
||||
submit_started_at = time.time()
|
||||
|
||||
# I2V: 首尾帧生成
|
||||
if video_type == "i2v" and start_media_id:
|
||||
@@ -1274,7 +1374,8 @@ class GenerationHandler:
|
||||
start_media_id=start_media_id,
|
||||
end_media_id=end_media_id,
|
||||
user_paygate_tier=token.user_paygate_tier or "PAYGATE_TIER_ONE",
|
||||
token_id=token.id
|
||||
token_id=token.id,
|
||||
token_video_concurrency=token.video_concurrency,
|
||||
)
|
||||
else:
|
||||
# 只有首帧 - 需要去掉 model_key 中的 _fl
|
||||
@@ -1292,7 +1393,8 @@ class GenerationHandler:
|
||||
aspect_ratio=model_config["aspect_ratio"],
|
||||
start_media_id=start_media_id,
|
||||
user_paygate_tier=token.user_paygate_tier or "PAYGATE_TIER_ONE",
|
||||
token_id=token.id
|
||||
token_id=token.id,
|
||||
token_video_concurrency=token.video_concurrency,
|
||||
)
|
||||
|
||||
# R2V: 多图生成
|
||||
@@ -1305,7 +1407,8 @@ class GenerationHandler:
|
||||
aspect_ratio=model_config["aspect_ratio"],
|
||||
reference_images=reference_images,
|
||||
user_paygate_tier=token.user_paygate_tier or "PAYGATE_TIER_ONE",
|
||||
token_id=token.id
|
||||
token_id=token.id,
|
||||
token_video_concurrency=token.video_concurrency,
|
||||
)
|
||||
|
||||
# T2V 或 R2V无图: 纯文本生成
|
||||
@@ -1317,8 +1420,11 @@ class GenerationHandler:
|
||||
model_key=model_config["model_key"],
|
||||
aspect_ratio=model_config["aspect_ratio"],
|
||||
user_paygate_tier=token.user_paygate_tier or "PAYGATE_TIER_ONE",
|
||||
token_id=token.id
|
||||
token_id=token.id,
|
||||
token_video_concurrency=token.video_concurrency,
|
||||
)
|
||||
if video_trace is not None:
|
||||
video_trace["submit_generation_ms"] = int((time.time() - submit_started_at) * 1000)
|
||||
|
||||
# 获取task_id和operations
|
||||
operations = result.get("operations", [])
|
||||
@@ -1424,7 +1530,8 @@ class GenerationHandler:
|
||||
aspect_ratio=aspect_ratio,
|
||||
resolution=upsample_config["resolution"],
|
||||
model_key=upsample_config["model_key"],
|
||||
token_id=token.id
|
||||
token_id=token.id,
|
||||
token_video_concurrency=token.video_concurrency,
|
||||
)
|
||||
|
||||
upsample_operations = upsample_result.get("operations", [])
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Load balancing module for Flow2API"""
|
||||
import asyncio
|
||||
import random
|
||||
from typing import Optional
|
||||
from typing import Optional, Dict
|
||||
from ..core.models import Token
|
||||
from .concurrency_manager import ConcurrencyManager
|
||||
from ..core.logger import debug_logger
|
||||
@@ -12,6 +13,39 @@ class LoadBalancer:
|
||||
def __init__(self, token_manager, concurrency_manager: Optional[ConcurrencyManager] = None):
|
||||
self.token_manager = token_manager
|
||||
self.concurrency_manager = concurrency_manager
|
||||
self._image_pending: Dict[int, int] = {}
|
||||
self._video_pending: Dict[int, int] = {}
|
||||
self._pending_lock = asyncio.Lock()
|
||||
|
||||
async def _get_pending_count(self, token_id: int, for_image_generation: bool, for_video_generation: bool) -> int:
|
||||
async with self._pending_lock:
|
||||
if for_image_generation:
|
||||
return max(0, int(self._image_pending.get(token_id, 0)))
|
||||
if for_video_generation:
|
||||
return max(0, int(self._video_pending.get(token_id, 0)))
|
||||
return 0
|
||||
|
||||
async def _add_pending(self, token_id: int, for_image_generation: bool, for_video_generation: bool):
|
||||
async with self._pending_lock:
|
||||
if for_image_generation:
|
||||
self._image_pending[token_id] = max(0, int(self._image_pending.get(token_id, 0))) + 1
|
||||
elif for_video_generation:
|
||||
self._video_pending[token_id] = max(0, int(self._video_pending.get(token_id, 0))) + 1
|
||||
|
||||
async def release_pending(self, token_id: int, for_image_generation: bool = False, for_video_generation: bool = False):
|
||||
async with self._pending_lock:
|
||||
if for_image_generation:
|
||||
current = max(0, int(self._image_pending.get(token_id, 0)))
|
||||
if current <= 1:
|
||||
self._image_pending.pop(token_id, None)
|
||||
else:
|
||||
self._image_pending[token_id] = current - 1
|
||||
elif for_video_generation:
|
||||
current = max(0, int(self._video_pending.get(token_id, 0)))
|
||||
if current <= 1:
|
||||
self._video_pending.pop(token_id, None)
|
||||
else:
|
||||
self._video_pending[token_id] = current - 1
|
||||
|
||||
async def _get_token_load(self, token_id: int, for_image_generation: bool, for_video_generation: bool) -> tuple[int, Optional[int]]:
|
||||
"""获取 token 当前负载。
|
||||
@@ -26,12 +60,20 @@ class LoadBalancer:
|
||||
if for_image_generation:
|
||||
inflight = await self.concurrency_manager.get_image_inflight(token_id)
|
||||
remaining = await self.concurrency_manager.get_image_remaining(token_id)
|
||||
return inflight, remaining
|
||||
pending = await self._get_pending_count(token_id, True, False)
|
||||
effective_inflight = inflight + pending
|
||||
if remaining is not None:
|
||||
remaining = max(0, remaining - pending)
|
||||
return effective_inflight, remaining
|
||||
|
||||
if for_video_generation:
|
||||
inflight = await self.concurrency_manager.get_video_inflight(token_id)
|
||||
remaining = await self.concurrency_manager.get_video_remaining(token_id)
|
||||
return inflight, remaining
|
||||
pending = await self._get_pending_count(token_id, False, True)
|
||||
effective_inflight = inflight + pending
|
||||
if remaining is not None:
|
||||
remaining = max(0, remaining - pending)
|
||||
return effective_inflight, remaining
|
||||
|
||||
return 0, None
|
||||
|
||||
@@ -53,7 +95,9 @@ class LoadBalancer:
|
||||
for_image_generation: bool = False,
|
||||
for_video_generation: bool = False,
|
||||
model: Optional[str] = None,
|
||||
reserve: bool = False
|
||||
reserve: bool = False,
|
||||
enforce_concurrency_filter: bool = True,
|
||||
track_pending: bool = False,
|
||||
) -> Optional[Token]:
|
||||
"""
|
||||
Select a token using load-aware balancing
|
||||
@@ -63,6 +107,13 @@ class LoadBalancer:
|
||||
for_video_generation: If True, only select tokens with video_enabled=True
|
||||
model: Model name (used to filter tokens for specific models)
|
||||
reserve: Whether to atomically reserve one concurrency slot for the selected token
|
||||
enforce_concurrency_filter:
|
||||
Whether to pre-filter tokens by current inflight/remaining capacity.
|
||||
For reserve=False generation paths, this should usually be False so
|
||||
requests can enter the downstream wait queue instead of failing fast.
|
||||
track_pending:
|
||||
Whether to count the selected token as a queued request immediately.
|
||||
This smooths burst distribution before the hard concurrency slot is acquired.
|
||||
|
||||
Returns:
|
||||
Selected token or None if no available tokens
|
||||
@@ -88,7 +139,11 @@ class LoadBalancer:
|
||||
filtered_reasons[token.id] = "图片生成已禁用"
|
||||
continue
|
||||
|
||||
if self.concurrency_manager and not await self.concurrency_manager.can_use_image(token.id):
|
||||
if (
|
||||
enforce_concurrency_filter
|
||||
and self.concurrency_manager
|
||||
and not await self.concurrency_manager.can_use_image(token.id)
|
||||
):
|
||||
filtered_reasons[token.id] = "图片并发已满"
|
||||
continue
|
||||
|
||||
@@ -97,7 +152,11 @@ class LoadBalancer:
|
||||
filtered_reasons[token.id] = "视频生成已禁用"
|
||||
continue
|
||||
|
||||
if self.concurrency_manager and not await self.concurrency_manager.can_use_video(token.id):
|
||||
if (
|
||||
enforce_concurrency_filter
|
||||
and self.concurrency_manager
|
||||
and not await self.concurrency_manager.can_use_video(token.id)
|
||||
):
|
||||
filtered_reasons[token.id] = "视频并发已满"
|
||||
continue
|
||||
|
||||
|
||||
Reference in New Issue
Block a user