From 36cbcc0d8b29ff0ea8244bc448f4eebdb3e488e6 Mon Sep 17 00:00:00 2001 From: TheSmallHanCat Date: Mon, 24 Nov 2025 18:27:05 +0800 Subject: [PATCH] =?UTF-8?q?feat=EF=BC=9Aflow2api=E5=88=9D=E7=89=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .dockerignore | 77 +++ .gitignore | 54 ++ Dockerfile | 12 + README.md | 261 ++++++++- config/setting.toml | 35 ++ config/setting_warp.toml | 35 ++ docker-compose.proxy.yml | 36 ++ docker-compose.yml | 14 + main.py | 13 + requirements.txt | 9 + src/api/__init__.py | 6 + src/api/admin.py | 669 +++++++++++++++++++++ src/api/routes.py | 147 +++++ src/core/__init__.py | 7 + src/core/auth.py | 39 ++ src/core/config.py | 183 ++++++ src/core/database.py | 879 ++++++++++++++++++++++++++++ src/core/logger.py | 243 ++++++++ src/core/models.py | 145 +++++ src/main.py | 162 +++++ src/services/__init__.py | 17 + src/services/concurrency_manager.py | 190 ++++++ src/services/file_cache.py | 199 +++++++ src/services/flow_client.py | 657 +++++++++++++++++++++ src/services/generation_handler.py | 850 +++++++++++++++++++++++++++ src/services/load_balancer.py | 87 +++ src/services/proxy_manager.py | 25 + src/services/token_manager.py | 384 ++++++++++++ static/login.html | 53 ++ static/manage.html | 586 +++++++++++++++++++ 30 files changed, 6073 insertions(+), 1 deletion(-) create mode 100644 .dockerignore create mode 100644 .gitignore create mode 100644 Dockerfile create mode 100644 config/setting.toml create mode 100644 config/setting_warp.toml create mode 100644 docker-compose.proxy.yml create mode 100644 docker-compose.yml create mode 100644 main.py create mode 100644 requirements.txt create mode 100644 src/api/__init__.py create mode 100644 src/api/admin.py create mode 100644 src/api/routes.py create mode 100644 src/core/__init__.py create mode 100644 src/core/auth.py create mode 100644 src/core/config.py create mode 100644 src/core/database.py create mode 100644 src/core/logger.py create mode 100644 src/core/models.py create mode 100644 src/main.py create mode 100644 src/services/__init__.py create mode 100644 src/services/concurrency_manager.py create mode 100644 src/services/file_cache.py create mode 100644 src/services/flow_client.py create mode 100644 src/services/generation_handler.py create mode 100644 src/services/load_balancer.py create mode 100644 src/services/proxy_manager.py create mode 100644 src/services/token_manager.py create mode 100644 static/login.html create mode 100644 static/manage.html diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..8d10335 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,77 @@ +# Git +.git +.gitignore +.gitattributes + +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST +*.manifest +*.spec +pip-log.txt +pip-delete-this-directory.txt + +# Virtual Environment +venv/ +env/ +ENV/ +.venv + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ +.DS_Store + +# Project specific +data/*.db +data/*.db-journal +tmp/* +logs/* +*.log + +# Docker +Dockerfile +docker-compose*.yml +.dockerignore + +# Documentation +README.md +DEPLOYMENT.md +LICENSE +*.md + +# Test files +tests/ +test_*.py +*_test.py + +# CI/CD +.github/ +.gitlab-ci.yml +.travis.yml + +# Environment files +.env +.env.* diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..710398a --- /dev/null +++ b/.gitignore @@ -0,0 +1,54 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +env/ +venv/ +ENV/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Database +*.db +*.sqlite +*.sqlite3 +data/*.db + +# Logs +*.log +logs.txt + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ +.DS_Store + +# Environment +.env +.env.local + +# Config (optional) +# config/setting.toml + +# Temporary files +*.tmp +*.bak +*.cache diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..d340750 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,12 @@ +FROM python:3.11-slim + +WORKDIR /app + +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +COPY . . + +EXPOSE 8000 + +CMD ["python", "main.py"] diff --git a/README.md b/README.md index 49550f3..36ec53d 100644 --- a/README.md +++ b/README.md @@ -1 +1,260 @@ -# flow2api \ No newline at end of file +# Flow2API + +
+ +[![License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE) +[![Python](https://img.shields.io/badge/python-3.8%2B-blue.svg)](https://www.python.org/) +[![FastAPI](https://img.shields.io/badge/fastapi-0.119.0-green.svg)](https://fastapi.tiangolo.com/) +[![Docker](https://img.shields.io/badge/docker-supported-blue.svg)](https://www.docker.com/) + +**一个功能完整的 OpenAI 兼容 API 服务,为 Flow 提供统一的接口** + +
+ +## ✨ 核心特性 + +- 🎨 **文生图** / **图生图** +- 🎬 **文生视频** / **图生视频** +- 🎞️ **首尾帧视频** +- 🔄 **AT自动刷新** +- 📊 **余额显示** - 实时查询和显示 VideoFX Credits +- 🚀 **负载均衡** - 多 Token 轮询和并发控制 +- 🌐 **代理支持** - 支持 HTTP/SOCKS5 代理 +- 📱 **Web 管理界面** - 直观的 Token 和配置管理 + +## 🚀 快速开始 + +### 前置要求 + +- Docker 和 Docker Compose(推荐) +- 或 Python 3.8+ + +### 方式一:Docker 部署(推荐) + +#### 标准模式(不使用代理) + +```bash +# 克隆项目 +git clone https://github.com/TheSmallHanCat/flow2api.git +cd sora2api + +# 启动服务 +docker-compose up -d + +# 查看日志 +docker-compose logs -f +``` + +#### WARP 模式(使用代理) + +```bash +# 使用 WARP 代理启动 +docker-compose -f docker-compose.warp.yml up -d + +# 查看日志 +docker-compose -f docker-compose.warp.yml logs -f +``` + +### 方式二:本地部署 + +```bash +# 克隆项目 +git clone https://github.com/TheSmallHanCat/flow2api.git +cd sora2api + +# 创建虚拟环境 +python -m venv venv + +# 激活虚拟环境 +# Windows +venv\Scripts\activate +# Linux/Mac +source venv/bin/activate + +# 安装依赖 +pip install -r requirements.txt + +# 启动服务 +python main.py +``` + +### 首次访问 + +服务启动后,访问管理后台: **http://localhost:8000** + +- **用户名**: `admin` +- **密码**: `admin` + +⚠️ **重要**: 首次登录后请立即修改密码! + +## 📋 支持的模型 + +### 图片生成 + +| 模型名称 | 说明| 尺寸 | +|---------|--------|--------| +| `gemini-2.5-flash-image-landscape` | 图/文生图 | 横屏 | +| `gemini-2.5-flash-image-portrait` | 图/文生图 | 竖屏 | +| `gemini-3.0-pro-image-landscape` | 图/文生图 | 横屏 | +| `gemini-3.0-pro-image-portrait` | 图/文生图 | 竖屏 | +| `imagen-4.0-generate-preview-landscape` | 图/文生图 | 横屏 | +| `imagen-4.0-generate-preview-portrait` | 图/文生图 | 竖屏 | + +### 视频生成 + +#### 文生视频 (T2V - Text to Video) +⚠️ **不支持上传图片** + +| 模型名称 | 说明| 尺寸 | +|---------|---------|--------| +| `veo_3_1_t2v_fast_portrait` | 文生视频 | 竖屏 | +| `veo_3_1_t2v_fast_landscape` | 文生视频 | 横屏 | +| `veo_2_1_fast_d_15_t2v_portrait` | 文生视频 | 竖屏 | +| `veo_2_1_fast_d_15_t2v_landscape` | 文生视频 | 横屏 | +| `veo_2_0_t2v_portrait` | 文生视频 | 竖屏 | +| `veo_2_0_t2v_landscape` | 文生视频 | 横屏 | + +#### 首尾帧模型 (I2V - Image to Video) +📸 **支持1-2张图片:首尾帧** + +| 模型名称 | 说明| 尺寸 | +|---------|---------|--------| +| `veo_3_1_i2v_s_fast_fl_portrait` | 图生视频 | 竖屏 | +| `veo_3_1_i2v_s_fast_fl_landscape` | 图生视频 | 横屏 | +| `veo_2_1_fast_d_15_i2v_portrait` | 图生视频 | 竖屏 | +| `veo_2_1_fast_d_15_i2v_landscape` | 图生视频 | 横屏 | +| `veo_2_0_i2v_portrait` | 图生视频 | 竖屏 | +| `veo_2_0_i2v_landscape` | 图生视频 | 横屏 | + +#### 多图生成 (R2V - Reference Images to Video) +🖼️ **支持多张图片** + +| 模型名称 | 说明| 尺寸 | +|---------|---------|--------| +| `veo_3_0_r2v_fast_portrait` | 图生视频 | 竖屏 | +| `veo_3_0_r2v_fast_landscape` | 图生视频 | 横屏 | + +## 📡 API 使用示例(需要使用流式) + +### 文生图 + +```bash +curl -X POST "http://localhost:8000/v1/chat/completions" \ + -H "Authorization: Bearer han1234" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gemini-2.5-flash-image-landscape", + "messages": [ + { + "role": "user", + "content": "一只可爱的猫咪在花园里玩耍" + } + ], + "stream": true + }' +``` + +### 图生图 + +```bash +curl -X POST "http://localhost:8000/v1/chat/completions" \ + -H "Authorization: Bearer han1234" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "imagen-4.0-generate-preview-landscape", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "将这张图片变成水彩画风格" + }, + { + "type": "image_url", + "image_url": { + "url": "data:image/jpeg;base64," + } + } + ] + } + ], + "stream": true + }' +``` + +### 文生视频 + +```bash +curl -X POST "http://localhost:8000/v1/chat/completions" \ + -H "Authorization: Bearer han1234" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "veo_3_1_t2v_fast_landscape", + "messages": [ + { + "role": "user", + "content": "一只小猫在草地上追逐蝴蝶" + } + ], + "stream": true + }' +``` + +### 首尾帧生成视频 + +```bash +curl -X POST "http://localhost:8000/v1/chat/completions" \ + -H "Authorization: Bearer han1234" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "veo_3_1_i2v_s_fast_fl_landscape", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "从第一张图过渡到第二张图" + }, + { + "type": "image_url", + "image_url": { + "url": "data:image/jpeg;base64,<首帧base64>" + } + }, + { + "type": "image_url", + "image_url": { + "url": "data:image/jpeg;base64,<尾帧base64>" + } + } + ] + } + ], + "stream": true + }' +``` + +--- + +## 📄 许可证 + +本项目采用 MIT 许可证。详见 [LICENSE](LICENSE) 文件。 + +--- + +## 🙏 致谢 + +感谢所有贡献者和使用者的支持! + +--- + +## 📞 联系方式 + +- 提交 Issue:[GitHub Issues](https://github.com/TheSmallHanCat/flow2api/issues) +- 讨论:[GitHub Discussions](https://github.com/TheSmallHanCat/flow2api/discussions) + +--- + +**⭐ 如果这个项目对你有帮助,请给个 Star!** diff --git a/config/setting.toml b/config/setting.toml new file mode 100644 index 0000000..1571506 --- /dev/null +++ b/config/setting.toml @@ -0,0 +1,35 @@ +[global] +api_key = "han1234" +admin_username = "admin" +admin_password = "admin" + +[flow] +labs_base_url = "https://labs.google/fx/api" +api_base_url = "https://aisandbox-pa.googleapis.com/v1" +timeout = 120 +max_retries = 3 +poll_interval = 3.0 +max_poll_attempts = 200 + +[server] +host = "0.0.0.0" +port = 8000 + +[debug] +enabled = false +log_requests = true +log_responses = true +mask_token = true + +[proxy] +proxy_enabled = false +proxy_url = "" + +[generation] +image_timeout = 300 +video_timeout = 1500 + +[cache] +enabled = false +timeout = 7200 # 缓存超时时间(秒), 默认2小时 +base_url = "" # 缓存文件访问的基础URL, 留空则使用服务器地址 diff --git a/config/setting_warp.toml b/config/setting_warp.toml new file mode 100644 index 0000000..4c806d1 --- /dev/null +++ b/config/setting_warp.toml @@ -0,0 +1,35 @@ +[global] +api_key = "han1234" +admin_username = "admin" +admin_password = "admin" + +[flow] +labs_base_url = "https://labs.google/fx/api" +api_base_url = "https://aisandbox-pa.googleapis.com/v1" +timeout = 120 +max_retries = 3 +poll_interval = 3.0 +max_poll_attempts = 200 + +[server] +host = "0.0.0.0" +port = 8000 + +[debug] +enabled = false +log_requests = true +log_responses = true +mask_token = true + +[proxy] +proxy_enabled = true +proxy_url = "socks5://warp:1080" + +[generation] +image_timeout = 300 +video_timeout = 1500 + +[cache] +enabled = false +timeout = 7200 # 缓存超时时间(秒), 默认2小时 +base_url = "" # 缓存文件访问的基础URL, 留空则使用服务器地址 diff --git a/docker-compose.proxy.yml b/docker-compose.proxy.yml new file mode 100644 index 0000000..2748871 --- /dev/null +++ b/docker-compose.proxy.yml @@ -0,0 +1,36 @@ +version: '3.8' + +services: + flow2api: + image: thesmallhancat/flow2api:latest + container_name: flow2api + ports: + - "8000:8000" + volumes: + - ./data:/app/data + - ./config/setting_warp.toml:/app/config/setting.toml + environment: + - PYTHONUNBUFFERED=1 + restart: unless-stopped + depends_on: + - warp + + warp: + image: caomingjun/warp + container_name: warp + restart: always + devices: + - /dev/net/tun:/dev/net/tun + ports: + - "1080:1080" + environment: + - WARP_SLEEP=2 + cap_add: + - MKNOD + - AUDIT_WRITE + - NET_ADMIN + sysctls: + - net.ipv6.conf.all.disable_ipv6=0 + - net.ipv4.conf.all.src_valid_mark=1 + volumes: + - ./data:/var/lib/cloudflare-warp \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..7cab3d4 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,14 @@ +version: '3.8' + +services: + flow2api: + image: thesmallhancat/flow2api:latest + container_name: flow2api + ports: + - "8000:8000" + volumes: + - ./data:/app/data + - ./config/setting.toml:/app/config/setting.toml + environment: + - PYTHONUNBUFFERED=1 + restart: unless-stopped diff --git a/main.py b/main.py new file mode 100644 index 0000000..be04a86 --- /dev/null +++ b/main.py @@ -0,0 +1,13 @@ +"""Flow2API - Main Entry Point""" +from src.main import app +import uvicorn + +if __name__ == "__main__": + from src.core.config import config + + uvicorn.run( + "src.main:app", + host=config.server_host, + port=config.server_port, + reload=False + ) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..5bcaaf1 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +fastapi==0.119.0 +uvicorn[standard]==0.32.1 +aiosqlite==0.20.0 +pydantic==2.10.4 +curl-cffi==0.7.3 +tomli==2.2.1 +bcrypt==4.2.1 +python-multipart==0.0.20 +python-dateutil==2.8.2 diff --git a/src/api/__init__.py b/src/api/__init__.py new file mode 100644 index 0000000..a4ce6fa --- /dev/null +++ b/src/api/__init__.py @@ -0,0 +1,6 @@ +"""API modules""" + +from .routes import router as api_router +from .admin import router as admin_router + +__all__ = ["api_router", "admin_router"] diff --git a/src/api/admin.py b/src/api/admin.py new file mode 100644 index 0000000..e15855f --- /dev/null +++ b/src/api/admin.py @@ -0,0 +1,669 @@ +"""Admin API routes""" +from fastapi import APIRouter, Depends, HTTPException, Header +from fastapi.responses import JSONResponse +from pydantic import BaseModel +from typing import Optional, List +from ..core.auth import AuthManager +from ..core.database import Database +from ..services.token_manager import TokenManager +from ..services.proxy_manager import ProxyManager + +router = APIRouter() + +# Dependency injection +token_manager: TokenManager = None +proxy_manager: ProxyManager = None +db: Database = None + + +def set_dependencies(tm: TokenManager, pm: ProxyManager, database: Database): + """Set service instances""" + global token_manager, proxy_manager, db + token_manager = tm + proxy_manager = pm + db = database + + +# ========== Request Models ========== + +class LoginRequest(BaseModel): + username: str + password: str + + +class AddTokenRequest(BaseModel): + st: str + project_id: Optional[str] = None # 用户可选输入project_id + project_name: Optional[str] = None + remark: Optional[str] = None + image_enabled: bool = True + video_enabled: bool = True + image_concurrency: int = -1 + video_concurrency: int = -1 + + +class UpdateTokenRequest(BaseModel): + st: str # Session Token (必填,用于刷新AT) + project_id: Optional[str] = None # 用户可选输入project_id + project_name: Optional[str] = None + remark: Optional[str] = None + image_enabled: Optional[bool] = None + video_enabled: Optional[bool] = None + image_concurrency: Optional[int] = None + video_concurrency: Optional[int] = None + + +class ProxyConfigRequest(BaseModel): + proxy_enabled: bool + proxy_url: Optional[str] = None + + +class GenerationConfigRequest(BaseModel): + image_timeout: int + video_timeout: int + + +class ChangePasswordRequest(BaseModel): + old_password: str + new_password: str + + +class UpdateAPIKeyRequest(BaseModel): + new_api_key: str + + +class UpdateDebugConfigRequest(BaseModel): + enabled: bool + + +class ST2ATRequest(BaseModel): + """ST转AT请求""" + st: str + + +# ========== Auth Middleware ========== + +async def verify_admin_token(authorization: str = Header(None)): + """Verify admin token""" + if not authorization or not authorization.startswith("Bearer "): + raise HTTPException(status_code=401, detail="Missing authorization") + + token = authorization[7:] + admin_config = await db.get_admin_config() + + # Simple token verification: check if matches api_key + if token != admin_config.api_key: + raise HTTPException(status_code=401, detail="Invalid admin token") + + return token + + +# ========== Auth Endpoints ========== + +@router.post("/api/admin/login") +async def admin_login(request: LoginRequest): + """Admin login""" + admin_config = await db.get_admin_config() + + if not AuthManager.verify_admin(request.username, request.password): + raise HTTPException(status_code=401, detail="Invalid credentials") + + return { + "success": True, + "token": admin_config.api_key, + "username": admin_config.username + } + + +@router.post("/api/admin/change-password") +async def change_password( + request: ChangePasswordRequest, + token: str = Depends(verify_admin_token) +): + """Change admin password""" + admin_config = await db.get_admin_config() + + # Verify old password + if not AuthManager.verify_admin(admin_config.username, request.old_password): + raise HTTPException(status_code=400, detail="旧密码错误") + + # Update password + await db.update_admin_config(password=request.new_password) + + return {"success": True, "message": "密码修改成功"} + + +# ========== Token Management ========== + +@router.get("/api/tokens") +async def get_tokens(token: str = Depends(verify_admin_token)): + """Get all tokens with statistics""" + tokens = await token_manager.get_all_tokens() + result = [] + + for t in tokens: + stats = await db.get_token_stats(t.id) + + result.append({ + "id": t.id, + "st": t.st, # Session Token for editing + "at": t.at, # Access Token for editing (从ST转换而来) + "at_expires": t.at_expires.isoformat() if t.at_expires else None, # 🆕 AT过期时间 + "token": t.at, # 兼容前端 token.token 的访问方式 + "email": t.email, + "name": t.name, + "remark": t.remark, + "is_active": t.is_active, + "created_at": t.created_at.isoformat() if t.created_at else None, + "last_used_at": t.last_used_at.isoformat() if t.last_used_at else None, + "use_count": t.use_count, + "credits": t.credits, # 🆕 余额 + "user_paygate_tier": t.user_paygate_tier, + "current_project_id": t.current_project_id, # 🆕 项目ID + "current_project_name": t.current_project_name, # 🆕 项目名称 + "image_enabled": t.image_enabled, + "video_enabled": t.video_enabled, + "image_concurrency": t.image_concurrency, + "video_concurrency": t.video_concurrency, + "image_count": stats.image_count if stats else 0, + "video_count": stats.video_count if stats else 0, + "error_count": stats.error_count if stats else 0 + }) + + return result # 直接返回数组,兼容前端 + + +@router.post("/api/tokens") +async def add_token( + request: AddTokenRequest, + token: str = Depends(verify_admin_token) +): + """Add a new token""" + try: + new_token = await token_manager.add_token( + st=request.st, + project_id=request.project_id, # 🆕 支持用户指定project_id + project_name=request.project_name, + remark=request.remark, + image_enabled=request.image_enabled, + video_enabled=request.video_enabled, + image_concurrency=request.image_concurrency, + video_concurrency=request.video_concurrency + ) + + return { + "success": True, + "message": "Token添加成功", + "token": { + "id": new_token.id, + "email": new_token.email, + "credits": new_token.credits, + "project_id": new_token.current_project_id, + "project_name": new_token.current_project_name + } + } + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=500, detail=f"添加Token失败: {str(e)}") + + +@router.put("/api/tokens/{token_id}") +async def update_token( + token_id: int, + request: UpdateTokenRequest, + token: str = Depends(verify_admin_token) +): + """Update token - 使用ST自动刷新AT""" + try: + # 先ST转AT + result = await token_manager.flow_client.st_to_at(request.st) + at = result["access_token"] + expires = result.get("expires") + + # 解析过期时间 + from datetime import datetime + at_expires = None + if expires: + try: + at_expires = datetime.fromisoformat(expires.replace('Z', '+00:00')) + except: + pass + + # 更新token (包含AT、ST、AT过期时间、project_id和project_name) + await token_manager.update_token( + token_id=token_id, + st=request.st, + at=at, + at_expires=at_expires, # 🆕 更新AT过期时间 + project_id=request.project_id, + project_name=request.project_name, + remark=request.remark, + image_enabled=request.image_enabled, + video_enabled=request.video_enabled, + image_concurrency=request.image_concurrency, + video_concurrency=request.video_concurrency + ) + + return {"success": True, "message": "Token更新成功"} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.delete("/api/tokens/{token_id}") +async def delete_token( + token_id: int, + token: str = Depends(verify_admin_token) +): + """Delete token""" + try: + await token_manager.delete_token(token_id) + return {"success": True, "message": "Token删除成功"} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/api/tokens/{token_id}/enable") +async def enable_token( + token_id: int, + token: str = Depends(verify_admin_token) +): + """Enable token""" + await token_manager.enable_token(token_id) + return {"success": True, "message": "Token已启用"} + + +@router.post("/api/tokens/{token_id}/disable") +async def disable_token( + token_id: int, + token: str = Depends(verify_admin_token) +): + """Disable token""" + await token_manager.disable_token(token_id) + return {"success": True, "message": "Token已禁用"} + + +@router.post("/api/tokens/{token_id}/refresh-credits") +async def refresh_credits( + token_id: int, + token: str = Depends(verify_admin_token) +): + """刷新Token余额 🆕""" + try: + credits = await token_manager.refresh_credits(token_id) + return { + "success": True, + "message": "余额刷新成功", + "credits": credits + } + except Exception as e: + raise HTTPException(status_code=500, detail=f"刷新余额失败: {str(e)}") + + +@router.post("/api/tokens/{token_id}/refresh-at") +async def refresh_at( + token_id: int, + token: str = Depends(verify_admin_token) +): + """手动刷新Token的AT (使用ST转换) 🆕""" + try: + # 调用token_manager的内部刷新方法 + success = await token_manager._refresh_at(token_id) + + if success: + # 获取更新后的token信息 + updated_token = await token_manager.get_token(token_id) + return { + "success": True, + "message": "AT刷新成功", + "token": { + "id": updated_token.id, + "email": updated_token.email, + "at_expires": updated_token.at_expires.isoformat() if updated_token.at_expires else None + } + } + else: + raise HTTPException(status_code=500, detail="AT刷新失败") + except Exception as e: + raise HTTPException(status_code=500, detail=f"刷新AT失败: {str(e)}") + + +@router.post("/api/tokens/st2at") +async def st_to_at( + request: ST2ATRequest, + token: str = Depends(verify_admin_token) +): + """Convert Session Token to Access Token (仅转换,不添加到数据库)""" + try: + result = await token_manager.flow_client.st_to_at(request.st) + return { + "success": True, + "message": "ST converted to AT successfully", + "access_token": result["access_token"], + "email": result.get("user", {}).get("email"), + "expires": result.get("expires") + } + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + +# ========== Config Management ========== + +@router.get("/api/config/proxy") +async def get_proxy_config(token: str = Depends(verify_admin_token)): + """Get proxy configuration""" + config = await proxy_manager.get_proxy_config() + return { + "success": True, + "config": { + "enabled": config.enabled, + "proxy_url": config.proxy_url + } + } + + +@router.get("/api/proxy/config") +async def get_proxy_config_alias(token: str = Depends(verify_admin_token)): + """Get proxy configuration (alias for frontend compatibility)""" + config = await proxy_manager.get_proxy_config() + return { + "proxy_enabled": config.enabled, # Frontend expects proxy_enabled + "proxy_url": config.proxy_url + } + + +@router.post("/api/proxy/config") +async def update_proxy_config_alias( + request: ProxyConfigRequest, + token: str = Depends(verify_admin_token) +): + """Update proxy configuration (alias for frontend compatibility)""" + await proxy_manager.update_proxy_config(request.proxy_enabled, request.proxy_url) + return {"success": True, "message": "代理配置更新成功"} + + +@router.post("/api/config/proxy") +async def update_proxy_config( + request: ProxyConfigRequest, + token: str = Depends(verify_admin_token) +): + """Update proxy configuration""" + await proxy_manager.update_proxy_config(request.proxy_enabled, request.proxy_url) + return {"success": True, "message": "代理配置更新成功"} + + +@router.get("/api/config/generation") +async def get_generation_config(token: str = Depends(verify_admin_token)): + """Get generation timeout configuration""" + config = await db.get_generation_config() + return { + "success": True, + "config": { + "image_timeout": config.image_timeout, + "video_timeout": config.video_timeout + } + } + + +@router.post("/api/config/generation") +async def update_generation_config( + request: GenerationConfigRequest, + token: str = Depends(verify_admin_token) +): + """Update generation timeout configuration""" + await db.update_generation_config(request.image_timeout, request.video_timeout) + return {"success": True, "message": "生成配置更新成功"} + + +# ========== System Info ========== + +@router.get("/api/system/info") +async def get_system_info(token: str = Depends(verify_admin_token)): + """Get system information""" + tokens = await token_manager.get_all_tokens() + active_tokens = [t for t in tokens if t.is_active] + + total_credits = sum(t.credits for t in active_tokens) + + return { + "success": True, + "info": { + "total_tokens": len(tokens), + "active_tokens": len(active_tokens), + "total_credits": total_credits, + "version": "1.0.0" + } + } + + +# ========== Additional Routes for Frontend Compatibility ========== + +@router.post("/api/login") +async def login(request: LoginRequest): + """Login endpoint (alias for /api/admin/login)""" + return await admin_login(request) + + +@router.get("/api/stats") +async def get_stats(token: str = Depends(verify_admin_token)): + """Get statistics for dashboard""" + tokens = await token_manager.get_all_tokens() + active_tokens = [t for t in tokens if t.is_active] + + # Calculate totals + total_images = 0 + total_videos = 0 + total_errors = 0 + today_images = 0 + today_videos = 0 + today_errors = 0 + + for t in tokens: + stats = await db.get_token_stats(t.id) + if stats: + total_images += stats.image_count + total_videos += stats.video_count + total_errors += stats.error_count + today_images += stats.today_image_count + today_videos += stats.today_video_count + today_errors += stats.today_error_count + + return { + "total_tokens": len(tokens), + "active_tokens": len(active_tokens), + "total_images": total_images, + "total_videos": total_videos, + "total_errors": total_errors, + "today_images": today_images, + "today_videos": today_videos, + "today_errors": today_errors + } + + +@router.get("/api/logs") +async def get_logs( + limit: int = 100, + token: str = Depends(verify_admin_token) +): + """Get request logs with token email""" + logs = await db.get_logs(limit=limit) + + return [{ + "id": log.get("id"), + "token_id": log.get("token_id"), + "token_email": log.get("token_email"), + "token_username": log.get("token_username"), + "operation": log.get("operation"), + "status_code": log.get("status_code"), + "duration": log.get("duration"), + "created_at": log.get("created_at") + } for log in logs] + + +@router.get("/api/admin/config") +async def get_admin_config(token: str = Depends(verify_admin_token)): + """Get admin configuration""" + from ..core.config import config + + admin_config = await db.get_admin_config() + + return { + "admin_username": admin_config.username, + "api_key": admin_config.api_key, + "error_ban_threshold": 3, # Default value + "debug_enabled": config.debug_enabled # Return actual debug status + } + + +@router.post("/api/admin/password") +async def update_admin_password( + request: ChangePasswordRequest, + token: str = Depends(verify_admin_token) +): + """Update admin password""" + return await change_password(request, token) + + +@router.post("/api/admin/apikey") +async def update_api_key( + request: UpdateAPIKeyRequest, + token: str = Depends(verify_admin_token) +): + """Update API key""" + await db.update_admin_config(api_key=request.new_api_key) + return {"success": True, "message": "API Key更新成功"} + + +@router.post("/api/admin/debug") +async def update_debug_config( + request: UpdateDebugConfigRequest, + token: str = Depends(verify_admin_token) +): + """Update debug configuration""" + try: + # Import config instance + from ..core.config import config + + # Update in-memory config + config.set_debug_enabled(request.enabled) + + status = "enabled" if request.enabled else "disabled" + return {"success": True, "message": f"Debug mode {status}", "enabled": request.enabled} + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to update debug config: {str(e)}") + + +@router.get("/api/generation/timeout") +async def get_generation_timeout(token: str = Depends(verify_admin_token)): + """Get generation timeout configuration""" + return await get_generation_config(token) + + +@router.post("/api/generation/timeout") +async def update_generation_timeout( + request: GenerationConfigRequest, + token: str = Depends(verify_admin_token) +): + """Update generation timeout configuration""" + return await update_generation_config(request, token) + + +# ========== AT Auto Refresh Config ========== + +@router.get("/api/token-refresh/config") +async def get_token_refresh_config(token: str = Depends(verify_admin_token)): + """Get AT auto refresh configuration (默认启用)""" + return { + "success": True, + "config": { + "at_auto_refresh_enabled": True # Flow2API默认启用AT自动刷新 + } + } + + +@router.post("/api/token-refresh/enabled") +async def update_token_refresh_enabled( + token: str = Depends(verify_admin_token) +): + """Update AT auto refresh enabled (Flow2API固定启用,此接口仅用于前端兼容)""" + return { + "success": True, + "message": "Flow2API的AT自动刷新默认启用且无法关闭" + } + + +# ========== Cache Configuration Endpoints ========== + +@router.get("/api/cache/config") +async def get_cache_config(token: str = Depends(verify_admin_token)): + """Get cache configuration""" + cache_config = await db.get_cache_config() + + # Calculate effective base URL + effective_base_url = cache_config.cache_base_url if cache_config.cache_base_url else f"http://127.0.0.1:8000" + + return { + "success": True, + "config": { + "enabled": cache_config.cache_enabled, + "timeout": cache_config.cache_timeout, + "base_url": cache_config.cache_base_url or "", + "effective_base_url": effective_base_url + } + } + + +@router.post("/api/cache/enabled") +async def update_cache_enabled( + request: dict, + token: str = Depends(verify_admin_token) +): + """Update cache enabled status""" + enabled = request.get("enabled", False) + await db.update_cache_config(enabled=enabled) + + # Update runtime config + from ..core.config import config + config.set_cache_enabled(enabled) + + return {"success": True, "message": f"缓存已{'启用' if enabled else '禁用'}"} + + +@router.post("/api/cache/config") +async def update_cache_config_full( + request: dict, + token: str = Depends(verify_admin_token) +): + """Update complete cache configuration""" + enabled = request.get("enabled") + timeout = request.get("timeout") + base_url = request.get("base_url") + + await db.update_cache_config(enabled=enabled, timeout=timeout, base_url=base_url) + + # Update runtime config + from ..core.config import config + if enabled is not None: + config.set_cache_enabled(enabled) + if timeout is not None: + config.set_cache_timeout(timeout) + if base_url is not None: + config.set_cache_base_url(base_url) + + return {"success": True, "message": "缓存配置更新成功"} + + +@router.post("/api/cache/base-url") +async def update_cache_base_url( + request: dict, + token: str = Depends(verify_admin_token) +): + """Update cache base URL""" + base_url = request.get("base_url", "") + await db.update_cache_config(base_url=base_url) + + # Update runtime config + from ..core.config import config + config.set_cache_base_url(base_url) + + return {"success": True, "message": "缓存Base URL更新成功"} diff --git a/src/api/routes.py b/src/api/routes.py new file mode 100644 index 0000000..6a7b8ca --- /dev/null +++ b/src/api/routes.py @@ -0,0 +1,147 @@ +"""API routes - OpenAI compatible endpoints""" +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import StreamingResponse, JSONResponse +from typing import List +import base64 +import re +import json +from ..core.auth import verify_api_key_header +from ..core.models import ChatCompletionRequest +from ..services.generation_handler import GenerationHandler, MODEL_CONFIG + +router = APIRouter() + +# Dependency injection will be set up in main.py +generation_handler: GenerationHandler = None + + +def set_generation_handler(handler: GenerationHandler): + """Set generation handler instance""" + global generation_handler + generation_handler = handler + + +@router.get("/v1/models") +async def list_models(api_key: str = Depends(verify_api_key_header)): + """List available models""" + models = [] + + for model_id, config in MODEL_CONFIG.items(): + description = f"{config['type'].capitalize()} generation" + if config['type'] == 'image': + description += f" - {config['model_name']}" + else: + description += f" - {config['model_key']}" + + models.append({ + "id": model_id, + "object": "model", + "owned_by": "flow2api", + "description": description + }) + + return { + "object": "list", + "data": models + } + + +@router.post("/v1/chat/completions") +async def create_chat_completion( + request: ChatCompletionRequest, + api_key: str = Depends(verify_api_key_header) +): + """Create chat completion (unified endpoint for image and video generation)""" + try: + # Extract prompt from messages + if not request.messages: + raise HTTPException(status_code=400, detail="Messages cannot be empty") + + last_message = request.messages[-1] + content = last_message.content + + # Handle both string and array format (OpenAI multimodal) + prompt = "" + images: List[bytes] = [] + + if isinstance(content, str): + # Simple text format + prompt = content + elif isinstance(content, list): + # Multimodal format + for item in content: + if item.get("type") == "text": + prompt = item.get("text", "") + elif item.get("type") == "image_url": + # Extract base64 image + image_url = item.get("image_url", {}).get("url", "") + if image_url.startswith("data:image"): + # Parse base64 + match = re.search(r"base64,(.+)", image_url) + if match: + image_base64 = match.group(1) + image_bytes = base64.b64decode(image_base64) + images.append(image_bytes) + + # Fallback to deprecated image parameter + if request.image and not images: + if request.image.startswith("data:image"): + match = re.search(r"base64,(.+)", request.image) + if match: + image_base64 = match.group(1) + image_bytes = base64.b64decode(image_base64) + images.append(image_bytes) + + if not prompt: + raise HTTPException(status_code=400, detail="Prompt cannot be empty") + + # Call generation handler + if request.stream: + # Streaming response + async def generate(): + async for chunk in generation_handler.handle_generation( + model=request.model, + prompt=prompt, + images=images if images else None, + stream=True + ): + yield chunk + + # Send [DONE] signal + yield "data: [DONE]\n\n" + + return StreamingResponse( + generate(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no" + } + ) + else: + # Non-streaming response + result = None + async for chunk in generation_handler.handle_generation( + model=request.model, + prompt=prompt, + images=images if images else None, + stream=False + ): + result = chunk + + if result: + # Parse the result JSON string + try: + result_json = json.loads(result) + return JSONResponse(content=result_json) + except json.JSONDecodeError: + # If not JSON, return as-is + return JSONResponse(content={"result": result}) + else: + raise HTTPException(status_code=500, detail="Generation failed: No response from handler") + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/src/core/__init__.py b/src/core/__init__.py new file mode 100644 index 0000000..db89fee --- /dev/null +++ b/src/core/__init__.py @@ -0,0 +1,7 @@ +"""Core modules""" + +from .config import config +from .auth import AuthManager, verify_api_key_header +from .logger import debug_logger + +__all__ = ["config", "AuthManager", "verify_api_key_header", "debug_logger"] diff --git a/src/core/auth.py b/src/core/auth.py new file mode 100644 index 0000000..0568573 --- /dev/null +++ b/src/core/auth.py @@ -0,0 +1,39 @@ +"""Authentication module""" +import bcrypt +from typing import Optional +from fastapi import HTTPException, Security +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from .config import config + +security = HTTPBearer() + +class AuthManager: + """Authentication manager""" + + @staticmethod + def verify_api_key(api_key: str) -> bool: + """Verify API key""" + return api_key == config.api_key + + @staticmethod + def verify_admin(username: str, password: str) -> bool: + """Verify admin credentials""" + # Compare with current config (which may be from database or config file) + return username == config.admin_username and password == config.admin_password + + @staticmethod + def hash_password(password: str) -> str: + """Hash password""" + return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode() + + @staticmethod + def verify_password(password: str, hashed: str) -> bool: + """Verify password""" + return bcrypt.checkpw(password.encode(), hashed.encode()) + +async def verify_api_key_header(credentials: HTTPAuthorizationCredentials = Security(security)) -> str: + """Verify API key from Authorization header""" + api_key = credentials.credentials + if not AuthManager.verify_api_key(api_key): + raise HTTPException(status_code=401, detail="Invalid API key") + return api_key diff --git a/src/core/config.py b/src/core/config.py new file mode 100644 index 0000000..bc85058 --- /dev/null +++ b/src/core/config.py @@ -0,0 +1,183 @@ +"""Configuration management for Flow2API""" +import tomli +from pathlib import Path +from typing import Dict, Any, Optional + +class Config: + """Application configuration""" + + def __init__(self): + self._config = self._load_config() + self._admin_username: Optional[str] = None + self._admin_password: Optional[str] = None + + def _load_config(self) -> Dict[str, Any]: + """Load configuration from setting.toml""" + config_path = Path(__file__).parent.parent.parent / "config" / "setting.toml" + with open(config_path, "rb") as f: + return tomli.load(f) + + def reload_config(self): + """Reload configuration from file""" + self._config = self._load_config() + + def get_raw_config(self) -> Dict[str, Any]: + """Get raw configuration dictionary""" + return self._config + + @property + def admin_username(self) -> str: + # If admin_username is set from database, use it; otherwise fall back to config file + if self._admin_username is not None: + return self._admin_username + return self._config["global"]["admin_username"] + + @admin_username.setter + def admin_username(self, value: str): + self._admin_username = value + self._config["global"]["admin_username"] = value + + def set_admin_username_from_db(self, username: str): + """Set admin username from database""" + self._admin_username = username + + # Flow2API specific properties + @property + def flow_labs_base_url(self) -> str: + """Google Labs base URL for project management""" + return self._config["flow"]["labs_base_url"] + + @property + def flow_api_base_url(self) -> str: + """Google AI Sandbox API base URL for generation""" + return self._config["flow"]["api_base_url"] + + @property + def flow_timeout(self) -> int: + return self._config["flow"]["timeout"] + + @property + def flow_max_retries(self) -> int: + return self._config["flow"]["max_retries"] + + @property + def poll_interval(self) -> float: + return self._config["flow"]["poll_interval"] + + @property + def max_poll_attempts(self) -> int: + return self._config["flow"]["max_poll_attempts"] + + @property + def server_host(self) -> str: + return self._config["server"]["host"] + + @property + def server_port(self) -> int: + return self._config["server"]["port"] + + @property + def debug_enabled(self) -> bool: + return self._config.get("debug", {}).get("enabled", False) + + @property + def debug_log_requests(self) -> bool: + return self._config.get("debug", {}).get("log_requests", True) + + @property + def debug_log_responses(self) -> bool: + return self._config.get("debug", {}).get("log_responses", True) + + @property + def debug_mask_token(self) -> bool: + return self._config.get("debug", {}).get("mask_token", True) + + # Mutable properties for runtime updates + @property + def api_key(self) -> str: + return self._config["global"]["api_key"] + + @api_key.setter + def api_key(self, value: str): + self._config["global"]["api_key"] = value + + @property + def admin_password(self) -> str: + # If admin_password is set from database, use it; otherwise fall back to config file + if self._admin_password is not None: + return self._admin_password + return self._config["global"]["admin_password"] + + @admin_password.setter + def admin_password(self, value: str): + self._admin_password = value + self._config["global"]["admin_password"] = value + + def set_admin_password_from_db(self, password: str): + """Set admin password from database""" + self._admin_password = password + + def set_debug_enabled(self, enabled: bool): + """Set debug mode enabled/disabled""" + if "debug" not in self._config: + self._config["debug"] = {} + self._config["debug"]["enabled"] = enabled + + @property + def image_timeout(self) -> int: + """Get image generation timeout in seconds""" + return self._config.get("generation", {}).get("image_timeout", 300) + + def set_image_timeout(self, timeout: int): + """Set image generation timeout in seconds""" + if "generation" not in self._config: + self._config["generation"] = {} + self._config["generation"]["image_timeout"] = timeout + + @property + def video_timeout(self) -> int: + """Get video generation timeout in seconds""" + return self._config.get("generation", {}).get("video_timeout", 1500) + + def set_video_timeout(self, timeout: int): + """Set video generation timeout in seconds""" + if "generation" not in self._config: + self._config["generation"] = {} + self._config["generation"]["video_timeout"] = timeout + + # Cache configuration + @property + def cache_enabled(self) -> bool: + """Get cache enabled status""" + return self._config.get("cache", {}).get("enabled", False) + + def set_cache_enabled(self, enabled: bool): + """Set cache enabled status""" + if "cache" not in self._config: + self._config["cache"] = {} + self._config["cache"]["enabled"] = enabled + + @property + def cache_timeout(self) -> int: + """Get cache timeout in seconds""" + return self._config.get("cache", {}).get("timeout", 7200) + + def set_cache_timeout(self, timeout: int): + """Set cache timeout in seconds""" + if "cache" not in self._config: + self._config["cache"] = {} + self._config["cache"]["timeout"] = timeout + + @property + def cache_base_url(self) -> str: + """Get cache base URL""" + return self._config.get("cache", {}).get("base_url", "") + + def set_cache_base_url(self, base_url: str): + """Set cache base URL""" + if "cache" not in self._config: + self._config["cache"] = {} + self._config["cache"]["base_url"] = base_url + +# Global config instance +config = Config() diff --git a/src/core/database.py b/src/core/database.py new file mode 100644 index 0000000..22c3bc5 --- /dev/null +++ b/src/core/database.py @@ -0,0 +1,879 @@ +"""Database storage layer for Flow2API""" +import aiosqlite +import json +from datetime import datetime +from typing import Optional, List +from pathlib import Path +from .models import Token, TokenStats, Task, RequestLog, AdminConfig, ProxyConfig, GenerationConfig, CacheConfig, Project + + +class Database: + """SQLite database manager""" + + def __init__(self, db_path: str = None): + if db_path is None: + # Store database in data directory + data_dir = Path(__file__).parent.parent.parent / "data" + data_dir.mkdir(exist_ok=True) + db_path = str(data_dir / "flow.db") + self.db_path = db_path + + def db_exists(self) -> bool: + """Check if database file exists""" + return Path(self.db_path).exists() + + async def _table_exists(self, db, table_name: str) -> bool: + """Check if a table exists in the database""" + cursor = await db.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", + (table_name,) + ) + result = await cursor.fetchone() + return result is not None + + async def _column_exists(self, db, table_name: str, column_name: str) -> bool: + """Check if a column exists in a table""" + try: + cursor = await db.execute(f"PRAGMA table_info({table_name})") + columns = await cursor.fetchall() + return any(col[1] == column_name for col in columns) + except: + return False + + async def _ensure_config_rows(self, db, config_dict: dict = None): + """Ensure all config tables have their default rows + + Args: + db: Database connection + config_dict: Configuration dictionary from setting.toml (optional) + If None, use default values instead of reading from TOML. + """ + # Ensure admin_config has a row + cursor = await db.execute("SELECT COUNT(*) FROM admin_config") + count = await cursor.fetchone() + if count[0] == 0: + admin_username = "admin" + admin_password = "admin" + api_key = "han1234" + + if config_dict: + global_config = config_dict.get("global", {}) + admin_username = global_config.get("admin_username", "admin") + admin_password = global_config.get("admin_password", "admin") + api_key = global_config.get("api_key", "han1234") + + await db.execute(""" + INSERT INTO admin_config (id, username, password, api_key) + VALUES (1, ?, ?, ?) + """, (admin_username, admin_password, api_key)) + + # Ensure proxy_config has a row + cursor = await db.execute("SELECT COUNT(*) FROM proxy_config") + count = await cursor.fetchone() + if count[0] == 0: + proxy_enabled = False + proxy_url = None + + if config_dict: + proxy_config = config_dict.get("proxy", {}) + proxy_enabled = proxy_config.get("proxy_enabled", False) + proxy_url = proxy_config.get("proxy_url", "") + proxy_url = proxy_url if proxy_url else None + + await db.execute(""" + INSERT INTO proxy_config (id, enabled, proxy_url) + VALUES (1, ?, ?) + """, (proxy_enabled, proxy_url)) + + # Ensure generation_config has a row + cursor = await db.execute("SELECT COUNT(*) FROM generation_config") + count = await cursor.fetchone() + if count[0] == 0: + image_timeout = 300 + video_timeout = 1500 + + if config_dict: + generation_config = config_dict.get("generation", {}) + image_timeout = generation_config.get("image_timeout", 300) + video_timeout = generation_config.get("video_timeout", 1500) + + await db.execute(""" + INSERT INTO generation_config (id, image_timeout, video_timeout) + VALUES (1, ?, ?) + """, (image_timeout, video_timeout)) + + # Ensure cache_config has a row + cursor = await db.execute("SELECT COUNT(*) FROM cache_config") + count = await cursor.fetchone() + if count[0] == 0: + cache_enabled = False + cache_timeout = 7200 + cache_base_url = None + + if config_dict: + cache_config = config_dict.get("cache", {}) + cache_enabled = cache_config.get("enabled", False) + cache_timeout = cache_config.get("timeout", 7200) + cache_base_url = cache_config.get("base_url", "") + # Convert empty string to None + cache_base_url = cache_base_url if cache_base_url else None + + await db.execute(""" + INSERT INTO cache_config (id, cache_enabled, cache_timeout, cache_base_url) + VALUES (1, ?, ?, ?) + """, (cache_enabled, cache_timeout, cache_base_url)) + + async def check_and_migrate_db(self, config_dict: dict = None): + """Check database integrity and perform migrations if needed + + This method is called during upgrade mode to: + 1. Create missing tables (if they don't exist) + 2. Add missing columns to existing tables + 3. Ensure all config tables have default rows + + Args: + config_dict: Configuration dictionary from setting.toml (optional) + Used only to initialize missing config rows with default values. + Existing config rows will NOT be overwritten. + """ + async with aiosqlite.connect(self.db_path) as db: + print("Checking database integrity and performing migrations...") + + # ========== Step 1: Create missing tables ========== + # Check and create cache_config table if missing + if not await self._table_exists(db, "cache_config"): + print(" ✓ Creating missing table: cache_config") + await db.execute(""" + CREATE TABLE cache_config ( + id INTEGER PRIMARY KEY DEFAULT 1, + cache_enabled BOOLEAN DEFAULT 0, + cache_timeout INTEGER DEFAULT 7200, + cache_base_url TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # ========== Step 2: Add missing columns to existing tables ========== + # Check and add missing columns to tokens table + if await self._table_exists(db, "tokens"): + columns_to_add = [ + ("at", "TEXT"), # Access Token + ("at_expires", "TIMESTAMP"), # AT expiration time + ("credits", "INTEGER DEFAULT 0"), # Balance + ("user_paygate_tier", "TEXT"), # User tier + ("current_project_id", "TEXT"), # Current project UUID + ("current_project_name", "TEXT"), # Project name + ("image_enabled", "BOOLEAN DEFAULT 1"), + ("video_enabled", "BOOLEAN DEFAULT 1"), + ("image_concurrency", "INTEGER DEFAULT -1"), + ("video_concurrency", "INTEGER DEFAULT -1"), + ] + + for col_name, col_type in columns_to_add: + if not await self._column_exists(db, "tokens", col_name): + try: + await db.execute(f"ALTER TABLE tokens ADD COLUMN {col_name} {col_type}") + print(f" ✓ Added column '{col_name}' to tokens table") + except Exception as e: + print(f" ✗ Failed to add column '{col_name}': {e}") + + # Check and add missing columns to token_stats table + if await self._table_exists(db, "token_stats"): + stats_columns_to_add = [ + ("today_image_count", "INTEGER DEFAULT 0"), + ("today_video_count", "INTEGER DEFAULT 0"), + ("today_error_count", "INTEGER DEFAULT 0"), + ("today_date", "DATE"), + ] + + for col_name, col_type in stats_columns_to_add: + if not await self._column_exists(db, "token_stats", col_name): + try: + await db.execute(f"ALTER TABLE token_stats ADD COLUMN {col_name} {col_type}") + print(f" ✓ Added column '{col_name}' to token_stats table") + except Exception as e: + print(f" ✗ Failed to add column '{col_name}': {e}") + + # ========== Step 3: Ensure all config tables have default rows ========== + # Note: This will NOT overwrite existing config rows + # It only ensures missing rows are created with default values + await self._ensure_config_rows(db, config_dict=None) + + await db.commit() + print("Database migration check completed.") + + async def init_db(self): + """Initialize database tables""" + async with aiosqlite.connect(self.db_path) as db: + # Tokens table (Flow2API版本) + await db.execute(""" + CREATE TABLE IF NOT EXISTS tokens ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + st TEXT UNIQUE NOT NULL, + at TEXT, + at_expires TIMESTAMP, + email TEXT NOT NULL, + name TEXT, + remark TEXT, + is_active BOOLEAN DEFAULT 1, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + last_used_at TIMESTAMP, + use_count INTEGER DEFAULT 0, + credits INTEGER DEFAULT 0, + user_paygate_tier TEXT, + current_project_id TEXT, + current_project_name TEXT, + image_enabled BOOLEAN DEFAULT 1, + video_enabled BOOLEAN DEFAULT 1, + image_concurrency INTEGER DEFAULT -1, + video_concurrency INTEGER DEFAULT -1 + ) + """) + + # Projects table (新增) + await db.execute(""" + CREATE TABLE IF NOT EXISTS projects ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + project_id TEXT UNIQUE NOT NULL, + token_id INTEGER NOT NULL, + project_name TEXT NOT NULL, + tool_name TEXT DEFAULT 'PINHOLE', + is_active BOOLEAN DEFAULT 1, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (token_id) REFERENCES tokens(id) + ) + """) + + # Token stats table + await db.execute(""" + CREATE TABLE IF NOT EXISTS token_stats ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + token_id INTEGER NOT NULL, + image_count INTEGER DEFAULT 0, + video_count INTEGER DEFAULT 0, + success_count INTEGER DEFAULT 0, + error_count INTEGER DEFAULT 0, + last_success_at TIMESTAMP, + last_error_at TIMESTAMP, + today_image_count INTEGER DEFAULT 0, + today_video_count INTEGER DEFAULT 0, + today_error_count INTEGER DEFAULT 0, + today_date DATE, + FOREIGN KEY (token_id) REFERENCES tokens(id) + ) + """) + + # Tasks table + await db.execute(""" + CREATE TABLE IF NOT EXISTS tasks ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + task_id TEXT UNIQUE NOT NULL, + token_id INTEGER NOT NULL, + model TEXT NOT NULL, + prompt TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'processing', + progress INTEGER DEFAULT 0, + result_urls TEXT, + error_message TEXT, + scene_id TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + completed_at TIMESTAMP, + FOREIGN KEY (token_id) REFERENCES tokens(id) + ) + """) + + # Request logs table + await db.execute(""" + CREATE TABLE IF NOT EXISTS request_logs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + token_id INTEGER, + operation TEXT NOT NULL, + request_body TEXT, + response_body TEXT, + status_code INTEGER NOT NULL, + duration FLOAT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (token_id) REFERENCES tokens(id) + ) + """) + + # Admin config table + await db.execute(""" + CREATE TABLE IF NOT EXISTS admin_config ( + id INTEGER PRIMARY KEY DEFAULT 1, + username TEXT DEFAULT 'admin', + password TEXT DEFAULT 'admin', + api_key TEXT DEFAULT 'han1234', + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Proxy config table + await db.execute(""" + CREATE TABLE IF NOT EXISTS proxy_config ( + id INTEGER PRIMARY KEY DEFAULT 1, + enabled BOOLEAN DEFAULT 0, + proxy_url TEXT, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Generation config table + await db.execute(""" + CREATE TABLE IF NOT EXISTS generation_config ( + id INTEGER PRIMARY KEY DEFAULT 1, + image_timeout INTEGER DEFAULT 300, + video_timeout INTEGER DEFAULT 1500, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Cache config table + await db.execute(""" + CREATE TABLE IF NOT EXISTS cache_config ( + id INTEGER PRIMARY KEY DEFAULT 1, + cache_enabled BOOLEAN DEFAULT 0, + cache_timeout INTEGER DEFAULT 7200, + cache_base_url TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Create indexes + await db.execute("CREATE INDEX IF NOT EXISTS idx_task_id ON tasks(task_id)") + await db.execute("CREATE INDEX IF NOT EXISTS idx_token_st ON tokens(st)") + await db.execute("CREATE INDEX IF NOT EXISTS idx_project_id ON projects(project_id)") + + # Migrate request_logs table if needed + await self._migrate_request_logs(db) + + await db.commit() + + async def _migrate_request_logs(self, db): + """Migrate request_logs table from old schema to new schema""" + try: + # Check if old columns exist + has_model = await self._column_exists(db, "request_logs", "model") + has_operation = await self._column_exists(db, "request_logs", "operation") + + if has_model and not has_operation: + # Old schema detected, need migration + print("🔄 检测到旧的request_logs表结构,开始迁移...") + + # Rename old table + await db.execute("ALTER TABLE request_logs RENAME TO request_logs_old") + + # Create new table with new schema + await db.execute(""" + CREATE TABLE request_logs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + token_id INTEGER, + operation TEXT NOT NULL, + request_body TEXT, + response_body TEXT, + status_code INTEGER NOT NULL, + duration FLOAT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (token_id) REFERENCES tokens(id) + ) + """) + + # Migrate data from old table (basic migration) + await db.execute(""" + INSERT INTO request_logs (token_id, operation, request_body, status_code, duration, created_at) + SELECT + token_id, + model as operation, + json_object('model', model, 'prompt', substr(prompt, 1, 100)) as request_body, + CASE + WHEN status = 'completed' THEN 200 + WHEN status = 'failed' THEN 500 + ELSE 0 + END as status_code, + response_time as duration, + created_at + FROM request_logs_old + """) + + # Drop old table + await db.execute("DROP TABLE request_logs_old") + + print("✅ request_logs表迁移完成") + except Exception as e: + print(f"⚠️ request_logs表迁移失败: {e}") + # Continue even if migration fails + + # Token operations + async def add_token(self, token: Token) -> int: + """Add a new token""" + async with aiosqlite.connect(self.db_path) as db: + cursor = await db.execute(""" + INSERT INTO tokens (st, at, at_expires, email, name, remark, is_active, + credits, user_paygate_tier, current_project_id, current_project_name, + image_enabled, video_enabled, image_concurrency, video_concurrency) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, (token.st, token.at, token.at_expires, token.email, token.name, token.remark, + token.is_active, token.credits, token.user_paygate_tier, + token.current_project_id, token.current_project_name, + token.image_enabled, token.video_enabled, + token.image_concurrency, token.video_concurrency)) + await db.commit() + token_id = cursor.lastrowid + + # Create stats entry + await db.execute(""" + INSERT INTO token_stats (token_id) VALUES (?) + """, (token_id,)) + await db.commit() + + return token_id + + async def get_token(self, token_id: int) -> Optional[Token]: + """Get token by ID""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM tokens WHERE id = ?", (token_id,)) + row = await cursor.fetchone() + if row: + return Token(**dict(row)) + return None + + async def get_token_by_st(self, st: str) -> Optional[Token]: + """Get token by ST""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM tokens WHERE st = ?", (st,)) + row = await cursor.fetchone() + if row: + return Token(**dict(row)) + return None + + async def get_all_tokens(self) -> List[Token]: + """Get all tokens""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM tokens ORDER BY created_at DESC") + rows = await cursor.fetchall() + return [Token(**dict(row)) for row in rows] + + async def get_active_tokens(self) -> List[Token]: + """Get all active tokens""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM tokens WHERE is_active = 1 ORDER BY last_used_at ASC") + rows = await cursor.fetchall() + return [Token(**dict(row)) for row in rows] + + async def update_token(self, token_id: int, **kwargs): + """Update token fields""" + async with aiosqlite.connect(self.db_path) as db: + updates = [] + params = [] + + for key, value in kwargs.items(): + if value is not None: + updates.append(f"{key} = ?") + params.append(value) + + if updates: + params.append(token_id) + query = f"UPDATE tokens SET {', '.join(updates)} WHERE id = ?" + await db.execute(query, params) + await db.commit() + + async def delete_token(self, token_id: int): + """Delete token and related data""" + async with aiosqlite.connect(self.db_path) as db: + await db.execute("DELETE FROM token_stats WHERE token_id = ?", (token_id,)) + await db.execute("DELETE FROM projects WHERE token_id = ?", (token_id,)) + await db.execute("DELETE FROM tokens WHERE id = ?", (token_id,)) + await db.commit() + + # Project operations + async def add_project(self, project: Project) -> int: + """Add a new project""" + async with aiosqlite.connect(self.db_path) as db: + cursor = await db.execute(""" + INSERT INTO projects (project_id, token_id, project_name, tool_name, is_active) + VALUES (?, ?, ?, ?, ?) + """, (project.project_id, project.token_id, project.project_name, + project.tool_name, project.is_active)) + await db.commit() + return cursor.lastrowid + + async def get_project_by_id(self, project_id: str) -> Optional[Project]: + """Get project by UUID""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM projects WHERE project_id = ?", (project_id,)) + row = await cursor.fetchone() + if row: + return Project(**dict(row)) + return None + + async def get_projects_by_token(self, token_id: int) -> List[Project]: + """Get all projects for a token""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute( + "SELECT * FROM projects WHERE token_id = ? ORDER BY created_at DESC", + (token_id,) + ) + rows = await cursor.fetchall() + return [Project(**dict(row)) for row in rows] + + async def delete_project(self, project_id: str): + """Delete project""" + async with aiosqlite.connect(self.db_path) as db: + await db.execute("DELETE FROM projects WHERE project_id = ?", (project_id,)) + await db.commit() + + # Task operations + async def create_task(self, task: Task) -> int: + """Create a new task""" + async with aiosqlite.connect(self.db_path) as db: + cursor = await db.execute(""" + INSERT INTO tasks (task_id, token_id, model, prompt, status, progress, scene_id) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, (task.task_id, task.token_id, task.model, task.prompt, + task.status, task.progress, task.scene_id)) + await db.commit() + return cursor.lastrowid + + async def get_task(self, task_id: str) -> Optional[Task]: + """Get task by ID""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM tasks WHERE task_id = ?", (task_id,)) + row = await cursor.fetchone() + if row: + task_dict = dict(row) + # Parse result_urls from JSON + if task_dict.get("result_urls"): + task_dict["result_urls"] = json.loads(task_dict["result_urls"]) + return Task(**task_dict) + return None + + async def update_task(self, task_id: str, **kwargs): + """Update task""" + async with aiosqlite.connect(self.db_path) as db: + updates = [] + params = [] + + for key, value in kwargs.items(): + if value is not None: + # Convert list to JSON string for result_urls + if key == "result_urls" and isinstance(value, list): + value = json.dumps(value) + updates.append(f"{key} = ?") + params.append(value) + + if updates: + params.append(task_id) + query = f"UPDATE tasks SET {', '.join(updates)} WHERE task_id = ?" + await db.execute(query, params) + await db.commit() + + # Token stats operations (kept for compatibility, now delegates to specific methods) + async def increment_token_stats(self, token_id: int, stat_type: str): + """Increment token statistics (delegates to specific methods)""" + if stat_type == "image": + await self.increment_image_count(token_id) + elif stat_type == "video": + await self.increment_video_count(token_id) + elif stat_type == "error": + await self.increment_error_count(token_id) + + async def get_token_stats(self, token_id: int) -> Optional[TokenStats]: + """Get token statistics""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM token_stats WHERE token_id = ?", (token_id,)) + row = await cursor.fetchone() + if row: + return TokenStats(**dict(row)) + return None + + async def increment_image_count(self, token_id: int): + """Increment image generation count with daily reset""" + from datetime import date + async with aiosqlite.connect(self.db_path) as db: + today = str(date.today()) + # Get current stats + cursor = await db.execute("SELECT today_date FROM token_stats WHERE token_id = ?", (token_id,)) + row = await cursor.fetchone() + + # If date changed, reset today's count + if row and row[0] != today: + await db.execute(""" + UPDATE token_stats + SET image_count = image_count + 1, + today_image_count = 1, + today_date = ? + WHERE token_id = ? + """, (today, token_id)) + else: + # Same day, just increment both + await db.execute(""" + UPDATE token_stats + SET image_count = image_count + 1, + today_image_count = today_image_count + 1, + today_date = ? + WHERE token_id = ? + """, (today, token_id)) + await db.commit() + + async def increment_video_count(self, token_id: int): + """Increment video generation count with daily reset""" + from datetime import date + async with aiosqlite.connect(self.db_path) as db: + today = str(date.today()) + # Get current stats + cursor = await db.execute("SELECT today_date FROM token_stats WHERE token_id = ?", (token_id,)) + row = await cursor.fetchone() + + # If date changed, reset today's count + if row and row[0] != today: + await db.execute(""" + UPDATE token_stats + SET video_count = video_count + 1, + today_video_count = 1, + today_date = ? + WHERE token_id = ? + """, (today, token_id)) + else: + # Same day, just increment both + await db.execute(""" + UPDATE token_stats + SET video_count = video_count + 1, + today_video_count = today_video_count + 1, + today_date = ? + WHERE token_id = ? + """, (today, token_id)) + await db.commit() + + async def increment_error_count(self, token_id: int): + """Increment error count with daily reset""" + from datetime import date + async with aiosqlite.connect(self.db_path) as db: + today = str(date.today()) + # Get current stats + cursor = await db.execute("SELECT today_date FROM token_stats WHERE token_id = ?", (token_id,)) + row = await cursor.fetchone() + + # If date changed, reset today's error count + if row and row[0] != today: + await db.execute(""" + UPDATE token_stats + SET error_count = error_count + 1, + today_error_count = 1, + today_date = ?, + last_error_at = CURRENT_TIMESTAMP + WHERE token_id = ? + """, (today, token_id)) + else: + # Same day, just increment both + await db.execute(""" + UPDATE token_stats + SET error_count = error_count + 1, + today_error_count = today_error_count + 1, + today_date = ?, + last_error_at = CURRENT_TIMESTAMP + WHERE token_id = ? + """, (today, token_id)) + await db.commit() + + # Config operations + async def get_admin_config(self) -> Optional[AdminConfig]: + """Get admin configuration""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM admin_config WHERE id = 1") + row = await cursor.fetchone() + if row: + return AdminConfig(**dict(row)) + return None + + async def update_admin_config(self, **kwargs): + """Update admin configuration""" + async with aiosqlite.connect(self.db_path) as db: + updates = [] + params = [] + + for key, value in kwargs.items(): + if value is not None: + updates.append(f"{key} = ?") + params.append(value) + + if updates: + updates.append("updated_at = CURRENT_TIMESTAMP") + query = f"UPDATE admin_config SET {', '.join(updates)} WHERE id = 1" + await db.execute(query, params) + await db.commit() + + async def get_proxy_config(self) -> Optional[ProxyConfig]: + """Get proxy configuration""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM proxy_config WHERE id = 1") + row = await cursor.fetchone() + if row: + return ProxyConfig(**dict(row)) + return None + + async def update_proxy_config(self, enabled: bool, proxy_url: Optional[str] = None): + """Update proxy configuration""" + async with aiosqlite.connect(self.db_path) as db: + await db.execute(""" + UPDATE proxy_config + SET enabled = ?, proxy_url = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = 1 + """, (enabled, proxy_url)) + await db.commit() + + async def get_generation_config(self) -> Optional[GenerationConfig]: + """Get generation configuration""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM generation_config WHERE id = 1") + row = await cursor.fetchone() + if row: + return GenerationConfig(**dict(row)) + return None + + async def update_generation_config(self, image_timeout: int, video_timeout: int): + """Update generation configuration""" + async with aiosqlite.connect(self.db_path) as db: + await db.execute(""" + UPDATE generation_config + SET image_timeout = ?, video_timeout = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = 1 + """, (image_timeout, video_timeout)) + await db.commit() + + # Request log operations + async def add_request_log(self, log: RequestLog): + """Add request log""" + async with aiosqlite.connect(self.db_path) as db: + await db.execute(""" + INSERT INTO request_logs (token_id, operation, request_body, response_body, status_code, duration) + VALUES (?, ?, ?, ?, ?, ?) + """, (log.token_id, log.operation, log.request_body, log.response_body, + log.status_code, log.duration)) + await db.commit() + + async def get_logs(self, limit: int = 100, token_id: Optional[int] = None): + """Get request logs with token email""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + + if token_id: + cursor = await db.execute(""" + SELECT + rl.id, + rl.token_id, + rl.operation, + rl.request_body, + rl.response_body, + rl.status_code, + rl.duration, + rl.created_at, + t.email as token_email, + t.name as token_username + FROM request_logs rl + LEFT JOIN tokens t ON rl.token_id = t.id + WHERE rl.token_id = ? + ORDER BY rl.created_at DESC + LIMIT ? + """, (token_id, limit)) + else: + cursor = await db.execute(""" + SELECT + rl.id, + rl.token_id, + rl.operation, + rl.request_body, + rl.response_body, + rl.status_code, + rl.duration, + rl.created_at, + t.email as token_email, + t.name as token_username + FROM request_logs rl + LEFT JOIN tokens t ON rl.token_id = t.id + ORDER BY rl.created_at DESC + LIMIT ? + """, (limit,)) + + rows = await cursor.fetchall() + return [dict(row) for row in rows] + + async def init_config_from_toml(self, config_dict: dict, is_first_startup: bool = True): + """ + Initialize database configuration from setting.toml + + Args: + config_dict: Configuration dictionary from setting.toml + is_first_startup: If True, initialize all config rows from setting.toml. + If False (upgrade mode), only ensure missing config rows exist with default values. + """ + async with aiosqlite.connect(self.db_path) as db: + if is_first_startup: + # First startup: Initialize all config tables with values from setting.toml + await self._ensure_config_rows(db, config_dict) + else: + # Upgrade mode: Only ensure missing config rows exist (with default values, not from TOML) + await self._ensure_config_rows(db, config_dict=None) + + await db.commit() + + # Cache config operations + async def get_cache_config(self) -> CacheConfig: + """Get cache configuration""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM cache_config WHERE id = 1") + row = await cursor.fetchone() + if row: + return CacheConfig(**dict(row)) + # Return default if not found + return CacheConfig(cache_enabled=False, cache_timeout=7200) + + async def update_cache_config(self, enabled: bool = None, timeout: int = None, base_url: Optional[str] = None): + """Update cache configuration""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + # Get current values + cursor = await db.execute("SELECT * FROM cache_config WHERE id = 1") + row = await cursor.fetchone() + + if row: + current = dict(row) + # Use new values if provided, otherwise keep existing + new_enabled = enabled if enabled is not None else current.get("cache_enabled", False) + new_timeout = timeout if timeout is not None else current.get("cache_timeout", 7200) + new_base_url = base_url if base_url is not None else current.get("cache_base_url") + + # If base_url is explicitly set to empty string, treat as None + if base_url == "": + new_base_url = None + + await db.execute(""" + UPDATE cache_config + SET cache_enabled = ?, cache_timeout = ?, cache_base_url = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = 1 + """, (new_enabled, new_timeout, new_base_url)) + else: + # Insert default row if not exists + new_enabled = enabled if enabled is not None else False + new_timeout = timeout if timeout is not None else 7200 + new_base_url = base_url if base_url is not None else None + + await db.execute(""" + INSERT INTO cache_config (id, cache_enabled, cache_timeout, cache_base_url) + VALUES (1, ?, ?, ?) + """, (new_enabled, new_timeout, new_base_url)) + + await db.commit() diff --git a/src/core/logger.py b/src/core/logger.py new file mode 100644 index 0000000..798ec05 --- /dev/null +++ b/src/core/logger.py @@ -0,0 +1,243 @@ +"""Debug logger module for detailed API request/response logging""" +import json +import logging +from datetime import datetime +from pathlib import Path +from typing import Dict, Any, Optional +from .config import config + +class DebugLogger: + """Debug logger for API requests and responses""" + + def __init__(self): + self.log_file = Path("logs.txt") + self._setup_logger() + + def _setup_logger(self): + """Setup file logger""" + # Create logger + self.logger = logging.getLogger("debug_logger") + self.logger.setLevel(logging.DEBUG) + + # Remove existing handlers + self.logger.handlers.clear() + + # Create file handler + file_handler = logging.FileHandler( + self.log_file, + mode='a', + encoding='utf-8' + ) + file_handler.setLevel(logging.DEBUG) + + # Create formatter + formatter = logging.Formatter( + '%(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + file_handler.setFormatter(formatter) + + # Add handler + self.logger.addHandler(file_handler) + + # Prevent propagation to root logger + self.logger.propagate = False + + def _mask_token(self, token: str) -> str: + """Mask token for logging (show first 6 and last 6 characters)""" + if not config.debug_mask_token or len(token) <= 12: + return token + return f"{token[:6]}...{token[-6:]}" + + def _format_timestamp(self) -> str: + """Format current timestamp""" + return datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] + + def _write_separator(self, char: str = "=", length: int = 100): + """Write separator line""" + self.logger.info(char * length) + + def log_request( + self, + method: str, + url: str, + headers: Dict[str, str], + body: Optional[Any] = None, + files: Optional[Dict] = None, + proxy: Optional[str] = None + ): + """Log API request details to log.txt""" + + if not config.debug_enabled or not config.debug_log_requests: + return + + try: + self._write_separator() + self.logger.info(f"🔵 [REQUEST] {self._format_timestamp()}") + self._write_separator("-") + + # Basic info + self.logger.info(f"Method: {method}") + self.logger.info(f"URL: {url}") + + # Headers + self.logger.info("\n📋 Headers:") + masked_headers = dict(headers) + if "Authorization" in masked_headers or "authorization" in masked_headers: + auth_key = "Authorization" if "Authorization" in masked_headers else "authorization" + auth_value = masked_headers[auth_key] + if auth_value.startswith("Bearer "): + token = auth_value[7:] + masked_headers[auth_key] = f"Bearer {self._mask_token(token)}" + + # Mask Cookie header (ST token) + if "Cookie" in masked_headers: + cookie_value = masked_headers["Cookie"] + if "__Secure-next-auth.session-token=" in cookie_value: + parts = cookie_value.split("=", 1) + if len(parts) == 2: + st_token = parts[1].split(";")[0] + masked_headers["Cookie"] = f"__Secure-next-auth.session-token={self._mask_token(st_token)}" + + for key, value in masked_headers.items(): + self.logger.info(f" {key}: {value}") + + # Body + if body is not None: + self.logger.info("\n📦 Request Body:") + if isinstance(body, (dict, list)): + body_str = json.dumps(body, indent=2, ensure_ascii=False) + self.logger.info(body_str) + else: + self.logger.info(str(body)) + + # Files + if files: + self.logger.info("\n📎 Files:") + try: + if hasattr(files, 'keys') and callable(getattr(files, 'keys', None)): + for key in files.keys(): + self.logger.info(f" {key}: ") + else: + self.logger.info(" ") + except (AttributeError, TypeError): + self.logger.info(" ") + + # Proxy + if proxy: + self.logger.info(f"\n🌐 Proxy: {proxy}") + + self._write_separator() + self.logger.info("") # Empty line + + except Exception as e: + self.logger.error(f"Error logging request: {e}") + + def log_response( + self, + status_code: int, + headers: Dict[str, str], + body: Any, + duration_ms: Optional[float] = None + ): + """Log API response details to log.txt""" + + if not config.debug_enabled or not config.debug_log_responses: + return + + try: + self._write_separator() + self.logger.info(f"🟢 [RESPONSE] {self._format_timestamp()}") + self._write_separator("-") + + # Status + status_emoji = "✅" if 200 <= status_code < 300 else "❌" + self.logger.info(f"Status: {status_code} {status_emoji}") + + # Duration + if duration_ms is not None: + self.logger.info(f"Duration: {duration_ms:.2f}ms") + + # Headers + self.logger.info("\n📋 Response Headers:") + for key, value in headers.items(): + self.logger.info(f" {key}: {value}") + + # Body + self.logger.info("\n📦 Response Body:") + if isinstance(body, (dict, list)): + body_str = json.dumps(body, indent=2, ensure_ascii=False) + self.logger.info(body_str) + elif isinstance(body, str): + # Try to parse as JSON + try: + parsed = json.loads(body) + body_str = json.dumps(parsed, indent=2, ensure_ascii=False) + self.logger.info(body_str) + except: + # Not JSON, log as text (limit length) + if len(body) > 2000: + self.logger.info(f"{body[:2000]}... (truncated)") + else: + self.logger.info(body) + else: + self.logger.info(str(body)) + + self._write_separator() + self.logger.info("") # Empty line + + except Exception as e: + self.logger.error(f"Error logging response: {e}") + + def log_error( + self, + error_message: str, + status_code: Optional[int] = None, + response_text: Optional[str] = None + ): + """Log API error details to log.txt""" + + if not config.debug_enabled: + return + + try: + self._write_separator() + self.logger.info(f"🔴 [ERROR] {self._format_timestamp()}") + self._write_separator("-") + + if status_code: + self.logger.info(f"Status Code: {status_code}") + + self.logger.info(f"Error Message: {error_message}") + + if response_text: + self.logger.info("\n📦 Error Response:") + # Try to parse as JSON + try: + parsed = json.loads(response_text) + body_str = json.dumps(parsed, indent=2, ensure_ascii=False) + self.logger.info(body_str) + except: + # Not JSON, log as text + if len(response_text) > 2000: + self.logger.info(f"{response_text[:2000]}... (truncated)") + else: + self.logger.info(response_text) + + self._write_separator() + self.logger.info("") # Empty line + + except Exception as e: + self.logger.error(f"Error logging error: {e}") + + def log_info(self, message: str): + """Log general info message to log.txt""" + if not config.debug_enabled: + return + try: + self.logger.info(f"ℹ️ [{self._format_timestamp()}] {message}") + except Exception as e: + self.logger.error(f"Error logging info: {e}") + +# Global debug logger instance +debug_logger = DebugLogger() diff --git a/src/core/models.py b/src/core/models.py new file mode 100644 index 0000000..c9d1fa9 --- /dev/null +++ b/src/core/models.py @@ -0,0 +1,145 @@ +"""Data models for Flow2API""" +from pydantic import BaseModel +from typing import Optional, List, Union, Any +from datetime import datetime + + +class Token(BaseModel): + """Token model for Flow2API""" + id: Optional[int] = None + + # 认证信息 (核心) + st: str # Session Token (__Secure-next-auth.session-token) + at: Optional[str] = None # Access Token (从ST转换而来) + at_expires: Optional[datetime] = None # AT过期时间 + + # 基础信息 + email: str + name: Optional[str] = "" + remark: Optional[str] = None + is_active: bool = True + created_at: Optional[datetime] = None + last_used_at: Optional[datetime] = None + use_count: int = 0 + + # VideoFX特有字段 + credits: int = 0 # 剩余credits + user_paygate_tier: Optional[str] = None # PAYGATE_TIER_ONE + + # 项目管理 + current_project_id: Optional[str] = None # 当前使用的项目UUID + current_project_name: Optional[str] = None # 项目名称 + + # 功能开关 + image_enabled: bool = True + video_enabled: bool = True + + # 并发限制 + image_concurrency: int = -1 # -1表示无限制 + video_concurrency: int = -1 # -1表示无限制 + + +class Project(BaseModel): + """Project model for VideoFX""" + id: Optional[int] = None + project_id: str # VideoFX项目UUID + token_id: int # 关联的Token ID + project_name: str # 项目名称 + tool_name: str = "PINHOLE" # 工具名称,固定为PINHOLE + is_active: bool = True + created_at: Optional[datetime] = None + + +class TokenStats(BaseModel): + """Token statistics""" + token_id: int + image_count: int = 0 + video_count: int = 0 + success_count: int = 0 + error_count: int = 0 + last_success_at: Optional[datetime] = None + last_error_at: Optional[datetime] = None + # 今日统计 + today_image_count: int = 0 + today_video_count: int = 0 + today_error_count: int = 0 + today_date: Optional[str] = None + + +class Task(BaseModel): + """Generation task""" + id: Optional[int] = None + task_id: str # Flow API返回的operation name + token_id: int + model: str + prompt: str + status: str # processing, completed, failed + progress: int = 0 # 0-100 + result_urls: Optional[List[str]] = None + error_message: Optional[str] = None + scene_id: Optional[str] = None # Flow API的sceneId + created_at: Optional[datetime] = None + completed_at: Optional[datetime] = None + + +class RequestLog(BaseModel): + """API request log""" + id: Optional[int] = None + token_id: Optional[int] = None + operation: str + request_body: Optional[str] = None + response_body: Optional[str] = None + status_code: int + duration: float + created_at: Optional[datetime] = None + + +class AdminConfig(BaseModel): + """Admin configuration""" + id: int = 1 + username: str + password: str + api_key: str + + +class ProxyConfig(BaseModel): + """Proxy configuration""" + id: int = 1 + enabled: bool = False + proxy_url: Optional[str] = None + + +class GenerationConfig(BaseModel): + """Generation timeout configuration""" + id: int = 1 + image_timeout: int = 300 # seconds + video_timeout: int = 1500 # seconds + + +class CacheConfig(BaseModel): + """Cache configuration""" + id: int = 1 + cache_enabled: bool = False + cache_timeout: int = 7200 # seconds (2 hours) + cache_base_url: Optional[str] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + + +# OpenAI Compatible Request Models +class ChatMessage(BaseModel): + """Chat message""" + role: str + content: Union[str, List[dict]] # string or multimodal array + + +class ChatCompletionRequest(BaseModel): + """Chat completion request (OpenAI compatible)""" + model: str + messages: List[ChatMessage] + stream: bool = False + temperature: Optional[float] = None + max_tokens: Optional[int] = None + # Flow2API specific parameters + image: Optional[str] = None # Base64 encoded image (deprecated, use messages) + video: Optional[str] = None # Base64 encoded video (deprecated) diff --git a/src/main.py b/src/main.py new file mode 100644 index 0000000..6d0b9ba --- /dev/null +++ b/src/main.py @@ -0,0 +1,162 @@ +"""FastAPI application initialization""" +from fastapi import FastAPI +from fastapi.responses import HTMLResponse, FileResponse +from fastapi.staticfiles import StaticFiles +from fastapi.middleware.cors import CORSMiddleware +from contextlib import asynccontextmanager +from pathlib import Path + +from .core.config import config +from .core.database import Database +from .services.flow_client import FlowClient +from .services.proxy_manager import ProxyManager +from .services.token_manager import TokenManager +from .services.load_balancer import LoadBalancer +from .services.concurrency_manager import ConcurrencyManager +from .services.generation_handler import GenerationHandler +from .api import routes, admin + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Application lifespan manager""" + # Startup + print("=" * 60) + print("Flow2API Starting...") + print("=" * 60) + + # Get config from setting.toml + config_dict = config.get_raw_config() + + # Check if database exists (determine if first startup) + is_first_startup = not db.db_exists() + + # Initialize database tables structure + await db.init_db() + + # Handle database initialization based on startup type + if is_first_startup: + print("🎉 First startup detected. Initializing database and configuration from setting.toml...") + await db.init_config_from_toml(config_dict, is_first_startup=True) + print("✓ Database and configuration initialized successfully.") + else: + print("🔄 Existing database detected. Checking for missing tables and columns...") + await db.check_and_migrate_db(config_dict) + print("✓ Database migration check completed.") + + # Load admin config from database + admin_config = await db.get_admin_config() + if admin_config: + config.set_admin_username_from_db(admin_config.username) + config.set_admin_password_from_db(admin_config.password) + config.api_key = admin_config.api_key + + # Load cache configuration from database + cache_config = await db.get_cache_config() + config.set_cache_enabled(cache_config.cache_enabled) + config.set_cache_timeout(cache_config.cache_timeout) + config.set_cache_base_url(cache_config.cache_base_url or "") + + # Load generation configuration from database + generation_config = await db.get_generation_config() + config.set_image_timeout(generation_config.image_timeout) + config.set_video_timeout(generation_config.video_timeout) + + # Initialize concurrency manager + tokens = await token_manager.get_all_tokens() + await concurrency_manager.initialize(tokens) + + # Start file cache cleanup task + await generation_handler.file_cache.start_cleanup_task() + + print(f"✓ Database initialized") + print(f"✓ Total tokens: {len(tokens)}") + print(f"✓ Cache: {'Enabled' if config.cache_enabled else 'Disabled'} (timeout: {config.cache_timeout}s)") + print(f"✓ File cache cleanup task started") + print(f"✓ Server running on http://{config.server_host}:{config.server_port}") + print("=" * 60) + + yield + + # Shutdown + print("Flow2API Shutting down...") + # Stop file cache cleanup task + await generation_handler.file_cache.stop_cleanup_task() + print("✓ File cache cleanup task stopped") + + +# Initialize components +db = Database() +proxy_manager = ProxyManager(db) +flow_client = FlowClient(proxy_manager) +token_manager = TokenManager(db, flow_client) +concurrency_manager = ConcurrencyManager() +load_balancer = LoadBalancer(token_manager, concurrency_manager) +generation_handler = GenerationHandler( + flow_client, + token_manager, + load_balancer, + db, + concurrency_manager, + proxy_manager # 添加 proxy_manager 参数 +) + +# Set dependencies +routes.set_generation_handler(generation_handler) +admin.set_dependencies(token_manager, proxy_manager, db) + +# Create FastAPI app +app = FastAPI( + title="Flow2API", + description="OpenAI-compatible API for Google VideoFX (Veo)", + version="1.0.0", + lifespan=lifespan +) + +# CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Include routers +app.include_router(routes.router) +app.include_router(admin.router) + +# Static files - serve tmp directory for cached files +tmp_dir = Path(__file__).parent.parent / "tmp" +tmp_dir.mkdir(exist_ok=True) +app.mount("/tmp", StaticFiles(directory=str(tmp_dir)), name="tmp") + +# HTML routes for frontend +static_path = Path(__file__).parent.parent / "static" + + +@app.get("/", response_class=HTMLResponse) +async def index(): + """Redirect to login page""" + login_file = static_path / "login.html" + if login_file.exists(): + return FileResponse(str(login_file)) + return HTMLResponse(content="

Flow2API

Frontend not found

", status_code=404) + + +@app.get("/login", response_class=HTMLResponse) +async def login_page(): + """Login page""" + login_file = static_path / "login.html" + if login_file.exists(): + return FileResponse(str(login_file)) + return HTMLResponse(content="

Login Page Not Found

", status_code=404) + + +@app.get("/manage", response_class=HTMLResponse) +async def manage_page(): + """Management console page""" + manage_file = static_path / "manage.html" + if manage_file.exists(): + return FileResponse(str(manage_file)) + return HTMLResponse(content="

Management Page Not Found

", status_code=404) diff --git a/src/services/__init__.py b/src/services/__init__.py new file mode 100644 index 0000000..163be89 --- /dev/null +++ b/src/services/__init__.py @@ -0,0 +1,17 @@ +"""Services modules""" + +from .flow_client import FlowClient +from .proxy_manager import ProxyManager +from .load_balancer import LoadBalancer +from .concurrency_manager import ConcurrencyManager +from .token_manager import TokenManager +from .generation_handler import GenerationHandler + +__all__ = [ + "FlowClient", + "ProxyManager", + "LoadBalancer", + "ConcurrencyManager", + "TokenManager", + "GenerationHandler" +] diff --git a/src/services/concurrency_manager.py b/src/services/concurrency_manager.py new file mode 100644 index 0000000..fe0f9de --- /dev/null +++ b/src/services/concurrency_manager.py @@ -0,0 +1,190 @@ +"""Concurrency manager for token-based rate limiting""" +import asyncio +from typing import Dict, Optional +from ..core.logger import debug_logger + + +class ConcurrencyManager: + """Manages concurrent request limits for each token""" + + def __init__(self): + """Initialize concurrency manager""" + self._image_concurrency: Dict[int, int] = {} # token_id -> remaining image concurrency + self._video_concurrency: Dict[int, int] = {} # token_id -> remaining video concurrency + self._lock = asyncio.Lock() # Protect concurrent access + + async def initialize(self, tokens: list): + """ + Initialize concurrency counters from token list + + Args: + tokens: List of Token objects with image_concurrency and video_concurrency fields + """ + async with self._lock: + for token in tokens: + if token.image_concurrency and token.image_concurrency > 0: + self._image_concurrency[token.id] = token.image_concurrency + if token.video_concurrency and token.video_concurrency > 0: + self._video_concurrency[token.id] = token.video_concurrency + + debug_logger.log_info(f"Concurrency manager initialized with {len(tokens)} tokens") + + async def can_use_image(self, token_id: int) -> bool: + """ + Check if token can be used for image generation + + Args: + token_id: Token ID + + Returns: + True if token has available image concurrency, False if concurrency is 0 + """ + async with self._lock: + # If not in dict, it means no limit (-1) + if token_id not in self._image_concurrency: + return True + + remaining = self._image_concurrency[token_id] + if remaining <= 0: + debug_logger.log_info(f"Token {token_id} image concurrency exhausted (remaining: {remaining})") + return False + + return True + + async def can_use_video(self, token_id: int) -> bool: + """ + Check if token can be used for video generation + + Args: + token_id: Token ID + + Returns: + True if token has available video concurrency, False if concurrency is 0 + """ + async with self._lock: + # If not in dict, it means no limit (-1) + if token_id not in self._video_concurrency: + return True + + remaining = self._video_concurrency[token_id] + if remaining <= 0: + debug_logger.log_info(f"Token {token_id} video concurrency exhausted (remaining: {remaining})") + return False + + return True + + async def acquire_image(self, token_id: int) -> bool: + """ + Acquire image concurrency slot + + Args: + token_id: Token ID + + Returns: + True if acquired, False if not available + """ + async with self._lock: + if token_id not in self._image_concurrency: + # No limit + return True + + if self._image_concurrency[token_id] <= 0: + return False + + self._image_concurrency[token_id] -= 1 + debug_logger.log_info(f"Token {token_id} acquired image slot (remaining: {self._image_concurrency[token_id]})") + return True + + async def acquire_video(self, token_id: int) -> bool: + """ + Acquire video concurrency slot + + Args: + token_id: Token ID + + Returns: + True if acquired, False if not available + """ + async with self._lock: + if token_id not in self._video_concurrency: + # No limit + return True + + if self._video_concurrency[token_id] <= 0: + return False + + self._video_concurrency[token_id] -= 1 + debug_logger.log_info(f"Token {token_id} acquired video slot (remaining: {self._video_concurrency[token_id]})") + return True + + async def release_image(self, token_id: int): + """ + Release image concurrency slot + + Args: + token_id: Token ID + """ + async with self._lock: + if token_id in self._image_concurrency: + self._image_concurrency[token_id] += 1 + debug_logger.log_info(f"Token {token_id} released image slot (remaining: {self._image_concurrency[token_id]})") + + async def release_video(self, token_id: int): + """ + Release video concurrency slot + + Args: + token_id: Token ID + """ + async with self._lock: + if token_id in self._video_concurrency: + self._video_concurrency[token_id] += 1 + debug_logger.log_info(f"Token {token_id} released video slot (remaining: {self._video_concurrency[token_id]})") + + async def get_image_remaining(self, token_id: int) -> Optional[int]: + """ + Get remaining image concurrency for token + + Args: + token_id: Token ID + + Returns: + Remaining count or None if no limit + """ + async with self._lock: + return self._image_concurrency.get(token_id) + + async def get_video_remaining(self, token_id: int) -> Optional[int]: + """ + Get remaining video concurrency for token + + Args: + token_id: Token ID + + Returns: + Remaining count or None if no limit + """ + async with self._lock: + return self._video_concurrency.get(token_id) + + async def reset_token(self, token_id: int, image_concurrency: int = -1, video_concurrency: int = -1): + """ + Reset concurrency counters for a token + + Args: + token_id: Token ID + image_concurrency: New image concurrency limit (-1 for no limit) + video_concurrency: New video concurrency limit (-1 for no limit) + """ + async with self._lock: + if image_concurrency > 0: + self._image_concurrency[token_id] = image_concurrency + elif token_id in self._image_concurrency: + del self._image_concurrency[token_id] + + if video_concurrency > 0: + self._video_concurrency[token_id] = video_concurrency + elif token_id in self._video_concurrency: + del self._video_concurrency[token_id] + + debug_logger.log_info(f"Token {token_id} concurrency reset (image: {image_concurrency}, video: {video_concurrency})") diff --git a/src/services/file_cache.py b/src/services/file_cache.py new file mode 100644 index 0000000..5866d21 --- /dev/null +++ b/src/services/file_cache.py @@ -0,0 +1,199 @@ +"""File caching service""" +import os +import asyncio +import hashlib +import time +from pathlib import Path +from typing import Optional +from datetime import datetime, timedelta +from curl_cffi.requests import AsyncSession +from ..core.config import config +from ..core.logger import debug_logger + + +class FileCache: + """File caching service for videos""" + + def __init__(self, cache_dir: str = "tmp", default_timeout: int = 7200, proxy_manager=None): + """ + Initialize file cache + + Args: + cache_dir: Cache directory path + default_timeout: Default cache timeout in seconds (default: 2 hours) + proxy_manager: ProxyManager instance for downloading files + """ + self.cache_dir = Path(cache_dir) + self.cache_dir.mkdir(exist_ok=True) + self.default_timeout = default_timeout + self.proxy_manager = proxy_manager + self._cleanup_task = None + + async def start_cleanup_task(self): + """Start background cleanup task""" + if self._cleanup_task is None: + self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + + async def stop_cleanup_task(self): + """Stop background cleanup task""" + if self._cleanup_task: + self._cleanup_task.cancel() + try: + await self._cleanup_task + except asyncio.CancelledError: + pass + self._cleanup_task = None + + async def _cleanup_loop(self): + """Background task to clean up expired files""" + while True: + try: + await asyncio.sleep(300) # Check every 5 minutes + await self._cleanup_expired_files() + except asyncio.CancelledError: + break + except Exception as e: + debug_logger.log_error( + error_message=f"Cleanup task error: {str(e)}", + status_code=0, + response_text="" + ) + + async def _cleanup_expired_files(self): + """Remove expired cache files""" + try: + current_time = time.time() + removed_count = 0 + + for file_path in self.cache_dir.iterdir(): + if file_path.is_file(): + # Check file age + file_age = current_time - file_path.stat().st_mtime + if file_age > self.default_timeout: + try: + file_path.unlink() + removed_count += 1 + except Exception: + pass + + if removed_count > 0: + debug_logger.log_info(f"Cleanup: removed {removed_count} expired cache files") + + except Exception as e: + debug_logger.log_error( + error_message=f"Failed to cleanup expired files: {str(e)}", + status_code=0, + response_text="" + ) + + def _generate_cache_filename(self, url: str, media_type: str) -> str: + """Generate unique filename for cached file""" + # Use URL hash as filename + url_hash = hashlib.md5(url.encode()).hexdigest() + + # Determine file extension + if media_type == "video": + ext = ".mp4" + elif media_type == "image": + ext = ".jpg" + else: + ext = "" + + return f"{url_hash}{ext}" + + async def download_and_cache(self, url: str, media_type: str) -> str: + """ + Download file from URL and cache it locally + + Args: + url: File URL to download + media_type: 'image' or 'video' + + Returns: + Local cache filename + """ + filename = self._generate_cache_filename(url, media_type) + file_path = self.cache_dir / filename + + # Check if already cached and not expired + if file_path.exists(): + file_age = time.time() - file_path.stat().st_mtime + if file_age < self.default_timeout: + debug_logger.log_info(f"Cache hit: {filename}") + return filename + else: + # Remove expired file + try: + file_path.unlink() + except Exception: + pass + + # Download file + debug_logger.log_info(f"Downloading file from: {url}") + + try: + # Get proxy if available + proxy_url = None + if self.proxy_manager: + proxy_config = await self.proxy_manager.get_proxy_config() + if proxy_config and proxy_config.enabled and proxy_config.proxy_url: + proxy_url = proxy_config.proxy_url + + # Download with proxy support + async with AsyncSession() as session: + proxies = {"http": proxy_url, "https": proxy_url} if proxy_url else None + response = await session.get(url, timeout=60, proxies=proxies) + + if response.status_code != 200: + raise Exception(f"Download failed: HTTP {response.status_code}") + + # Save to cache + with open(file_path, 'wb') as f: + f.write(response.content) + + debug_logger.log_info(f"File cached: {filename} ({len(response.content)} bytes)") + return filename + + except Exception as e: + debug_logger.log_error( + error_message=f"Failed to download file: {str(e)}", + status_code=0, + response_text=str(e) + ) + raise Exception(f"Failed to cache file: {str(e)}") + + def get_cache_path(self, filename: str) -> Path: + """Get full path to cached file""" + return self.cache_dir / filename + + def set_timeout(self, timeout: int): + """Set cache timeout in seconds""" + self.default_timeout = timeout + debug_logger.log_info(f"Cache timeout updated to {timeout} seconds") + + def get_timeout(self) -> int: + """Get current cache timeout""" + return self.default_timeout + + async def clear_all(self): + """Clear all cached files""" + try: + removed_count = 0 + for file_path in self.cache_dir.iterdir(): + if file_path.is_file(): + try: + file_path.unlink() + removed_count += 1 + except Exception: + pass + + debug_logger.log_info(f"Cache cleared: removed {removed_count} files") + return removed_count + + except Exception as e: + debug_logger.log_error( + error_message=f"Failed to clear cache: {str(e)}", + status_code=0, + response_text="" + ) + raise diff --git a/src/services/flow_client.py b/src/services/flow_client.py new file mode 100644 index 0000000..cb70f4c --- /dev/null +++ b/src/services/flow_client.py @@ -0,0 +1,657 @@ +"""Flow API Client for VideoFX (Veo)""" +import time +import uuid +import random +import base64 +from typing import Dict, Any, Optional, List +from curl_cffi.requests import AsyncSession +from ..core.logger import debug_logger +from ..core.config import config + + +class FlowClient: + """VideoFX API客户端""" + + def __init__(self, proxy_manager): + self.proxy_manager = proxy_manager + self.labs_base_url = config.flow_labs_base_url # https://labs.google/fx/api + self.api_base_url = config.flow_api_base_url # https://aisandbox-pa.googleapis.com/v1 + self.timeout = config.flow_timeout + + async def _make_request( + self, + method: str, + url: str, + headers: Optional[Dict] = None, + json_data: Optional[Dict] = None, + use_st: bool = False, + st_token: Optional[str] = None, + use_at: bool = False, + at_token: Optional[str] = None + ) -> Dict[str, Any]: + """统一HTTP请求处理 + + Args: + method: HTTP方法 (GET/POST) + url: 完整URL + headers: 请求头 + json_data: JSON请求体 + use_st: 是否使用ST认证 (Cookie方式) + st_token: Session Token + use_at: 是否使用AT认证 (Bearer方式) + at_token: Access Token + """ + proxy_url = await self.proxy_manager.get_proxy_url() + + if headers is None: + headers = {} + + # ST认证 - 使用Cookie + if use_st and st_token: + headers["Cookie"] = f"__Secure-next-auth.session-token={st_token}" + + # AT认证 - 使用Bearer + if use_at and at_token: + headers["authorization"] = f"Bearer {at_token}" + + # 通用请求头 + headers.update({ + "Content-Type": "application/json", + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36" + }) + + # Log request + if config.debug_enabled: + debug_logger.log_request( + method=method, + url=url, + headers=headers, + body=json_data, + proxy=proxy_url + ) + + start_time = time.time() + + try: + async with AsyncSession() as session: + if method.upper() == "GET": + response = await session.get( + url, + headers=headers, + proxy=proxy_url, + timeout=self.timeout, + impersonate="chrome110" + ) + else: # POST + response = await session.post( + url, + headers=headers, + json=json_data, + proxy=proxy_url, + timeout=self.timeout, + impersonate="chrome110" + ) + + duration_ms = (time.time() - start_time) * 1000 + + # Log response + if config.debug_enabled: + debug_logger.log_response( + status_code=response.status_code, + headers=dict(response.headers), + body=response.text, + duration_ms=duration_ms + ) + + response.raise_for_status() + return response.json() + + except Exception as e: + duration_ms = (time.time() - start_time) * 1000 + error_msg = str(e) + + if config.debug_enabled: + debug_logger.log_error( + error_message=error_msg, + status_code=getattr(e, 'status_code', None), + response_text=getattr(e, 'response_text', None) + ) + + raise Exception(f"Flow API request failed: {error_msg}") + + # ========== 认证相关 (使用ST) ========== + + async def st_to_at(self, st: str) -> dict: + """ST转AT + + Args: + st: Session Token + + Returns: + { + "access_token": "AT", + "expires": "2025-11-15T04:46:04.000Z", + "user": {...} + } + """ + url = f"{self.labs_base_url}/auth/session" + result = await self._make_request( + method="GET", + url=url, + use_st=True, + st_token=st + ) + return result + + # ========== 项目管理 (使用ST) ========== + + async def create_project(self, st: str, title: str) -> str: + """创建项目,返回project_id + + Args: + st: Session Token + title: 项目标题 + + Returns: + project_id (UUID) + """ + url = f"{self.labs_base_url}/trpc/project.createProject" + json_data = { + "json": { + "projectTitle": title, + "toolName": "PINHOLE" + } + } + + result = await self._make_request( + method="POST", + url=url, + json_data=json_data, + use_st=True, + st_token=st + ) + + # 解析返回的project_id + project_id = result["result"]["data"]["json"]["result"]["projectId"] + return project_id + + async def delete_project(self, st: str, project_id: str): + """删除项目 + + Args: + st: Session Token + project_id: 项目ID + """ + url = f"{self.labs_base_url}/trpc/project.deleteProject" + json_data = { + "json": { + "projectToDeleteId": project_id + } + } + + await self._make_request( + method="POST", + url=url, + json_data=json_data, + use_st=True, + st_token=st + ) + + # ========== 余额查询 (使用AT) ========== + + async def get_credits(self, at: str) -> dict: + """查询余额 + + Args: + at: Access Token + + Returns: + { + "credits": 920, + "userPaygateTier": "PAYGATE_TIER_ONE" + } + """ + url = f"{self.api_base_url}/credits" + result = await self._make_request( + method="GET", + url=url, + use_at=True, + at_token=at + ) + return result + + # ========== 图片上传 (使用AT) ========== + + async def upload_image( + self, + at: str, + image_bytes: bytes, + aspect_ratio: str = "IMAGE_ASPECT_RATIO_LANDSCAPE" + ) -> str: + """上传图片,返回mediaGenerationId + + Args: + at: Access Token + image_bytes: 图片字节数据 + aspect_ratio: 图片或视频宽高比(会自动转换为图片格式) + + Returns: + mediaGenerationId (CAM...) + """ + # 转换视频aspect_ratio为图片aspect_ratio + # VIDEO_ASPECT_RATIO_LANDSCAPE -> IMAGE_ASPECT_RATIO_LANDSCAPE + # VIDEO_ASPECT_RATIO_PORTRAIT -> IMAGE_ASPECT_RATIO_PORTRAIT + if aspect_ratio.startswith("VIDEO_"): + aspect_ratio = aspect_ratio.replace("VIDEO_", "IMAGE_") + + # 编码为base64 (去掉前缀) + image_base64 = base64.b64encode(image_bytes).decode('utf-8') + + url = f"{self.api_base_url}:uploadUserImage" + json_data = { + "imageInput": { + "rawImageBytes": image_base64, + "mimeType": "image/jpeg", + "isUserUploaded": True, + "aspectRatio": aspect_ratio + }, + "clientContext": { + "sessionId": self._generate_session_id(), + "tool": "ASSET_MANAGER" + } + } + + result = await self._make_request( + method="POST", + url=url, + json_data=json_data, + use_at=True, + at_token=at + ) + + # 返回mediaGenerationId + media_id = result["mediaGenerationId"]["mediaGenerationId"] + return media_id + + # ========== 图片生成 (使用AT) - 同步返回 ========== + + async def generate_image( + self, + at: str, + project_id: str, + prompt: str, + model_name: str, + aspect_ratio: str, + image_inputs: Optional[List[Dict]] = None + ) -> dict: + """生成图片(同步返回) + + Args: + at: Access Token + project_id: 项目ID + prompt: 提示词 + model_name: GEM_PIX, GEM_PIX_2 或 IMAGEN_3_5 + aspect_ratio: 图片宽高比 + image_inputs: 参考图片列表(图生图时使用) + + Returns: + { + "media": [{ + "image": { + "generatedImage": { + "fifeUrl": "图片URL", + ... + } + } + }] + } + """ + url = f"{self.api_base_url}/projects/{project_id}/flowMedia:batchGenerateImages" + + # 构建请求 + request_data = { + "clientContext": { + "sessionId": self._generate_session_id() + }, + "seed": random.randint(1, 99999), + "imageModelName": model_name, + "imageAspectRatio": aspect_ratio, + "prompt": prompt, + "imageInputs": image_inputs or [] + } + + json_data = { + "requests": [request_data] + } + + result = await self._make_request( + method="POST", + url=url, + json_data=json_data, + use_at=True, + at_token=at + ) + + return result + + # ========== 视频生成 (使用AT) - 异步返回 ========== + + async def generate_video_text( + self, + at: str, + project_id: str, + prompt: str, + model_key: str, + aspect_ratio: str, + user_paygate_tier: str = "PAYGATE_TIER_ONE" + ) -> dict: + """文生视频,返回task_id + + Args: + at: Access Token + project_id: 项目ID + prompt: 提示词 + model_key: veo_3_1_t2v_fast 等 + aspect_ratio: 视频宽高比 + user_paygate_tier: 用户等级 + + Returns: + { + "operations": [{ + "operation": {"name": "task_id"}, + "sceneId": "uuid", + "status": "MEDIA_GENERATION_STATUS_PENDING" + }], + "remainingCredits": 900 + } + """ + url = f"{self.api_base_url}/video:batchAsyncGenerateVideoText" + + scene_id = str(uuid.uuid4()) + + json_data = { + "clientContext": { + "sessionId": self._generate_session_id(), + "projectId": project_id, + "tool": "PINHOLE", + "userPaygateTier": user_paygate_tier + }, + "requests": [{ + "aspectRatio": aspect_ratio, + "seed": random.randint(1, 99999), + "textInput": { + "prompt": prompt + }, + "videoModelKey": model_key, + "metadata": { + "sceneId": scene_id + } + }] + } + + result = await self._make_request( + method="POST", + url=url, + json_data=json_data, + use_at=True, + at_token=at + ) + + return result + + async def generate_video_reference_images( + self, + at: str, + project_id: str, + prompt: str, + model_key: str, + aspect_ratio: str, + reference_images: List[Dict], + user_paygate_tier: str = "PAYGATE_TIER_ONE" + ) -> dict: + """图生视频,返回task_id + + Args: + at: Access Token + project_id: 项目ID + prompt: 提示词 + model_key: veo_3_0_r2v_fast + aspect_ratio: 视频宽高比 + reference_images: 参考图片列表 [{"imageUsageType": "IMAGE_USAGE_TYPE_ASSET", "mediaId": "..."}] + user_paygate_tier: 用户等级 + + Returns: + 同 generate_video_text + """ + url = f"{self.api_base_url}/video:batchAsyncGenerateVideoReferenceImages" + + scene_id = str(uuid.uuid4()) + + json_data = { + "clientContext": { + "sessionId": self._generate_session_id(), + "projectId": project_id, + "tool": "PINHOLE", + "userPaygateTier": user_paygate_tier + }, + "requests": [{ + "aspectRatio": aspect_ratio, + "seed": random.randint(1, 99999), + "textInput": { + "prompt": prompt + }, + "videoModelKey": model_key, + "referenceImages": reference_images, + "metadata": { + "sceneId": scene_id + } + }] + } + + result = await self._make_request( + method="POST", + url=url, + json_data=json_data, + use_at=True, + at_token=at + ) + + return result + + async def generate_video_start_end( + self, + at: str, + project_id: str, + prompt: str, + model_key: str, + aspect_ratio: str, + start_media_id: str, + end_media_id: str, + user_paygate_tier: str = "PAYGATE_TIER_ONE" + ) -> dict: + """收尾帧生成视频,返回task_id + + Args: + at: Access Token + project_id: 项目ID + prompt: 提示词 + model_key: veo_3_1_i2v_s_fast_fl + aspect_ratio: 视频宽高比 + start_media_id: 起始帧mediaId + end_media_id: 结束帧mediaId + user_paygate_tier: 用户等级 + + Returns: + 同 generate_video_text + """ + url = f"{self.api_base_url}/video:batchAsyncGenerateVideoStartAndEndImage" + + scene_id = str(uuid.uuid4()) + + json_data = { + "clientContext": { + "sessionId": self._generate_session_id(), + "projectId": project_id, + "tool": "PINHOLE", + "userPaygateTier": user_paygate_tier + }, + "requests": [{ + "aspectRatio": aspect_ratio, + "seed": random.randint(1, 99999), + "textInput": { + "prompt": prompt + }, + "videoModelKey": model_key, + "startImage": { + "mediaId": start_media_id + }, + "endImage": { + "mediaId": end_media_id + }, + "metadata": { + "sceneId": scene_id + } + }] + } + + result = await self._make_request( + method="POST", + url=url, + json_data=json_data, + use_at=True, + at_token=at + ) + + return result + + async def generate_video_start_image( + self, + at: str, + project_id: str, + prompt: str, + model_key: str, + aspect_ratio: str, + start_media_id: str, + user_paygate_tier: str = "PAYGATE_TIER_ONE" + ) -> dict: + """仅首帧生成视频,返回task_id + + Args: + at: Access Token + project_id: 项目ID + prompt: 提示词 + model_key: veo_3_1_i2v_s_fast_fl等 + aspect_ratio: 视频宽高比 + start_media_id: 起始帧mediaId + user_paygate_tier: 用户等级 + + Returns: + 同 generate_video_text + """ + url = f"{self.api_base_url}/video:batchAsyncGenerateVideoStartAndEndImage" + + scene_id = str(uuid.uuid4()) + + json_data = { + "clientContext": { + "sessionId": self._generate_session_id(), + "projectId": project_id, + "tool": "PINHOLE", + "userPaygateTier": user_paygate_tier + }, + "requests": [{ + "aspectRatio": aspect_ratio, + "seed": random.randint(1, 99999), + "textInput": { + "prompt": prompt + }, + "videoModelKey": model_key, + "startImage": { + "mediaId": start_media_id + }, + # 注意: 没有endImage字段,只用首帧 + "metadata": { + "sceneId": scene_id + } + }] + } + + result = await self._make_request( + method="POST", + url=url, + json_data=json_data, + use_at=True, + at_token=at + ) + + return result + + # ========== 任务轮询 (使用AT) ========== + + async def check_video_status(self, at: str, operations: List[Dict]) -> dict: + """查询视频生成状态 + + Args: + at: Access Token + operations: 操作列表 [{"operation": {"name": "task_id"}, "sceneId": "...", "status": "..."}] + + Returns: + { + "operations": [{ + "operation": { + "name": "task_id", + "metadata": {...} # 完成时包含视频信息 + }, + "status": "MEDIA_GENERATION_STATUS_SUCCESSFUL" + }] + } + """ + url = f"{self.api_base_url}/video:batchCheckAsyncVideoGenerationStatus" + + json_data = { + "operations": operations + } + + result = await self._make_request( + method="POST", + url=url, + json_data=json_data, + use_at=True, + at_token=at + ) + + return result + + # ========== 媒体删除 (使用ST) ========== + + async def delete_media(self, st: str, media_names: List[str]): + """删除媒体 + + Args: + st: Session Token + media_names: 媒体ID列表 + """ + url = f"{self.labs_base_url}/trpc/media.deleteMedia" + json_data = { + "json": { + "names": media_names + } + } + + await self._make_request( + method="POST", + url=url, + json_data=json_data, + use_st=True, + st_token=st + ) + + # ========== 辅助方法 ========== + + def _generate_session_id(self) -> str: + """生成sessionId: ;timestamp""" + return f";{int(time.time() * 1000)}" + + def _generate_scene_id(self) -> str: + """生成sceneId: UUID""" + return str(uuid.uuid4()) diff --git a/src/services/generation_handler.py b/src/services/generation_handler.py new file mode 100644 index 0000000..4d47b5f --- /dev/null +++ b/src/services/generation_handler.py @@ -0,0 +1,850 @@ +"""Generation handler for Flow2API""" +import asyncio +import base64 +import json +import time +from typing import Optional, AsyncGenerator, List, Dict, Any +from ..core.logger import debug_logger +from ..core.config import config +from ..core.models import Task, RequestLog +from .file_cache import FileCache + + +# Model configuration +MODEL_CONFIG = { + # 图片生成 - GEM_PIX (Gemini 2.5 Flash) + "gemini-2.5-flash-image-landscape": { + "type": "image", + "model_name": "GEM_PIX", + "aspect_ratio": "IMAGE_ASPECT_RATIO_LANDSCAPE" + }, + "gemini-2.5-flash-image-portrait": { + "type": "image", + "model_name": "GEM_PIX", + "aspect_ratio": "IMAGE_ASPECT_RATIO_PORTRAIT" + }, + + # 图片生成 - GEM_PIX_2 (Gemini 3.0 Pro) + "gemini-3.0-pro-image-landscape": { + "type": "image", + "model_name": "GEM_PIX_2", + "aspect_ratio": "IMAGE_ASPECT_RATIO_LANDSCAPE" + }, + "gemini-3.0-pro-image-portrait": { + "type": "image", + "model_name": "GEM_PIX_2", + "aspect_ratio": "IMAGE_ASPECT_RATIO_PORTRAIT" + }, + + # 图片生成 - IMAGEN_3_5 (Imagen 4.0) + "imagen-4.0-generate-preview-landscape": { + "type": "image", + "model_name": "IMAGEN_3_5", + "aspect_ratio": "IMAGE_ASPECT_RATIO_LANDSCAPE" + }, + "imagen-4.0-generate-preview-portrait": { + "type": "image", + "model_name": "IMAGEN_3_5", + "aspect_ratio": "IMAGE_ASPECT_RATIO_PORTRAIT" + }, + + # ========== 文生视频 (T2V - Text to Video) ========== + # 不支持上传图片,只使用文本提示词生成 + + # veo_3_1_t2v_fast_portrait (竖屏) + # 上游模型名: veo_3_1_t2v_fast_portrait + "veo_3_1_t2v_fast_portrait": { + "type": "video", + "video_type": "t2v", + "model_key": "veo_3_1_t2v_fast_portrait", + "aspect_ratio": "VIDEO_ASPECT_RATIO_PORTRAIT", + "supports_images": False + }, + # veo_3_1_t2v_fast_landscape (横屏) + # 上游模型名: veo_3_1_t2v_fast + "veo_3_1_t2v_fast_landscape": { + "type": "video", + "video_type": "t2v", + "model_key": "veo_3_1_t2v_fast", + "aspect_ratio": "VIDEO_ASPECT_RATIO_LANDSCAPE", + "supports_images": False + }, + + # veo_2_1_fast_d_15_t2v (需要新增横竖屏) + "veo_2_1_fast_d_15_t2v_portrait": { + "type": "video", + "video_type": "t2v", + "model_key": "veo_2_1_fast_d_15_t2v", + "aspect_ratio": "VIDEO_ASPECT_RATIO_PORTRAIT", + "supports_images": False + }, + "veo_2_1_fast_d_15_t2v_landscape": { + "type": "video", + "video_type": "t2v", + "model_key": "veo_2_1_fast_d_15_t2v", + "aspect_ratio": "VIDEO_ASPECT_RATIO_LANDSCAPE", + "supports_images": False + }, + + # veo_2_0_t2v (需要新增横竖屏) + "veo_2_0_t2v_portrait": { + "type": "video", + "video_type": "t2v", + "model_key": "veo_2_0_t2v", + "aspect_ratio": "VIDEO_ASPECT_RATIO_PORTRAIT", + "supports_images": False + }, + "veo_2_0_t2v_landscape": { + "type": "video", + "video_type": "t2v", + "model_key": "veo_2_0_t2v", + "aspect_ratio": "VIDEO_ASPECT_RATIO_LANDSCAPE", + "supports_images": False + }, + + # ========== 首尾帧模型 (I2V - Image to Video) ========== + # 支持1-2张图片:1张作为首帧,2张作为首尾帧 + + # veo_3_1_i2v_s_fast_fl (需要新增横竖屏) + "veo_3_1_i2v_s_fast_fl_portrait": { + "type": "video", + "video_type": "i2v", + "model_key": "veo_3_1_i2v_s_fast_fl", + "aspect_ratio": "VIDEO_ASPECT_RATIO_PORTRAIT", + "supports_images": True, + "min_images": 1, + "max_images": 2 + }, + "veo_3_1_i2v_s_fast_fl_landscape": { + "type": "video", + "video_type": "i2v", + "model_key": "veo_3_1_i2v_s_fast_fl", + "aspect_ratio": "VIDEO_ASPECT_RATIO_LANDSCAPE", + "supports_images": True, + "min_images": 1, + "max_images": 2 + }, + + # veo_2_1_fast_d_15_i2v (需要新增横竖屏) + "veo_2_1_fast_d_15_i2v_portrait": { + "type": "video", + "video_type": "i2v", + "model_key": "veo_2_1_fast_d_15_i2v", + "aspect_ratio": "VIDEO_ASPECT_RATIO_PORTRAIT", + "supports_images": True, + "min_images": 1, + "max_images": 2 + }, + "veo_2_1_fast_d_15_i2v_landscape": { + "type": "video", + "video_type": "i2v", + "model_key": "veo_2_1_fast_d_15_i2v", + "aspect_ratio": "VIDEO_ASPECT_RATIO_LANDSCAPE", + "supports_images": True, + "min_images": 1, + "max_images": 2 + }, + + # veo_2_0_i2v (需要新增横竖屏) + "veo_2_0_i2v_portrait": { + "type": "video", + "video_type": "i2v", + "model_key": "veo_2_0_i2v", + "aspect_ratio": "VIDEO_ASPECT_RATIO_PORTRAIT", + "supports_images": True, + "min_images": 1, + "max_images": 2 + }, + "veo_2_0_i2v_landscape": { + "type": "video", + "video_type": "i2v", + "model_key": "veo_2_0_i2v", + "aspect_ratio": "VIDEO_ASPECT_RATIO_LANDSCAPE", + "supports_images": True, + "min_images": 1, + "max_images": 2 + }, + + # ========== 多图生成 (R2V - Reference Images to Video) ========== + # 支持多张图片,不限制数量 + + # veo_3_0_r2v_fast (需要新增横竖屏) + "veo_3_0_r2v_fast_portrait": { + "type": "video", + "video_type": "r2v", + "model_key": "veo_3_0_r2v_fast", + "aspect_ratio": "VIDEO_ASPECT_RATIO_PORTRAIT", + "supports_images": True, + "min_images": 0, + "max_images": None # 不限制 + }, + "veo_3_0_r2v_fast_landscape": { + "type": "video", + "video_type": "r2v", + "model_key": "veo_3_0_r2v_fast", + "aspect_ratio": "VIDEO_ASPECT_RATIO_LANDSCAPE", + "supports_images": True, + "min_images": 0, + "max_images": None # 不限制 + } +} + + +class GenerationHandler: + """统一生成处理器""" + + def __init__(self, flow_client, token_manager, load_balancer, db, concurrency_manager, proxy_manager): + self.flow_client = flow_client + self.token_manager = token_manager + self.load_balancer = load_balancer + self.db = db + self.concurrency_manager = concurrency_manager + self.file_cache = FileCache( + cache_dir="tmp", + default_timeout=config.cache_timeout, + proxy_manager=proxy_manager + ) + + async def check_token_availability(self, is_image: bool, is_video: bool) -> bool: + """检查Token可用性 + + Args: + is_image: 是否检查图片生成Token + is_video: 是否检查视频生成Token + + Returns: + True表示有可用Token, False表示无可用Token + """ + token_obj = await self.load_balancer.select_token( + for_image_generation=is_image, + for_video_generation=is_video + ) + return token_obj is not None + + async def handle_generation( + self, + model: str, + prompt: str, + images: Optional[List[bytes]] = None, + stream: bool = False + ) -> AsyncGenerator: + """统一生成入口 + + Args: + model: 模型名称 + prompt: 提示词 + images: 图片列表 (bytes格式) + stream: 是否流式输出 + """ + start_time = time.time() + token = None + + # 1. 验证模型 + if model not in MODEL_CONFIG: + error_msg = f"不支持的模型: {model}" + debug_logger.log_error(error_msg) + yield self._create_error_response(error_msg) + return + + model_config = MODEL_CONFIG[model] + generation_type = model_config["type"] + debug_logger.log_info(f"[GENERATION] 开始生成 - 模型: {model}, 类型: {generation_type}, Prompt: {prompt[:50]}...") + + # 非流式模式: 只检查可用性 + if not stream: + is_image = (generation_type == "image") + is_video = (generation_type == "video") + available = await self.check_token_availability(is_image, is_video) + + if available: + if is_image: + message = "所有Token可用于图片生成。请启用流式模式使用生成功能。" + else: + message = "所有Token可用于视频生成。请启用流式模式使用生成功能。" + else: + if is_image: + message = "没有可用的Token进行图片生成" + else: + message = "没有可用的Token进行视频生成" + + yield self._create_completion_response(message, is_availability_check=True) + return + + # 向用户展示开始信息 + if stream: + yield self._create_stream_chunk( + f"✨ {'视频' if generation_type == 'video' else '图片'}生成任务已启动\n", + role="assistant" + ) + + # 2. 选择Token + debug_logger.log_info(f"[GENERATION] 正在选择可用Token...") + + if generation_type == "image": + token = await self.load_balancer.select_token(for_image_generation=True) + else: + token = await self.load_balancer.select_token(for_video_generation=True) + + if not token: + error_msg = self._get_no_token_error_message(generation_type) + debug_logger.log_error(f"[GENERATION] {error_msg}") + if stream: + yield self._create_stream_chunk(f"❌ {error_msg}\n") + yield self._create_error_response(error_msg) + return + + debug_logger.log_info(f"[GENERATION] 已选择Token: {token.id} ({token.email})") + + try: + # 3. 确保AT有效 + debug_logger.log_info(f"[GENERATION] 检查Token AT有效性...") + if stream: + yield self._create_stream_chunk("初始化生成环境...\n") + + if not await self.token_manager.is_at_valid(token.id): + error_msg = "Token AT无效或刷新失败" + debug_logger.log_error(f"[GENERATION] {error_msg}") + if stream: + yield self._create_stream_chunk(f"❌ {error_msg}\n") + yield self._create_error_response(error_msg) + return + + # 重新获取token (AT可能已刷新) + token = await self.token_manager.get_token(token.id) + + # 4. 确保Project存在 + debug_logger.log_info(f"[GENERATION] 检查/创建Project...") + + project_id = await self.token_manager.ensure_project_exists(token.id) + debug_logger.log_info(f"[GENERATION] Project ID: {project_id}") + + # 5. 根据类型处理 + if generation_type == "image": + debug_logger.log_info(f"[GENERATION] 开始图片生成流程...") + async for chunk in self._handle_image_generation( + token, project_id, model_config, prompt, images, stream + ): + yield chunk + else: # video + debug_logger.log_info(f"[GENERATION] 开始视频生成流程...") + async for chunk in self._handle_video_generation( + token, project_id, model_config, prompt, images, stream + ): + yield chunk + + # 6. 记录使用 + is_video = (generation_type == "video") + await self.token_manager.record_usage(token.id, is_video=is_video) + debug_logger.log_info(f"[GENERATION] ✅ 生成成功完成") + + # 7. 记录成功日志 + duration = time.time() - start_time + await self._log_request( + token.id, + f"generate_{generation_type}", + {"model": model, "prompt": prompt[:100], "has_images": images is not None and len(images) > 0}, + {"status": "success"}, + 200, + duration + ) + + except Exception as e: + error_msg = f"生成失败: {str(e)}" + debug_logger.log_error(f"[GENERATION] ❌ {error_msg}") + if stream: + yield self._create_stream_chunk(f"❌ {error_msg}\n") + if token: + await self.token_manager.record_error(token.id) + yield self._create_error_response(error_msg) + + # 记录失败日志 + duration = time.time() - start_time + await self._log_request( + token.id if token else None, + f"generate_{generation_type if model_config else 'unknown'}", + {"model": model, "prompt": prompt[:100], "has_images": images is not None and len(images) > 0}, + {"error": error_msg}, + 500, + duration + ) + + def _get_no_token_error_message(self, generation_type: str) -> str: + """获取无可用Token时的详细错误信息""" + if generation_type == "image": + return "没有可用的Token进行图片生成。所有Token都处于禁用、冷却、锁定或已过期状态。" + else: + return "没有可用的Token进行视频生成。所有Token都处于禁用、冷却、配额耗尽或已过期状态。" + + async def _handle_image_generation( + self, + token, + project_id: str, + model_config: dict, + prompt: str, + images: Optional[List[bytes]], + stream: bool + ) -> AsyncGenerator: + """处理图片生成 (同步返回)""" + + # 获取并发槽位 + if self.concurrency_manager: + if not await self.concurrency_manager.acquire_image(token.id): + yield self._create_error_response("图片并发限制已达上限") + return + + try: + # 上传图片 (如果有) + image_inputs = [] + if images and len(images) > 0: + if stream: + yield self._create_stream_chunk("上传参考图片...\n") + + image_bytes = images[0] # 图生图只需要一张 + media_id = await self.flow_client.upload_image( + token.at, + image_bytes, + model_config["aspect_ratio"] + ) + + image_inputs = [{ + "name": media_id, + "imageInputType": "IMAGE_INPUT_TYPE_REFERENCE" + }] + + # 调用生成API + if stream: + yield self._create_stream_chunk("正在生成图片...\n") + + result = await self.flow_client.generate_image( + at=token.at, + project_id=project_id, + prompt=prompt, + model_name=model_config["model_name"], + aspect_ratio=model_config["aspect_ratio"], + image_inputs=image_inputs + ) + + # 提取URL + media = result.get("media", []) + if not media: + yield self._create_error_response("生成结果为空") + return + + image_url = media[0]["image"]["generatedImage"]["fifeUrl"] + + # 缓存图片 (如果启用) + local_url = image_url + if config.cache_enabled: + try: + if stream: + yield self._create_stream_chunk("缓存图片中...\n") + cached_filename = await self.file_cache.download_and_cache(image_url, "image") + local_url = f"{self._get_base_url()}/tmp/{cached_filename}" + except Exception as e: + debug_logger.log_error(f"Failed to cache image: {str(e)}") + # 缓存失败不影响结果返回,使用原始URL + local_url = image_url + + # 返回结果 + if stream: + yield self._create_stream_chunk( + f"", + finish_reason="stop" + ) + else: + yield self._create_completion_response( + local_url, # 直接传URL,让方法内部格式化 + media_type="image" + ) + + finally: + # 释放并发槽位 + if self.concurrency_manager: + await self.concurrency_manager.release_image(token.id) + + async def _handle_video_generation( + self, + token, + project_id: str, + model_config: dict, + prompt: str, + images: Optional[List[bytes]], + stream: bool + ) -> AsyncGenerator: + """处理视频生成 (异步轮询)""" + + # 获取并发槽位 + if self.concurrency_manager: + if not await self.concurrency_manager.acquire_video(token.id): + yield self._create_error_response("视频并发限制已达上限") + return + + try: + # 获取模型类型和配置 + video_type = model_config.get("video_type") + supports_images = model_config.get("supports_images", False) + min_images = model_config.get("min_images", 0) + max_images = model_config.get("max_images", 0) + + # 图片数量 + image_count = len(images) if images else 0 + + # ========== 验证和处理图片 ========== + + # T2V: 文生视频 - 不支持图片 + if video_type == "t2v": + if image_count > 0: + if stream: + yield self._create_stream_chunk("⚠️ 文生视频模型不支持上传图片,将忽略图片仅使用文本提示词生成\n") + debug_logger.log_warning(f"[T2V] 模型 {model_config['model_key']} 不支持图片,已忽略 {image_count} 张图片") + images = None # 清空图片 + image_count = 0 + + # I2V: 首尾帧模型 - 需要1-2张图片 + elif video_type == "i2v": + if image_count < min_images or image_count > max_images: + error_msg = f"❌ 首尾帧模型需要 {min_images}-{max_images} 张图片,当前提供了 {image_count} 张" + if stream: + yield self._create_stream_chunk(f"{error_msg}\n") + yield self._create_error_response(error_msg) + return + + # R2V: 多图生成 - 支持多张图片,不限制数量 + elif video_type == "r2v": + # 不再限制最大图片数量 + pass + + # ========== 上传图片 ========== + start_media_id = None + end_media_id = None + reference_images = [] + + # I2V: 首尾帧处理 + if video_type == "i2v" and images: + if image_count == 1: + # 只有1张图: 仅作为首帧 + if stream: + yield self._create_stream_chunk("上传首帧图片...\n") + start_media_id = await self.flow_client.upload_image( + token.at, images[0], model_config["aspect_ratio"] + ) + debug_logger.log_info(f"[I2V] 仅上传首帧: {start_media_id}") + + elif image_count == 2: + # 2张图: 首帧+尾帧 + if stream: + yield self._create_stream_chunk("上传首帧和尾帧图片...\n") + start_media_id = await self.flow_client.upload_image( + token.at, images[0], model_config["aspect_ratio"] + ) + end_media_id = await self.flow_client.upload_image( + token.at, images[1], model_config["aspect_ratio"] + ) + debug_logger.log_info(f"[I2V] 上传首尾帧: {start_media_id}, {end_media_id}") + + # R2V: 多图处理 + elif video_type == "r2v" and images: + if stream: + yield self._create_stream_chunk(f"上传 {image_count} 张参考图片...\n") + + for idx, img in enumerate(images): # 上传所有图片,不限制数量 + media_id = await self.flow_client.upload_image( + token.at, img, model_config["aspect_ratio"] + ) + reference_images.append({ + "imageUsageType": "IMAGE_USAGE_TYPE_ASSET", + "mediaId": media_id + }) + debug_logger.log_info(f"[R2V] 上传了 {len(reference_images)} 张参考图片") + + # ========== 调用生成API ========== + if stream: + yield self._create_stream_chunk("提交视频生成任务...\n") + + # I2V: 首尾帧生成 + if video_type == "i2v" and start_media_id: + if end_media_id: + # 有首尾帧 + result = await self.flow_client.generate_video_start_end( + at=token.at, + project_id=project_id, + prompt=prompt, + model_key=model_config["model_key"], + aspect_ratio=model_config["aspect_ratio"], + start_media_id=start_media_id, + end_media_id=end_media_id, + user_paygate_tier=token.user_paygate_tier or "PAYGATE_TIER_ONE" + ) + else: + # 只有首帧 + result = await self.flow_client.generate_video_start_image( + at=token.at, + project_id=project_id, + prompt=prompt, + model_key=model_config["model_key"], + aspect_ratio=model_config["aspect_ratio"], + start_media_id=start_media_id, + user_paygate_tier=token.user_paygate_tier or "PAYGATE_TIER_ONE" + ) + + # R2V: 多图生成 + elif video_type == "r2v" and reference_images: + result = await self.flow_client.generate_video_reference_images( + at=token.at, + project_id=project_id, + prompt=prompt, + model_key=model_config["model_key"], + aspect_ratio=model_config["aspect_ratio"], + reference_images=reference_images, + user_paygate_tier=token.user_paygate_tier or "PAYGATE_TIER_ONE" + ) + + # T2V 或 R2V无图: 纯文本生成 + else: + result = await self.flow_client.generate_video_text( + at=token.at, + project_id=project_id, + prompt=prompt, + model_key=model_config["model_key"], + aspect_ratio=model_config["aspect_ratio"], + user_paygate_tier=token.user_paygate_tier or "PAYGATE_TIER_ONE" + ) + + # 获取task_id和operations + operations = result.get("operations", []) + if not operations: + yield self._create_error_response("生成任务创建失败") + return + + operation = operations[0] + task_id = operation["operation"]["name"] + scene_id = operation.get("sceneId") + + # 保存Task到数据库 + task = Task( + task_id=task_id, + token_id=token.id, + model=model_config["model_key"], + prompt=prompt, + status="processing", + scene_id=scene_id + ) + await self.db.create_task(task) + + # 轮询结果 + if stream: + yield self._create_stream_chunk(f"视频生成中...\n") + + async for chunk in self._poll_video_result(token, operations, stream): + yield chunk + + finally: + # 释放并发槽位 + if self.concurrency_manager: + await self.concurrency_manager.release_video(token.id) + + async def _poll_video_result( + self, + token, + operations: List[Dict], + stream: bool + ) -> AsyncGenerator: + """轮询视频生成结果""" + + max_attempts = config.max_poll_attempts + poll_interval = config.poll_interval + + for attempt in range(max_attempts): + await asyncio.sleep(poll_interval) + + try: + result = await self.flow_client.check_video_status(token.at, operations) + checked_operations = result.get("operations", []) + + if not checked_operations: + continue + + operation = checked_operations[0] + status = operation.get("status") + + # 状态更新 - 每20秒报告一次 (poll_interval=3秒, 20秒约7次轮询) + progress_update_interval = 7 # 每7次轮询 = 21秒 + if stream and attempt % progress_update_interval == 0: # 每20秒报告一次 + progress = min(int((attempt / max_attempts) * 100), 95) + yield self._create_stream_chunk(f"生成进度: {progress}%\n") + + # 检查状态 + if status == "MEDIA_GENERATION_STATUS_SUCCESSFUL": + # 成功 + metadata = operation["operation"].get("metadata", {}) + video_info = metadata.get("video", {}) + video_url = video_info.get("fifeUrl") + + if not video_url: + yield self._create_error_response("视频URL为空") + return + + # 缓存视频 (如果启用) + local_url = video_url + if config.cache_enabled: + try: + if stream: + yield self._create_stream_chunk("缓存视频中...\n") + cached_filename = await self.file_cache.download_and_cache(video_url, "video") + local_url = f"{self._get_base_url()}/tmp/{cached_filename}" + except Exception as e: + debug_logger.log_error(f"Failed to cache video: {str(e)}") + # 缓存失败不影响结果返回,使用原始URL + local_url = video_url + + # 更新数据库 + task_id = operation["operation"]["name"] + await self.db.update_task( + task_id, + status="completed", + progress=100, + result_urls=[local_url], + completed_at=time.time() + ) + + # 返回结果 + if stream: + yield self._create_stream_chunk( + f"", + finish_reason="stop" + ) + else: + yield self._create_completion_response( + local_url, # 直接传URL,让方法内部格式化 + media_type="video" + ) + return + + elif status.startswith("MEDIA_GENERATION_STATUS_ERROR"): + # 失败 + yield self._create_error_response(f"视频生成失败: {status}") + return + + except Exception as e: + debug_logger.log_error(f"Poll error: {str(e)}") + continue + + # 超时 + yield self._create_error_response(f"视频生成超时 (已轮询{max_attempts}次)") + + # ========== 响应格式化 ========== + + def _create_stream_chunk(self, content: str, role: str = None, finish_reason: str = None) -> str: + """创建流式响应chunk""" + import json + import time + + chunk = { + "id": f"chatcmpl-{int(time.time())}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": "flow2api", + "choices": [{ + "index": 0, + "delta": {}, + "finish_reason": finish_reason + }] + } + + if role: + chunk["choices"][0]["delta"]["role"] = role + + if finish_reason: + chunk["choices"][0]["delta"]["content"] = content + else: + chunk["choices"][0]["delta"]["reasoning_content"] = content + + return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n" + + def _create_completion_response(self, content: str, media_type: str = "image", is_availability_check: bool = False) -> str: + """创建非流式响应 + + Args: + content: 媒体URL或纯文本消息 + media_type: 媒体类型 ("image" 或 "video") + is_availability_check: 是否为可用性检查响应 (纯文本消息) + + Returns: + JSON格式的响应 + """ + import json + import time + + # 可用性检查: 返回纯文本消息 + if is_availability_check: + formatted_content = content + else: + # 媒体生成: 根据媒体类型格式化内容为Markdown + if media_type == "video": + formatted_content = f"```html\n\n```" + else: # image + formatted_content = f"![Generated Image]({content})" + + response = { + "id": f"chatcmpl-{int(time.time())}", + "object": "chat.completion", + "created": int(time.time()), + "model": "flow2api", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": formatted_content + }, + "finish_reason": "stop" + }] + } + + return json.dumps(response, ensure_ascii=False) + + def _create_error_response(self, error_message: str) -> str: + """创建错误响应""" + import json + + error = { + "error": { + "message": error_message, + "type": "invalid_request_error", + "code": "generation_failed" + } + } + + return json.dumps(error, ensure_ascii=False) + + def _get_base_url(self) -> str: + """获取基础URL用于缓存文件访问""" + # 优先使用配置的cache_base_url + if config.cache_base_url: + return config.cache_base_url + # 否则使用服务器地址 + return f"http://{config.server_host}:{config.server_port}" + + async def _log_request( + self, + token_id: Optional[int], + operation: str, + request_data: Dict[str, Any], + response_data: Dict[str, Any], + status_code: int, + duration: float + ): + """记录请求到数据库""" + try: + log = RequestLog( + token_id=token_id, + operation=operation, + request_body=json.dumps(request_data, ensure_ascii=False), + response_body=json.dumps(response_data, ensure_ascii=False), + status_code=status_code, + duration=duration + ) + await self.db.add_request_log(log) + except Exception as e: + # 日志记录失败不影响主流程 + debug_logger.log_error(f"Failed to log request: {e}") + diff --git a/src/services/load_balancer.py b/src/services/load_balancer.py new file mode 100644 index 0000000..ff043d0 --- /dev/null +++ b/src/services/load_balancer.py @@ -0,0 +1,87 @@ +"""Load balancing module for Flow2API""" +import random +from typing import Optional +from ..core.models import Token +from .concurrency_manager import ConcurrencyManager +from ..core.logger import debug_logger + + +class LoadBalancer: + """Token load balancer with random selection""" + + def __init__(self, token_manager, concurrency_manager: Optional[ConcurrencyManager] = None): + self.token_manager = token_manager + self.concurrency_manager = concurrency_manager + + async def select_token( + self, + for_image_generation: bool = False, + for_video_generation: bool = False + ) -> Optional[Token]: + """ + Select a token using random load balancing + + Args: + for_image_generation: If True, only select tokens with image_enabled=True + for_video_generation: If True, only select tokens with video_enabled=True + + Returns: + Selected token or None if no available tokens + """ + debug_logger.log_info(f"[LOAD_BALANCER] 开始选择Token (图片生成={for_image_generation}, 视频生成={for_video_generation})") + + active_tokens = await self.token_manager.get_active_tokens() + debug_logger.log_info(f"[LOAD_BALANCER] 获取到 {len(active_tokens)} 个活跃Token") + + if not active_tokens: + debug_logger.log_info(f"[LOAD_BALANCER] ❌ 没有活跃的Token") + return None + + # Filter tokens based on generation type + available_tokens = [] + filtered_reasons = {} # 记录过滤原因 + + for token in active_tokens: + # Check if token has valid AT (not expired) + if not await self.token_manager.is_at_valid(token.id): + filtered_reasons[token.id] = "AT无效或已过期" + continue + + # Filter for image generation + if for_image_generation: + if not token.image_enabled: + filtered_reasons[token.id] = "图片生成已禁用" + continue + + # Check concurrency limit + if self.concurrency_manager and not await self.concurrency_manager.can_use_image(token.id): + filtered_reasons[token.id] = "图片并发已满" + continue + + # Filter for video generation + if for_video_generation: + if not token.video_enabled: + filtered_reasons[token.id] = "视频生成已禁用" + continue + + # Check concurrency limit + if self.concurrency_manager and not await self.concurrency_manager.can_use_video(token.id): + filtered_reasons[token.id] = "视频并发已满" + continue + + available_tokens.append(token) + + # 输出过滤信息 + if filtered_reasons: + debug_logger.log_info(f"[LOAD_BALANCER] 已过滤Token:") + for token_id, reason in filtered_reasons.items(): + debug_logger.log_info(f"[LOAD_BALANCER] - Token {token_id}: {reason}") + + if not available_tokens: + debug_logger.log_info(f"[LOAD_BALANCER] ❌ 没有可用的Token (图片生成={for_image_generation}, 视频生成={for_video_generation})") + return None + + # Random selection + selected = random.choice(available_tokens) + debug_logger.log_info(f"[LOAD_BALANCER] ✅ 已选择Token {selected.id} ({selected.email}) - 余额: {selected.credits}") + return selected diff --git a/src/services/proxy_manager.py b/src/services/proxy_manager.py new file mode 100644 index 0000000..eaa6535 --- /dev/null +++ b/src/services/proxy_manager.py @@ -0,0 +1,25 @@ +"""Proxy management module""" +from typing import Optional +from ..core.database import Database +from ..core.models import ProxyConfig + +class ProxyManager: + """Proxy configuration manager""" + + def __init__(self, db: Database): + self.db = db + + async def get_proxy_url(self) -> Optional[str]: + """Get proxy URL if enabled, otherwise return None""" + config = await self.db.get_proxy_config() + if config and config.enabled and config.proxy_url: + return config.proxy_url + return None + + async def update_proxy_config(self, enabled: bool, proxy_url: Optional[str]): + """Update proxy configuration""" + await self.db.update_proxy_config(enabled, proxy_url) + + async def get_proxy_config(self) -> ProxyConfig: + """Get proxy configuration""" + return await self.db.get_proxy_config() diff --git a/src/services/token_manager.py b/src/services/token_manager.py new file mode 100644 index 0000000..afb8dbc --- /dev/null +++ b/src/services/token_manager.py @@ -0,0 +1,384 @@ +"""Token manager for Flow2API with AT auto-refresh""" +import asyncio +from datetime import datetime, timedelta, timezone +from typing import Optional, List +from ..core.database import Database +from ..core.models import Token, Project +from ..core.logger import debug_logger +from .flow_client import FlowClient +from .proxy_manager import ProxyManager + + +class TokenManager: + """Token lifecycle manager with AT auto-refresh""" + + def __init__(self, db: Database, flow_client: FlowClient): + self.db = db + self.flow_client = flow_client + self._lock = asyncio.Lock() + + # ========== Token CRUD ========== + + async def get_all_tokens(self) -> List[Token]: + """Get all tokens""" + return await self.db.get_all_tokens() + + async def get_active_tokens(self) -> List[Token]: + """Get all active tokens""" + return await self.db.get_active_tokens() + + async def get_token(self, token_id: int) -> Optional[Token]: + """Get token by ID""" + return await self.db.get_token(token_id) + + async def delete_token(self, token_id: int): + """Delete token""" + await self.db.delete_token(token_id) + + async def enable_token(self, token_id: int): + """Enable a token""" + await self.db.update_token(token_id, is_active=True) + + async def disable_token(self, token_id: int): + """Disable a token""" + await self.db.update_token(token_id, is_active=False) + + # ========== Token添加 (支持Project创建) ========== + + async def add_token( + self, + st: str, + project_id: Optional[str] = None, + project_name: Optional[str] = None, + remark: Optional[str] = None, + image_enabled: bool = True, + video_enabled: bool = True, + image_concurrency: int = -1, + video_concurrency: int = -1 + ) -> Token: + """Add a new token + + Args: + st: Session Token (必需) + project_id: 项目ID (可选,如果提供则直接使用,不创建新项目) + project_name: 项目名称 (可选,如果不提供则自动生成) + remark: 备注 + image_enabled: 是否启用图片生成 + video_enabled: 是否启用视频生成 + image_concurrency: 图片并发限制 + video_concurrency: 视频并发限制 + + Returns: + Token object + """ + # Step 1: 检查ST是否已存在 + existing_token = await self.db.get_token_by_st(st) + if existing_token: + raise ValueError(f"Token 已存在(邮箱: {existing_token.email})") + + # Step 2: 使用ST转换AT + debug_logger.log_info(f"[ADD_TOKEN] Converting ST to AT...") + try: + result = await self.flow_client.st_to_at(st) + at = result["access_token"] + expires = result.get("expires") + user_info = result.get("user", {}) + email = user_info.get("email", "") + name = user_info.get("name", email.split("@")[0] if email else "") + + # 解析过期时间 + at_expires = None + if expires: + try: + at_expires = datetime.fromisoformat(expires.replace('Z', '+00:00')) + except: + pass + + except Exception as e: + raise ValueError(f"ST转AT失败: {str(e)}") + + # Step 3: 查询余额 + try: + credits_result = await self.flow_client.get_credits(at) + credits = credits_result.get("credits", 0) + user_paygate_tier = credits_result.get("userPaygateTier") + except: + credits = 0 + user_paygate_tier = None + + # Step 4: 处理Project ID和名称 + if project_id: + # 用户提供了project_id,直接使用 + debug_logger.log_info(f"[ADD_TOKEN] Using provided project_id: {project_id}") + if not project_name: + # 如果没有提供project_name,生成一个 + now = datetime.now() + project_name = now.strftime("%b %d - %H:%M") + else: + # 用户没有提供project_id,需要创建新项目 + if not project_name: + # 自动生成项目名称 + now = datetime.now() + project_name = now.strftime("%b %d - %H:%M") + + try: + project_id = await self.flow_client.create_project(st, project_name) + debug_logger.log_info(f"[ADD_TOKEN] Created new project: {project_name} (ID: {project_id})") + except Exception as e: + raise ValueError(f"创建项目失败: {str(e)}") + + # Step 5: 创建Token对象 + token = Token( + st=st, + at=at, + at_expires=at_expires, + email=email, + name=name, + remark=remark, + is_active=True, + credits=credits, + user_paygate_tier=user_paygate_tier, + current_project_id=project_id, + current_project_name=project_name, + image_enabled=image_enabled, + video_enabled=video_enabled, + image_concurrency=image_concurrency, + video_concurrency=video_concurrency + ) + + # Step 6: 保存到数据库 + token_id = await self.db.add_token(token) + token.id = token_id + + # Step 7: 保存Project到数据库 + project = Project( + project_id=project_id, + token_id=token_id, + project_name=project_name, + tool_name="PINHOLE" + ) + await self.db.add_project(project) + + debug_logger.log_info(f"[ADD_TOKEN] Token added successfully (ID: {token_id}, Email: {email})") + return token + + async def update_token( + self, + token_id: int, + st: Optional[str] = None, + at: Optional[str] = None, + project_id: Optional[str] = None, + project_name: Optional[str] = None, + remark: Optional[str] = None, + image_enabled: Optional[bool] = None, + video_enabled: Optional[bool] = None, + image_concurrency: Optional[int] = None, + video_concurrency: Optional[int] = None + ): + """Update token (支持修改project_id和project_name)""" + update_fields = {} + + if st is not None: + update_fields["st"] = st + if at is not None: + update_fields["at"] = at + if project_id is not None: + update_fields["current_project_id"] = project_id + if project_name is not None: + update_fields["current_project_name"] = project_name + if remark is not None: + update_fields["remark"] = remark + if image_enabled is not None: + update_fields["image_enabled"] = image_enabled + if video_enabled is not None: + update_fields["video_enabled"] = video_enabled + if image_concurrency is not None: + update_fields["image_concurrency"] = image_concurrency + if video_concurrency is not None: + update_fields["video_concurrency"] = video_concurrency + + if update_fields: + await self.db.update_token(token_id, **update_fields) + + # ========== AT自动刷新逻辑 (核心) ========== + + async def is_at_valid(self, token_id: int) -> bool: + """检查AT是否有效,如果无效或即将过期则自动刷新 + + Returns: + True if AT is valid or refreshed successfully + False if AT cannot be refreshed + """ + token = await self.db.get_token(token_id) + if not token: + return False + + # 如果AT不存在,需要刷新 + if not token.at: + debug_logger.log_info(f"[AT_CHECK] Token {token_id}: AT不存在,需要刷新") + return await self._refresh_at(token_id) + + # 如果没有过期时间,假设需要刷新 + if not token.at_expires: + debug_logger.log_info(f"[AT_CHECK] Token {token_id}: AT过期时间未知,尝试刷新") + return await self._refresh_at(token_id) + + # 检查是否即将过期 (提前1小时刷新) + now = datetime.now(timezone.utc) + # 确保at_expires也是timezone-aware + if token.at_expires.tzinfo is None: + at_expires_aware = token.at_expires.replace(tzinfo=timezone.utc) + else: + at_expires_aware = token.at_expires + + time_until_expiry = at_expires_aware - now + + if time_until_expiry.total_seconds() < 3600: # 1 hour (3600 seconds) + debug_logger.log_info(f"[AT_CHECK] Token {token_id}: AT即将过期 (剩余 {time_until_expiry.total_seconds():.0f} 秒),需要刷新") + return await self._refresh_at(token_id) + + # AT有效 + return True + + async def _refresh_at(self, token_id: int) -> bool: + """内部方法: 刷新AT + + Returns: + True if refresh successful, False otherwise + """ + async with self._lock: + token = await self.db.get_token(token_id) + if not token: + return False + + try: + debug_logger.log_info(f"[AT_REFRESH] Token {token_id}: 开始刷新AT...") + + # 使用ST转AT + result = await self.flow_client.st_to_at(token.st) + new_at = result["access_token"] + expires = result.get("expires") + + # 解析过期时间 + new_at_expires = None + if expires: + try: + new_at_expires = datetime.fromisoformat(expires.replace('Z', '+00:00')) + except: + pass + + # 更新数据库 + await self.db.update_token( + token_id, + at=new_at, + at_expires=new_at_expires + ) + + debug_logger.log_info(f"[AT_REFRESH] Token {token_id}: AT刷新成功") + debug_logger.log_info(f" - 新过期时间: {new_at_expires}") + + # 同时刷新credits + try: + credits_result = await self.flow_client.get_credits(new_at) + await self.db.update_token( + token_id, + credits=credits_result.get("credits", 0) + ) + except: + pass + + return True + + except Exception as e: + debug_logger.log_error(f"[AT_REFRESH] Token {token_id}: AT刷新失败 - {str(e)}") + # 刷新失败,禁用Token + await self.disable_token(token_id) + return False + + async def ensure_project_exists(self, token_id: int) -> str: + """确保Token有可用的Project + + Returns: + project_id + """ + token = await self.db.get_token(token_id) + if not token: + raise ValueError("Token not found") + + # 如果已有project_id,直接返回 + if token.current_project_id: + return token.current_project_id + + # 创建新Project + now = datetime.now() + project_name = now.strftime("%b %d - %H:%M") + + try: + project_id = await self.flow_client.create_project(token.st, project_name) + debug_logger.log_info(f"[PROJECT] Created project for token {token_id}: {project_name}") + + # 更新Token + await self.db.update_token( + token_id, + current_project_id=project_id, + current_project_name=project_name + ) + + # 保存Project到数据库 + project = Project( + project_id=project_id, + token_id=token_id, + project_name=project_name + ) + await self.db.add_project(project) + + return project_id + + except Exception as e: + raise ValueError(f"Failed to create project: {str(e)}") + + # ========== Token使用统计 ========== + + async def record_usage(self, token_id: int, is_video: bool = False): + """Record token usage""" + await self.db.update_token(token_id, use_count=1, last_used_at=datetime.now()) + + if is_video: + await self.db.increment_token_stats(token_id, "video") + else: + await self.db.increment_token_stats(token_id, "image") + + async def record_error(self, token_id: int): + """Record token error""" + await self.db.increment_token_stats(token_id, "error") + + # ========== 余额刷新 ========== + + async def refresh_credits(self, token_id: int) -> int: + """刷新Token余额 + + Returns: + credits + """ + token = await self.db.get_token(token_id) + if not token: + return 0 + + # 确保AT有效 + if not await self.is_at_valid(token_id): + return 0 + + # 重新获取token (AT可能已刷新) + token = await self.db.get_token(token_id) + + try: + result = await self.flow_client.get_credits(token.at) + credits = result.get("credits", 0) + + # 更新数据库 + await self.db.update_token(token_id, credits=credits) + + return credits + except Exception as e: + debug_logger.log_error(f"Failed to refresh credits for token {token_id}: {str(e)}") + return 0 diff --git a/static/login.html b/static/login.html new file mode 100644 index 0000000..6d85839 --- /dev/null +++ b/static/login.html @@ -0,0 +1,53 @@ + + + + + + 登录 - Flow2API + + + + + +
+
+
+

Flow2API

+

管理员控制台

+
+
+ +
+
+
+
+ + +
+
+ + +
+ +
+ +
+

Flow2API © 2025

+
+
+
+
+ + + + diff --git a/static/manage.html b/static/manage.html new file mode 100644 index 0000000..d9f1753 --- /dev/null +++ b/static/manage.html @@ -0,0 +1,586 @@ + + + + + + 管理控制台 - Flow2API + + + + + + +
+
+
+ Flow2API +
+
+ + + + + + +
+
+
+ +
+ +
+ +
+ + +
+ +
+
+

Token 总数

+

-

+
+
+

活跃 Token

+

-

+
+
+

今日图片/总图片

+

-

+
+
+

今日视频/总视频

+

-

+
+
+

今日错误/总错误

+

-

+
+
+ + +
+
+

Token 列表

+
+ +
+ 自动刷新AT +
+ + +
+ Token距离过期<1h时自动使用ST刷新AT +
+
+
+
+ + + + +
+
+ +
+ + + + + + + + + + + + + + + + + + + +
邮箱状态过期时间余额项目名称项目ID图片视频错误备注操作
+
+
+
+ + + + + + + + + +
+ + + + + + + + + + + + +