mirror of
https://github.com/TheSmallHanCat/flow2api.git
synced 2026-05-07 22:43:16 +08:00
Fix external resource URLs returning 0.0.0.0 and the default port 8000
@@ -9,7 +9,7 @@ import re
 from urllib.parse import urlparse
 
 from curl_cffi.requests import AsyncSession
-from fastapi import APIRouter, Depends, HTTPException, Query
+from fastapi import APIRouter, Depends, HTTPException, Query, Request
 from fastapi.responses import JSONResponse, StreamingResponse
 
 from ..core.auth import verify_api_key_flexible
@@ -335,6 +335,19 @@ def _resolve_request_model(model: str, request: Any) -> str:
     return resolved_model
 
 
+def _get_request_base_url(request: Request) -> Optional[str]:
+    """Derive the externally reachable base URL from the actual request headers."""
+    forwarded_proto = (request.headers.get("x-forwarded-proto") or "").split(",")[0].strip()
+    forwarded_host = (request.headers.get("x-forwarded-host") or "").split(",")[0].strip()
+    host = (forwarded_host or request.headers.get("host") or "").strip()
+
+    if not host:
+        return None
+
+    proto = forwarded_proto or request.url.scheme or "http"
+    return f"{proto}://{host}"
+
+
 async def _normalize_openai_request(
     request: ChatCompletionRequest,
 ) -> NormalizedGenerationRequest:
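For context, the derivation above can be exercised on its own. A minimal sketch, assuming only the logic in the added function; FakeRequest and the sample header values are invented for illustration and are not part of the commit:

    from typing import Optional

    class FakeRequest:
        # Hypothetical stand-in for fastapi.Request, just enough for this demo.
        def __init__(self, headers, scheme="http"):
            self.headers = headers
            self.url = type("URL", (), {"scheme": scheme})()

    def get_request_base_url(request) -> Optional[str]:
        forwarded_proto = (request.headers.get("x-forwarded-proto") or "").split(",")[0].strip()
        forwarded_host = (request.headers.get("x-forwarded-host") or "").split(",")[0].strip()
        host = (forwarded_host or request.headers.get("host") or "").strip()
        if not host:
            return None
        proto = forwarded_proto or request.url.scheme or "http"
        return f"{proto}://{host}"

    # Behind a reverse proxy, the forwarded headers win:
    assert get_request_base_url(FakeRequest(
        {"x-forwarded-proto": "https", "x-forwarded-host": "api.example.com", "host": "127.0.0.1:8000"}
    )) == "https://api.example.com"

    # Direct access falls back to the Host header and the request scheme:
    assert get_request_base_url(FakeRequest({"host": "192.168.1.5:8000"})) == "http://192.168.1.5:8000"

Starlette's header mapping is case-insensitive, so the lowercase lookups in the real function are safe.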
@@ -385,6 +398,7 @@ async def _collect_non_stream_result(
     model: str,
     prompt: str,
     images: List[bytes],
+    base_url_override: Optional[str] = None,
 ) -> str:
     handler = _ensure_generation_handler()
     result = None
@@ -393,6 +407,7 @@ async def _collect_non_stream_result(
         prompt=prompt,
         images=images if images else None,
         stream=False,
+        base_url_override=base_url_override,
     ):
         result = chunk
 
@@ -455,6 +470,33 @@ def _extract_openai_message_content(payload: Dict[str, Any]) -> str:
     return content if isinstance(content, str) else ""
 
 
+def _extract_url_from_openai_payload(payload: Dict[str, Any]) -> Optional[str]:
+    direct_url = payload.get("url")
+    if isinstance(direct_url, str) and direct_url.strip():
+        return direct_url.strip()
+
+    content = _extract_openai_message_content(payload).strip()
+    if not content:
+        return None
+
+    image_match = MARKDOWN_IMAGE_RE.search(content)
+    if image_match:
+        return image_match.group(1).strip()
+
+    video_match = HTML_VIDEO_RE.search(content)
+    if video_match:
+        return video_match.group(1).strip()
+
+    return None
+
+
+def _enrich_payload_with_direct_url(payload: Dict[str, Any]) -> Dict[str, Any]:
+    extracted_url = _extract_url_from_openai_payload(payload)
+    if extracted_url and not payload.get("url"):
+        payload["url"] = extracted_url
+    return payload
+
+
 async def _build_image_parts_from_uri(uri: str) -> List[Dict[str, Any]]:
     if uri.startswith("data:image"):
         mime_type, _ = _decode_data_url(uri)
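MARKDOWN_IMAGE_RE and HTML_VIDEO_RE are defined elsewhere in this module and are not part of the diff. A minimal sketch of the extraction step with assumed patterns (the regexes below are illustrative stand-ins, not the project's actual definitions):

    import re

    MARKDOWN_IMAGE_RE = re.compile(r"!\[[^\]]*\]\((\S+?)\)")
    HTML_VIDEO_RE = re.compile(r"<video[^>]*\bsrc=[\"']([^\"']+)[\"']", re.IGNORECASE)

    content = "Done! ![image](https://host/tmp/abc123.png)"
    match = MARKDOWN_IMAGE_RE.search(content)
    if match:
        print(match.group(1))  # https://host/tmp/abc123.png

    content = '<video controls src="https://host/tmp/clip.mp4"></video>'
    match = HTML_VIDEO_RE.search(content)
    if match:
        print(match.group(1))  # https://host/tmp/clip.mp4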
@@ -585,6 +627,7 @@ async def _convert_openai_stream_chunk_to_gemini_event(
 
 async def _iterate_openai_stream(
     normalized: NormalizedGenerationRequest,
+    base_url_override: Optional[str] = None,
 ):
     handler = _ensure_generation_handler()
     async for chunk in handler.handle_generation(
@@ -592,6 +635,7 @@ async def _iterate_openai_stream(
         prompt=normalized.prompt,
         images=normalized.images if normalized.images else None,
         stream=True,
+        base_url_override=base_url_override,
     ):
         if chunk.startswith("data: "):
             yield chunk
@@ -606,6 +650,7 @@ async def _iterate_openai_stream(
 async def _iterate_gemini_stream(
     normalized: NormalizedGenerationRequest,
     response_model: str,
+    base_url_override: Optional[str] = None,
 ):
     handler = _ensure_generation_handler()
     async for chunk in handler.handle_generation(
@@ -613,6 +658,7 @@ async def _iterate_gemini_stream(
         prompt=normalized.prompt,
         images=normalized.images if normalized.images else None,
         stream=True,
+        base_url_override=base_url_override,
     ):
         if chunk.startswith("data: "):
             payload_text = chunk[6:].strip()
@@ -713,6 +759,7 @@ async def get_gemini_model(model: str, api_key: str = Depends(verify_api_key_flexible)):
 @router.post("/v1/chat/completions")
 async def create_chat_completion(
     request: ChatCompletionRequest,
+    raw_request: Request,
     api_key: str = Depends(verify_api_key_flexible),
 ):
     """OpenAI-compatible unified generation endpoint."""
@@ -721,9 +768,11 @@ async def create_chat_completion(
     if not normalized.prompt:
         raise HTTPException(status_code=400, detail="Prompt cannot be empty")
 
+    request_base_url = _get_request_base_url(raw_request)
+
     if request.stream:
         return StreamingResponse(
-            _iterate_openai_stream(normalized),
+            _iterate_openai_stream(normalized, request_base_url),
             media_type="text/event-stream",
             headers={
                 "Cache-Control": "no-cache",
@@ -732,11 +781,14 @@ async def create_chat_completion(
             },
         )
 
-    payload = _parse_handler_result(
-        await _collect_non_stream_result(
-            normalized.model,
-            normalized.prompt,
-            normalized.images,
+    payload = _enrich_payload_with_direct_url(
+        _parse_handler_result(
+            await _collect_non_stream_result(
+                normalized.model,
+                normalized.prompt,
+                normalized.images,
+                request_base_url,
+            )
         )
     )
     return _build_openai_json_response(payload)
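The net effect on a non-stream response: when the model reply embeds a media link as markdown, the wrapper promotes it to a top-level "url" field. A sketch of the before/after shape, assuming the usual OpenAI chat-completion payload (not taken verbatim from the project):

    before = {
        "choices": [{"message": {"content": "![image](https://host/tmp/abc.png)"}}],
    }
    # after _enrich_payload_with_direct_url(...):
    after = {
        "choices": [{"message": {"content": "![image](https://host/tmp/abc.png)"}}],
        "url": "https://host/tmp/abc.png",  # promoted so clients get a direct link
    }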
@@ -752,6 +804,7 @@ async def create_chat_completion(
 async def generate_content(
     model: str,
     request: GeminiGenerateContentRequest,
+    raw_request: Request,
     api_key: str = Depends(verify_api_key_flexible),
 ):
     """Gemini official generateContent endpoint."""
@@ -760,11 +813,16 @@ async def generate_content(
     if not normalized.prompt:
         raise HTTPException(status_code=400, detail="Prompt cannot be empty")
 
-    payload = _parse_handler_result(
-        await _collect_non_stream_result(
-            normalized.model,
-            normalized.prompt,
-            normalized.images,
+    request_base_url = _get_request_base_url(raw_request)
+
+    payload = _enrich_payload_with_direct_url(
+        _parse_handler_result(
+            await _collect_non_stream_result(
+                normalized.model,
+                normalized.prompt,
+                normalized.images,
+                request_base_url,
+            )
         )
     )
     if "error" in payload:
@@ -791,6 +849,7 @@ async def generate_content(
 async def stream_generate_content(
     model: str,
     request: GeminiGenerateContentRequest,
+    raw_request: Request,
     alt: Optional[str] = Query(None),
     api_key: str = Depends(verify_api_key_flexible),
 ):
@@ -800,8 +859,10 @@ async def stream_generate_content(
     if not normalized.prompt:
         raise HTTPException(status_code=400, detail="Prompt cannot be empty")
 
+    request_base_url = _get_request_base_url(raw_request)
+
     return StreamingResponse(
-        _iterate_gemini_stream(normalized, model),
+        _iterate_gemini_stream(normalized, model, request_base_url),
         media_type="text/event-stream",
         headers={
             "Cache-Control": "no-cache",
@@ -695,6 +695,7 @@ class GenerationHandler:
         return {
             "url": None,
             "generated_assets": None,
+            "base_url": None,
         }
 
     def _mark_generation_failed(self, generation_result: Optional[Dict[str, Any]], error_message: str):
@@ -759,7 +760,8 @@ class GenerationHandler:
         model: str,
         prompt: str,
         images: Optional[List[bytes]] = None,
-        stream: bool = False
+        stream: bool = False,
+        base_url_override: Optional[str] = None
     ) -> AsyncGenerator:
         """Unified generation entry point
 
@@ -781,6 +783,7 @@ class GenerationHandler:
         }
         generation_result = self._create_generation_result()
         response_state = self._create_response_state()
+        response_state["base_url"] = (base_url_override or "").strip().rstrip("/") or None
         request_log_state: Dict[str, Any] = {"id": None, "progress": 0}
 
         # Prevent concurrent paths from reusing the previous request's fingerprint context
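The normalization on the added line collapses empty or unset overrides to None. A one-liner sketch, plain Python:

    def normalize(base_url_override):
        return (base_url_override or "").strip().rstrip("/") or None

    assert normalize(" https://api.example.com/ ") == "https://api.example.com"
    assert normalize("") is None
    assert normalize(None) is None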
@@ -1271,7 +1274,7 @@ class GenerationHandler:
         if stream:
             yield self._create_stream_chunk(f"Caching {resolution_name} image...\n")
         cached_filename = await self.file_cache.cache_base64_image(encoded_image, resolution_name)
-        local_url = f"{self._get_base_url()}/tmp/{cached_filename}"
+        local_url = f"{self._get_base_url(response_state)}/tmp/{cached_filename}"
         response_state["url"] = local_url
         response_state["generated_assets"]["upscaled_image"]["local_url"] = local_url
         response_state["generated_assets"]["upscaled_image"]["url"] = local_url
@@ -1351,7 +1354,7 @@ class GenerationHandler:
             yield self._create_stream_chunk("Caching 1K image file...\n")
         try:
             cached_filename = await self.file_cache.download_and_cache(image_url, "image")
-            local_url = f"{self._get_base_url()}/tmp/{cached_filename}"
+            local_url = f"{self._get_base_url(response_state)}/tmp/{cached_filename}"
             if stream:
                 yield self._create_stream_chunk("✅ 1K image cached; returning the cached URL...\n")
         except Exception as e:
@@ -1781,7 +1784,7 @@ class GenerationHandler:
             if stream:
                 yield self._create_stream_chunk("Caching video file...\n")
             cached_filename = await self.file_cache.download_and_cache(video_url, "video")
-            local_url = f"{self._get_base_url()}/tmp/{cached_filename}"
+            local_url = f"{self._get_base_url(response_state)}/tmp/{cached_filename}"
             if stream:
                 yield self._create_stream_chunk("✅ Video cached; returning the cached URL...\n")
         except Exception as e:
@@ -1963,13 +1966,24 @@ class GenerationHandler:
 
         return json.dumps(error, ensure_ascii=False)
 
-    def _get_base_url(self) -> str:
+    def _get_base_url(self, response_state: Optional[Dict[str, Any]] = None) -> str:
         """Return the base URL used to access cached files."""
-        # Prefer the configured cache_base_url
+        request_base_url = ""
+        if isinstance(response_state, dict):
+            request_base_url = (response_state.get("base_url") or "").strip().rstrip("/")
+        if request_base_url:
+            return request_base_url
+
+        # Prefer the configured cache_base_url
         if config.cache_base_url:
-            return config.cache_base_url
-        # Otherwise use the server address
-        return f"http://{config.server_host}:{config.server_port}"
+            return config.cache_base_url.rstrip("/")
+
+        # Fall back to the server address; avoid returning the listen address 0.0.0.0 / :: to clients
+        server_host = (config.server_host or "").strip()
+        if server_host in {"", "0.0.0.0", "::", "[::]"}:
+            server_host = "127.0.0.1"
+
+        return f"http://{server_host}:{config.server_port}"
 
     async def _update_request_log_progress(
         self,
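Putting it together, the rewritten _get_base_url resolves in three steps: per-request override, then configured cache_base_url, then a sanitized server address. A standalone sketch; the Config stub is invented, the real config object lives elsewhere in the project:

    from typing import Any, Dict, Optional

    class Config:
        # Hypothetical stub mirroring the fields _get_base_url reads.
        cache_base_url = ""          # optional operator-configured public URL
        server_host = "0.0.0.0"      # listen address; not necessarily client-reachable
        server_port = 8000

    config = Config()

    def get_base_url(response_state: Optional[Dict[str, Any]] = None) -> str:
        # 1) Per-request base URL derived from Host / X-Forwarded-* headers.
        if isinstance(response_state, dict):
            request_base_url = (response_state.get("base_url") or "").strip().rstrip("/")
            if request_base_url:
                return request_base_url
        # 2) Operator-configured cache_base_url.
        if config.cache_base_url:
            return config.cache_base_url.rstrip("/")
        # 3) Last resort: the server address, with unreachable listen addresses rewritten.
        server_host = (config.server_host or "").strip()
        if server_host in {"", "0.0.0.0", "::", "[::]"}:
            server_host = "127.0.0.1"
        return f"http://{server_host}:{config.server_port}"

    print(get_base_url({"base_url": "https://api.example.com"}))  # https://api.example.com
    print(get_base_url())  # http://127.0.0.1:8000; 0.0.0.0 is never returned

The 127.0.0.1 fallback is only a local default; in a deployed setup the per-request derivation or cache_base_url is expected to supply the public address.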