diff --git a/src/api/admin.py b/src/api/admin.py index 34f0648..3e0fa96 100644 --- a/src/api/admin.py +++ b/src/api/admin.py @@ -391,39 +391,34 @@ async def change_password( @router.get("/api/tokens") async def get_tokens(token: str = Depends(verify_admin_token)): """Get all tokens with statistics""" - tokens = await token_manager.get_all_tokens() - result = [] + token_rows = await db.get_all_tokens_with_stats() + to_iso = lambda value: value.isoformat() if hasattr(value, "isoformat") else value - for t in tokens: - stats = await db.get_token_stats(t.id) - - result.append({ - "id": t.id, - "st": t.st, # Session Token for editing - "at": t.at, # Access Token for editing (从ST转换而来) - "at_expires": t.at_expires.isoformat() if t.at_expires else None, # 🆕 AT过期时间 - "token": t.at, # 兼容前端 token.token 的访问方式 - "email": t.email, - "name": t.name, - "remark": t.remark, - "is_active": t.is_active, - "created_at": t.created_at.isoformat() if t.created_at else None, - "last_used_at": t.last_used_at.isoformat() if t.last_used_at else None, - "use_count": t.use_count, - "credits": t.credits, # 🆕 余额 - "user_paygate_tier": t.user_paygate_tier, - "current_project_id": t.current_project_id, # 🆕 项目ID - "current_project_name": t.current_project_name, # 🆕 项目名称 - "image_enabled": t.image_enabled, - "video_enabled": t.video_enabled, - "image_concurrency": t.image_concurrency, - "video_concurrency": t.video_concurrency, - "image_count": stats.image_count if stats else 0, - "video_count": stats.video_count if stats else 0, - "error_count": stats.error_count if stats else 0 - }) - - return result # 直接返回数组,兼容前端 + return [{ + "id": row.get("id"), + "st": row.get("st"), # Session Token for editing + "at": row.get("at"), # Access Token for editing (从ST转换而来) + "at_expires": to_iso(row.get("at_expires")) if row.get("at_expires") else None, # 🆕 AT过期时间 + "token": row.get("at"), # 兼容前端 token.token 的访问方式 + "email": row.get("email"), + "name": row.get("name"), + "remark": row.get("remark"), + "is_active": bool(row.get("is_active")), + "created_at": to_iso(row.get("created_at")) if row.get("created_at") else None, + "last_used_at": to_iso(row.get("last_used_at")) if row.get("last_used_at") else None, + "use_count": row.get("use_count"), + "credits": row.get("credits"), # 🆕 余额 + "user_paygate_tier": row.get("user_paygate_tier"), + "current_project_id": row.get("current_project_id"), # 🆕 项目ID + "current_project_name": row.get("current_project_name"), # 🆕 项目名称 + "image_enabled": bool(row.get("image_enabled")), + "video_enabled": bool(row.get("video_enabled")), + "image_concurrency": row.get("image_concurrency"), + "video_concurrency": row.get("video_concurrency"), + "image_count": row.get("image_count", 0), + "video_count": row.get("video_count", 0), + "error_count": row.get("error_count", 0) + } for row in token_rows] # 直接返回数组,兼容前端 @router.post("/api/tokens") @@ -653,6 +648,11 @@ async def import_tokens( added = 0 updated = 0 errors = [] + # 保持与历史逻辑一致:按 created_at DESC 的结果中,优先命中同邮箱“最新一条” + existing_by_email = {} + for existing_token in await token_manager.get_all_tokens(): + if existing_token.email and existing_token.email not in existing_by_email: + existing_by_email[existing_token.email] = existing_token for idx, item in enumerate(request.tokens): try: @@ -686,8 +686,7 @@ async def import_tokens( pass # 使用邮箱检查是否已存在 - existing_tokens = await token_manager.get_all_tokens() - existing = next((t for t in existing_tokens if t.email == email), None) + existing = existing_by_email.get(email) if existing: # 更新现有Token @@ -704,6 +703,14 @@ async def import_tokens( # 如果过期则禁用 if is_expired: await token_manager.disable_token(existing.id) + existing.is_active = False + existing.st = st + existing.at = at + existing.at_expires = at_expires + existing.image_enabled = item.image_enabled + existing.video_enabled = item.video_enabled + existing.image_concurrency = item.image_concurrency + existing.video_concurrency = item.video_concurrency updated += 1 else: # 添加新Token @@ -717,6 +724,8 @@ async def import_tokens( # 如果过期则禁用 if is_expired: await token_manager.disable_token(new_token.id) + new_token.is_active = False + existing_by_email[email] = new_token added += 1 except Exception as e: @@ -894,17 +903,14 @@ async def update_generation_config( @router.get("/api/system/info") async def get_system_info(token: str = Depends(verify_admin_token)): """Get system information""" - tokens = await token_manager.get_all_tokens() - active_tokens = [t for t in tokens if t.is_active] - - total_credits = sum(t.credits for t in active_tokens) + stats = await db.get_system_info_stats() return { "success": True, "info": { - "total_tokens": len(tokens), - "active_tokens": len(active_tokens), - "total_credits": total_credits, + "total_tokens": stats["total_tokens"], + "active_tokens": stats["active_tokens"], + "total_credits": stats["total_credits"], "version": "1.0.0" } } @@ -927,37 +933,7 @@ async def logout(token: str = Depends(verify_admin_token)): @router.get("/api/stats") async def get_stats(token: str = Depends(verify_admin_token)): """Get statistics for dashboard""" - tokens = await token_manager.get_all_tokens() - active_tokens = [t for t in tokens if t.is_active] - - # Calculate totals - total_images = 0 - total_videos = 0 - total_errors = 0 - today_images = 0 - today_videos = 0 - today_errors = 0 - - for t in tokens: - stats = await db.get_token_stats(t.id) - if stats: - total_images += stats.image_count - total_videos += stats.video_count - total_errors += stats.error_count # Historical total errors - today_images += stats.today_image_count - today_videos += stats.today_video_count - today_errors += stats.today_error_count - - return { - "total_tokens": len(tokens), - "active_tokens": len(active_tokens), - "total_images": total_images, - "total_videos": total_videos, - "total_errors": total_errors, - "today_images": today_images, - "today_videos": today_videos, - "today_errors": today_errors - } + return await db.get_dashboard_stats() @router.get("/api/logs") @@ -965,10 +941,33 @@ async def get_logs( limit: int = 100, token: str = Depends(verify_admin_token) ): - """Get request logs with token email""" - logs = await db.get_logs(limit=limit) + """Get lightweight request logs for list view""" + limit = max(1, min(limit, 100)) + logs = await db.get_logs(limit=limit, include_payload=False) return [{ + "id": log.get("id"), + "token_id": log.get("token_id"), + "token_email": log.get("token_email"), + "token_username": log.get("token_username"), + "operation": log.get("operation"), + "status_code": log.get("status_code"), + "duration": log.get("duration"), + "created_at": log.get("created_at") + } for log in logs] + + +@router.get("/api/logs/{log_id}") +async def get_log_detail( + log_id: int, + token: str = Depends(verify_admin_token) +): + """Get single request log detail (payload loaded on demand)""" + log = await db.get_log_detail(log_id) + if not log: + raise HTTPException(status_code=404, detail="日志不存在") + + return { "id": log.get("id"), "token_id": log.get("token_id"), "token_email": log.get("token_email"), @@ -979,7 +978,7 @@ async def get_logs( "created_at": log.get("created_at"), "request_body": log.get("request_body"), "response_body": log.get("response_body") - } for log in logs] + } @router.delete("/api/logs") diff --git a/src/core/database.py b/src/core/database.py index a26f584..2f2cdde 100644 --- a/src/core/database.py +++ b/src/core/database.py @@ -2,7 +2,7 @@ import aiosqlite import json from datetime import datetime -from typing import Optional, List +from typing import Optional, List, Dict, Any from pathlib import Path from .models import Token, TokenStats, Task, RequestLog, AdminConfig, ProxyConfig, GenerationConfig, CacheConfig, Project, CaptchaConfig, PluginConfig @@ -577,10 +577,19 @@ class Database: await db.execute("CREATE INDEX IF NOT EXISTS idx_task_id ON tasks(task_id)") await db.execute("CREATE INDEX IF NOT EXISTS idx_token_st ON tokens(st)") await db.execute("CREATE INDEX IF NOT EXISTS idx_project_id ON projects(project_id)") + await db.execute("CREATE INDEX IF NOT EXISTS idx_tokens_email ON tokens(email)") + await db.execute("CREATE INDEX IF NOT EXISTS idx_tokens_is_active_last_used_at ON tokens(is_active, last_used_at)") # Migrate request_logs table if needed await self._migrate_request_logs(db) + # Request logs query indexes (列表按 created_at 排序 / token 过滤) + await db.execute("CREATE INDEX IF NOT EXISTS idx_request_logs_created_at ON request_logs(created_at DESC)") + await db.execute("CREATE INDEX IF NOT EXISTS idx_request_logs_token_id_created_at ON request_logs(token_id, created_at DESC)") + + # Token stats lookup index + await db.execute("CREATE INDEX IF NOT EXISTS idx_token_stats_token_id ON token_stats(token_id)") + await db.commit() async def _migrate_request_logs(self, db): @@ -700,6 +709,81 @@ class Database: rows = await cursor.fetchall() return [Token(**dict(row)) for row in rows] + async def get_all_tokens_with_stats(self) -> List[Dict[str, Any]]: + """Get all tokens with merged statistics in one query""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute(""" + SELECT + t.*, + COALESCE(ts.image_count, 0) AS image_count, + COALESCE(ts.video_count, 0) AS video_count, + COALESCE(ts.error_count, 0) AS error_count + FROM tokens t + LEFT JOIN token_stats ts ON ts.token_id = t.id + ORDER BY t.created_at DESC + """) + rows = await cursor.fetchall() + return [dict(row) for row in rows] + + async def get_dashboard_stats(self) -> Dict[str, int]: + """Get dashboard counters with aggregated SQL queries""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + + token_cursor = await db.execute(""" + SELECT + COUNT(*) AS total_tokens, + COALESCE(SUM(CASE WHEN is_active = 1 THEN 1 ELSE 0 END), 0) AS active_tokens + FROM tokens + """) + token_row = await token_cursor.fetchone() + + stats_cursor = await db.execute(""" + SELECT + COALESCE(SUM(image_count), 0) AS total_images, + COALESCE(SUM(video_count), 0) AS total_videos, + COALESCE(SUM(error_count), 0) AS total_errors, + COALESCE(SUM(today_image_count), 0) AS today_images, + COALESCE(SUM(today_video_count), 0) AS today_videos, + COALESCE(SUM(today_error_count), 0) AS today_errors + FROM token_stats + """) + stats_row = await stats_cursor.fetchone() + + token_data = dict(token_row) if token_row else {} + stats_data = dict(stats_row) if stats_row else {} + + return { + "total_tokens": int(token_data.get("total_tokens") or 0), + "active_tokens": int(token_data.get("active_tokens") or 0), + "total_images": int(stats_data.get("total_images") or 0), + "total_videos": int(stats_data.get("total_videos") or 0), + "total_errors": int(stats_data.get("total_errors") or 0), + "today_images": int(stats_data.get("today_images") or 0), + "today_videos": int(stats_data.get("today_videos") or 0), + "today_errors": int(stats_data.get("today_errors") or 0) + } + + async def get_system_info_stats(self) -> Dict[str, int]: + """Get lightweight system counters used by admin dashboard""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute(""" + SELECT + COUNT(*) AS total_tokens, + COALESCE(SUM(CASE WHEN is_active = 1 THEN 1 ELSE 0 END), 0) AS active_tokens, + COALESCE(SUM(CASE WHEN is_active = 1 THEN credits ELSE 0 END), 0) AS total_credits + FROM tokens + """) + row = await cursor.fetchone() + data = dict(row) if row else {} + return { + "total_tokens": int(data.get("total_tokens") or 0), + "active_tokens": int(data.get("active_tokens") or 0), + "total_credits": int(data.get("total_credits") or 0) + } + async def get_active_tokens(self) -> List[Token]: """Get all active tokens""" async with aiosqlite.connect(self.db_path) as db: @@ -1062,19 +1146,19 @@ class Database: log.status_code, log.duration)) await db.commit() - async def get_logs(self, limit: int = 100, token_id: Optional[int] = None): - """Get request logs with token email""" + async def get_logs(self, limit: int = 100, token_id: Optional[int] = None, include_payload: bool = False): + """Get request logs with token info, optionally including payload fields""" async with aiosqlite.connect(self.db_path) as db: db.row_factory = aiosqlite.Row + payload_columns = "rl.request_body, rl.response_body," if include_payload else "" if token_id: - cursor = await db.execute(""" + cursor = await db.execute(f""" SELECT rl.id, rl.token_id, rl.operation, - rl.request_body, - rl.response_body, + {payload_columns} rl.status_code, rl.duration, rl.created_at, @@ -1087,13 +1171,12 @@ class Database: LIMIT ? """, (token_id, limit)) else: - cursor = await db.execute(""" + cursor = await db.execute(f""" SELECT rl.id, rl.token_id, rl.operation, - rl.request_body, - rl.response_body, + {payload_columns} rl.status_code, rl.duration, rl.created_at, @@ -1108,6 +1191,30 @@ class Database: rows = await cursor.fetchall() return [dict(row) for row in rows] + async def get_log_detail(self, log_id: int) -> Optional[Dict[str, Any]]: + """Get single request log detail including payload fields""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute(""" + SELECT + rl.id, + rl.token_id, + rl.operation, + rl.request_body, + rl.response_body, + rl.status_code, + rl.duration, + rl.created_at, + t.email as token_email, + t.name as token_username + FROM request_logs rl + LEFT JOIN tokens t ON rl.token_id = t.id + WHERE rl.id = ? + LIMIT 1 + """, (log_id,)) + row = await cursor.fetchone() + return dict(row) if row else None + async def clear_all_logs(self): """Clear all request logs""" async with aiosqlite.connect(self.db_path) as db: diff --git a/src/services/generation_handler.py b/src/services/generation_handler.py index 9f44075..5f076e2 100644 --- a/src/services/generation_handler.py +++ b/src/services/generation_handler.py @@ -793,7 +793,8 @@ class GenerationHandler: if stream: yield self._create_stream_chunk("初始化生成环境...\n") - if not await self.token_manager.is_at_valid(token.id): + token = await self.token_manager.ensure_valid_token(token) + if not token: error_msg = "Token AT无效或刷新失败" debug_logger.log_error(f"[GENERATION] {error_msg}") if stream: @@ -801,9 +802,6 @@ class GenerationHandler: yield self._create_error_response(error_msg) return - # 重新获取token (AT可能已刷新) - token = await self.token_manager.get_token(token.id) - # 4. 确保Project存在 debug_logger.log_info(f"[GENERATION] 检查/创建Project...") diff --git a/src/services/load_balancer.py b/src/services/load_balancer.py index 78d9795..5629743 100644 --- a/src/services/load_balancer.py +++ b/src/services/load_balancer.py @@ -144,9 +144,11 @@ class LoadBalancer: # 只为候选列表中真正尝试到的 token 做 AT 校验,避免每次请求把所有 token 全扫一遍 for item in available_tokens: token = item["token"] + token_id = token.id - if not await self.token_manager.is_at_valid(token.id): - debug_logger.log_info(f"[LOAD_BALANCER] 跳过 Token {token.id}: AT无效或已过期") + token = await self.token_manager.ensure_valid_token(token) + if not token: + debug_logger.log_info(f"[LOAD_BALANCER] 跳过 Token {token_id}: AT无效或已过期") continue if reserve and not await self._reserve_slot(token.id, for_image_generation, for_video_generation): diff --git a/src/services/token_manager.py b/src/services/token_manager.py index e5e40a0..40116f6 100644 --- a/src/services/token_manager.py +++ b/src/services/token_manager.py @@ -230,43 +230,58 @@ class TokenManager: # ========== AT自动刷新逻辑 (核心) ========== - async def is_at_valid(self, token_id: int) -> bool: - """检查AT是否有效,如果无效或即将过期则自动刷新 - - Returns: - True if AT is valid or refreshed successfully - False if AT cannot be refreshed - """ - token = await self.db.get_token(token_id) - if not token: - return False - - # 如果AT不存在,需要刷新 + def _should_refresh_at(self, token: Token) -> bool: + """根据当前 token 快照判断是否需要刷新 AT。""" if not token.at: - debug_logger.log_info(f"[AT_CHECK] Token {token_id}: AT不存在,需要刷新") - return await self._refresh_at(token_id) + debug_logger.log_info(f"[AT_CHECK] Token {token.id}: AT不存在,需要刷新") + return True - # 如果没有过期时间,假设需要刷新 if not token.at_expires: - debug_logger.log_info(f"[AT_CHECK] Token {token_id}: AT过期时间未知,尝试刷新") - return await self._refresh_at(token_id) + debug_logger.log_info(f"[AT_CHECK] Token {token.id}: AT过期时间未知,尝试刷新") + return True - # 检查是否即将过期 (提前1小时刷新) now = datetime.now(timezone.utc) - # 确保at_expires也是timezone-aware if token.at_expires.tzinfo is None: at_expires_aware = token.at_expires.replace(tzinfo=timezone.utc) else: at_expires_aware = token.at_expires time_until_expiry = at_expires_aware - now + if time_until_expiry.total_seconds() < 3600: + debug_logger.log_info( + f"[AT_CHECK] Token {token.id}: AT即将过期 " + f"(剩余 {time_until_expiry.total_seconds():.0f} 秒),需要刷新" + ) + return True - if time_until_expiry.total_seconds() < 3600: # 1 hour (3600 seconds) - debug_logger.log_info(f"[AT_CHECK] Token {token_id}: AT即将过期 (剩余 {time_until_expiry.total_seconds():.0f} 秒),需要刷新") - return await self._refresh_at(token_id) + return False - # AT有效 - return True + async def ensure_valid_token(self, token: Optional[Token]) -> Optional[Token]: + """确保 token 的 AT 可用,并在必要时返回刷新后的最新对象。""" + if not token: + return None + + if not self._should_refresh_at(token): + return token + + if not await self._refresh_at(token.id): + return None + + return await self.db.get_token(token.id) + + async def is_at_valid(self, token_id: int, token: Optional[Token] = None) -> bool: + """检查AT是否有效,如果无效或即将过期则自动刷新 + + Returns: + True if AT is valid or refreshed successfully + False if AT cannot be refreshed + """ + token_obj = token if token and token.id == token_id else await self.db.get_token(token_id) + if not token_obj: + return False + + valid_token = await self.ensure_valid_token(token_obj) + return valid_token is not None async def _refresh_at(self, token_id: int) -> bool: @@ -572,12 +587,10 @@ class TokenManager: return 0 # 确保AT有效 - if not await self.is_at_valid(token_id): + token = await self.ensure_valid_token(token) + if not token: return 0 - # 重新获取token (AT可能已刷新) - token = await self.db.get_token(token_id) - try: result = await self.flow_client.get_credits(token.at) credits = result.get("credits", 0) diff --git a/static/manage.html b/static/manage.html index cec1e18..2d5a02a 100644 --- a/static/manage.html +++ b/static/manage.html @@ -785,11 +785,11 @@ generateRandomToken=()=>{const chars='ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789';let token='';for(let i=0;i<32;i++){token+=chars.charAt(Math.floor(Math.random()*chars.length))}$('cfgPluginConnectionToken').value=token;showToast('随机Token已生成','success')}, toggleATAutoRefresh=async()=>{try{const enabled=$('atAutoRefreshToggle').checked;const r=await apiRequest('/api/token-refresh/enabled',{method:'POST',body:JSON.stringify({enabled:enabled})});if(!r){$('atAutoRefreshToggle').checked=!enabled;return}const d=await r.json();if(d.success){showToast(enabled?'AT自动刷新已启用':'AT自动刷新已禁用','success')}else{showToast('操作失败: '+(d.detail||'未知错误'),'error');$('atAutoRefreshToggle').checked=!enabled}}catch(e){showToast('操作失败: '+e.message,'error');$('atAutoRefreshToggle').checked=!enabled}}, loadATAutoRefreshConfig=async()=>{try{const r=await apiRequest('/api/token-refresh/config');if(!r)return;const d=await r.json();if(d.success&&d.config){$('atAutoRefreshToggle').checked=d.config.at_auto_refresh_enabled||false}else{console.error('AT自动刷新配置数据格式错误:',d)}}catch(e){console.error('加载AT自动刷新配置失败:',e)}}, - loadLogs=async()=>{try{const r=await apiRequest('/api/logs?limit=100');if(!r)return;const logs=await r.json();window.allLogs=logs;const tb=$('logsTableBody');tb.innerHTML=logs.map(l=>`
${label}: data URL(长度 ${String(url).length})
`; + } const safeUrl=escapeLogHtml(url); return `${label}: ${safeUrl}
`; } @@ -826,25 +861,38 @@ function renderMediaPreview(label,url,withUrl=true){ if(!url) return ''; - const safeUrl=escapeLogHtml(url); - let previewHtml=''; - if(isVideoUrl(url)){ - previewHtml=``; - }else if(isImageUrl(url)){ - previewHtml=`${escapeLogHtml(label)}
${withUrl?renderLogLink('URL',url):''}${previewHtml}${escapeLogHtml(label)}
${withUrl?renderLogLink('URL',url):''}${previewTrigger}${escapeLogHtml(requestBodyObj?JSON.stringify(requestBodyObj,null,2):(log.request_body||'无'))}${escapeLogHtml(requestPayloadText)}放大分辨率: ${escapeLogHtml(upResolution)}
`; if(upPreviewUrl){ assetsHtml+=renderMediaPreview(`${upResolution}结果`,upPreviewUrl,false); @@ -879,21 +927,20 @@ detailHtml+=`无资产详情
'}${escapeLogHtml(JSON.stringify(responseBodyObj,null,2))}${escapeLogHtml(responsePayloadText)}${escapeLogHtml(log.response_body||'无')}${escapeLogHtml(responsePayloadText)}${escapeLogHtml(responseBodyObj.error.message||responseBodyObj.error||'未知错误')}
${escapeLogHtml(responseBodyObj?JSON.stringify(responseBodyObj,null,2):(log.response_body||'无'))}${escapeLogHtml(responsePayloadText)}