修复外部资源地址返回0.0.0.0和默认8000端口问题

This commit is contained in:
shskjw
2026-03-31 00:27:37 +08:00
parent 3b76631659
commit ef4e9b3fa0
2 changed files with 97 additions and 22 deletions

View File

@@ -9,7 +9,7 @@ import re
from urllib.parse import urlparse
from curl_cffi.requests import AsyncSession
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import JSONResponse, StreamingResponse
from ..core.auth import verify_api_key_flexible
@@ -335,6 +335,19 @@ def _resolve_request_model(model: str, request: Any) -> str:
return resolved_model
def _get_request_base_url(request: Request) -> Optional[str]:
"""根据实际请求头推导对外可访问的基础地址。"""
forwarded_proto = (request.headers.get("x-forwarded-proto") or "").split(",")[0].strip()
forwarded_host = (request.headers.get("x-forwarded-host") or "").split(",")[0].strip()
host = (forwarded_host or request.headers.get("host") or "").strip()
if not host:
return None
proto = forwarded_proto or request.url.scheme or "http"
return f"{proto}://{host}"
async def _normalize_openai_request(
request: ChatCompletionRequest,
) -> NormalizedGenerationRequest:
@@ -385,6 +398,7 @@ async def _collect_non_stream_result(
model: str,
prompt: str,
images: List[bytes],
base_url_override: Optional[str] = None,
) -> str:
handler = _ensure_generation_handler()
result = None
@@ -393,6 +407,7 @@ async def _collect_non_stream_result(
prompt=prompt,
images=images if images else None,
stream=False,
base_url_override=base_url_override,
):
result = chunk
@@ -455,6 +470,33 @@ def _extract_openai_message_content(payload: Dict[str, Any]) -> str:
return content if isinstance(content, str) else ""
def _extract_url_from_openai_payload(payload: Dict[str, Any]) -> Optional[str]:
direct_url = payload.get("url")
if isinstance(direct_url, str) and direct_url.strip():
return direct_url.strip()
content = _extract_openai_message_content(payload).strip()
if not content:
return None
image_match = MARKDOWN_IMAGE_RE.search(content)
if image_match:
return image_match.group(1).strip()
video_match = HTML_VIDEO_RE.search(content)
if video_match:
return video_match.group(1).strip()
return None
def _enrich_payload_with_direct_url(payload: Dict[str, Any]) -> Dict[str, Any]:
extracted_url = _extract_url_from_openai_payload(payload)
if extracted_url and not payload.get("url"):
payload["url"] = extracted_url
return payload
async def _build_image_parts_from_uri(uri: str) -> List[Dict[str, Any]]:
if uri.startswith("data:image"):
mime_type, _ = _decode_data_url(uri)
@@ -585,6 +627,7 @@ async def _convert_openai_stream_chunk_to_gemini_event(
async def _iterate_openai_stream(
normalized: NormalizedGenerationRequest,
base_url_override: Optional[str] = None,
):
handler = _ensure_generation_handler()
async for chunk in handler.handle_generation(
@@ -592,6 +635,7 @@ async def _iterate_openai_stream(
prompt=normalized.prompt,
images=normalized.images if normalized.images else None,
stream=True,
base_url_override=base_url_override,
):
if chunk.startswith("data: "):
yield chunk
@@ -606,6 +650,7 @@ async def _iterate_openai_stream(
async def _iterate_gemini_stream(
normalized: NormalizedGenerationRequest,
response_model: str,
base_url_override: Optional[str] = None,
):
handler = _ensure_generation_handler()
async for chunk in handler.handle_generation(
@@ -613,6 +658,7 @@ async def _iterate_gemini_stream(
prompt=normalized.prompt,
images=normalized.images if normalized.images else None,
stream=True,
base_url_override=base_url_override,
):
if chunk.startswith("data: "):
payload_text = chunk[6:].strip()
@@ -713,6 +759,7 @@ async def get_gemini_model(model: str, api_key: str = Depends(verify_api_key_fle
@router.post("/v1/chat/completions")
async def create_chat_completion(
request: ChatCompletionRequest,
raw_request: Request,
api_key: str = Depends(verify_api_key_flexible),
):
"""OpenAI-compatible unified generation endpoint."""
@@ -721,9 +768,11 @@ async def create_chat_completion(
if not normalized.prompt:
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
request_base_url = _get_request_base_url(raw_request)
if request.stream:
return StreamingResponse(
_iterate_openai_stream(normalized),
_iterate_openai_stream(normalized, request_base_url),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
@@ -732,11 +781,14 @@ async def create_chat_completion(
},
)
payload = _parse_handler_result(
await _collect_non_stream_result(
normalized.model,
normalized.prompt,
normalized.images,
payload = _enrich_payload_with_direct_url(
_parse_handler_result(
await _collect_non_stream_result(
normalized.model,
normalized.prompt,
normalized.images,
request_base_url,
)
)
)
return _build_openai_json_response(payload)
@@ -752,6 +804,7 @@ async def create_chat_completion(
async def generate_content(
model: str,
request: GeminiGenerateContentRequest,
raw_request: Request,
api_key: str = Depends(verify_api_key_flexible),
):
"""Gemini official generateContent endpoint."""
@@ -760,11 +813,16 @@ async def generate_content(
if not normalized.prompt:
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
payload = _parse_handler_result(
await _collect_non_stream_result(
normalized.model,
normalized.prompt,
normalized.images,
request_base_url = _get_request_base_url(raw_request)
payload = _enrich_payload_with_direct_url(
_parse_handler_result(
await _collect_non_stream_result(
normalized.model,
normalized.prompt,
normalized.images,
request_base_url,
)
)
)
if "error" in payload:
@@ -791,6 +849,7 @@ async def generate_content(
async def stream_generate_content(
model: str,
request: GeminiGenerateContentRequest,
raw_request: Request,
alt: Optional[str] = Query(None),
api_key: str = Depends(verify_api_key_flexible),
):
@@ -800,8 +859,10 @@ async def stream_generate_content(
if not normalized.prompt:
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
request_base_url = _get_request_base_url(raw_request)
return StreamingResponse(
_iterate_gemini_stream(normalized, model),
_iterate_gemini_stream(normalized, model, request_base_url),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",

View File

@@ -695,6 +695,7 @@ class GenerationHandler:
return {
"url": None,
"generated_assets": None,
"base_url": None,
}
def _mark_generation_failed(self, generation_result: Optional[Dict[str, Any]], error_message: str):
@@ -759,7 +760,8 @@ class GenerationHandler:
model: str,
prompt: str,
images: Optional[List[bytes]] = None,
stream: bool = False
stream: bool = False,
base_url_override: Optional[str] = None
) -> AsyncGenerator:
"""统一生成入口
@@ -781,6 +783,7 @@ class GenerationHandler:
}
generation_result = self._create_generation_result()
response_state = self._create_response_state()
response_state["base_url"] = (base_url_override or "").strip().rstrip("/") or None
request_log_state: Dict[str, Any] = {"id": None, "progress": 0}
# 防止并发链路复用到上一次请求的指纹上下文
@@ -1271,7 +1274,7 @@ class GenerationHandler:
if stream:
yield self._create_stream_chunk(f"缓存 {resolution_name} 图片中...\n")
cached_filename = await self.file_cache.cache_base64_image(encoded_image, resolution_name)
local_url = f"{self._get_base_url()}/tmp/{cached_filename}"
local_url = f"{self._get_base_url(response_state)}/tmp/{cached_filename}"
response_state["url"] = local_url
response_state["generated_assets"]["upscaled_image"]["local_url"] = local_url
response_state["generated_assets"]["upscaled_image"]["url"] = local_url
@@ -1351,7 +1354,7 @@ class GenerationHandler:
yield self._create_stream_chunk("正在缓存 1K 图片文件...\n")
try:
cached_filename = await self.file_cache.download_and_cache(image_url, "image")
local_url = f"{self._get_base_url()}/tmp/{cached_filename}"
local_url = f"{self._get_base_url(response_state)}/tmp/{cached_filename}"
if stream:
yield self._create_stream_chunk("✅ 1K 图片缓存成功,准备返回缓存地址...\n")
except Exception as e:
@@ -1781,7 +1784,7 @@ class GenerationHandler:
if stream:
yield self._create_stream_chunk("正在缓存视频文件...\n")
cached_filename = await self.file_cache.download_and_cache(video_url, "video")
local_url = f"{self._get_base_url()}/tmp/{cached_filename}"
local_url = f"{self._get_base_url(response_state)}/tmp/{cached_filename}"
if stream:
yield self._create_stream_chunk("✅ 视频缓存成功,准备返回缓存地址...\n")
except Exception as e:
@@ -1963,13 +1966,24 @@ class GenerationHandler:
return json.dumps(error, ensure_ascii=False)
def _get_base_url(self) -> str:
def _get_base_url(self, response_state: Optional[Dict[str, Any]] = None) -> str:
"""获取基础URL用于缓存文件访问"""
# 优先使用配置的cache_base_url
request_base_url = ""
if isinstance(response_state, dict):
request_base_url = (response_state.get("base_url") or "").strip().rstrip("/")
if request_base_url:
return request_base_url
# 优先使用配置的 cache_base_url
if config.cache_base_url:
return config.cache_base_url
# 否则使用服务器地址
return f"http://{config.server_host}:{config.server_port}"
return config.cache_base_url.rstrip("/")
# 回退到服务地址,避免把监听地址 0.0.0.0 / :: 直接返回给客户端
server_host = (config.server_host or "").strip()
if server_host in {"", "0.0.0.0", "::", "[::]"}:
server_host = "127.0.0.1"
return f"http://{server_host}:{config.server_port}"
async def _update_request_log_progress(
self,