修复外部资源地址返回0.0.0.0和默认8000端口问题

This commit is contained in:
shskjw
2026-03-31 00:27:37 +08:00
parent 3b76631659
commit ef4e9b3fa0
2 changed files with 97 additions and 22 deletions

View File

@@ -9,7 +9,7 @@ import re
from urllib.parse import urlparse
from curl_cffi.requests import AsyncSession
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import JSONResponse, StreamingResponse
from ..core.auth import verify_api_key_flexible
@@ -335,6 +335,19 @@ def _resolve_request_model(model: str, request: Any) -> str:
return resolved_model
def _get_request_base_url(request: Request) -> Optional[str]:
"""根据实际请求头推导对外可访问的基础地址。"""
forwarded_proto = (request.headers.get("x-forwarded-proto") or "").split(",")[0].strip()
forwarded_host = (request.headers.get("x-forwarded-host") or "").split(",")[0].strip()
host = (forwarded_host or request.headers.get("host") or "").strip()
if not host:
return None
proto = forwarded_proto or request.url.scheme or "http"
return f"{proto}://{host}"
async def _normalize_openai_request(
request: ChatCompletionRequest,
) -> NormalizedGenerationRequest:
@@ -385,6 +398,7 @@ async def _collect_non_stream_result(
model: str,
prompt: str,
images: List[bytes],
base_url_override: Optional[str] = None,
) -> str:
handler = _ensure_generation_handler()
result = None
@@ -393,6 +407,7 @@ async def _collect_non_stream_result(
prompt=prompt,
images=images if images else None,
stream=False,
base_url_override=base_url_override,
):
result = chunk
@@ -455,6 +470,33 @@ def _extract_openai_message_content(payload: Dict[str, Any]) -> str:
return content if isinstance(content, str) else ""
def _extract_url_from_openai_payload(payload: Dict[str, Any]) -> Optional[str]:
direct_url = payload.get("url")
if isinstance(direct_url, str) and direct_url.strip():
return direct_url.strip()
content = _extract_openai_message_content(payload).strip()
if not content:
return None
image_match = MARKDOWN_IMAGE_RE.search(content)
if image_match:
return image_match.group(1).strip()
video_match = HTML_VIDEO_RE.search(content)
if video_match:
return video_match.group(1).strip()
return None
def _enrich_payload_with_direct_url(payload: Dict[str, Any]) -> Dict[str, Any]:
extracted_url = _extract_url_from_openai_payload(payload)
if extracted_url and not payload.get("url"):
payload["url"] = extracted_url
return payload
async def _build_image_parts_from_uri(uri: str) -> List[Dict[str, Any]]:
if uri.startswith("data:image"):
mime_type, _ = _decode_data_url(uri)
@@ -585,6 +627,7 @@ async def _convert_openai_stream_chunk_to_gemini_event(
async def _iterate_openai_stream(
normalized: NormalizedGenerationRequest,
base_url_override: Optional[str] = None,
):
handler = _ensure_generation_handler()
async for chunk in handler.handle_generation(
@@ -592,6 +635,7 @@ async def _iterate_openai_stream(
prompt=normalized.prompt,
images=normalized.images if normalized.images else None,
stream=True,
base_url_override=base_url_override,
):
if chunk.startswith("data: "):
yield chunk
@@ -606,6 +650,7 @@ async def _iterate_openai_stream(
async def _iterate_gemini_stream(
normalized: NormalizedGenerationRequest,
response_model: str,
base_url_override: Optional[str] = None,
):
handler = _ensure_generation_handler()
async for chunk in handler.handle_generation(
@@ -613,6 +658,7 @@ async def _iterate_gemini_stream(
prompt=normalized.prompt,
images=normalized.images if normalized.images else None,
stream=True,
base_url_override=base_url_override,
):
if chunk.startswith("data: "):
payload_text = chunk[6:].strip()
@@ -713,6 +759,7 @@ async def get_gemini_model(model: str, api_key: str = Depends(verify_api_key_fle
@router.post("/v1/chat/completions")
async def create_chat_completion(
request: ChatCompletionRequest,
raw_request: Request,
api_key: str = Depends(verify_api_key_flexible),
):
"""OpenAI-compatible unified generation endpoint."""
@@ -721,9 +768,11 @@ async def create_chat_completion(
if not normalized.prompt:
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
request_base_url = _get_request_base_url(raw_request)
if request.stream:
return StreamingResponse(
_iterate_openai_stream(normalized),
_iterate_openai_stream(normalized, request_base_url),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
@@ -732,11 +781,14 @@ async def create_chat_completion(
},
)
payload = _parse_handler_result(
await _collect_non_stream_result(
normalized.model,
normalized.prompt,
normalized.images,
payload = _enrich_payload_with_direct_url(
_parse_handler_result(
await _collect_non_stream_result(
normalized.model,
normalized.prompt,
normalized.images,
request_base_url,
)
)
)
return _build_openai_json_response(payload)
@@ -752,6 +804,7 @@ async def create_chat_completion(
async def generate_content(
model: str,
request: GeminiGenerateContentRequest,
raw_request: Request,
api_key: str = Depends(verify_api_key_flexible),
):
"""Gemini official generateContent endpoint."""
@@ -760,11 +813,16 @@ async def generate_content(
if not normalized.prompt:
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
payload = _parse_handler_result(
await _collect_non_stream_result(
normalized.model,
normalized.prompt,
normalized.images,
request_base_url = _get_request_base_url(raw_request)
payload = _enrich_payload_with_direct_url(
_parse_handler_result(
await _collect_non_stream_result(
normalized.model,
normalized.prompt,
normalized.images,
request_base_url,
)
)
)
if "error" in payload:
@@ -791,6 +849,7 @@ async def generate_content(
async def stream_generate_content(
model: str,
request: GeminiGenerateContentRequest,
raw_request: Request,
alt: Optional[str] = Query(None),
api_key: str = Depends(verify_api_key_flexible),
):
@@ -800,8 +859,10 @@ async def stream_generate_content(
if not normalized.prompt:
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
request_base_url = _get_request_base_url(raw_request)
return StreamingResponse(
_iterate_gemini_stream(normalized, model),
_iterate_gemini_stream(normalized, model, request_base_url),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",