diff --git a/src/api/routes.py b/src/api/routes.py index 6a7b8ca..1b345b7 100644 --- a/src/api/routes.py +++ b/src/api/routes.py @@ -1,13 +1,17 @@ """API routes - OpenAI compatible endpoints""" from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import StreamingResponse, JSONResponse -from typing import List +from typing import List, Optional import base64 import re import json +import time +from urllib.parse import urlparse +from curl_cffi.requests import AsyncSession from ..core.auth import verify_api_key_header from ..core.models import ChatCompletionRequest from ..services.generation_handler import GenerationHandler, MODEL_CONFIG +from ..core.logger import debug_logger router = APIRouter() @@ -21,6 +25,40 @@ def set_generation_handler(handler: GenerationHandler): generation_handler = handler +async def retrieve_image_data(url: str) -> Optional[bytes]: + """ + 智能获取图片数据: + 1. 优先检查是否为本地 /tmp/ 缓存文件,如果是则直接读取磁盘 + 2. 如果本地不存在或是外部链接,则进行网络下载 + """ + # 优先尝试本地读取 + try: + if "/tmp/" in url and generation_handler and generation_handler.file_cache: + path = urlparse(url).path + filename = path.split("/tmp/")[-1] + local_file_path = generation_handler.file_cache.cache_dir / filename + + if local_file_path.exists() and local_file_path.is_file(): + data = local_file_path.read_bytes() + if data: + return data + except Exception as e: + debug_logger.log_warning(f"[CONTEXT] 本地缓存读取失败: {str(e)}") + + # 回退逻辑:网络下载 + try: + async with AsyncSession() as session: + response = await session.get(url, timeout=30, impersonate="chrome110", verify=False) + if response.status_code == 200: + return response.content + else: + debug_logger.log_warning(f"[CONTEXT] 图片下载失败,状态码: {response.status_code}") + except Exception as e: + debug_logger.log_error(f"[CONTEXT] 图片下载异常: {str(e)}") + + return None + + @router.get("/v1/models") async def list_models(api_key: str = Depends(verify_api_key_header)): """List available models""" @@ -92,6 +130,33 @@ async def create_chat_completion( image_bytes = base64.b64decode(image_base64) images.append(image_bytes) + # 自动参考图:仅对图片模型生效 + model_config = MODEL_CONFIG.get(request.model) + + if model_config and model_config["type"] == "image" and not images and len(request.messages) > 1: + debug_logger.log_info(f"[CONTEXT] 开始查找历史参考图,消息数量: {len(request.messages)}") + + # 如果当前请求没有上传图片,则尝试从历史记录中寻找最近的一张生成图 + for msg in reversed(request.messages[:-1]): + if msg.role == "assistant" and isinstance(msg.content, str): + # 匹配 Markdown 图片格式: ![...](http...) + matches = re.findall(r"!\[.*?\]\((.*?)\)", msg.content) + if matches: + last_image_url = matches[-1] + + if last_image_url.startswith("http"): + try: + downloaded_bytes = await retrieve_image_data(last_image_url) + if downloaded_bytes and len(downloaded_bytes) > 0: + images.append(downloaded_bytes) + debug_logger.log_info(f"[CONTEXT] ✅ 自动使用历史参考图: {last_image_url}") + break + else: + debug_logger.log_warning(f"[CONTEXT] 图片下载失败或为空,尝试下一个: {last_image_url}") + except Exception as e: + debug_logger.log_error(f"[CONTEXT] 处理参考图时出错: {str(e)}") + # 继续尝试下一个图片 + if not prompt: raise HTTPException(status_code=400, detail="Prompt cannot be empty")