diff --git a/src/api/routes.py b/src/api/routes.py
index 087a764..e564423 100644
--- a/src/api/routes.py
+++ b/src/api/routes.py
@@ -9,7 +9,7 @@ import re
 from urllib.parse import urlparse
 
 from curl_cffi.requests import AsyncSession
-from fastapi import APIRouter, Depends, HTTPException, Query
+from fastapi import APIRouter, Depends, HTTPException, Query, Request
 from fastapi.responses import JSONResponse, StreamingResponse
 
 from ..core.auth import verify_api_key_flexible
@@ -335,6 +335,19 @@ def _resolve_request_model(model: str, request: Any) -> str:
     return resolved_model
 
 
+def _get_request_base_url(request: Request) -> Optional[str]:
+    """Derive the externally reachable base URL from the actual request headers."""
+    forwarded_proto = (request.headers.get("x-forwarded-proto") or "").split(",")[0].strip()
+    forwarded_host = (request.headers.get("x-forwarded-host") or "").split(",")[0].strip()
+    host = (forwarded_host or request.headers.get("host") or "").strip()
+
+    if not host:
+        return None
+
+    proto = forwarded_proto or request.url.scheme or "http"
+    return f"{proto}://{host}"
+
+
 async def _normalize_openai_request(
     request: ChatCompletionRequest,
 ) -> NormalizedGenerationRequest:
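The `_get_request_base_url` helper prefers the de-facto standard `X-Forwarded-Proto` / `X-Forwarded-Host` headers (first value only, since each proxy hop may append to a comma-separated list), then falls back to the `Host` header and the scheme the server itself saw. A minimal sketch of that precedence, assuming a throwaway FastAPI app; the `/base-url` probe route below is hypothetical and only mirrors the helper's logic:

```python
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient

app = FastAPI()

@app.get("/base-url")  # hypothetical probe route, not part of this change
async def base_url(request: Request):
    # Forwarded headers win; take only the first hop of a comma-separated list.
    proto = (request.headers.get("x-forwarded-proto") or "").split(",")[0].strip()
    host = (
        (request.headers.get("x-forwarded-host") or "").split(",")[0].strip()
        or (request.headers.get("host") or "").strip()
    )
    return {"base_url": f"{proto or request.url.scheme}://{host}" if host else None}

client = TestClient(app)

# Behind a TLS-terminating proxy, the forwarded values determine the URL.
behind_proxy = client.get("/base-url", headers={
    "x-forwarded-proto": "https",
    "x-forwarded-host": "api.example.com",
})
assert behind_proxy.json() == {"base_url": "https://api.example.com"}

# Hit directly, it degrades to the Host header and the observed scheme.
assert client.get("/base-url").json() == {"base_url": "http://testserver"}
```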
@@ -385,6 +398,7 @@ async def _collect_non_stream_result(
     model: str,
     prompt: str,
     images: List[bytes],
+    base_url_override: Optional[str] = None,
 ) -> str:
     handler = _ensure_generation_handler()
     result = None
@@ -393,6 +407,7 @@
         prompt=prompt,
         images=images if images else None,
         stream=False,
+        base_url_override=base_url_override,
     ):
         result = chunk
 
@@ -455,6 +470,33 @@ def _extract_openai_message_content(payload: Dict[str, Any]) -> str:
     return content if isinstance(content, str) else ""
 
 
+def _extract_url_from_openai_payload(payload: Dict[str, Any]) -> Optional[str]:
+    direct_url = payload.get("url")
+    if isinstance(direct_url, str) and direct_url.strip():
+        return direct_url.strip()
+
+    content = _extract_openai_message_content(payload).strip()
+    if not content:
+        return None
+
+    image_match = MARKDOWN_IMAGE_RE.search(content)
+    if image_match:
+        return image_match.group(1).strip()
+
+    video_match = HTML_VIDEO_RE.search(content)
+    if video_match:
+        return video_match.group(1).strip()
+
+    return None
+
+
+def _enrich_payload_with_direct_url(payload: Dict[str, Any]) -> Dict[str, Any]:
+    extracted_url = _extract_url_from_openai_payload(payload)
+    if extracted_url and not payload.get("url"):
+        payload["url"] = extracted_url
+    return payload
+
+
 async def _build_image_parts_from_uri(uri: str) -> List[Dict[str, Any]]:
     if uri.startswith("data:image"):
         mime_type, _ = _decode_data_url(uri)
@@ -585,6 +627,7 @@ async def _convert_openai_stream_chunk_to_gemini_event(
 
 async def _iterate_openai_stream(
     normalized: NormalizedGenerationRequest,
+    base_url_override: Optional[str] = None,
 ):
     handler = _ensure_generation_handler()
     async for chunk in handler.handle_generation(
@@ -592,6 +635,7 @@ async def _iterate_openai_stream(
         prompt=normalized.prompt,
         images=normalized.images if normalized.images else None,
         stream=True,
+        base_url_override=base_url_override,
     ):
         if chunk.startswith("data: "):
             yield chunk
@@ -606,6 +650,7 @@ async def _iterate_openai_stream(
 async def _iterate_gemini_stream(
     normalized: NormalizedGenerationRequest,
     response_model: str,
+    base_url_override: Optional[str] = None,
 ):
     handler = _ensure_generation_handler()
     async for chunk in handler.handle_generation(
@@ -613,6 +658,7 @@ async def _iterate_gemini_stream(
         prompt=normalized.prompt,
         images=normalized.images if normalized.images else None,
         stream=True,
+        base_url_override=base_url_override,
     ):
         if chunk.startswith("data: "):
             payload_text = chunk[6:].strip()
@@ -713,6 +759,7 @@ async def get_gemini_model(model: str, api_key: str = Depends(verify_api_key_fle
 @router.post("/v1/chat/completions")
 async def create_chat_completion(
     request: ChatCompletionRequest,
+    raw_request: Request,
     api_key: str = Depends(verify_api_key_flexible),
 ):
     """OpenAI-compatible unified generation endpoint."""
@@ -721,9 +768,11 @@
     if not normalized.prompt:
         raise HTTPException(status_code=400, detail="Prompt cannot be empty")
 
+    request_base_url = _get_request_base_url(raw_request)
+
     if request.stream:
         return StreamingResponse(
-            _iterate_openai_stream(normalized),
+            _iterate_openai_stream(normalized, request_base_url),
             media_type="text/event-stream",
             headers={
                 "Cache-Control": "no-cache",
@@ -732,11 +781,14 @@
             },
         )
 
-    payload = _parse_handler_result(
-        await _collect_non_stream_result(
-            normalized.model,
-            normalized.prompt,
-            normalized.images,
+    payload = _enrich_payload_with_direct_url(
+        _parse_handler_result(
+            await _collect_non_stream_result(
+                normalized.model,
+                normalized.prompt,
+                normalized.images,
+                request_base_url,
+            )
         )
     )
     return _build_openai_json_response(payload)
@@ -752,6 +804,7 @@
 async def generate_content(
     model: str,
     request: GeminiGenerateContentRequest,
+    raw_request: Request,
     api_key: str = Depends(verify_api_key_flexible),
 ):
     """Gemini official generateContent endpoint."""
@@ -760,11 +813,16 @@
     if not normalized.prompt:
         raise HTTPException(status_code=400, detail="Prompt cannot be empty")
 
-    payload = _parse_handler_result(
-        await _collect_non_stream_result(
-            normalized.model,
-            normalized.prompt,
-            normalized.images,
+    request_base_url = _get_request_base_url(raw_request)
+
+    payload = _enrich_payload_with_direct_url(
+        _parse_handler_result(
+            await _collect_non_stream_result(
+                normalized.model,
+                normalized.prompt,
+                normalized.images,
+                request_base_url,
+            )
         )
     )
     if "error" in payload:
@@ -791,6 +849,7 @@
 async def stream_generate_content(
     model: str,
     request: GeminiGenerateContentRequest,
+    raw_request: Request,
     alt: Optional[str] = Query(None),
     api_key: str = Depends(verify_api_key_flexible),
 ):
@@ -800,8 +859,10 @@
     if not normalized.prompt:
         raise HTTPException(status_code=400, detail="Prompt cannot be empty")
 
+    request_base_url = _get_request_base_url(raw_request)
+
     return StreamingResponse(
-        _iterate_gemini_stream(normalized, model),
+        _iterate_gemini_stream(normalized, model, request_base_url),
         media_type="text/event-stream",
         headers={
             "Cache-Control": "no-cache",
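In the routes.py changes above, `_enrich_payload_with_direct_url` promotes a media link embedded in the message content to a top-level `url` field, but never overwrites a `url` the handler already set. `MARKDOWN_IMAGE_RE` and `HTML_VIDEO_RE` are defined elsewhere in routes.py and are not part of this diff; the patterns below are illustrative stand-ins only:

```python
import re
from typing import Optional

# Illustrative stand-ins; the real MARKDOWN_IMAGE_RE / HTML_VIDEO_RE may differ.
MARKDOWN_IMAGE_RE = re.compile(r"!\[[^\]]*\]\(([^)\s]+)\)")
HTML_VIDEO_RE = re.compile(r"<video[^>]*\bsrc=[\"']([^\"']+)[\"']", re.IGNORECASE)

def extract_media_url(content: str) -> Optional[str]:
    """Mirror the lookup order of _extract_url_from_openai_payload."""
    image = MARKDOWN_IMAGE_RE.search(content)
    if image:
        return image.group(1).strip()  # markdown image link wins
    video = HTML_VIDEO_RE.search(content)
    if video:
        return video.group(1).strip()  # otherwise an HTML <video> source
    return None

assert extract_media_url("![1K](https://cdn.example.com/a.png)") == "https://cdn.example.com/a.png"
assert extract_media_url('<video src="https://cdn.example.com/b.mp4"></video>') == "https://cdn.example.com/b.mp4"
assert extract_media_url("no media here") is None
```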
response_state["base_url"] = (base_url_override or "").strip().rstrip("/") or None request_log_state: Dict[str, Any] = {"id": None, "progress": 0} # 防止并发链路复用到上一次请求的指纹上下文 @@ -1271,7 +1274,7 @@ class GenerationHandler: if stream: yield self._create_stream_chunk(f"缓存 {resolution_name} 图片中...\n") cached_filename = await self.file_cache.cache_base64_image(encoded_image, resolution_name) - local_url = f"{self._get_base_url()}/tmp/{cached_filename}" + local_url = f"{self._get_base_url(response_state)}/tmp/{cached_filename}" response_state["url"] = local_url response_state["generated_assets"]["upscaled_image"]["local_url"] = local_url response_state["generated_assets"]["upscaled_image"]["url"] = local_url @@ -1351,7 +1354,7 @@ class GenerationHandler: yield self._create_stream_chunk("正在缓存 1K 图片文件...\n") try: cached_filename = await self.file_cache.download_and_cache(image_url, "image") - local_url = f"{self._get_base_url()}/tmp/{cached_filename}" + local_url = f"{self._get_base_url(response_state)}/tmp/{cached_filename}" if stream: yield self._create_stream_chunk("✅ 1K 图片缓存成功,准备返回缓存地址...\n") except Exception as e: @@ -1781,7 +1784,7 @@ class GenerationHandler: if stream: yield self._create_stream_chunk("正在缓存视频文件...\n") cached_filename = await self.file_cache.download_and_cache(video_url, "video") - local_url = f"{self._get_base_url()}/tmp/{cached_filename}" + local_url = f"{self._get_base_url(response_state)}/tmp/{cached_filename}" if stream: yield self._create_stream_chunk("✅ 视频缓存成功,准备返回缓存地址...\n") except Exception as e: @@ -1963,13 +1966,24 @@ class GenerationHandler: return json.dumps(error, ensure_ascii=False) - def _get_base_url(self) -> str: + def _get_base_url(self, response_state: Optional[Dict[str, Any]] = None) -> str: """获取基础URL用于缓存文件访问""" - # 优先使用配置的cache_base_url + request_base_url = "" + if isinstance(response_state, dict): + request_base_url = (response_state.get("base_url") or "").strip().rstrip("/") + if request_base_url: + return request_base_url + + # 优先使用配置的 cache_base_url if config.cache_base_url: - return config.cache_base_url - # 否则使用服务器地址 - return f"http://{config.server_host}:{config.server_port}" + return config.cache_base_url.rstrip("/") + + # 回退到服务地址,避免把监听地址 0.0.0.0 / :: 直接返回给客户端 + server_host = (config.server_host or "").strip() + if server_host in {"", "0.0.0.0", "::", "[::]"}: + server_host = "127.0.0.1" + + return f"http://{server_host}:{config.server_port}" async def _update_request_log_progress( self,