mirror of
https://github.com/TheSmallHanCat/flow2api.git
synced 2026-05-10 07:45:51 +08:00
150 lines
5.7 KiB
Python
150 lines
5.7 KiB
Python
import os
|
||
import json
|
||
import re
|
||
import base64
|
||
import aiohttp # Async test. Need to install
|
||
import asyncio
|
||
|
||
|
||
# --- 配置区域 ---
|
||
BASE_URL = os.getenv('GEMINI_FLOW2API_URL', 'http://127.0.0.1:8000')
|
||
BACKEND_URL = BASE_URL + "/v1/chat/completions"
|
||
API_KEY = os.getenv('GEMINI_FLOW2API_APIKEY', 'Bearer han1234')
|
||
if API_KEY is None:
|
||
raise ValueError('[gemini flow2api] api key not set')
|
||
MODEL_LANDSCAPE = "gemini-3.0-pro-image-landscape"
|
||
MODEL_PORTRAIT = "gemini-3.0-pro-image-portrait"
|
||
|
||
# 修改: 增加 model 参数,默认为 None
|
||
async def request_backend_generation(
|
||
prompt: str,
|
||
images: list[bytes] = None,
|
||
model: str = None) -> bytes | None:
|
||
"""
|
||
请求后端生成图片。
|
||
:param prompt: 提示词
|
||
:param images: 图片二进制列表
|
||
:param model: 指定模型名称 (可选)
|
||
:return: 成功返回图片bytes,失败返回None
|
||
"""
|
||
# 更新token
|
||
images = images or []
|
||
|
||
# 逻辑: 如果未指定 model,默认使用 Landscape
|
||
use_model = model if model else MODEL_LANDSCAPE
|
||
|
||
# 1. 构造 Payload
|
||
if images:
|
||
content_payload = [{"type": "text", "text": prompt}]
|
||
print(f"[Backend] 正在处理 {len(images)} 张图片输入...")
|
||
for img_bytes in images:
|
||
b64_str = base64.b64encode(img_bytes).decode('utf-8')
|
||
content_payload.append({
|
||
"type": "image_url",
|
||
"image_url": {"url": f"data:image/jpeg;base64,{b64_str}"}
|
||
})
|
||
else:
|
||
content_payload = prompt
|
||
|
||
payload = {
|
||
"model": use_model, # 使用选定的模型
|
||
"messages": [{"role": "user", "content": content_payload}],
|
||
"stream": True
|
||
}
|
||
|
||
headers = {
|
||
"Authorization": API_KEY,
|
||
"Content-Type": "application/json"
|
||
}
|
||
|
||
image_url = None
|
||
print(f"[Backend] Model: {use_model} | 发起请求: {prompt[:20]}...")
|
||
|
||
try:
|
||
async with aiohttp.ClientSession() as session:
|
||
async with session.post(BACKEND_URL, json=payload, headers=headers, timeout=120) as response:
|
||
if response.status != 200:
|
||
err_text = await response.text()
|
||
content = response.content
|
||
print(f"[Backend Error] Status {response.status}: {err_text} {content}")
|
||
raise Exception(f"API Error: {response.status}: {err_text}")
|
||
|
||
async for line in response.content:
|
||
line_str = line.decode('utf-8').strip()
|
||
if line_str.startswith('{"error'):
|
||
chunk = json.loads(data_str)
|
||
delta = chunk.get("choices", [{}])[0].get("delta", {})
|
||
msg = delta['reasoning_content']
|
||
if '401' in msg:
|
||
msg += '\nAccess Token 已失效,需重新配置。'
|
||
elif '400' in msg:
|
||
msg += '\n返回内容被拦截。'
|
||
raise Exception(msg)
|
||
|
||
if not line_str or not line_str.startswith('data: '):
|
||
continue
|
||
|
||
data_str = line_str[6:]
|
||
if data_str == '[DONE]':
|
||
break
|
||
|
||
try:
|
||
chunk = json.loads(data_str)
|
||
delta = chunk.get("choices", [{}])[0].get("delta", {})
|
||
|
||
# 打印思考过程
|
||
if "reasoning_content" in delta:
|
||
print(delta['reasoning_content'], end="", flush=True)
|
||
|
||
# 提取内容中的图片链接
|
||
if "content" in delta:
|
||
content_text = delta["content"]
|
||
img_match = re.search(r'!\[.*?\]\((.*?)\)', content_text)
|
||
if img_match:
|
||
image_url = img_match.group(1)
|
||
print(f"\n[Backend] 捕获图片链接: {image_url}")
|
||
except json.JSONDecodeError:
|
||
continue
|
||
|
||
# 3. 下载生成的图片
|
||
if image_url:
|
||
async with session.get(image_url) as img_resp:
|
||
if img_resp.status == 200:
|
||
image_bytes = await img_resp.read()
|
||
return image_bytes
|
||
else:
|
||
print(f"[Backend Error] 图片下载失败: {img_resp.status}")
|
||
except Exception as e:
|
||
print(f"[Backend Exception] {e}")
|
||
raise e
|
||
|
||
return None
|
||
|
||
if __name__ == '__main__':
|
||
async def main():
|
||
print("=== AI 绘图接口测试 ===")
|
||
user_prompt = input("请输入提示词 (例如 '一只猫'): ").strip()
|
||
if not user_prompt:
|
||
user_prompt = "A cute cat in the garden"
|
||
|
||
print(f"正在请求: {user_prompt}")
|
||
|
||
# 这里的 images 传空列表用于测试文生图
|
||
# 如果想测试图生图,你需要手动读取本地文件:
|
||
# with open("output_test.jpg", "rb") as f: img_data = f.read()
|
||
# result = await request_backend_generation(user_prompt, [img_data])
|
||
|
||
result = await request_backend_generation(user_prompt)
|
||
|
||
if result:
|
||
filename = "output_test.jpg"
|
||
with open(filename, "wb") as f:
|
||
f.write(result)
|
||
print(f"\n[Success] 图片已保存为 {filename},大小: {len(result)} bytes")
|
||
else:
|
||
print("\n[Failed] 生成失败")
|
||
|
||
# 运行测试
|
||
if os.name == 'nt': # Windows 兼容性
|
||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||
asyncio.run(main()) |