From db559a82c583cf15c0dcd66251dae373f59405a5 Mon Sep 17 00:00:00 2001 From: genz27 Date: Mon, 16 Mar 2026 21:44:32 +0800 Subject: [PATCH] fix: improve gemini compatibility and cache flow 1. add Gemini /models and /v1beta/models discovery endpoints for newapi compatibility 2. persist tmp cache in Docker and lock cache_timeout=0 behavior with regression tests 3. refine image generation status flow to show captcha verification after uploads --- .gitignore | 2 +- README.md | 2 + docker-compose.headed.yml | 1 + docker-compose.local.yml | 1 + docker-compose.proxy.yml | 1 + docker-compose.yml | 1 + src/api/routes.py | 104 ++++++++++++++++++++---- src/services/flow_client.py | 7 +- src/services/generation_handler.py | 26 +++++- static/manage.html | 2 +- tests/test_file_cache.py | 31 +++++++ tests/test_gemini_generate_content.py | 25 ++++++ tests/test_generation_handler.py | 111 ++++++++++++++++++++++++++ 13 files changed, 293 insertions(+), 21 deletions(-) create mode 100644 tests/test_file_cache.py create mode 100644 tests/test_generation_handler.py diff --git a/.gitignore b/.gitignore index 89f5639..4c133db 100644 --- a/.gitignore +++ b/.gitignore @@ -58,7 +58,7 @@ browser_data browser_data_rt data +tmp/ config/setting.toml config/setting_warp.toml config/setting_warp_example.toml -tmp/browser_pids/ diff --git a/README.md b/README.md index d48bd4d..c4112d4 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,8 @@ docker-compose up -d docker-compose logs -f ``` +> 说明:Compose 已默认挂载 `./tmp:/app/tmp`。如果把缓存超时设为 `0`,语义是“不自动过期删除”;若希望容器重建后仍保留缓存文件,也需要保留这个 `tmp` 挂载。 + #### WARP 模式(使用代理) ```bash diff --git a/docker-compose.headed.yml b/docker-compose.headed.yml index f521467..dcf489d 100644 --- a/docker-compose.headed.yml +++ b/docker-compose.headed.yml @@ -11,6 +11,7 @@ services: - "8000:8000" volumes: - ./data:/app/data + - ./tmp:/app/tmp - ./config/setting.toml:/app/config/setting.toml environment: - PYTHONUNBUFFERED=1 diff --git a/docker-compose.local.yml b/docker-compose.local.yml index 8e55b4e..a7c8cad 100644 --- a/docker-compose.local.yml +++ b/docker-compose.local.yml @@ -11,6 +11,7 @@ services: - "38000:8000" volumes: - ./data:/app/data + - ./tmp:/app/tmp - ./config/setting.toml:/app/config/setting.toml environment: - PYTHONUNBUFFERED=1 diff --git a/docker-compose.proxy.yml b/docker-compose.proxy.yml index 0cb3265..07b2221 100644 --- a/docker-compose.proxy.yml +++ b/docker-compose.proxy.yml @@ -8,6 +8,7 @@ services: - "38000:8000" volumes: - ./data:/app/data + - ./tmp:/app/tmp - ./config/setting_warp.toml:/app/config/setting.toml environment: - PYTHONUNBUFFERED=1 diff --git a/docker-compose.yml b/docker-compose.yml index 60d2f1c..ae005d0 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -8,6 +8,7 @@ services: - "38000:8000" volumes: - ./data:/app/data + - ./tmp:/app/tmp - ./config/setting.toml:/app/config/setting.toml environment: - PYTHONUNBUFFERED=1 diff --git a/src/api/routes.py b/src/api/routes.py index 569ab5e..087a764 100644 --- a/src/api/routes.py +++ b/src/api/routes.py @@ -67,6 +67,56 @@ def _ensure_generation_handler() -> GenerationHandler: return generation_handler +def _build_model_description(model_config: Dict[str, Any]) -> str: + """Build a human-readable description for model listing endpoints.""" + description = f"{model_config['type'].capitalize()} generation" + if model_config["type"] == "image": + description += f" - {model_config['model_name']}" + else: + description += f" - {model_config['model_key']}" + return description + + +def _get_openai_model_catalog() -> List[Dict[str, str]]: + """Collect OpenAI-compatible model list entries.""" + return [ + { + "id": model_id, + "description": _build_model_description(model_config), + } + for model_id, model_config in MODEL_CONFIG.items() + ] + + +def _get_gemini_model_catalog() -> Dict[str, str]: + """Collect Gemini-compatible model metadata for /models endpoints.""" + catalog: Dict[str, str] = {} + + for alias_id, description in get_base_model_aliases().items(): + catalog[alias_id] = description + + for model_id, model_config in MODEL_CONFIG.items(): + catalog.setdefault(model_id, _build_model_description(model_config)) + + return catalog + + +def _build_gemini_model_resource(model_id: str, description: str) -> Dict[str, Any]: + """Build a Gemini-compatible model resource payload.""" + return { + "name": f"models/{model_id}", + "displayName": model_id, + "description": description, + "version": "flow2api", + "inputTokenLimit": 0, + "outputTokenLimit": 0, + "supportedGenerationMethods": [ + "generateContent", + "streamGenerateContent", + ], + } + + def _decode_data_url(data_url: str) -> tuple[str, bytes]: match = DATA_URL_RE.match(data_url) if not match: @@ -601,23 +651,15 @@ async def _iterate_gemini_stream( @router.get("/v1/models") async def list_models(api_key: str = Depends(verify_api_key_flexible)): """List available models.""" - models = [] - - for model_id, config in MODEL_CONFIG.items(): - description = f"{config['type'].capitalize()} generation" - if config["type"] == "image": - description += f" - {config['model_name']}" - else: - description += f" - {config['model_key']}" - - models.append( - { - "id": model_id, - "object": "model", - "owned_by": "flow2api", - "description": description, - } - ) + models = [ + { + "id": model["id"], + "object": "model", + "owned_by": "flow2api", + "description": model["description"], + } + for model in _get_openai_model_catalog() + ] return {"object": "list", "data": models} @@ -640,6 +682,34 @@ async def list_model_aliases(api_key: str = Depends(verify_api_key_flexible)): return {"object": "list", "data": alias_models} +@router.get("/v1beta/models") +@router.get("/models") +async def list_gemini_models(api_key: str = Depends(verify_api_key_flexible)): + """List available models using Gemini-compatible response shape.""" + catalog = _get_gemini_model_catalog() + return { + "models": [ + _build_gemini_model_resource(model_id, description) + for model_id, description in catalog.items() + ] + } + + +@router.get("/v1beta/models/{model}") +@router.get("/models/{model}") +async def get_gemini_model(model: str, api_key: str = Depends(verify_api_key_flexible)): + """Return a single model using Gemini-compatible response shape.""" + catalog = _get_gemini_model_catalog() + description = catalog.get(model) + if not description: + return JSONResponse( + status_code=404, + content=_build_gemini_error_payload(404, f"Model not found: {model}"), + ) + + return _build_gemini_model_resource(model, description) + + @router.post("/v1/chat/completions") async def create_chat_completion( request: ChatCompletionRequest, diff --git a/src/services/flow_client.py b/src/services/flow_client.py index 3aebe73..0d0ef2a 100644 --- a/src/services/flow_client.py +++ b/src/services/flow_client.py @@ -7,7 +7,7 @@ import uuid import random import base64 import ssl -from typing import Dict, Any, Optional, List, Union +from typing import Dict, Any, Optional, List, Union, Callable, Awaitable from urllib.parse import quote import urllib.error import urllib.request @@ -895,6 +895,7 @@ class FlowClient: image_inputs: Optional[List[Dict]] = None, token_id: Optional[int] = None, token_image_concurrency: Optional[int] = None, + progress_callback: Optional[Callable[[str, int], Awaitable[None]]] = None, ) -> tuple[dict, str, Dict[str, Any]]: """生成图片(同步返回) @@ -930,6 +931,8 @@ class FlowClient: attempt_started_at = time.time() # 每次重试都重新获取 reCAPTCHA token recaptcha_started_at = time.time() + if progress_callback is not None: + await progress_callback("solving_image_captcha", 38) launch_gate_acquired = False launch_ok, launch_queue_ms, launch_stagger_ms = await self._acquire_image_launch_gate( token_id=token_id, @@ -973,6 +976,8 @@ class FlowClient: if should_retry: continue raise last_error + if progress_callback is not None: + await progress_callback("submitting_image", 48) session_id = self._generate_session_id() # 构建请求 - 新版接口在外层和 requests 内都带 clientContext diff --git a/src/services/generation_handler.py b/src/services/generation_handler.py index d81ce47..f026678 100644 --- a/src/services/generation_handler.py +++ b/src/services/generation_handler.py @@ -1148,7 +1148,18 @@ class GenerationHandler: # 调用生成API if stream: - yield self._create_stream_chunk("正在生成图片...\n") + if images and len(images) > 0: + yield self._create_stream_chunk("参考图片上传完成,正在进行打码验证...\n") + else: + yield self._create_stream_chunk("正在进行打码验证并提交图片生成请求...\n") + + async def _image_progress_callback(status_text: str, progress: int): + await self._update_request_log_progress( + request_log_state, + token_id=token.id, + status_text=status_text, + progress=progress, + ) generate_started_at = time.time() result, generation_session_id, upstream_trace = await self.flow_client.generate_image( @@ -1160,6 +1171,7 @@ class GenerationHandler: image_inputs=image_inputs, token_id=token.id, token_image_concurrency=token.image_concurrency, + progress_callback=_image_progress_callback, ) if image_trace is not None: image_trace["generate_api_ms"] = int((time.time() - generate_started_at) * 1000) @@ -1169,6 +1181,12 @@ class GenerationHandler: first_attempt = attempts[0] if isinstance(attempts[0], dict) else {} image_trace["launch_queue_wait_ms"] = int(first_attempt.get("launch_queue_ms") or 0) image_trace["launch_stagger_wait_ms"] = int(first_attempt.get("launch_stagger_ms") or 0) + await self._update_request_log_progress( + request_log_state, + token_id=token.id, + status_text="image_generated", + progress=72, + ) # 提取URL和mediaId media = result.get("media", []) @@ -1228,6 +1246,12 @@ class GenerationHandler: if config.cache_enabled: try: + await self._update_request_log_progress( + request_log_state, + token_id=token.id, + status_text="caching_image", + progress=90, + ) if stream: yield self._create_stream_chunk(f"缓存 {resolution_name} 图片中...\n") cached_filename = await self.file_cache.cache_base64_image(encoded_image, resolution_name) diff --git a/static/manage.html b/static/manage.html index 4b622ea..8171cef 100644 --- a/static/manage.html +++ b/static/manage.html @@ -827,7 +827,7 @@ generateRandomToken=()=>{const chars='ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789';let token='';for(let i=0;i<32;i++){token+=chars.charAt(Math.floor(Math.random()*chars.length))}$('cfgPluginConnectionToken').value=token;showToast('随机Token已生成','success')}, toggleATAutoRefresh=async()=>{try{const enabled=$('atAutoRefreshToggle').checked;const r=await apiRequest('/api/token-refresh/enabled',{method:'POST',body:JSON.stringify({enabled:enabled})});if(!r){$('atAutoRefreshToggle').checked=!enabled;return}const d=await r.json();if(d.success){showToast(enabled?'AT自动刷新已启用':'AT自动刷新已禁用','success')}else{showToast('操作失败: '+(d.detail||'未知错误'),'error');$('atAutoRefreshToggle').checked=!enabled}}catch(e){showToast('操作失败: '+e.message,'error');$('atAutoRefreshToggle').checked=!enabled}}, loadATAutoRefreshConfig=async()=>{try{const r=await apiRequest('/api/token-refresh/config');if(!r)return;const d=await r.json();if(d.success&&d.config){$('atAutoRefreshToggle').checked=d.config.at_auto_refresh_enabled||false}else{console.error('AT自动刷新配置数据格式错误:',d)}}catch(e){console.error('加载AT自动刷新配置失败:',e)}}, - formatLogStatus=l=>{const statusText=(l.status_text||'').trim();if(statusText){const map={started:'\u5df2\u542f\u52a8',token_selected:'\u5df2\u9009\u4e2d\u8d26\u53f7',token_ready:'\u51c6\u5907\u751f\u6210\u73af\u5883',project_ready:'\u9879\u76ee\u5df2\u5c31\u7eea',uploading_images:'\u4e0a\u4f20\u53c2\u8003\u56fe\u4e2d',submitting_image:'\u56fe\u7247\u63d0\u4ea4\u4e2d',image_generated:'\u56fe\u7247\u751f\u6210\u5b8c\u6210',preparing_video:'\u51c6\u5907\u89c6\u9891\u4efb\u52a1',submitting_video:'\u89c6\u9891\u63d0\u4ea4\u4e2d',video_submitted:'\u89c6\u9891\u4efb\u52a1\u5df2\u63d0\u4ea4',video_polling:'\u89c6\u9891\u751f\u6210\u4e2d',caching_image:'\u7f13\u5b58\u56fe\u7247\u4e2d',caching_video:'\u7f13\u5b58\u89c6\u9891\u4e2d',completed:'\u5df2\u5b8c\u6210',failed:'\u5931\u8d25',processing:'\u5904\u7406\u4e2d',upsampling_2k:'\u6b63\u5728\u653e\u5927\u52302K',upsampling_4k:'\u6b63\u5728\u653e\u5927\u52304K',upsampling_1080p:'\u6b63\u5728\u653e\u5927\u52301080P'};return map[statusText]||statusText}if(l.status_code===102)return'\u5904\u7406\u4e2d';if(l.status_code===200)return'\u5df2\u5b8c\u6210';if(l.status_code&&l.status_code>=400)return'\u5931\u8d25';return'-'}, + formatLogStatus=l=>{const statusText=(l.status_text||'').trim();if(statusText){const map={started:'\u5df2\u542f\u52a8',token_selected:'\u5df2\u9009\u4e2d\u8d26\u53f7',token_ready:'\u51c6\u5907\u751f\u6210\u73af\u5883',project_ready:'\u9879\u76ee\u5df2\u5c31\u7eea',uploading_images:'\u4e0a\u4f20\u53c2\u8003\u56fe\u4e2d',solving_image_captcha:'\u56fe\u7247\u6253\u7801\u9a8c\u8bc1\u4e2d',submitting_image:'\u56fe\u7247\u63d0\u4ea4\u4e2d',image_generated:'\u56fe\u7247\u751f\u6210\u5b8c\u6210',preparing_video:'\u51c6\u5907\u89c6\u9891\u4efb\u52a1',submitting_video:'\u89c6\u9891\u63d0\u4ea4\u4e2d',video_submitted:'\u89c6\u9891\u4efb\u52a1\u5df2\u63d0\u4ea4',video_polling:'\u89c6\u9891\u751f\u6210\u4e2d',caching_image:'\u7f13\u5b58\u56fe\u7247\u4e2d',caching_video:'\u7f13\u5b58\u89c6\u9891\u4e2d',completed:'\u5df2\u5b8c\u6210',failed:'\u5931\u8d25',processing:'\u5904\u7406\u4e2d',upsampling_2k:'\u6b63\u5728\u653e\u5927\u52302K',upsampling_4k:'\u6b63\u5728\u653e\u5927\u52304K',upsampling_1080p:'\u6b63\u5728\u653e\u5927\u52301080P'};return map[statusText]||statusText}if(l.status_code===102)return'\u5904\u7406\u4e2d';if(l.status_code===200)return'\u5df2\u5b8c\u6210';if(l.status_code&&l.status_code>=400)return'\u5931\u8d25';return'-'}, formatLogStatusClass=l=>{const statusText=formatLogStatus(l);if(statusText==='\u5904\u7406\u4e2d')return'bg-amber-50 text-amber-700';if(statusText==='\u5df2\u5b8c\u6210')return'bg-green-50 text-green-700';if(statusText==='\u5931\u8d25')return'bg-red-50 text-red-700';return'bg-gray-100 text-gray-700'}, formatLogProgress=l=>{if(l.progress===null||l.progress===undefined||l.progress==='')return'-';const progress=Number(l.progress);return Number.isFinite(progress)?`${Math.max(0,Math.min(100,progress))}%`:'-'}, getLogOperationLabel=l=>{const operation=String(l&&l.operation||'').trim();if(operation==='generate_image')return'图片';if(operation==='generate_video')return'视频';return''}, diff --git a/tests/test_file_cache.py b/tests/test_file_cache.py new file mode 100644 index 0000000..ada3ac0 --- /dev/null +++ b/tests/test_file_cache.py @@ -0,0 +1,31 @@ +import asyncio +import os +import time + +from src.services.file_cache import FileCache + + +def test_cleanup_keeps_files_when_timeout_is_zero(tmp_path): + cache = FileCache(cache_dir=str(tmp_path), default_timeout=0) + cached_file = tmp_path / "expired.jpg" + cached_file.write_bytes(b"cached") + + expired_at = time.time() - 3600 + os.utime(cached_file, (expired_at, expired_at)) + + asyncio.run(cache._cleanup_expired_files()) + + assert cached_file.exists() + + +def test_cleanup_removes_files_when_timeout_is_positive(tmp_path): + cache = FileCache(cache_dir=str(tmp_path), default_timeout=1) + cached_file = tmp_path / "expired.jpg" + cached_file.write_bytes(b"cached") + + expired_at = time.time() - 3600 + os.utime(cached_file, (expired_at, expired_at)) + + asyncio.run(cache._cleanup_expired_files()) + + assert not cached_file.exists() diff --git a/tests/test_gemini_generate_content.py b/tests/test_gemini_generate_content.py index 467ae51..10d30a8 100644 --- a/tests/test_gemini_generate_content.py +++ b/tests/test_gemini_generate_content.py @@ -147,3 +147,28 @@ def test_models_generate_content_supports_system_instruction_and_file_data(clien assert response.json()["modelVersion"] == "gemini-3.1-flash-image" assert fake_handler.calls[0]["prompt"] == "answer in English\n\ndraw a cat" assert len(fake_handler.calls[0]["images"]) == 1 + + +def test_models_root_lists_gemini_models_for_newapi_compatibility(client): + response = client.get("/models") + + assert response.status_code == 200 + body = response.json() + names = {item["name"] for item in body["models"]} + + assert "models/gemini-3.0-pro-image" in names + assert "models/gemini-3.1-flash-image" in names + assert any( + item["name"] == "models/gemini-3.1-flash-image" + and item["supportedGenerationMethods"] == ["generateContent", "streamGenerateContent"] + for item in body["models"] + ) + + +def test_v1beta_models_root_lists_gemini_models(client): + response = client.get("/v1beta/models") + + assert response.status_code == 200 + body = response.json() + + assert any(item["name"] == "models/gemini-3.0-pro-image" for item in body["models"]) diff --git a/tests/test_generation_handler.py b/tests/test_generation_handler.py new file mode 100644 index 0000000..a4f068e --- /dev/null +++ b/tests/test_generation_handler.py @@ -0,0 +1,111 @@ +import asyncio +from types import SimpleNamespace + +from src.services.generation_handler import GenerationHandler + + +class FakeFlowClient: + async def upload_image(self, at, image_bytes, aspect_ratio, project_id=None): + return "media-uploaded" + + async def generate_image( + self, + at, + project_id, + prompt, + model_name, + aspect_ratio, + image_inputs=None, + token_id=None, + token_image_concurrency=None, + progress_callback=None, + ): + if progress_callback is not None: + await progress_callback("solving_image_captcha", 38) + await progress_callback("submitting_image", 48) + return ( + { + "media": [ + { + "name": "media-generated", + "image": { + "generatedImage": { + "fifeUrl": "https://example.com/generated.png" + } + }, + } + ] + }, + "session-1", + {"generation_attempts": [{"launch_queue_ms": 0, "launch_stagger_ms": 0}]}, + ) + + +class FakeDB: + def __init__(self): + self.status_updates = [] + + async def update_request_log(self, log_id, **kwargs): + self.status_updates.append( + { + "log_id": log_id, + "status_text": kwargs.get("status_text"), + "progress": kwargs.get("progress"), + } + ) + + +async def _collect(async_gen): + items = [] + async for item in async_gen: + items.append(item) + return items + + +def test_image_generation_progress_switches_from_upload_to_captcha(): + db = FakeDB() + handler = GenerationHandler( + flow_client=FakeFlowClient(), + token_manager=None, + load_balancer=None, + db=db, + concurrency_manager=None, + proxy_manager=None, + ) + token = SimpleNamespace( + id=1, + at="at-token", + image_concurrency=-1, + user_paygate_tier="PAYGATE_TIER_NOT_PAID", + ) + generation_result = handler._create_generation_result() + request_log_state = {"id": 123} + + asyncio.run( + _collect( + handler._handle_image_generation( + token=token, + project_id="project-1", + model_config={ + "model_name": "NARWHAL", + "aspect_ratio": "IMAGE_ASPECT_RATIO_SQUARE", + }, + prompt="draw a cat", + images=[b"fake-image"], + stream=False, + perf_trace={}, + generation_result=generation_result, + request_log_state=request_log_state, + pending_token_state={"active": False}, + ) + ) + ) + + status_texts = [item["status_text"] for item in db.status_updates] + + assert status_texts[:4] == [ + "uploading_images", + "solving_image_captcha", + "submitting_image", + "image_generated", + ]