mirror of
https://github.com/TheSmallHanCat/flow2api.git
synced 2026-06-09 18:42:19 +08:00
fix: improve gemini compatibility and cache flow
1. add Gemini /models and /v1beta/models discovery endpoints for newapi compatibility 2. persist tmp cache in Docker and lock cache_timeout=0 behavior with regression tests 3. refine image generation status flow to show captcha verification after uploads
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -58,7 +58,7 @@ browser_data
|
||||
browser_data_rt
|
||||
|
||||
data
|
||||
tmp/
|
||||
config/setting.toml
|
||||
config/setting_warp.toml
|
||||
config/setting_warp_example.toml
|
||||
tmp/browser_pids/
|
||||
|
||||
@@ -55,6 +55,8 @@ docker-compose up -d
|
||||
docker-compose logs -f
|
||||
```
|
||||
|
||||
> 说明:Compose 已默认挂载 `./tmp:/app/tmp`。如果把缓存超时设为 `0`,语义是“不自动过期删除”;若希望容器重建后仍保留缓存文件,也需要保留这个 `tmp` 挂载。
|
||||
|
||||
#### WARP 模式(使用代理)
|
||||
|
||||
```bash
|
||||
|
||||
@@ -11,6 +11,7 @@ services:
|
||||
- "8000:8000"
|
||||
volumes:
|
||||
- ./data:/app/data
|
||||
- ./tmp:/app/tmp
|
||||
- ./config/setting.toml:/app/config/setting.toml
|
||||
environment:
|
||||
- PYTHONUNBUFFERED=1
|
||||
|
||||
@@ -11,6 +11,7 @@ services:
|
||||
- "38000:8000"
|
||||
volumes:
|
||||
- ./data:/app/data
|
||||
- ./tmp:/app/tmp
|
||||
- ./config/setting.toml:/app/config/setting.toml
|
||||
environment:
|
||||
- PYTHONUNBUFFERED=1
|
||||
|
||||
@@ -8,6 +8,7 @@ services:
|
||||
- "38000:8000"
|
||||
volumes:
|
||||
- ./data:/app/data
|
||||
- ./tmp:/app/tmp
|
||||
- ./config/setting_warp.toml:/app/config/setting.toml
|
||||
environment:
|
||||
- PYTHONUNBUFFERED=1
|
||||
|
||||
@@ -8,6 +8,7 @@ services:
|
||||
- "38000:8000"
|
||||
volumes:
|
||||
- ./data:/app/data
|
||||
- ./tmp:/app/tmp
|
||||
- ./config/setting.toml:/app/config/setting.toml
|
||||
environment:
|
||||
- PYTHONUNBUFFERED=1
|
||||
|
||||
@@ -67,6 +67,56 @@ def _ensure_generation_handler() -> GenerationHandler:
|
||||
return generation_handler
|
||||
|
||||
|
||||
def _build_model_description(model_config: Dict[str, Any]) -> str:
|
||||
"""Build a human-readable description for model listing endpoints."""
|
||||
description = f"{model_config['type'].capitalize()} generation"
|
||||
if model_config["type"] == "image":
|
||||
description += f" - {model_config['model_name']}"
|
||||
else:
|
||||
description += f" - {model_config['model_key']}"
|
||||
return description
|
||||
|
||||
|
||||
def _get_openai_model_catalog() -> List[Dict[str, str]]:
|
||||
"""Collect OpenAI-compatible model list entries."""
|
||||
return [
|
||||
{
|
||||
"id": model_id,
|
||||
"description": _build_model_description(model_config),
|
||||
}
|
||||
for model_id, model_config in MODEL_CONFIG.items()
|
||||
]
|
||||
|
||||
|
||||
def _get_gemini_model_catalog() -> Dict[str, str]:
|
||||
"""Collect Gemini-compatible model metadata for /models endpoints."""
|
||||
catalog: Dict[str, str] = {}
|
||||
|
||||
for alias_id, description in get_base_model_aliases().items():
|
||||
catalog[alias_id] = description
|
||||
|
||||
for model_id, model_config in MODEL_CONFIG.items():
|
||||
catalog.setdefault(model_id, _build_model_description(model_config))
|
||||
|
||||
return catalog
|
||||
|
||||
|
||||
def _build_gemini_model_resource(model_id: str, description: str) -> Dict[str, Any]:
|
||||
"""Build a Gemini-compatible model resource payload."""
|
||||
return {
|
||||
"name": f"models/{model_id}",
|
||||
"displayName": model_id,
|
||||
"description": description,
|
||||
"version": "flow2api",
|
||||
"inputTokenLimit": 0,
|
||||
"outputTokenLimit": 0,
|
||||
"supportedGenerationMethods": [
|
||||
"generateContent",
|
||||
"streamGenerateContent",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _decode_data_url(data_url: str) -> tuple[str, bytes]:
|
||||
match = DATA_URL_RE.match(data_url)
|
||||
if not match:
|
||||
@@ -601,23 +651,15 @@ async def _iterate_gemini_stream(
|
||||
@router.get("/v1/models")
|
||||
async def list_models(api_key: str = Depends(verify_api_key_flexible)):
|
||||
"""List available models."""
|
||||
models = []
|
||||
|
||||
for model_id, config in MODEL_CONFIG.items():
|
||||
description = f"{config['type'].capitalize()} generation"
|
||||
if config["type"] == "image":
|
||||
description += f" - {config['model_name']}"
|
||||
else:
|
||||
description += f" - {config['model_key']}"
|
||||
|
||||
models.append(
|
||||
{
|
||||
"id": model_id,
|
||||
"object": "model",
|
||||
"owned_by": "flow2api",
|
||||
"description": description,
|
||||
}
|
||||
)
|
||||
models = [
|
||||
{
|
||||
"id": model["id"],
|
||||
"object": "model",
|
||||
"owned_by": "flow2api",
|
||||
"description": model["description"],
|
||||
}
|
||||
for model in _get_openai_model_catalog()
|
||||
]
|
||||
|
||||
return {"object": "list", "data": models}
|
||||
|
||||
@@ -640,6 +682,34 @@ async def list_model_aliases(api_key: str = Depends(verify_api_key_flexible)):
|
||||
return {"object": "list", "data": alias_models}
|
||||
|
||||
|
||||
@router.get("/v1beta/models")
|
||||
@router.get("/models")
|
||||
async def list_gemini_models(api_key: str = Depends(verify_api_key_flexible)):
|
||||
"""List available models using Gemini-compatible response shape."""
|
||||
catalog = _get_gemini_model_catalog()
|
||||
return {
|
||||
"models": [
|
||||
_build_gemini_model_resource(model_id, description)
|
||||
for model_id, description in catalog.items()
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/v1beta/models/{model}")
|
||||
@router.get("/models/{model}")
|
||||
async def get_gemini_model(model: str, api_key: str = Depends(verify_api_key_flexible)):
|
||||
"""Return a single model using Gemini-compatible response shape."""
|
||||
catalog = _get_gemini_model_catalog()
|
||||
description = catalog.get(model)
|
||||
if not description:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=_build_gemini_error_payload(404, f"Model not found: {model}"),
|
||||
)
|
||||
|
||||
return _build_gemini_model_resource(model, description)
|
||||
|
||||
|
||||
@router.post("/v1/chat/completions")
|
||||
async def create_chat_completion(
|
||||
request: ChatCompletionRequest,
|
||||
|
||||
@@ -7,7 +7,7 @@ import uuid
|
||||
import random
|
||||
import base64
|
||||
import ssl
|
||||
from typing import Dict, Any, Optional, List, Union
|
||||
from typing import Dict, Any, Optional, List, Union, Callable, Awaitable
|
||||
from urllib.parse import quote
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
@@ -895,6 +895,7 @@ class FlowClient:
|
||||
image_inputs: Optional[List[Dict]] = None,
|
||||
token_id: Optional[int] = None,
|
||||
token_image_concurrency: Optional[int] = None,
|
||||
progress_callback: Optional[Callable[[str, int], Awaitable[None]]] = None,
|
||||
) -> tuple[dict, str, Dict[str, Any]]:
|
||||
"""生成图片(同步返回)
|
||||
|
||||
@@ -930,6 +931,8 @@ class FlowClient:
|
||||
attempt_started_at = time.time()
|
||||
# 每次重试都重新获取 reCAPTCHA token
|
||||
recaptcha_started_at = time.time()
|
||||
if progress_callback is not None:
|
||||
await progress_callback("solving_image_captcha", 38)
|
||||
launch_gate_acquired = False
|
||||
launch_ok, launch_queue_ms, launch_stagger_ms = await self._acquire_image_launch_gate(
|
||||
token_id=token_id,
|
||||
@@ -973,6 +976,8 @@ class FlowClient:
|
||||
if should_retry:
|
||||
continue
|
||||
raise last_error
|
||||
if progress_callback is not None:
|
||||
await progress_callback("submitting_image", 48)
|
||||
session_id = self._generate_session_id()
|
||||
|
||||
# 构建请求 - 新版接口在外层和 requests 内都带 clientContext
|
||||
|
||||
@@ -1148,7 +1148,18 @@ class GenerationHandler:
|
||||
|
||||
# 调用生成API
|
||||
if stream:
|
||||
yield self._create_stream_chunk("正在生成图片...\n")
|
||||
if images and len(images) > 0:
|
||||
yield self._create_stream_chunk("参考图片上传完成,正在进行打码验证...\n")
|
||||
else:
|
||||
yield self._create_stream_chunk("正在进行打码验证并提交图片生成请求...\n")
|
||||
|
||||
async def _image_progress_callback(status_text: str, progress: int):
|
||||
await self._update_request_log_progress(
|
||||
request_log_state,
|
||||
token_id=token.id,
|
||||
status_text=status_text,
|
||||
progress=progress,
|
||||
)
|
||||
|
||||
generate_started_at = time.time()
|
||||
result, generation_session_id, upstream_trace = await self.flow_client.generate_image(
|
||||
@@ -1160,6 +1171,7 @@ class GenerationHandler:
|
||||
image_inputs=image_inputs,
|
||||
token_id=token.id,
|
||||
token_image_concurrency=token.image_concurrency,
|
||||
progress_callback=_image_progress_callback,
|
||||
)
|
||||
if image_trace is not None:
|
||||
image_trace["generate_api_ms"] = int((time.time() - generate_started_at) * 1000)
|
||||
@@ -1169,6 +1181,12 @@ class GenerationHandler:
|
||||
first_attempt = attempts[0] if isinstance(attempts[0], dict) else {}
|
||||
image_trace["launch_queue_wait_ms"] = int(first_attempt.get("launch_queue_ms") or 0)
|
||||
image_trace["launch_stagger_wait_ms"] = int(first_attempt.get("launch_stagger_ms") or 0)
|
||||
await self._update_request_log_progress(
|
||||
request_log_state,
|
||||
token_id=token.id,
|
||||
status_text="image_generated",
|
||||
progress=72,
|
||||
)
|
||||
|
||||
# 提取URL和mediaId
|
||||
media = result.get("media", [])
|
||||
@@ -1228,6 +1246,12 @@ class GenerationHandler:
|
||||
|
||||
if config.cache_enabled:
|
||||
try:
|
||||
await self._update_request_log_progress(
|
||||
request_log_state,
|
||||
token_id=token.id,
|
||||
status_text="caching_image",
|
||||
progress=90,
|
||||
)
|
||||
if stream:
|
||||
yield self._create_stream_chunk(f"缓存 {resolution_name} 图片中...\n")
|
||||
cached_filename = await self.file_cache.cache_base64_image(encoded_image, resolution_name)
|
||||
|
||||
@@ -827,7 +827,7 @@
|
||||
generateRandomToken=()=>{const chars='ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789';let token='';for(let i=0;i<32;i++){token+=chars.charAt(Math.floor(Math.random()*chars.length))}$('cfgPluginConnectionToken').value=token;showToast('随机Token已生成','success')},
|
||||
toggleATAutoRefresh=async()=>{try{const enabled=$('atAutoRefreshToggle').checked;const r=await apiRequest('/api/token-refresh/enabled',{method:'POST',body:JSON.stringify({enabled:enabled})});if(!r){$('atAutoRefreshToggle').checked=!enabled;return}const d=await r.json();if(d.success){showToast(enabled?'AT自动刷新已启用':'AT自动刷新已禁用','success')}else{showToast('操作失败: '+(d.detail||'未知错误'),'error');$('atAutoRefreshToggle').checked=!enabled}}catch(e){showToast('操作失败: '+e.message,'error');$('atAutoRefreshToggle').checked=!enabled}},
|
||||
loadATAutoRefreshConfig=async()=>{try{const r=await apiRequest('/api/token-refresh/config');if(!r)return;const d=await r.json();if(d.success&&d.config){$('atAutoRefreshToggle').checked=d.config.at_auto_refresh_enabled||false}else{console.error('AT自动刷新配置数据格式错误:',d)}}catch(e){console.error('加载AT自动刷新配置失败:',e)}},
|
||||
formatLogStatus=l=>{const statusText=(l.status_text||'').trim();if(statusText){const map={started:'\u5df2\u542f\u52a8',token_selected:'\u5df2\u9009\u4e2d\u8d26\u53f7',token_ready:'\u51c6\u5907\u751f\u6210\u73af\u5883',project_ready:'\u9879\u76ee\u5df2\u5c31\u7eea',uploading_images:'\u4e0a\u4f20\u53c2\u8003\u56fe\u4e2d',submitting_image:'\u56fe\u7247\u63d0\u4ea4\u4e2d',image_generated:'\u56fe\u7247\u751f\u6210\u5b8c\u6210',preparing_video:'\u51c6\u5907\u89c6\u9891\u4efb\u52a1',submitting_video:'\u89c6\u9891\u63d0\u4ea4\u4e2d',video_submitted:'\u89c6\u9891\u4efb\u52a1\u5df2\u63d0\u4ea4',video_polling:'\u89c6\u9891\u751f\u6210\u4e2d',caching_image:'\u7f13\u5b58\u56fe\u7247\u4e2d',caching_video:'\u7f13\u5b58\u89c6\u9891\u4e2d',completed:'\u5df2\u5b8c\u6210',failed:'\u5931\u8d25',processing:'\u5904\u7406\u4e2d',upsampling_2k:'\u6b63\u5728\u653e\u5927\u52302K',upsampling_4k:'\u6b63\u5728\u653e\u5927\u52304K',upsampling_1080p:'\u6b63\u5728\u653e\u5927\u52301080P'};return map[statusText]||statusText}if(l.status_code===102)return'\u5904\u7406\u4e2d';if(l.status_code===200)return'\u5df2\u5b8c\u6210';if(l.status_code&&l.status_code>=400)return'\u5931\u8d25';return'-'},
|
||||
formatLogStatus=l=>{const statusText=(l.status_text||'').trim();if(statusText){const map={started:'\u5df2\u542f\u52a8',token_selected:'\u5df2\u9009\u4e2d\u8d26\u53f7',token_ready:'\u51c6\u5907\u751f\u6210\u73af\u5883',project_ready:'\u9879\u76ee\u5df2\u5c31\u7eea',uploading_images:'\u4e0a\u4f20\u53c2\u8003\u56fe\u4e2d',solving_image_captcha:'\u56fe\u7247\u6253\u7801\u9a8c\u8bc1\u4e2d',submitting_image:'\u56fe\u7247\u63d0\u4ea4\u4e2d',image_generated:'\u56fe\u7247\u751f\u6210\u5b8c\u6210',preparing_video:'\u51c6\u5907\u89c6\u9891\u4efb\u52a1',submitting_video:'\u89c6\u9891\u63d0\u4ea4\u4e2d',video_submitted:'\u89c6\u9891\u4efb\u52a1\u5df2\u63d0\u4ea4',video_polling:'\u89c6\u9891\u751f\u6210\u4e2d',caching_image:'\u7f13\u5b58\u56fe\u7247\u4e2d',caching_video:'\u7f13\u5b58\u89c6\u9891\u4e2d',completed:'\u5df2\u5b8c\u6210',failed:'\u5931\u8d25',processing:'\u5904\u7406\u4e2d',upsampling_2k:'\u6b63\u5728\u653e\u5927\u52302K',upsampling_4k:'\u6b63\u5728\u653e\u5927\u52304K',upsampling_1080p:'\u6b63\u5728\u653e\u5927\u52301080P'};return map[statusText]||statusText}if(l.status_code===102)return'\u5904\u7406\u4e2d';if(l.status_code===200)return'\u5df2\u5b8c\u6210';if(l.status_code&&l.status_code>=400)return'\u5931\u8d25';return'-'},
|
||||
formatLogStatusClass=l=>{const statusText=formatLogStatus(l);if(statusText==='\u5904\u7406\u4e2d')return'bg-amber-50 text-amber-700';if(statusText==='\u5df2\u5b8c\u6210')return'bg-green-50 text-green-700';if(statusText==='\u5931\u8d25')return'bg-red-50 text-red-700';return'bg-gray-100 text-gray-700'},
|
||||
formatLogProgress=l=>{if(l.progress===null||l.progress===undefined||l.progress==='')return'-';const progress=Number(l.progress);return Number.isFinite(progress)?`${Math.max(0,Math.min(100,progress))}%`:'-'},
|
||||
getLogOperationLabel=l=>{const operation=String(l&&l.operation||'').trim();if(operation==='generate_image')return'图片';if(operation==='generate_video')return'视频';return''},
|
||||
|
||||
31
tests/test_file_cache.py
Normal file
31
tests/test_file_cache.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
|
||||
from src.services.file_cache import FileCache
|
||||
|
||||
|
||||
def test_cleanup_keeps_files_when_timeout_is_zero(tmp_path):
|
||||
cache = FileCache(cache_dir=str(tmp_path), default_timeout=0)
|
||||
cached_file = tmp_path / "expired.jpg"
|
||||
cached_file.write_bytes(b"cached")
|
||||
|
||||
expired_at = time.time() - 3600
|
||||
os.utime(cached_file, (expired_at, expired_at))
|
||||
|
||||
asyncio.run(cache._cleanup_expired_files())
|
||||
|
||||
assert cached_file.exists()
|
||||
|
||||
|
||||
def test_cleanup_removes_files_when_timeout_is_positive(tmp_path):
|
||||
cache = FileCache(cache_dir=str(tmp_path), default_timeout=1)
|
||||
cached_file = tmp_path / "expired.jpg"
|
||||
cached_file.write_bytes(b"cached")
|
||||
|
||||
expired_at = time.time() - 3600
|
||||
os.utime(cached_file, (expired_at, expired_at))
|
||||
|
||||
asyncio.run(cache._cleanup_expired_files())
|
||||
|
||||
assert not cached_file.exists()
|
||||
@@ -147,3 +147,28 @@ def test_models_generate_content_supports_system_instruction_and_file_data(clien
|
||||
assert response.json()["modelVersion"] == "gemini-3.1-flash-image"
|
||||
assert fake_handler.calls[0]["prompt"] == "answer in English\n\ndraw a cat"
|
||||
assert len(fake_handler.calls[0]["images"]) == 1
|
||||
|
||||
|
||||
def test_models_root_lists_gemini_models_for_newapi_compatibility(client):
|
||||
response = client.get("/models")
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
names = {item["name"] for item in body["models"]}
|
||||
|
||||
assert "models/gemini-3.0-pro-image" in names
|
||||
assert "models/gemini-3.1-flash-image" in names
|
||||
assert any(
|
||||
item["name"] == "models/gemini-3.1-flash-image"
|
||||
and item["supportedGenerationMethods"] == ["generateContent", "streamGenerateContent"]
|
||||
for item in body["models"]
|
||||
)
|
||||
|
||||
|
||||
def test_v1beta_models_root_lists_gemini_models(client):
|
||||
response = client.get("/v1beta/models")
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
|
||||
assert any(item["name"] == "models/gemini-3.0-pro-image" for item in body["models"])
|
||||
|
||||
111
tests/test_generation_handler.py
Normal file
111
tests/test_generation_handler.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
from src.services.generation_handler import GenerationHandler
|
||||
|
||||
|
||||
class FakeFlowClient:
|
||||
async def upload_image(self, at, image_bytes, aspect_ratio, project_id=None):
|
||||
return "media-uploaded"
|
||||
|
||||
async def generate_image(
|
||||
self,
|
||||
at,
|
||||
project_id,
|
||||
prompt,
|
||||
model_name,
|
||||
aspect_ratio,
|
||||
image_inputs=None,
|
||||
token_id=None,
|
||||
token_image_concurrency=None,
|
||||
progress_callback=None,
|
||||
):
|
||||
if progress_callback is not None:
|
||||
await progress_callback("solving_image_captcha", 38)
|
||||
await progress_callback("submitting_image", 48)
|
||||
return (
|
||||
{
|
||||
"media": [
|
||||
{
|
||||
"name": "media-generated",
|
||||
"image": {
|
||||
"generatedImage": {
|
||||
"fifeUrl": "https://example.com/generated.png"
|
||||
}
|
||||
},
|
||||
}
|
||||
]
|
||||
},
|
||||
"session-1",
|
||||
{"generation_attempts": [{"launch_queue_ms": 0, "launch_stagger_ms": 0}]},
|
||||
)
|
||||
|
||||
|
||||
class FakeDB:
|
||||
def __init__(self):
|
||||
self.status_updates = []
|
||||
|
||||
async def update_request_log(self, log_id, **kwargs):
|
||||
self.status_updates.append(
|
||||
{
|
||||
"log_id": log_id,
|
||||
"status_text": kwargs.get("status_text"),
|
||||
"progress": kwargs.get("progress"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def _collect(async_gen):
|
||||
items = []
|
||||
async for item in async_gen:
|
||||
items.append(item)
|
||||
return items
|
||||
|
||||
|
||||
def test_image_generation_progress_switches_from_upload_to_captcha():
|
||||
db = FakeDB()
|
||||
handler = GenerationHandler(
|
||||
flow_client=FakeFlowClient(),
|
||||
token_manager=None,
|
||||
load_balancer=None,
|
||||
db=db,
|
||||
concurrency_manager=None,
|
||||
proxy_manager=None,
|
||||
)
|
||||
token = SimpleNamespace(
|
||||
id=1,
|
||||
at="at-token",
|
||||
image_concurrency=-1,
|
||||
user_paygate_tier="PAYGATE_TIER_NOT_PAID",
|
||||
)
|
||||
generation_result = handler._create_generation_result()
|
||||
request_log_state = {"id": 123}
|
||||
|
||||
asyncio.run(
|
||||
_collect(
|
||||
handler._handle_image_generation(
|
||||
token=token,
|
||||
project_id="project-1",
|
||||
model_config={
|
||||
"model_name": "NARWHAL",
|
||||
"aspect_ratio": "IMAGE_ASPECT_RATIO_SQUARE",
|
||||
},
|
||||
prompt="draw a cat",
|
||||
images=[b"fake-image"],
|
||||
stream=False,
|
||||
perf_trace={},
|
||||
generation_result=generation_result,
|
||||
request_log_state=request_log_state,
|
||||
pending_token_state={"active": False},
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
status_texts = [item["status_text"] for item in db.status_updates]
|
||||
|
||||
assert status_texts[:4] == [
|
||||
"uploading_images",
|
||||
"solving_image_captcha",
|
||||
"submitting_image",
|
||||
"image_generated",
|
||||
]
|
||||
Reference in New Issue
Block a user