From fa2abd560abaf7c9b8eb5aa320b229fffbf19e86 Mon Sep 17 00:00:00 2001 From: "yuechenglong.5" Date: Tue, 20 Jan 2026 10:17:39 +0800 Subject: [PATCH 1/2] =?UTF-8?q?chore:=20cherry-pick=20=E6=96=87=E6=A1=A3?= =?UTF-8?q?=E6=9B=B4=E6=96=B0=E5=92=8C=E5=88=A0=E9=99=A4=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - docs: 添加 Kiro OAuth web 认证端点说明 (ace7c0c) - chore: 删除包含敏感数据的测试文件 (8f06f6a) - 保留本地修改: refresh_manager, token_repository 等 --- README.md | 17 +- README_CN.md | 17 +- cmd/server/main.go | 8 + internal/auth/kiro/oauth_web.go | 1 + internal/auth/kiro/refresh_manager.go | 145 +++++++ internal/auth/kiro/token.go | 6 +- internal/auth/kiro/token_repository.go | 273 +++++++++++++ internal/runtime/executor/kiro_executor.go | 7 + test_api.py | 452 --------------------- test_auth_diff.go | 273 ------------- test_auth_idc_go1.go | 323 --------------- test_auth_js_style.go | 237 ----------- test_kiro_debug.go | 348 ---------------- test_proxy_debug.go | 367 ----------------- 14 files changed, 469 insertions(+), 2005 deletions(-) create mode 100644 internal/auth/kiro/refresh_manager.go create mode 100644 internal/auth/kiro/token_repository.go delete mode 100644 test_api.py delete mode 100644 test_auth_diff.go delete mode 100644 test_auth_idc_go1.go delete mode 100644 test_auth_js_style.go delete mode 100644 test_kiro_debug.go delete mode 100644 test_proxy_debug.go diff --git a/README.md b/README.md index 1555e643..092a3214 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ The Plus release stays in lockstep with the mainline features. - **OAuth Web Authentication**: Browser-based OAuth login for Kiro with beautiful web UI - **Rate Limiter**: Built-in request rate limiting to prevent API abuse -- **Background Token Refresh**: Automatic token refresh in background to avoid expiration +- **Background Token Refresh**: Automatic token refresh 10 minutes before expiration - **Metrics & Monitoring**: Request metrics collection for monitoring and debugging - **Device Fingerprint**: Device fingerprint generation for enhanced security - **Cooldown Management**: Smart cooldown mechanism for API rate limits @@ -25,6 +25,21 @@ The Plus release stays in lockstep with the mainline features. - **Model Converter**: Unified model name conversion across providers - **UTF-8 Stream Processing**: Improved streaming response handling +## Kiro Authentication + +### Web-based OAuth Login + +Access the Kiro OAuth web interface at: + +``` +http://your-server:8080/v0/oauth/kiro +``` + +This provides a browser-based OAuth flow for Kiro (AWS CodeWhisperer) authentication with: +- AWS Builder ID login +- AWS Identity Center (IDC) login +- Token import from Kiro IDE + ## Quick Deployment with Docker ### One-Command Deployment diff --git a/README_CN.md b/README_CN.md index 6ac2e483..b5b4d5f9 100644 --- a/README_CN.md +++ b/README_CN.md @@ -17,7 +17,7 @@ - **OAuth Web 认证**: 基于浏览器的 Kiro OAuth 登录,提供美观的 Web UI - **请求限流器**: 内置请求限流,防止 API 滥用 -- **后台令牌刷新**: 自动后台刷新令牌,避免过期 +- **后台令牌刷新**: 过期前 10 分钟自动刷新令牌 - **监控指标**: 请求指标收集,用于监控和调试 - **设备指纹**: 设备指纹生成,增强安全性 - **冷却管理**: 智能冷却机制,应对 API 速率限制 @@ -25,6 +25,21 @@ - **模型转换器**: 跨供应商的统一模型名称转换 - **UTF-8 流处理**: 改进的流式响应处理 +## Kiro 认证 + +### 网页端 OAuth 登录 + +访问 Kiro OAuth 网页认证界面: + +``` +http://your-server:8080/v0/oauth/kiro +``` + +提供基于浏览器的 Kiro (AWS CodeWhisperer) OAuth 认证流程,支持: +- AWS Builder ID 登录 +- AWS Identity Center (IDC) 登录 +- 从 Kiro IDE 导入令牌 + ## Docker 快速部署 ### 一键部署 diff --git a/cmd/server/main.go b/cmd/server/main.go index 8148ceee..d0f70f67 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -17,6 +17,7 @@ import ( "github.com/joho/godotenv" configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access" + "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" "github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo" "github.com/router-for-me/CLIProxyAPI/v6/internal/cmd" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" @@ -533,6 +534,13 @@ func main() { } // Start the main proxy service managementasset.StartAutoUpdater(context.Background(), configFilePath) + + // 初始化并启动 Kiro token 后台刷新 + if cfg.AuthDir != "" { + kiro.InitializeAndStart(cfg.AuthDir, cfg) + defer kiro.StopGlobalRefreshManager() + } + cmd.StartService(cfg, configFilePath, password) } } diff --git a/internal/auth/kiro/oauth_web.go b/internal/auth/kiro/oauth_web.go index 13198516..4ffbb7fd 100644 --- a/internal/auth/kiro/oauth_web.go +++ b/internal/auth/kiro/oauth_web.go @@ -385,6 +385,7 @@ func (h *OAuthWebHandler) pollForToken(ctx context.Context, session *webAuthSess ClientSecret: session.clientSecret, Email: email, Region: session.region, + StartURL: session.startURL, } h.mu.Lock() diff --git a/internal/auth/kiro/refresh_manager.go b/internal/auth/kiro/refresh_manager.go new file mode 100644 index 00000000..cd27b432 --- /dev/null +++ b/internal/auth/kiro/refresh_manager.go @@ -0,0 +1,145 @@ +package kiro + +import ( + "context" + "sync" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + log "github.com/sirupsen/logrus" +) + +// RefreshManager 是后台刷新器的单例管理器 +type RefreshManager struct { + mu sync.Mutex + refresher *BackgroundRefresher + ctx context.Context + cancel context.CancelFunc + started bool +} + +var ( + globalRefreshManager *RefreshManager + managerOnce sync.Once +) + +// GetRefreshManager 获取全局刷新管理器实例 +func GetRefreshManager() *RefreshManager { + managerOnce.Do(func() { + globalRefreshManager = &RefreshManager{} + }) + return globalRefreshManager +} + +// Initialize 初始化后台刷新器 +// baseDir: token 文件所在的目录 +// cfg: 应用配置 +func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error { + m.mu.Lock() + defer m.mu.Unlock() + + if m.started { + log.Debug("refresh manager: already initialized") + return nil + } + + if baseDir == "" { + log.Warn("refresh manager: base directory not provided, skipping initialization") + return nil + } + + // 创建 token 存储库 + repo := NewFileTokenRepository(baseDir) + + // 创建后台刷新器,配置参数 + m.refresher = NewBackgroundRefresher( + repo, + WithInterval(time.Minute), // 每分钟检查一次 + WithBatchSize(50), // 每批最多处理 50 个 token + WithConcurrency(10), // 最多 10 个并发刷新 + WithConfig(cfg), // 设置 OAuth 和 SSO 客户端 + ) + + log.Infof("refresh manager: initialized with base directory %s", baseDir) + return nil +} + +// Start 启动后台刷新 +func (m *RefreshManager) Start() { + m.mu.Lock() + defer m.mu.Unlock() + + if m.started { + log.Debug("refresh manager: already started") + return + } + + if m.refresher == nil { + log.Warn("refresh manager: not initialized, cannot start") + return + } + + m.ctx, m.cancel = context.WithCancel(context.Background()) + m.refresher.Start(m.ctx) + m.started = true + + log.Info("refresh manager: background refresh started") +} + +// Stop 停止后台刷新 +func (m *RefreshManager) Stop() { + m.mu.Lock() + defer m.mu.Unlock() + + if !m.started { + return + } + + if m.cancel != nil { + m.cancel() + } + + if m.refresher != nil { + m.refresher.Stop() + } + + m.started = false + log.Info("refresh manager: background refresh stopped") +} + +// IsRunning 检查后台刷新是否正在运行 +func (m *RefreshManager) IsRunning() bool { + m.mu.Lock() + defer m.mu.Unlock() + return m.started +} + +// UpdateBaseDir 更新 token 目录(用于运行时配置更改) +func (m *RefreshManager) UpdateBaseDir(baseDir string) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.refresher != nil && m.refresher.tokenRepo != nil { + if repo, ok := m.refresher.tokenRepo.(*FileTokenRepository); ok { + repo.SetBaseDir(baseDir) + log.Infof("refresh manager: updated base directory to %s", baseDir) + } + } +} + +// InitializeAndStart 初始化并启动后台刷新(便捷方法) +func InitializeAndStart(baseDir string, cfg *config.Config) { + manager := GetRefreshManager() + if err := manager.Initialize(baseDir, cfg); err != nil { + log.Errorf("refresh manager: initialization failed: %v", err) + return + } + manager.Start() +} + +// StopGlobalRefreshManager 停止全局刷新管理器 +func StopGlobalRefreshManager() { + if globalRefreshManager != nil { + globalRefreshManager.Stop() + } +} diff --git a/internal/auth/kiro/token.go b/internal/auth/kiro/token.go index bfbdc795..0484a2dc 100644 --- a/internal/auth/kiro/token.go +++ b/internal/auth/kiro/token.go @@ -26,13 +26,13 @@ type KiroTokenStorage struct { // LastRefresh is the timestamp of the last token refresh LastRefresh string `json:"last_refresh"` // ClientID is the OAuth client ID (required for token refresh) - ClientID string `json:"clientId,omitempty"` + ClientID string `json:"client_id,omitempty"` // ClientSecret is the OAuth client secret (required for token refresh) - ClientSecret string `json:"clientSecret,omitempty"` + ClientSecret string `json:"client_secret,omitempty"` // Region is the AWS region Region string `json:"region,omitempty"` // StartURL is the AWS Identity Center start URL (for IDC auth) - StartURL string `json:"startUrl,omitempty"` + StartURL string `json:"start_url,omitempty"` // Email is the user's email address Email string `json:"email,omitempty"` } diff --git a/internal/auth/kiro/token_repository.go b/internal/auth/kiro/token_repository.go new file mode 100644 index 00000000..f7ed76a8 --- /dev/null +++ b/internal/auth/kiro/token_repository.go @@ -0,0 +1,273 @@ +package kiro + +import ( + "context" + "encoding/json" + "fmt" + "io/fs" + "os" + "path/filepath" + "sort" + "strings" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +// FileTokenRepository 实现 TokenRepository 接口,基于文件系统存储 +type FileTokenRepository struct { + mu sync.RWMutex + baseDir string +} + +// NewFileTokenRepository 创建一个新的文件 token 存储库 +func NewFileTokenRepository(baseDir string) *FileTokenRepository { + return &FileTokenRepository{ + baseDir: baseDir, + } +} + +// SetBaseDir 设置基础目录 +func (r *FileTokenRepository) SetBaseDir(dir string) { + r.mu.Lock() + r.baseDir = strings.TrimSpace(dir) + r.mu.Unlock() +} + +// FindOldestUnverified 查找需要刷新的 token(按最后验证时间排序) +func (r *FileTokenRepository) FindOldestUnverified(limit int) []*Token { + r.mu.RLock() + baseDir := r.baseDir + r.mu.RUnlock() + + if baseDir == "" { + log.Debug("token repository: base directory not configured") + return nil + } + + var tokens []*Token + + err := filepath.WalkDir(baseDir, func(path string, d fs.DirEntry, walkErr error) error { + if walkErr != nil { + return nil // 忽略错误,继续遍历 + } + if d.IsDir() { + return nil + } + if !strings.HasSuffix(strings.ToLower(d.Name()), ".json") { + return nil + } + + // 只处理 kiro 相关的 token 文件 + if !strings.HasPrefix(d.Name(), "kiro-") { + return nil + } + + token, err := r.readTokenFile(path) + if err != nil { + log.Debugf("token repository: failed to read token file %s: %v", path, err) + return nil + } + + if token != nil && token.RefreshToken != "" { + // 检查 token 是否需要刷新(过期前 5 分钟) + if token.ExpiresAt.IsZero() || time.Until(token.ExpiresAt) < 5*time.Minute { + tokens = append(tokens, token) + } + } + + return nil + }) + + if err != nil { + log.Warnf("token repository: error walking directory: %v", err) + } + + // 按最后验证时间排序(最旧的优先) + sort.Slice(tokens, func(i, j int) bool { + return tokens[i].LastVerified.Before(tokens[j].LastVerified) + }) + + // 限制返回数量 + if limit > 0 && len(tokens) > limit { + tokens = tokens[:limit] + } + + return tokens +} + +// UpdateToken 更新 token 并持久化到文件 +func (r *FileTokenRepository) UpdateToken(token *Token) error { + if token == nil { + return fmt.Errorf("token repository: token is nil") + } + + r.mu.RLock() + baseDir := r.baseDir + r.mu.RUnlock() + + if baseDir == "" { + return fmt.Errorf("token repository: base directory not configured") + } + + // 构建文件路径 + filePath := filepath.Join(baseDir, token.ID) + if !strings.HasSuffix(filePath, ".json") { + filePath += ".json" + } + + // 读取现有文件内容 + existingData := make(map[string]any) + if data, err := os.ReadFile(filePath); err == nil { + _ = json.Unmarshal(data, &existingData) + } + + // 更新字段 + existingData["access_token"] = token.AccessToken + existingData["refresh_token"] = token.RefreshToken + existingData["last_refresh"] = time.Now().Format(time.RFC3339) + + if !token.ExpiresAt.IsZero() { + existingData["expires_at"] = token.ExpiresAt.Format(time.RFC3339) + } + + // 保持原有的关键字段 + if token.ClientID != "" { + existingData["client_id"] = token.ClientID + } + if token.ClientSecret != "" { + existingData["client_secret"] = token.ClientSecret + } + if token.AuthMethod != "" { + existingData["auth_method"] = token.AuthMethod + } + if token.Region != "" { + existingData["region"] = token.Region + } + if token.StartURL != "" { + existingData["start_url"] = token.StartURL + } + + // 序列化并写入文件 + raw, err := json.MarshalIndent(existingData, "", " ") + if err != nil { + return fmt.Errorf("token repository: marshal failed: %w", err) + } + + // 原子写入:先写入临时文件,再重命名 + tmpPath := filePath + ".tmp" + if err := os.WriteFile(tmpPath, raw, 0o600); err != nil { + return fmt.Errorf("token repository: write temp file failed: %w", err) + } + if err := os.Rename(tmpPath, filePath); err != nil { + _ = os.Remove(tmpPath) + return fmt.Errorf("token repository: rename failed: %w", err) + } + + log.Debugf("token repository: updated token %s", token.ID) + return nil +} + +// readTokenFile 从文件读取 token +func (r *FileTokenRepository) readTokenFile(path string) (*Token, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + var metadata map[string]any + if err := json.Unmarshal(data, &metadata); err != nil { + return nil, err + } + + // 检查是否是 kiro token + tokenType, _ := metadata["type"].(string) + if tokenType != "kiro" { + return nil, nil + } + + // 检查 auth_method + authMethod, _ := metadata["auth_method"].(string) + if authMethod != "idc" && authMethod != "builder-id" { + return nil, nil // 只处理 IDC 和 Builder ID token + } + + token := &Token{ + ID: filepath.Base(path), + AuthMethod: authMethod, + } + + // 解析各字段 + if v, ok := metadata["access_token"].(string); ok { + token.AccessToken = v + } + if v, ok := metadata["refresh_token"].(string); ok { + token.RefreshToken = v + } + if v, ok := metadata["client_id"].(string); ok { + token.ClientID = v + } + if v, ok := metadata["client_secret"].(string); ok { + token.ClientSecret = v + } + if v, ok := metadata["region"].(string); ok { + token.Region = v + } + if v, ok := metadata["start_url"].(string); ok { + token.StartURL = v + } + if v, ok := metadata["provider"].(string); ok { + token.Provider = v + } + + // 解析时间字段 + if v, ok := metadata["expires_at"].(string); ok { + if t, err := time.Parse(time.RFC3339, v); err == nil { + token.ExpiresAt = t + } + } + if v, ok := metadata["last_refresh"].(string); ok { + if t, err := time.Parse(time.RFC3339, v); err == nil { + token.LastVerified = t + } + } + + return token, nil +} + +// ListKiroTokens 列出所有 Kiro token(用于调试) +func (r *FileTokenRepository) ListKiroTokens(ctx context.Context) ([]*Token, error) { + r.mu.RLock() + baseDir := r.baseDir + r.mu.RUnlock() + + if baseDir == "" { + return nil, fmt.Errorf("token repository: base directory not configured") + } + + var tokens []*Token + + err := filepath.WalkDir(baseDir, func(path string, d fs.DirEntry, walkErr error) error { + if walkErr != nil { + return nil + } + if d.IsDir() { + return nil + } + if !strings.HasPrefix(d.Name(), "kiro-") || !strings.HasSuffix(d.Name(), ".json") { + return nil + } + + token, err := r.readTokenFile(path) + if err != nil { + return nil + } + if token != nil { + tokens = append(tokens, token) + } + return nil + }) + + return tokens, err +} diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index b0c14c61..b842d5c8 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -3617,6 +3617,13 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c if tokenData.ClientSecret != "" { updated.Metadata["client_secret"] = tokenData.ClientSecret } + // Preserve region and start_url for IDC token refresh + if tokenData.Region != "" { + updated.Metadata["region"] = tokenData.Region + } + if tokenData.StartURL != "" { + updated.Metadata["start_url"] = tokenData.StartURL + } if updated.Attributes == nil { updated.Attributes = make(map[string]string) diff --git a/test_api.py b/test_api.py deleted file mode 100644 index 1849e2ba..00000000 --- a/test_api.py +++ /dev/null @@ -1,452 +0,0 @@ -#!/usr/bin/env python3 -""" -CLIProxyAPI 全面测试脚本 -测试模型列表、流式输出、thinking模式及复杂任务 -""" - -import requests -import json -import time -import sys -import io -from typing import Optional, List, Dict, Any - -# 修复 Windows 控制台编码问题 -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace') -sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', errors='replace') - -# 配置 -BASE_URL = "http://localhost:8317" -API_KEY = "your-api-key-1" -HEADERS = { - "Authorization": f"Bearer {API_KEY}", - "Content-Type": "application/json" -} - -# 复杂任务提示词 - 用于测试 thinking 模式 -COMPLEX_TASK_PROMPT = """请帮我分析以下复杂的编程问题,并给出详细的解决方案: - -问题:设计一个高并发的分布式任务调度系统,需要满足以下要求: -1. 支持百万级任务队列 -2. 任务可以设置优先级、延迟执行、定时执行 -3. 支持任务依赖关系(DAG调度) -4. 失败重试机制,支持指数退避 -5. 任务结果持久化和查询 -6. 水平扩展能力 -7. 监控和告警 - -请从以下几个方面详细分析: -1. 整体架构设计 -2. 核心数据结构 -3. 调度算法选择 -4. 容错机制设计 -5. 性能优化策略 -6. 技术选型建议 - -请逐步思考每个方面,给出你的推理过程。""" - -# 简单测试提示词 -SIMPLE_PROMPT = "Hello! Please respond with 'OK' if you receive this message." - -def print_separator(title: str): - print(f"\n{'='*60}") - print(f" {title}") - print(f"{'='*60}\n") - -def print_result(name: str, success: bool, detail: str = ""): - status = "✅ PASS" if success else "❌ FAIL" - print(f"{status} | {name}") - if detail: - print(f" └─ {detail[:200]}{'...' if len(detail) > 200 else ''}") - -def get_models() -> List[str]: - """获取可用模型列表""" - print_separator("获取模型列表") - try: - resp = requests.get(f"{BASE_URL}/v1/models", headers=HEADERS, timeout=30) - if resp.status_code == 200: - data = resp.json() - models = [m.get("id", m.get("name", "unknown")) for m in data.get("data", [])] - print(f"找到 {len(models)} 个模型:") - for m in models: - print(f" - {m}") - return models - else: - print(f"❌ 获取模型列表失败: HTTP {resp.status_code}") - print(f" 响应: {resp.text[:500]}") - return [] - except Exception as e: - print(f"❌ 获取模型列表异常: {e}") - return [] - -def test_model_basic(model: str) -> tuple: - """基础可用性测试,返回 (success, error_detail)""" - try: - payload = { - "model": model, - "messages": [{"role": "user", "content": SIMPLE_PROMPT}], - "max_tokens": 50, - "stream": False - } - resp = requests.post( - f"{BASE_URL}/v1/chat/completions", - headers=HEADERS, - json=payload, - timeout=60 - ) - if resp.status_code == 200: - data = resp.json() - content = data.get("choices", [{}])[0].get("message", {}).get("content", "") - return (bool(content), f"content_len={len(content)}") - else: - return (False, f"HTTP {resp.status_code}: {resp.text[:300]}") - except Exception as e: - return (False, str(e)) - -def test_streaming(model: str) -> Dict[str, Any]: - """测试流式输出""" - result = {"success": False, "chunks": 0, "content": "", "error": None} - try: - payload = { - "model": model, - "messages": [{"role": "user", "content": "Count from 1 to 5, one number per line."}], - "max_tokens": 100, - "stream": True - } - resp = requests.post( - f"{BASE_URL}/v1/chat/completions", - headers=HEADERS, - json=payload, - timeout=60, - stream=True - ) - - if resp.status_code != 200: - result["error"] = f"HTTP {resp.status_code}: {resp.text[:200]}" - return result - - content_parts = [] - for line in resp.iter_lines(): - if line: - line_str = line.decode('utf-8') - if line_str.startswith("data: "): - data_str = line_str[6:] - if data_str.strip() == "[DONE]": - break - try: - data = json.loads(data_str) - result["chunks"] += 1 - choices = data.get("choices", []) - if choices: - delta = choices[0].get("delta", {}) - if "content" in delta and delta["content"]: - content_parts.append(delta["content"]) - except json.JSONDecodeError: - pass - except Exception as e: - result["error"] = f"Parse error: {e}, data: {data_str[:200]}" - - result["content"] = "".join(content_parts) - result["success"] = result["chunks"] > 0 and len(result["content"]) > 0 - - except Exception as e: - result["error"] = str(e) - - return result - -def test_thinking_mode(model: str, complex_task: bool = False) -> Dict[str, Any]: - """测试 thinking 模式""" - result = { - "success": False, - "has_reasoning": False, - "reasoning_content": "", - "content": "", - "error": None, - "chunks": 0 - } - - prompt = COMPLEX_TASK_PROMPT if complex_task else "What is 15 * 23? Please think step by step." - - try: - # 尝试不同的 thinking 模式参数格式 - payload = { - "model": model, - "messages": [{"role": "user", "content": prompt}], - "max_tokens": 8000 if complex_task else 2000, - "stream": True - } - - # 根据模型类型添加 thinking 参数 - if "claude" in model.lower(): - payload["thinking"] = {"type": "enabled", "budget_tokens": 5000 if complex_task else 2000} - elif "gemini" in model.lower(): - payload["thinking"] = {"thinking_budget": 5000 if complex_task else 2000} - elif "gpt" in model.lower() or "codex" in model.lower() or "o1" in model.lower() or "o3" in model.lower(): - payload["reasoning_effort"] = "high" if complex_task else "medium" - else: - # 通用格式 - payload["thinking"] = {"type": "enabled", "budget_tokens": 5000 if complex_task else 2000} - - resp = requests.post( - f"{BASE_URL}/v1/chat/completions", - headers=HEADERS, - json=payload, - timeout=300 if complex_task else 120, - stream=True - ) - - if resp.status_code != 200: - result["error"] = f"HTTP {resp.status_code}: {resp.text[:500]}" - return result - - content_parts = [] - reasoning_parts = [] - - for line in resp.iter_lines(): - if line: - line_str = line.decode('utf-8') - if line_str.startswith("data: "): - data_str = line_str[6:] - if data_str.strip() == "[DONE]": - break - try: - data = json.loads(data_str) - result["chunks"] += 1 - - choices = data.get("choices", []) - if not choices: - continue - choice = choices[0] - delta = choice.get("delta", {}) - - # 检查 reasoning_content (Claude/OpenAI格式) - if "reasoning_content" in delta and delta["reasoning_content"]: - reasoning_parts.append(delta["reasoning_content"]) - result["has_reasoning"] = True - - # 检查 thinking (Gemini格式) - if "thinking" in delta and delta["thinking"]: - reasoning_parts.append(delta["thinking"]) - result["has_reasoning"] = True - - # 常规内容 - if "content" in delta and delta["content"]: - content_parts.append(delta["content"]) - - except json.JSONDecodeError as e: - pass - except Exception as e: - result["error"] = f"Parse error: {e}" - - result["reasoning_content"] = "".join(reasoning_parts) - result["content"] = "".join(content_parts) - result["success"] = result["chunks"] > 0 and (len(result["content"]) > 0 or len(result["reasoning_content"]) > 0) - - except requests.exceptions.Timeout: - result["error"] = "Request timeout" - except Exception as e: - result["error"] = str(e) - - return result - -def run_full_test(): - """运行完整测试""" - print("\n" + "="*60) - print(" CLIProxyAPI 全面测试") - print("="*60) - print(f"目标地址: {BASE_URL}") - print(f"API Key: {API_KEY[:10]}...") - - # 1. 获取模型列表 - models = get_models() - if not models: - print("\n❌ 无法获取模型列表,测试终止") - return - - # 2. 基础可用性测试 - print_separator("基础可用性测试") - available_models = [] - for model in models: - success, detail = test_model_basic(model) - print_result(f"模型: {model}", success, detail) - if success: - available_models.append(model) - - print(f"\n可用模型: {len(available_models)}/{len(models)}") - - if not available_models: - print("\n❌ 没有可用的模型,测试终止") - return - - # 3. 流式输出测试 - print_separator("流式输出测试") - streaming_results = {} - for model in available_models: - result = test_streaming(model) - streaming_results[model] = result - detail = f"chunks={result['chunks']}, content_len={len(result['content'])}" - if result["error"]: - detail = f"error: {result['error']}" - print_result(f"模型: {model}", result["success"], detail) - - # 4. Thinking 模式测试 (简单任务) - print_separator("Thinking 模式测试 (简单任务)") - thinking_results = {} - for model in available_models: - result = test_thinking_mode(model, complex_task=False) - thinking_results[model] = result - detail = f"reasoning={result['has_reasoning']}, chunks={result['chunks']}" - if result["error"]: - detail = f"error: {result['error']}" - print_result(f"模型: {model}", result["success"], detail) - - # 5. Thinking 模式测试 (复杂任务) - 只测试支持 thinking 的模型 - print_separator("Thinking 模式测试 (复杂任务)") - complex_thinking_results = {} - - # 选择前3个可用模型进行复杂任务测试 - test_models = available_models[:3] - print(f"测试模型 (取前3个): {test_models}\n") - - for model in test_models: - print(f"⏳ 正在测试 {model} (复杂任务,可能需要较长时间)...") - result = test_thinking_mode(model, complex_task=True) - complex_thinking_results[model] = result - - if result["success"]: - detail = f"reasoning={result['has_reasoning']}, reasoning_len={len(result['reasoning_content'])}, content_len={len(result['content'])}" - else: - detail = f"error: {result['error']}" if result["error"] else "Unknown error" - - print_result(f"模型: {model}", result["success"], detail) - - # 如果有 reasoning 内容,打印前500字符 - if result["has_reasoning"] and result["reasoning_content"]: - print(f"\n 📝 Reasoning 内容预览 (前500字符):") - print(f" {result['reasoning_content'][:500]}...") - - # 6. 总结报告 - print_separator("测试总结报告") - - print(f"📊 模型总数: {len(models)}") - print(f"✅ 可用模型: {len(available_models)}") - print(f"❌ 不可用模型: {len(models) - len(available_models)}") - - print(f"\n📊 流式输出测试:") - streaming_pass = sum(1 for r in streaming_results.values() if r["success"]) - print(f" 通过: {streaming_pass}/{len(streaming_results)}") - - print(f"\n📊 Thinking 模式测试 (简单):") - thinking_pass = sum(1 for r in thinking_results.values() if r["success"]) - thinking_with_reasoning = sum(1 for r in thinking_results.values() if r["has_reasoning"]) - print(f" 通过: {thinking_pass}/{len(thinking_results)}") - print(f" 包含推理内容: {thinking_with_reasoning}/{len(thinking_results)}") - - print(f"\n📊 Thinking 模式测试 (复杂):") - complex_pass = sum(1 for r in complex_thinking_results.values() if r["success"]) - complex_with_reasoning = sum(1 for r in complex_thinking_results.values() if r["has_reasoning"]) - print(f" 通过: {complex_pass}/{len(complex_thinking_results)}") - print(f" 包含推理内容: {complex_with_reasoning}/{len(complex_thinking_results)}") - - # 列出所有错误 - print(f"\n📋 错误详情:") - has_errors = False - - for model, result in streaming_results.items(): - if result["error"]: - has_errors = True - print(f" [流式] {model}: {result['error'][:100]}") - - for model, result in thinking_results.items(): - if result["error"]: - has_errors = True - print(f" [Thinking简单] {model}: {result['error'][:100]}") - - for model, result in complex_thinking_results.items(): - if result["error"]: - has_errors = True - print(f" [Thinking复杂] {model}: {result['error'][:100]}") - - if not has_errors: - print(" 无错误") - - print("\n" + "="*60) - print(" 测试完成") - print("="*60 + "\n") - -def test_single_model_basic(model: str): - """单独测试一个模型的基础功能""" - print_separator(f"基础测试: {model}") - success, detail = test_model_basic(model) - print_result(f"模型: {model}", success, detail) - return success - -def test_single_model_streaming(model: str): - """单独测试一个模型的流式输出""" - print_separator(f"流式测试: {model}") - result = test_streaming(model) - detail = f"chunks={result['chunks']}, content_len={len(result['content'])}" - if result["error"]: - detail = f"error: {result['error']}" - print_result(f"模型: {model}", result["success"], detail) - if result["content"]: - print(f"\n内容: {result['content'][:300]}") - return result - -def test_single_model_thinking(model: str, complex_task: bool = False): - """单独测试一个模型的thinking模式""" - task_type = "复杂" if complex_task else "简单" - print_separator(f"Thinking测试({task_type}): {model}") - result = test_thinking_mode(model, complex_task=complex_task) - detail = f"reasoning={result['has_reasoning']}, chunks={result['chunks']}" - if result["error"]: - detail = f"error: {result['error']}" - print_result(f"模型: {model}", result["success"], detail) - if result["reasoning_content"]: - print(f"\nReasoning预览: {result['reasoning_content'][:500]}") - if result["content"]: - print(f"\n内容预览: {result['content'][:500]}") - return result - -def print_usage(): - print(""" -用法: python test_api.py [options] - -命令: - models - 获取模型列表 - basic - 测试单个模型基础功能 - stream - 测试单个模型流式输出 - thinking - 测试单个模型thinking模式(简单任务) - thinking-complex - 测试单个模型thinking模式(复杂任务) - all - 运行完整测试(原有功能) - -示例: - python test_api.py models - python test_api.py basic claude-sonnet - python test_api.py stream claude-sonnet - python test_api.py thinking claude-sonnet -""") - -if __name__ == "__main__": - import sys - - if len(sys.argv) < 2: - print_usage() - sys.exit(0) - - cmd = sys.argv[1].lower() - - if cmd == "models": - get_models() - elif cmd == "basic" and len(sys.argv) >= 3: - test_single_model_basic(sys.argv[2]) - elif cmd == "stream" and len(sys.argv) >= 3: - test_single_model_streaming(sys.argv[2]) - elif cmd == "thinking" and len(sys.argv) >= 3: - test_single_model_thinking(sys.argv[2], complex_task=False) - elif cmd == "thinking-complex" and len(sys.argv) >= 3: - test_single_model_thinking(sys.argv[2], complex_task=True) - elif cmd == "all": - run_full_test() - else: - print_usage() diff --git a/test_auth_diff.go b/test_auth_diff.go deleted file mode 100644 index b294622e..00000000 --- a/test_auth_diff.go +++ /dev/null @@ -1,273 +0,0 @@ -// 测试脚本 3:对比 CLIProxyAPIPlus 与官方格式的差异 -// 这个脚本分析 CLIProxyAPIPlus 保存的 token 与官方格式的差异 -// 运行方式: go run test_auth_diff.go -package main - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "net/http" - "os" - "path/filepath" - "strings" - "time" -) - -func main() { - fmt.Println("=" + strings.Repeat("=", 59)) - fmt.Println(" 测试脚本 3: Token 格式差异分析") - fmt.Println("=" + strings.Repeat("=", 59)) - - homeDir := os.Getenv("USERPROFILE") - - // 加载官方 IDE Token (Kiro IDE 生成) - fmt.Println("\n[1] 官方 Kiro IDE Token 格式") - fmt.Println("-" + strings.Repeat("-", 59)) - - ideTokenPath := filepath.Join(homeDir, ".aws", "sso", "cache", "kiro-auth-token.json") - ideToken := loadAndAnalyze(ideTokenPath, "Kiro IDE") - - // 加载 CLIProxyAPIPlus 保存的 Token - fmt.Println("\n[2] CLIProxyAPIPlus 保存的 Token 格式") - fmt.Println("-" + strings.Repeat("-", 59)) - - cliProxyDir := filepath.Join(homeDir, ".cli-proxy-api") - files, _ := os.ReadDir(cliProxyDir) - - var cliProxyTokens []map[string]interface{} - for _, f := range files { - if strings.HasPrefix(f.Name(), "kiro") && strings.HasSuffix(f.Name(), ".json") { - p := filepath.Join(cliProxyDir, f.Name()) - token := loadAndAnalyze(p, f.Name()) - if token != nil { - cliProxyTokens = append(cliProxyTokens, token) - } - } - } - - // 对比分析 - fmt.Println("\n[3] 关键差异分析") - fmt.Println("-" + strings.Repeat("-", 59)) - - if ideToken == nil { - fmt.Println("❌ 无法加载 IDE Token,跳过对比") - } else if len(cliProxyTokens) == 0 { - fmt.Println("❌ 无法加载 CLIProxyAPIPlus Token,跳过对比") - } else { - // 对比最新的 CLIProxyAPIPlus token - cliToken := cliProxyTokens[0] - - fmt.Println("\n字段对比:") - fmt.Printf("%-20s | %-15s | %-15s\n", "字段", "IDE Token", "CLIProxy Token") - fmt.Println(strings.Repeat("-", 55)) - - fields := []string{ - "accessToken", "refreshToken", "clientId", "clientSecret", - "authMethod", "auth_method", "provider", "region", "expiresAt", "expires_at", - } - - for _, field := range fields { - ideVal := getFieldStatus(ideToken, field) - cliVal := getFieldStatus(cliToken, field) - - status := " " - if ideVal != cliVal { - if ideVal == "✅ 有" && cliVal == "❌ 无" { - status = "⚠️" - } else if ideVal == "❌ 无" && cliVal == "✅ 有" { - status = "📝" - } - } - - fmt.Printf("%-20s | %-15s | %-15s %s\n", field, ideVal, cliVal, status) - } - - // 关键问题检测 - fmt.Println("\n🔍 问题检测:") - - // 检查 clientId/clientSecret - if hasField(ideToken, "clientId") && !hasField(cliToken, "clientId") { - fmt.Println(" ⚠️ 问题: CLIProxyAPIPlus 缺少 clientId 字段!") - fmt.Println(" 原因: IdC 认证刷新 token 时需要 clientId") - } - - if hasField(ideToken, "clientSecret") && !hasField(cliToken, "clientSecret") { - fmt.Println(" ⚠️ 问题: CLIProxyAPIPlus 缺少 clientSecret 字段!") - fmt.Println(" 原因: IdC 认证刷新 token 时需要 clientSecret") - } - - // 检查字段名差异 - if hasField(cliToken, "auth_method") && !hasField(cliToken, "authMethod") { - fmt.Println(" 📝 注意: CLIProxy 使用 auth_method (snake_case)") - fmt.Println(" 而官方使用 authMethod (camelCase)") - } - - if hasField(cliToken, "expires_at") && !hasField(cliToken, "expiresAt") { - fmt.Println(" 📝 注意: CLIProxy 使用 expires_at (snake_case)") - fmt.Println(" 而官方使用 expiresAt (camelCase)") - } - } - - // Step 4: 测试使用完整格式的 token - fmt.Println("\n[4] 测试完整格式 Token (带 clientId/clientSecret)") - fmt.Println("-" + strings.Repeat("-", 59)) - - if ideToken != nil { - testWithFullToken(ideToken) - } - - fmt.Println("\n" + strings.Repeat("=", 60)) - fmt.Println(" 分析完成") - fmt.Println(strings.Repeat("=", 60)) - - // 给出建议 - fmt.Println("\n💡 修复建议:") - fmt.Println(" 1. CLIProxyAPIPlus 导入 token 时需要保留 clientId 和 clientSecret") - fmt.Println(" 2. IdC 认证刷新 token 必须使用这两个字段") - fmt.Println(" 3. 检查 CLIProxyAPIPlus 的 token 导入逻辑:") - fmt.Println(" - internal/auth/kiro/aws.go LoadKiroIDEToken()") - fmt.Println(" - sdk/auth/kiro.go ImportFromKiroIDE()") -} - -func loadAndAnalyze(path, name string) map[string]interface{} { - data, err := os.ReadFile(path) - if err != nil { - fmt.Printf("❌ 无法加载 %s: %v\n", name, err) - return nil - } - - var token map[string]interface{} - if err := json.Unmarshal(data, &token); err != nil { - fmt.Printf("❌ 无法解析 %s: %v\n", name, err) - return nil - } - - fmt.Printf("📄 %s\n", path) - fmt.Printf(" 字段数: %d\n", len(token)) - - // 列出所有字段 - fmt.Printf(" 字段列表: ") - keys := make([]string, 0, len(token)) - for k := range token { - keys = append(keys, k) - } - fmt.Printf("%v\n", keys) - - return token -} - -func getFieldStatus(token map[string]interface{}, field string) string { - if token == nil { - return "N/A" - } - if v, ok := token[field]; ok && v != nil && v != "" { - return "✅ 有" - } - return "❌ 无" -} - -func hasField(token map[string]interface{}, field string) bool { - if token == nil { - return false - } - v, ok := token[field] - return ok && v != nil && v != "" -} - -func testWithFullToken(token map[string]interface{}) { - accessToken, _ := token["accessToken"].(string) - refreshToken, _ := token["refreshToken"].(string) - clientId, _ := token["clientId"].(string) - clientSecret, _ := token["clientSecret"].(string) - region, _ := token["region"].(string) - - if region == "" { - region = "us-east-1" - } - - // 测试当前 accessToken - fmt.Println("\n测试当前 accessToken...") - if testAPICall(accessToken, region) { - fmt.Println("✅ 当前 accessToken 有效") - return - } - - fmt.Println("⚠️ 当前 accessToken 无效,尝试刷新...") - - // 检查是否有完整的刷新所需字段 - if clientId == "" || clientSecret == "" { - fmt.Println("❌ 缺少 clientId 或 clientSecret,无法刷新") - fmt.Println(" 这就是问题所在!") - return - } - - // 尝试刷新 - fmt.Println("\n使用完整字段刷新 token...") - url := fmt.Sprintf("https://oidc.%s.amazonaws.com/token", region) - - requestBody := map[string]interface{}{ - "refreshToken": refreshToken, - "clientId": clientId, - "clientSecret": clientSecret, - "grantType": "refresh_token", - } - - body, _ := json.Marshal(requestBody) - req, _ := http.NewRequest("POST", url, bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - fmt.Printf("❌ 请求失败: %v\n", err) - return - } - defer resp.Body.Close() - - respBody, _ := io.ReadAll(resp.Body) - - if resp.StatusCode == 200 { - var refreshResp map[string]interface{} - json.Unmarshal(respBody, &refreshResp) - - newAccessToken, _ := refreshResp["accessToken"].(string) - fmt.Println("✅ Token 刷新成功!") - - // 验证新 token - if testAPICall(newAccessToken, region) { - fmt.Println("✅ 新 Token 验证成功!") - fmt.Println("\n✅ 结论: 使用完整格式 (含 clientId/clientSecret) 可以正常工作") - } - } else { - fmt.Printf("❌ 刷新失败: HTTP %d\n", resp.StatusCode) - fmt.Printf(" 响应: %s\n", string(respBody)) - } -} - -func testAPICall(accessToken, region string) bool { - url := fmt.Sprintf("https://codewhisperer.%s.amazonaws.com", region) - - payload := map[string]interface{}{ - "origin": "AI_EDITOR", - "isEmailRequired": true, - "resourceType": "AGENTIC_REQUEST", - } - body, _ := json.Marshal(payload) - - req, _ := http.NewRequest("POST", url, bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/x-amz-json-1.0") - req.Header.Set("x-amz-target", "AmazonCodeWhispererService.GetUsageLimits") - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("Accept", "application/json") - - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - return false - } - defer resp.Body.Close() - - return resp.StatusCode == 200 -} diff --git a/test_auth_idc_go1.go b/test_auth_idc_go1.go deleted file mode 100644 index 55fd5829..00000000 --- a/test_auth_idc_go1.go +++ /dev/null @@ -1,323 +0,0 @@ -// 测试脚本 1:模拟 kiro2api_go1 的 IdC 认证方式 -// 这个脚本完整模拟 kiro-gateway/temp/kiro2api_go1 的认证逻辑 -// 运行方式: go run test_auth_idc_go1.go -package main - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "math/rand" - "net/http" - "os" - "path/filepath" - "strings" - "time" -) - -// 配置常量 - 来自 kiro2api_go1/config/config.go -const ( - IdcRefreshTokenURL = "https://oidc.us-east-1.amazonaws.com/token" - CodeWhispererAPIURL = "https://codewhisperer.us-east-1.amazonaws.com" -) - -// AuthConfig - 来自 kiro2api_go1/auth/config.go -type AuthConfig struct { - AuthType string `json:"auth"` - RefreshToken string `json:"refreshToken"` - ClientID string `json:"clientId,omitempty"` - ClientSecret string `json:"clientSecret,omitempty"` -} - -// IdcRefreshRequest - 来自 kiro2api_go1/types/token.go -type IdcRefreshRequest struct { - ClientId string `json:"clientId"` - ClientSecret string `json:"clientSecret"` - GrantType string `json:"grantType"` - RefreshToken string `json:"refreshToken"` -} - -// RefreshResponse - 来自 kiro2api_go1/types/token.go -type RefreshResponse struct { - AccessToken string `json:"accessToken"` - RefreshToken string `json:"refreshToken,omitempty"` - ExpiresIn int `json:"expiresIn"` - TokenType string `json:"tokenType,omitempty"` -} - -// Fingerprint - 简化的指纹结构 -type Fingerprint struct { - OSType string - ConnectionBehavior string - AcceptLanguage string - SecFetchMode string - AcceptEncoding string -} - -func generateFingerprint() *Fingerprint { - osTypes := []string{"darwin", "windows", "linux"} - connections := []string{"keep-alive", "close"} - languages := []string{"en-US,en;q=0.9", "zh-CN,zh;q=0.9", "en-GB,en;q=0.9"} - fetchModes := []string{"cors", "navigate", "no-cors"} - - return &Fingerprint{ - OSType: osTypes[rand.Intn(len(osTypes))], - ConnectionBehavior: connections[rand.Intn(len(connections))], - AcceptLanguage: languages[rand.Intn(len(languages))], - SecFetchMode: fetchModes[rand.Intn(len(fetchModes))], - AcceptEncoding: "gzip, deflate, br", - } -} - -func main() { - rand.Seed(time.Now().UnixNano()) - - fmt.Println("=" + strings.Repeat("=", 59)) - fmt.Println(" 测试脚本 1: kiro2api_go1 风格 IdC 认证") - fmt.Println("=" + strings.Repeat("=", 59)) - - // Step 1: 加载官方格式的 token 文件 - fmt.Println("\n[Step 1] 加载官方格式 Token 文件") - fmt.Println("-" + strings.Repeat("-", 59)) - - // 尝试从多个位置加载 - tokenPaths := []string{ - // 优先使用包含完整 clientId/clientSecret 的文件 - "E:/ai_project_2api/kiro-gateway/configs/kiro/kiro-auth-token-1768317098.json", - filepath.Join(os.Getenv("USERPROFILE"), ".aws", "sso", "cache", "kiro-auth-token.json"), - } - - var tokenData map[string]interface{} - var loadedPath string - - for _, p := range tokenPaths { - data, err := os.ReadFile(p) - if err == nil { - if err := json.Unmarshal(data, &tokenData); err == nil { - loadedPath = p - break - } - } - } - - if tokenData == nil { - fmt.Println("❌ 无法加载任何 token 文件") - return - } - - fmt.Printf("✅ 加载文件: %s\n", loadedPath) - - // 提取关键字段 - accessToken, _ := tokenData["accessToken"].(string) - refreshToken, _ := tokenData["refreshToken"].(string) - clientId, _ := tokenData["clientId"].(string) - clientSecret, _ := tokenData["clientSecret"].(string) - authMethod, _ := tokenData["authMethod"].(string) - region, _ := tokenData["region"].(string) - - if region == "" { - region = "us-east-1" - } - - fmt.Printf("\n当前 Token 信息:\n") - fmt.Printf(" AuthMethod: %s\n", authMethod) - fmt.Printf(" Region: %s\n", region) - fmt.Printf(" AccessToken: %s...\n", truncate(accessToken, 50)) - fmt.Printf(" RefreshToken: %s...\n", truncate(refreshToken, 50)) - fmt.Printf(" ClientID: %s\n", truncate(clientId, 30)) - fmt.Printf(" ClientSecret: %s...\n", truncate(clientSecret, 50)) - - // Step 2: 验证 IdC 认证所需字段 - fmt.Println("\n[Step 2] 验证 IdC 认证必需字段") - fmt.Println("-" + strings.Repeat("-", 59)) - - missingFields := []string{} - if refreshToken == "" { - missingFields = append(missingFields, "refreshToken") - } - if clientId == "" { - missingFields = append(missingFields, "clientId") - } - if clientSecret == "" { - missingFields = append(missingFields, "clientSecret") - } - - if len(missingFields) > 0 { - fmt.Printf("❌ 缺少必需字段: %v\n", missingFields) - fmt.Println(" IdC 认证需要: refreshToken, clientId, clientSecret") - return - } - fmt.Println("✅ 所有必需字段都存在") - - // Step 3: 测试直接使用 accessToken 调用 API - fmt.Println("\n[Step 3] 测试当前 AccessToken 有效性") - fmt.Println("-" + strings.Repeat("-", 59)) - - if testAPICall(accessToken, region) { - fmt.Println("✅ 当前 AccessToken 有效,无需刷新") - } else { - fmt.Println("⚠️ 当前 AccessToken 无效,需要刷新") - - // Step 4: 使用 kiro2api_go1 风格刷新 token - fmt.Println("\n[Step 4] 使用 kiro2api_go1 风格刷新 Token") - fmt.Println("-" + strings.Repeat("-", 59)) - - newToken, err := refreshIdCToken(AuthConfig{ - AuthType: "IdC", - RefreshToken: refreshToken, - ClientID: clientId, - ClientSecret: clientSecret, - }, region) - - if err != nil { - fmt.Printf("❌ 刷新失败: %v\n", err) - return - } - - fmt.Println("✅ Token 刷新成功!") - fmt.Printf(" 新 AccessToken: %s...\n", truncate(newToken.AccessToken, 50)) - fmt.Printf(" ExpiresIn: %d 秒\n", newToken.ExpiresIn) - - // Step 5: 验证新 token - fmt.Println("\n[Step 5] 验证新 Token") - fmt.Println("-" + strings.Repeat("-", 59)) - - if testAPICall(newToken.AccessToken, region) { - fmt.Println("✅ 新 Token 验证成功!") - - // 保存新 token - saveNewToken(loadedPath, newToken, tokenData) - } else { - fmt.Println("❌ 新 Token 验证失败") - } - } - - fmt.Println("\n" + strings.Repeat("=", 60)) - fmt.Println(" 测试完成") - fmt.Println(strings.Repeat("=", 60)) -} - -// refreshIdCToken - 完全模拟 kiro2api_go1/auth/refresh.go 的 refreshIdCToken 函数 -func refreshIdCToken(authConfig AuthConfig, region string) (*RefreshResponse, error) { - refreshReq := IdcRefreshRequest{ - ClientId: authConfig.ClientID, - ClientSecret: authConfig.ClientSecret, - GrantType: "refresh_token", - RefreshToken: authConfig.RefreshToken, - } - - reqBody, err := json.Marshal(refreshReq) - if err != nil { - return nil, fmt.Errorf("序列化IdC请求失败: %v", err) - } - - url := fmt.Sprintf("https://oidc.%s.amazonaws.com/token", region) - req, err := http.NewRequest("POST", url, bytes.NewBuffer(reqBody)) - if err != nil { - return nil, fmt.Errorf("创建IdC请求失败: %v", err) - } - - // 设置 IdC 特殊 headers(使用指纹随机化)- 完全模拟 kiro2api_go1 - fp := generateFingerprint() - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Host", fmt.Sprintf("oidc.%s.amazonaws.com", region)) - req.Header.Set("Connection", fp.ConnectionBehavior) - req.Header.Set("x-amz-user-agent", fmt.Sprintf("aws-sdk-js/3.738.0 ua/2.1 os/%s lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE", fp.OSType)) - req.Header.Set("Accept", "*/*") - req.Header.Set("Accept-Language", fp.AcceptLanguage) - req.Header.Set("sec-fetch-mode", fp.SecFetchMode) - req.Header.Set("User-Agent", "node") - req.Header.Set("Accept-Encoding", fp.AcceptEncoding) - - fmt.Println("发送刷新请求:") - fmt.Printf(" URL: %s\n", url) - fmt.Println(" Headers:") - for k, v := range req.Header { - if k == "Content-Type" || k == "Host" || k == "X-Amz-User-Agent" || k == "User-Agent" { - fmt.Printf(" %s: %s\n", k, v[0]) - } - } - - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - return nil, fmt.Errorf("IdC请求失败: %v", err) - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("IdC刷新失败: 状态码 %d, 响应: %s", resp.StatusCode, string(body)) - } - - var refreshResp RefreshResponse - if err := json.Unmarshal(body, &refreshResp); err != nil { - return nil, fmt.Errorf("解析IdC响应失败: %v", err) - } - - return &refreshResp, nil -} - -func testAPICall(accessToken, region string) bool { - url := fmt.Sprintf("https://codewhisperer.%s.amazonaws.com", region) - - payload := map[string]interface{}{ - "origin": "AI_EDITOR", - "isEmailRequired": true, - "resourceType": "AGENTIC_REQUEST", - } - body, _ := json.Marshal(payload) - - req, _ := http.NewRequest("POST", url, bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/x-amz-json-1.0") - req.Header.Set("x-amz-target", "AmazonCodeWhispererService.GetUsageLimits") - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("Accept", "application/json") - - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - fmt.Printf(" 请求错误: %v\n", err) - return false - } - defer resp.Body.Close() - - respBody, _ := io.ReadAll(resp.Body) - fmt.Printf(" API 响应: HTTP %d\n", resp.StatusCode) - - if resp.StatusCode == 200 { - return true - } - - fmt.Printf(" 错误详情: %s\n", truncate(string(respBody), 200)) - return false -} - -func saveNewToken(originalPath string, newToken *RefreshResponse, originalData map[string]interface{}) { - // 更新 token 数据 - originalData["accessToken"] = newToken.AccessToken - if newToken.RefreshToken != "" { - originalData["refreshToken"] = newToken.RefreshToken - } - originalData["expiresAt"] = time.Now().Add(time.Duration(newToken.ExpiresIn) * time.Second).Format(time.RFC3339) - - data, _ := json.MarshalIndent(originalData, "", " ") - - // 保存到新文件 - newPath := strings.TrimSuffix(originalPath, ".json") + "_refreshed.json" - if err := os.WriteFile(newPath, data, 0644); err != nil { - fmt.Printf("⚠️ 保存失败: %v\n", err) - } else { - fmt.Printf("✅ 新 Token 已保存到: %s\n", newPath) - } -} - -func truncate(s string, n int) string { - if len(s) <= n { - return s - } - return s[:n] -} diff --git a/test_auth_js_style.go b/test_auth_js_style.go deleted file mode 100644 index 6ded3305..00000000 --- a/test_auth_js_style.go +++ /dev/null @@ -1,237 +0,0 @@ -// 测试脚本 2:模拟 kiro2Api_js 的认证方式 -// 这个脚本完整模拟 kiro-gateway/temp/kiro2Api_js 的认证逻辑 -// 运行方式: go run test_auth_js_style.go -package main - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "net/http" - "os" - "path/filepath" - "strings" - "time" -) - -// 常量 - 来自 kiro2Api_js/src/kiro/auth.js -const ( - REFRESH_URL_TEMPLATE = "https://prod.{{region}}.auth.desktop.kiro.dev/refreshToken" - REFRESH_IDC_URL_TEMPLATE = "https://oidc.{{region}}.amazonaws.com/token" - AUTH_METHOD_SOCIAL = "social" - AUTH_METHOD_IDC = "IdC" -) - -func main() { - fmt.Println("=" + strings.Repeat("=", 59)) - fmt.Println(" 测试脚本 2: kiro2Api_js 风格认证") - fmt.Println("=" + strings.Repeat("=", 59)) - - // Step 1: 加载 token 文件 - fmt.Println("\n[Step 1] 加载 Token 文件") - fmt.Println("-" + strings.Repeat("-", 59)) - - tokenPaths := []string{ - filepath.Join(os.Getenv("USERPROFILE"), ".aws", "sso", "cache", "kiro-auth-token.json"), - "E:/ai_project_2api/kiro-gateway/configs/kiro/kiro-auth-token-1768317098.json", - } - - var tokenData map[string]interface{} - var loadedPath string - - for _, p := range tokenPaths { - data, err := os.ReadFile(p) - if err == nil { - if err := json.Unmarshal(data, &tokenData); err == nil { - loadedPath = p - break - } - } - } - - if tokenData == nil { - fmt.Println("❌ 无法加载任何 token 文件") - return - } - - fmt.Printf("✅ 加载文件: %s\n", loadedPath) - - // 提取字段 - 模拟 kiro2Api_js/src/kiro/auth.js initializeAuth - accessToken, _ := tokenData["accessToken"].(string) - refreshToken, _ := tokenData["refreshToken"].(string) - clientId, _ := tokenData["clientId"].(string) - clientSecret, _ := tokenData["clientSecret"].(string) - authMethod, _ := tokenData["authMethod"].(string) - region, _ := tokenData["region"].(string) - - if region == "" { - region = "us-east-1" - fmt.Println("⚠️ Region 未设置,使用默认值 us-east-1") - } - - fmt.Printf("\nToken 信息:\n") - fmt.Printf(" AuthMethod: %s\n", authMethod) - fmt.Printf(" Region: %s\n", region) - fmt.Printf(" 有 ClientID: %v\n", clientId != "") - fmt.Printf(" 有 ClientSecret: %v\n", clientSecret != "") - - // Step 2: 测试当前 token - fmt.Println("\n[Step 2] 测试当前 AccessToken") - fmt.Println("-" + strings.Repeat("-", 59)) - - if testAPI(accessToken, region) { - fmt.Println("✅ 当前 AccessToken 有效") - return - } - - fmt.Println("⚠️ 当前 AccessToken 无效,开始刷新...") - - // Step 3: 根据 authMethod 选择刷新方式 - 模拟 doRefreshToken - fmt.Println("\n[Step 3] 刷新 Token (JS 风格)") - fmt.Println("-" + strings.Repeat("-", 59)) - - var refreshURL string - var requestBody map[string]interface{} - - // 判断认证方式 - 模拟 kiro2Api_js auth.js doRefreshToken - if authMethod == AUTH_METHOD_SOCIAL { - // Social 认证 - refreshURL = strings.Replace(REFRESH_URL_TEMPLATE, "{{region}}", region, 1) - requestBody = map[string]interface{}{ - "refreshToken": refreshToken, - } - fmt.Println("使用 Social 认证方式") - } else { - // IdC 认证 (默认) - refreshURL = strings.Replace(REFRESH_IDC_URL_TEMPLATE, "{{region}}", region, 1) - requestBody = map[string]interface{}{ - "refreshToken": refreshToken, - "clientId": clientId, - "clientSecret": clientSecret, - "grantType": "refresh_token", - } - fmt.Println("使用 IdC 认证方式") - } - - fmt.Printf("刷新 URL: %s\n", refreshURL) - fmt.Printf("请求字段: %v\n", getKeys(requestBody)) - - // 发送刷新请求 - body, _ := json.Marshal(requestBody) - req, _ := http.NewRequest("POST", refreshURL, bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - fmt.Printf("❌ 请求失败: %v\n", err) - return - } - defer resp.Body.Close() - - respBody, _ := io.ReadAll(resp.Body) - - fmt.Printf("\n响应状态: HTTP %d\n", resp.StatusCode) - - if resp.StatusCode != 200 { - fmt.Printf("❌ 刷新失败: %s\n", string(respBody)) - - // 分析错误 - var errResp map[string]interface{} - if err := json.Unmarshal(respBody, &errResp); err == nil { - if errType, ok := errResp["error"].(string); ok { - fmt.Printf("错误类型: %s\n", errType) - if errType == "invalid_grant" { - fmt.Println("\n💡 提示: refresh_token 可能已过期,需要重新授权") - } - } - if errDesc, ok := errResp["error_description"].(string); ok { - fmt.Printf("错误描述: %s\n", errDesc) - } - } - return - } - - // 解析响应 - var refreshResp map[string]interface{} - json.Unmarshal(respBody, &refreshResp) - - newAccessToken, _ := refreshResp["accessToken"].(string) - newRefreshToken, _ := refreshResp["refreshToken"].(string) - expiresIn, _ := refreshResp["expiresIn"].(float64) - - fmt.Println("✅ Token 刷新成功!") - fmt.Printf(" 新 AccessToken: %s...\n", truncate(newAccessToken, 50)) - fmt.Printf(" ExpiresIn: %.0f 秒\n", expiresIn) - if newRefreshToken != "" { - fmt.Printf(" 新 RefreshToken: %s...\n", truncate(newRefreshToken, 50)) - } - - // Step 4: 验证新 token - fmt.Println("\n[Step 4] 验证新 Token") - fmt.Println("-" + strings.Repeat("-", 59)) - - if testAPI(newAccessToken, region) { - fmt.Println("✅ 新 Token 验证成功!") - - // 保存新 token - 模拟 saveCredentialsToFile - tokenData["accessToken"] = newAccessToken - if newRefreshToken != "" { - tokenData["refreshToken"] = newRefreshToken - } - tokenData["expiresAt"] = time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339) - - saveData, _ := json.MarshalIndent(tokenData, "", " ") - newPath := strings.TrimSuffix(loadedPath, ".json") + "_js_refreshed.json" - os.WriteFile(newPath, saveData, 0644) - fmt.Printf("✅ 已保存到: %s\n", newPath) - } else { - fmt.Println("❌ 新 Token 验证失败") - } - - fmt.Println("\n" + strings.Repeat("=", 60)) - fmt.Println(" 测试完成") - fmt.Println(strings.Repeat("=", 60)) -} - -func testAPI(accessToken, region string) bool { - url := fmt.Sprintf("https://codewhisperer.%s.amazonaws.com", region) - - payload := map[string]interface{}{ - "origin": "AI_EDITOR", - "isEmailRequired": true, - "resourceType": "AGENTIC_REQUEST", - } - body, _ := json.Marshal(payload) - - req, _ := http.NewRequest("POST", url, bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/x-amz-json-1.0") - req.Header.Set("x-amz-target", "AmazonCodeWhispererService.GetUsageLimits") - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("Accept", "application/json") - - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - return false - } - defer resp.Body.Close() - - return resp.StatusCode == 200 -} - -func getKeys(m map[string]interface{}) []string { - keys := make([]string, 0, len(m)) - for k := range m { - keys = append(keys, k) - } - return keys -} - -func truncate(s string, n int) string { - if len(s) <= n { - return s - } - return s[:n] -} diff --git a/test_kiro_debug.go b/test_kiro_debug.go deleted file mode 100644 index 0fbbed6c..00000000 --- a/test_kiro_debug.go +++ /dev/null @@ -1,348 +0,0 @@ -// 独立测试脚本:排查 Kiro Token 403 错误 -// 运行方式: go run test_kiro_debug.go -package main - -import ( - "bytes" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net/http" - "os" - "path/filepath" - "strings" - "time" -) - -// Token 结构 - 匹配 Kiro IDE 格式 -type KiroIDEToken struct { - AccessToken string `json:"accessToken"` - RefreshToken string `json:"refreshToken"` - ExpiresAt string `json:"expiresAt"` - ClientIDHash string `json:"clientIdHash,omitempty"` - AuthMethod string `json:"authMethod"` - Provider string `json:"provider"` - Region string `json:"region,omitempty"` -} - -// Token 结构 - 匹配 CLIProxyAPIPlus 格式 -type CLIProxyToken struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - ProfileArn string `json:"profile_arn"` - ExpiresAt string `json:"expires_at"` - AuthMethod string `json:"auth_method"` - Provider string `json:"provider"` - ClientID string `json:"client_id,omitempty"` - ClientSecret string `json:"client_secret,omitempty"` - Email string `json:"email,omitempty"` - Type string `json:"type"` -} - -func main() { - fmt.Println("=" + strings.Repeat("=", 59)) - fmt.Println(" Kiro Token 403 错误排查工具") - fmt.Println("=" + strings.Repeat("=", 59)) - - homeDir, _ := os.UserHomeDir() - - // Step 1: 检查 Kiro IDE Token 文件 - fmt.Println("\n[Step 1] 检查 Kiro IDE Token 文件") - fmt.Println("-" + strings.Repeat("-", 59)) - - ideTokenPath := filepath.Join(homeDir, ".aws", "sso", "cache", "kiro-auth-token.json") - ideToken, err := loadKiroIDEToken(ideTokenPath) - if err != nil { - fmt.Printf("❌ 无法加载 Kiro IDE Token: %v\n", err) - return - } - fmt.Printf("✅ Token 文件: %s\n", ideTokenPath) - fmt.Printf(" AuthMethod: %s\n", ideToken.AuthMethod) - fmt.Printf(" Provider: %s\n", ideToken.Provider) - fmt.Printf(" Region: %s\n", ideToken.Region) - fmt.Printf(" ExpiresAt: %s\n", ideToken.ExpiresAt) - fmt.Printf(" AccessToken (前50字符): %s...\n", truncate(ideToken.AccessToken, 50)) - - // Step 2: 检查 Token 过期状态 - fmt.Println("\n[Step 2] 检查 Token 过期状态") - fmt.Println("-" + strings.Repeat("-", 59)) - - expiresAt, err := parseExpiresAt(ideToken.ExpiresAt) - if err != nil { - fmt.Printf("❌ 无法解析过期时间: %v\n", err) - } else { - now := time.Now() - if now.After(expiresAt) { - fmt.Printf("❌ Token 已过期!过期时间: %s,当前时间: %s\n", expiresAt.Format(time.RFC3339), now.Format(time.RFC3339)) - } else { - remaining := expiresAt.Sub(now) - fmt.Printf("✅ Token 未过期,剩余: %s\n", remaining.Round(time.Second)) - } - } - - // Step 3: 检查 CLIProxyAPIPlus 保存的 Token - fmt.Println("\n[Step 3] 检查 CLIProxyAPIPlus 保存的 Token") - fmt.Println("-" + strings.Repeat("-", 59)) - - cliProxyDir := filepath.Join(homeDir, ".cli-proxy-api") - files, _ := os.ReadDir(cliProxyDir) - for _, f := range files { - if strings.HasPrefix(f.Name(), "kiro") && strings.HasSuffix(f.Name(), ".json") { - filePath := filepath.Join(cliProxyDir, f.Name()) - cliToken, err := loadCLIProxyToken(filePath) - if err != nil { - fmt.Printf("❌ %s: 加载失败 - %v\n", f.Name(), err) - continue - } - fmt.Printf("📄 %s:\n", f.Name()) - fmt.Printf(" AuthMethod: %s\n", cliToken.AuthMethod) - fmt.Printf(" Provider: %s\n", cliToken.Provider) - fmt.Printf(" ExpiresAt: %s\n", cliToken.ExpiresAt) - fmt.Printf(" AccessToken (前50字符): %s...\n", truncate(cliToken.AccessToken, 50)) - - // 比较 Token - if cliToken.AccessToken == ideToken.AccessToken { - fmt.Printf(" ✅ AccessToken 与 IDE Token 一致\n") - } else { - fmt.Printf(" ⚠️ AccessToken 与 IDE Token 不一致!\n") - } - } - } - - // Step 4: 直接测试 Token 有效性 (调用 Kiro API) - fmt.Println("\n[Step 4] 直接测试 Token 有效性") - fmt.Println("-" + strings.Repeat("-", 59)) - - testTokenValidity(ideToken.AccessToken, ideToken.Region) - - // Step 5: 测试不同的请求头格式 - fmt.Println("\n[Step 5] 测试不同的请求头格式") - fmt.Println("-" + strings.Repeat("-", 59)) - - testDifferentHeaders(ideToken.AccessToken, ideToken.Region) - - // Step 6: 解析 JWT 内容 - fmt.Println("\n[Step 6] 解析 JWT Token 内容") - fmt.Println("-" + strings.Repeat("-", 59)) - - parseJWT(ideToken.AccessToken) - - fmt.Println("\n" + strings.Repeat("=", 60)) - fmt.Println(" 排查完成") - fmt.Println(strings.Repeat("=", 60)) -} - -func loadKiroIDEToken(path string) (*KiroIDEToken, error) { - data, err := os.ReadFile(path) - if err != nil { - return nil, err - } - var token KiroIDEToken - if err := json.Unmarshal(data, &token); err != nil { - return nil, err - } - return &token, nil -} - -func loadCLIProxyToken(path string) (*CLIProxyToken, error) { - data, err := os.ReadFile(path) - if err != nil { - return nil, err - } - var token CLIProxyToken - if err := json.Unmarshal(data, &token); err != nil { - return nil, err - } - return &token, nil -} - -func parseExpiresAt(s string) (time.Time, error) { - formats := []string{ - time.RFC3339, - "2006-01-02T15:04:05.000Z", - "2006-01-02T15:04:05Z", - } - for _, f := range formats { - if t, err := time.Parse(f, s); err == nil { - return t, nil - } - } - return time.Time{}, fmt.Errorf("无法解析时间格式: %s", s) -} - -func truncate(s string, n int) string { - if len(s) <= n { - return s - } - return s[:n] -} - -func testTokenValidity(accessToken, region string) { - if region == "" { - region = "us-east-1" - } - - // 测试 GetUsageLimits API - url := fmt.Sprintf("https://codewhisperer.%s.amazonaws.com", region) - - payload := map[string]interface{}{ - "origin": "AI_EDITOR", - "isEmailRequired": true, - "resourceType": "AGENTIC_REQUEST", - } - body, _ := json.Marshal(payload) - - req, _ := http.NewRequest("POST", url, bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/x-amz-json-1.0") - req.Header.Set("x-amz-target", "AmazonCodeWhispererService.GetUsageLimits") - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("Accept", "application/json") - - fmt.Printf("请求 URL: %s\n", url) - fmt.Printf("请求头:\n") - for k, v := range req.Header { - if k == "Authorization" { - fmt.Printf(" %s: Bearer %s...\n", k, truncate(v[0][7:], 30)) - } else { - fmt.Printf(" %s: %s\n", k, v[0]) - } - } - - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - fmt.Printf("❌ 请求失败: %v\n", err) - return - } - defer resp.Body.Close() - - respBody, _ := io.ReadAll(resp.Body) - fmt.Printf("响应状态: %d\n", resp.StatusCode) - fmt.Printf("响应内容: %s\n", string(respBody)) - - if resp.StatusCode == 200 { - fmt.Println("✅ Token 有效!") - } else if resp.StatusCode == 403 { - fmt.Println("❌ Token 无效或已过期 (403)") - } -} - -func testDifferentHeaders(accessToken, region string) { - if region == "" { - region = "us-east-1" - } - - tests := []struct { - name string - headers map[string]string - }{ - { - name: "最小请求头", - headers: map[string]string{ - "Content-Type": "application/json", - "Authorization": "Bearer " + accessToken, - }, - }, - { - name: "模拟 kiro2api_go1 风格", - headers: map[string]string{ - "Content-Type": "application/json", - "Accept": "text/event-stream", - "Authorization": "Bearer " + accessToken, - "x-amzn-kiro-agent-mode": "vibe", - "x-amzn-codewhisperer-optout": "true", - "amz-sdk-invocation-id": "test-invocation-id", - "amz-sdk-request": "attempt=1; max=3", - "x-amz-user-agent": "aws-sdk-js/1.0.27 KiroIDE-0.8.0-abc123", - "User-Agent": "aws-sdk-js/1.0.27 ua/2.1 os/windows#10.0 lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.27 m/E KiroIDE-0.8.0-abc123", - }, - }, - { - name: "模拟 CLIProxyAPIPlus 风格", - headers: map[string]string{ - "Content-Type": "application/x-amz-json-1.0", - "x-amz-target": "AmazonCodeWhispererService.GetUsageLimits", - "Authorization": "Bearer " + accessToken, - "Accept": "application/json", - "amz-sdk-invocation-id": "test-invocation-id", - "amz-sdk-request": "attempt=1; max=1", - "Connection": "close", - }, - }, - } - - url := fmt.Sprintf("https://codewhisperer.%s.amazonaws.com", region) - payload := map[string]interface{}{ - "origin": "AI_EDITOR", - "isEmailRequired": true, - "resourceType": "AGENTIC_REQUEST", - } - body, _ := json.Marshal(payload) - - for _, test := range tests { - fmt.Printf("\n测试: %s\n", test.name) - - req, _ := http.NewRequest("POST", url, bytes.NewBuffer(body)) - for k, v := range test.headers { - req.Header.Set(k, v) - } - - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - fmt.Printf(" ❌ 请求失败: %v\n", err) - continue - } - - respBody, _ := io.ReadAll(resp.Body) - resp.Body.Close() - - if resp.StatusCode == 200 { - fmt.Printf(" ✅ 成功 (HTTP %d)\n", resp.StatusCode) - } else { - fmt.Printf(" ❌ 失败 (HTTP %d): %s\n", resp.StatusCode, truncate(string(respBody), 100)) - } - } -} - -func parseJWT(token string) { - parts := strings.Split(token, ".") - if len(parts) < 2 { - fmt.Println("Token 不是 JWT 格式") - return - } - - // 解码 header - headerData, err := base64.RawURLEncoding.DecodeString(parts[0]) - if err != nil { - fmt.Printf("无法解码 JWT header: %v\n", err) - } else { - var header map[string]interface{} - json.Unmarshal(headerData, &header) - fmt.Printf("JWT Header: %v\n", header) - } - - // 解码 payload - payloadData, err := base64.RawURLEncoding.DecodeString(parts[1]) - if err != nil { - fmt.Printf("无法解码 JWT payload: %v\n", err) - } else { - var payload map[string]interface{} - json.Unmarshal(payloadData, &payload) - fmt.Printf("JWT Payload:\n") - for k, v := range payload { - fmt.Printf(" %s: %v\n", k, v) - } - - // 检查过期时间 - if exp, ok := payload["exp"].(float64); ok { - expTime := time.Unix(int64(exp), 0) - if time.Now().After(expTime) { - fmt.Printf(" ⚠️ JWT 已过期! exp=%s\n", expTime.Format(time.RFC3339)) - } else { - fmt.Printf(" ✅ JWT 未过期, 剩余: %s\n", expTime.Sub(time.Now()).Round(time.Second)) - } - } - } -} diff --git a/test_proxy_debug.go b/test_proxy_debug.go deleted file mode 100644 index 82369e74..00000000 --- a/test_proxy_debug.go +++ /dev/null @@ -1,367 +0,0 @@ -// 测试脚本 2:通过 CLIProxyAPIPlus 代理层排查问题 -// 运行方式: go run test_proxy_debug.go -package main - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "net/http" - "os" - "path/filepath" - "strings" - "time" -) - -const ( - ProxyURL = "http://localhost:8317" - APIKey = "your-api-key-1" -) - -func main() { - fmt.Println("=" + strings.Repeat("=", 59)) - fmt.Println(" CLIProxyAPIPlus 代理层问题排查") - fmt.Println("=" + strings.Repeat("=", 59)) - - // Step 1: 检查代理服务状态 - fmt.Println("\n[Step 1] 检查代理服务状态") - fmt.Println("-" + strings.Repeat("-", 59)) - - resp, err := http.Get(ProxyURL + "/health") - if err != nil { - fmt.Printf("❌ 代理服务不可达: %v\n", err) - fmt.Println("请确保服务正在运行: go run ./cmd/server/main.go") - return - } - resp.Body.Close() - fmt.Printf("✅ 代理服务正常 (HTTP %d)\n", resp.StatusCode) - - // Step 2: 获取模型列表 - fmt.Println("\n[Step 2] 获取模型列表") - fmt.Println("-" + strings.Repeat("-", 59)) - - models := getModels() - if len(models) == 0 { - fmt.Println("❌ 没有可用的模型,检查凭据加载") - checkCredentials() - return - } - fmt.Printf("✅ 找到 %d 个模型:\n", len(models)) - for _, m := range models { - fmt.Printf(" - %s\n", m) - } - - // Step 3: 测试模型请求 - 捕获详细错误 - fmt.Println("\n[Step 3] 测试模型请求(详细日志)") - fmt.Println("-" + strings.Repeat("-", 59)) - - if len(models) > 0 { - testModel := models[0] - testModelRequest(testModel) - } - - // Step 4: 检查代理内部 Token 状态 - fmt.Println("\n[Step 4] 检查代理服务加载的凭据") - fmt.Println("-" + strings.Repeat("-", 59)) - - checkProxyCredentials() - - // Step 5: 对比直接请求和代理请求 - fmt.Println("\n[Step 5] 对比直接请求 vs 代理请求") - fmt.Println("-" + strings.Repeat("-", 59)) - - compareDirectVsProxy() - - fmt.Println("\n" + strings.Repeat("=", 60)) - fmt.Println(" 排查完成") - fmt.Println(strings.Repeat("=", 60)) -} - -func getModels() []string { - req, _ := http.NewRequest("GET", ProxyURL+"/v1/models", nil) - req.Header.Set("Authorization", "Bearer "+APIKey) - - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - fmt.Printf("❌ 请求失败: %v\n", err) - return nil - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - - if resp.StatusCode != 200 { - fmt.Printf("❌ HTTP %d: %s\n", resp.StatusCode, string(body)) - return nil - } - - var result struct { - Data []struct { - ID string `json:"id"` - } `json:"data"` - } - json.Unmarshal(body, &result) - - models := make([]string, len(result.Data)) - for i, m := range result.Data { - models[i] = m.ID - } - return models -} - -func checkCredentials() { - homeDir, _ := os.UserHomeDir() - cliProxyDir := filepath.Join(homeDir, ".cli-proxy-api") - - fmt.Printf("\n检查凭据目录: %s\n", cliProxyDir) - files, err := os.ReadDir(cliProxyDir) - if err != nil { - fmt.Printf("❌ 无法读取目录: %v\n", err) - return - } - - for _, f := range files { - if strings.HasSuffix(f.Name(), ".json") { - fmt.Printf(" 📄 %s\n", f.Name()) - } - } -} - -func testModelRequest(model string) { - fmt.Printf("测试模型: %s\n", model) - - payload := map[string]interface{}{ - "model": model, - "messages": []map[string]string{ - {"role": "user", "content": "Say 'OK' if you receive this."}, - }, - "max_tokens": 50, - "stream": false, - } - body, _ := json.Marshal(payload) - - req, _ := http.NewRequest("POST", ProxyURL+"/v1/chat/completions", bytes.NewBuffer(body)) - req.Header.Set("Authorization", "Bearer "+APIKey) - req.Header.Set("Content-Type", "application/json") - - fmt.Println("\n发送请求:") - fmt.Printf(" URL: %s/v1/chat/completions\n", ProxyURL) - fmt.Printf(" Model: %s\n", model) - - client := &http.Client{Timeout: 60 * time.Second} - resp, err := client.Do(req) - if err != nil { - fmt.Printf("❌ 请求失败: %v\n", err) - return - } - defer resp.Body.Close() - - respBody, _ := io.ReadAll(resp.Body) - - fmt.Printf("\n响应:\n") - fmt.Printf(" Status: %d\n", resp.StatusCode) - fmt.Printf(" Headers:\n") - for k, v := range resp.Header { - fmt.Printf(" %s: %s\n", k, strings.Join(v, ", ")) - } - - // 格式化 JSON 输出 - var prettyJSON bytes.Buffer - if err := json.Indent(&prettyJSON, respBody, " ", " "); err == nil { - fmt.Printf(" Body:\n %s\n", prettyJSON.String()) - } else { - fmt.Printf(" Body: %s\n", string(respBody)) - } - - if resp.StatusCode == 200 { - fmt.Println("\n✅ 请求成功!") - } else { - fmt.Println("\n❌ 请求失败!分析错误原因...") - analyzeError(respBody) - } -} - -func analyzeError(body []byte) { - var errResp struct { - Message string `json:"message"` - Reason string `json:"reason"` - Error struct { - Message string `json:"message"` - Type string `json:"type"` - } `json:"error"` - } - json.Unmarshal(body, &errResp) - - if errResp.Message != "" { - fmt.Printf("错误消息: %s\n", errResp.Message) - } - if errResp.Reason != "" { - fmt.Printf("错误原因: %s\n", errResp.Reason) - } - if errResp.Error.Message != "" { - fmt.Printf("错误详情: %s (类型: %s)\n", errResp.Error.Message, errResp.Error.Type) - } - - // 分析常见错误 - bodyStr := string(body) - if strings.Contains(bodyStr, "bearer token") || strings.Contains(bodyStr, "invalid") { - fmt.Println("\n可能的原因:") - fmt.Println(" 1. Token 已过期 - 需要刷新") - fmt.Println(" 2. Token 格式不正确 - 检查凭据文件") - fmt.Println(" 3. 代理服务加载了旧的 Token") - } -} - -func checkProxyCredentials() { - // 尝试通过管理 API 获取凭据状态 - req, _ := http.NewRequest("GET", ProxyURL+"/v0/management/auth/list", nil) - // 使用配置中的管理密钥 admin123 - req.Header.Set("Authorization", "Bearer admin123") - - client := &http.Client{Timeout: 10 * time.Second} - resp, err := client.Do(req) - if err != nil { - fmt.Printf("❌ 无法访问管理 API: %v\n", err) - return - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - - if resp.StatusCode == 200 { - fmt.Println("管理 API 返回的凭据列表:") - var prettyJSON bytes.Buffer - if err := json.Indent(&prettyJSON, body, " ", " "); err == nil { - fmt.Printf("%s\n", prettyJSON.String()) - } else { - fmt.Printf("%s\n", string(body)) - } - } else { - fmt.Printf("管理 API 返回: HTTP %d\n", resp.StatusCode) - fmt.Printf("响应: %s\n", truncate(string(body), 200)) - } -} - -func compareDirectVsProxy() { - homeDir, _ := os.UserHomeDir() - tokenPath := filepath.Join(homeDir, ".aws", "sso", "cache", "kiro-auth-token.json") - - data, err := os.ReadFile(tokenPath) - if err != nil { - fmt.Printf("❌ 无法读取 Token 文件: %v\n", err) - return - } - - var token struct { - AccessToken string `json:"accessToken"` - Region string `json:"region"` - } - json.Unmarshal(data, &token) - - if token.Region == "" { - token.Region = "us-east-1" - } - - // 直接请求 - fmt.Println("\n1. 直接请求 Kiro API:") - directSuccess := testDirectKiroAPI(token.AccessToken, token.Region) - - // 通过代理请求 - fmt.Println("\n2. 通过代理请求:") - proxySuccess := testProxyAPI() - - // 结论 - fmt.Println("\n结论:") - if directSuccess && !proxySuccess { - fmt.Println(" ⚠️ 直接请求成功,代理请求失败") - fmt.Println(" 问题在于 CLIProxyAPIPlus 代理层") - fmt.Println(" 可能原因:") - fmt.Println(" 1. 代理服务使用了过期的 Token") - fmt.Println(" 2. Token 刷新逻辑有问题") - fmt.Println(" 3. 请求头构造不正确") - } else if directSuccess && proxySuccess { - fmt.Println(" ✅ 两者都成功") - } else if !directSuccess && !proxySuccess { - fmt.Println(" ❌ 两者都失败 - Token 本身可能有问题") - } -} - -func testDirectKiroAPI(accessToken, region string) bool { - url := fmt.Sprintf("https://codewhisperer.%s.amazonaws.com", region) - - payload := map[string]interface{}{ - "origin": "AI_EDITOR", - "isEmailRequired": true, - "resourceType": "AGENTIC_REQUEST", - } - body, _ := json.Marshal(payload) - - req, _ := http.NewRequest("POST", url, bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/x-amz-json-1.0") - req.Header.Set("x-amz-target", "AmazonCodeWhispererService.GetUsageLimits") - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("Accept", "application/json") - - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - fmt.Printf(" ❌ 请求失败: %v\n", err) - return false - } - defer resp.Body.Close() - - if resp.StatusCode == 200 { - fmt.Printf(" ✅ 成功 (HTTP %d)\n", resp.StatusCode) - return true - } - respBody, _ := io.ReadAll(resp.Body) - fmt.Printf(" ❌ 失败 (HTTP %d): %s\n", resp.StatusCode, truncate(string(respBody), 100)) - return false -} - -func testProxyAPI() bool { - models := getModels() - if len(models) == 0 { - fmt.Println(" ❌ 没有可用模型") - return false - } - - payload := map[string]interface{}{ - "model": models[0], - "messages": []map[string]string{ - {"role": "user", "content": "Say OK"}, - }, - "max_tokens": 10, - "stream": false, - } - body, _ := json.Marshal(payload) - - req, _ := http.NewRequest("POST", ProxyURL+"/v1/chat/completions", bytes.NewBuffer(body)) - req.Header.Set("Authorization", "Bearer "+APIKey) - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{Timeout: 60 * time.Second} - resp, err := client.Do(req) - if err != nil { - fmt.Printf(" ❌ 请求失败: %v\n", err) - return false - } - defer resp.Body.Close() - - if resp.StatusCode == 200 { - fmt.Printf(" ✅ 成功 (HTTP %d)\n", resp.StatusCode) - return true - } - respBody, _ := io.ReadAll(resp.Body) - fmt.Printf(" ❌ 失败 (HTTP %d): %s\n", resp.StatusCode, truncate(string(respBody), 100)) - return false -} - -func truncate(s string, n int) string { - if len(s) <= n { - return s - } - return s[:n] + "..." -} From 194f66ca9c1c7abfbed0c1e8874b5ed6d9ba9ec3 Mon Sep 17 00:00:00 2001 From: "yuechenglong.5" Date: Wed, 21 Jan 2026 11:03:07 +0800 Subject: [PATCH 2/2] =?UTF-8?q?feat(kiro):=20=E6=B7=BB=E5=8A=A0=E5=90=8E?= =?UTF-8?q?=E5=8F=B0=E4=BB=A4=E7=89=8C=E5=88=B7=E6=96=B0=E9=80=9A=E7=9F=A5?= =?UTF-8?q?=E6=9C=BA=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在 BackgroundRefresher 中添加 onTokenRefreshed 回调函数和并发安全锁 - 实现 WithOnTokenRefreshed 选项函数用于设置刷新成功回调 - 在 RefreshManager 中添加 SetOnTokenRefreshed 方法支持运行时更新回调 - 为 KiroExecutor 添加 reloadAuthFromFile 方法实现文件重新加载回退机制 - 在 Watcher 中实现 NotifyTokenRefreshed 方法处理刷新通知并更新内存Auth对象 - 通过 Service.GetWatcher 连接刷新器回调到 Watcher 通知链路 - 添加方案A和方案B双重保障解决后台刷新与内存对象时间差问题 --- internal/auth/kiro/background_refresh.go | 48 +++++- internal/auth/kiro/refresh_manager.go | 50 ++++-- internal/runtime/executor/kiro_executor.go | 183 ++++++++++++++++++--- internal/watcher/watcher.go | 108 ++++++++++++ sdk/cliproxy/service.go | 22 +++ sdk/cliproxy/types.go | 14 ++ sdk/cliproxy/watcher.go | 3 + 7 files changed, 386 insertions(+), 42 deletions(-) diff --git a/internal/auth/kiro/background_refresh.go b/internal/auth/kiro/background_refresh.go index 3fecc417..1203ff47 100644 --- a/internal/auth/kiro/background_refresh.go +++ b/internal/auth/kiro/background_refresh.go @@ -50,14 +50,16 @@ func WithConcurrency(concurrency int) RefresherOption { } type BackgroundRefresher struct { - interval time.Duration - batchSize int - concurrency int - tokenRepo TokenRepository - stopCh chan struct{} - wg sync.WaitGroup - oauth *KiroOAuth - ssoClient *SSOOIDCClient + interval time.Duration + batchSize int + concurrency int + tokenRepo TokenRepository + stopCh chan struct{} + wg sync.WaitGroup + oauth *KiroOAuth + ssoClient *SSOOIDCClient + callbackMu sync.RWMutex // 保护回调函数的并发访问 + onTokenRefreshed func(tokenID string, tokenData *KiroTokenData) // 刷新成功回调 } func NewBackgroundRefresher(repo TokenRepository, opts ...RefresherOption) *BackgroundRefresher { @@ -84,6 +86,17 @@ func WithConfig(cfg *config.Config) RefresherOption { } } +// WithOnTokenRefreshed sets the callback function to be called when a token is successfully refreshed. +// The callback receives the token ID (filename) and the new token data. +// This allows external components (e.g., Watcher) to be notified of token updates. +func WithOnTokenRefreshed(callback func(tokenID string, tokenData *KiroTokenData)) RefresherOption { + return func(r *BackgroundRefresher) { + r.callbackMu.Lock() + r.onTokenRefreshed = callback + r.callbackMu.Unlock() + } +} + func (r *BackgroundRefresher) Start(ctx context.Context) { r.wg.Add(1) go func() { @@ -188,5 +201,24 @@ func (r *BackgroundRefresher) refreshSingle(ctx context.Context, token *Token) { if err := r.tokenRepo.UpdateToken(token); err != nil { log.Printf("failed to update token %s: %v", token.ID, err) + return + } + + // 方案 A: 刷新成功后触发回调,通知 Watcher 更新内存中的 Auth 对象 + r.callbackMu.RLock() + callback := r.onTokenRefreshed + r.callbackMu.RUnlock() + + if callback != nil { + // 使用 defer recover 隔离回调 panic,防止崩溃整个进程 + func() { + defer func() { + if rec := recover(); rec != nil { + log.Printf("background refresh: callback panic for token %s: %v", token.ID, rec) + } + }() + log.Printf("background refresh: notifying token refresh callback for %s", token.ID) + callback(token.ID, newTokenData) + }() } } diff --git a/internal/auth/kiro/refresh_manager.go b/internal/auth/kiro/refresh_manager.go index cd27b432..05e27a54 100644 --- a/internal/auth/kiro/refresh_manager.go +++ b/internal/auth/kiro/refresh_manager.go @@ -11,11 +11,12 @@ import ( // RefreshManager 是后台刷新器的单例管理器 type RefreshManager struct { - mu sync.Mutex - refresher *BackgroundRefresher - ctx context.Context - cancel context.CancelFunc - started bool + mu sync.Mutex + refresher *BackgroundRefresher + ctx context.Context + cancel context.CancelFunc + started bool + onTokenRefreshed func(tokenID string, tokenData *KiroTokenData) // 刷新成功回调 } var ( @@ -52,13 +53,19 @@ func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error { repo := NewFileTokenRepository(baseDir) // 创建后台刷新器,配置参数 - m.refresher = NewBackgroundRefresher( - repo, - WithInterval(time.Minute), // 每分钟检查一次 - WithBatchSize(50), // 每批最多处理 50 个 token - WithConcurrency(10), // 最多 10 个并发刷新 - WithConfig(cfg), // 设置 OAuth 和 SSO 客户端 - ) + opts := []RefresherOption{ + WithInterval(time.Minute), // 每分钟检查一次 + WithBatchSize(50), // 每批最多处理 50 个 token + WithConcurrency(10), // 最多 10 个并发刷新 + WithConfig(cfg), // 设置 OAuth 和 SSO 客户端 + } + + // 如果已设置回调,传递给 BackgroundRefresher + if m.onTokenRefreshed != nil { + opts = append(opts, WithOnTokenRefreshed(m.onTokenRefreshed)) + } + + m.refresher = NewBackgroundRefresher(repo, opts...) log.Infof("refresh manager: initialized with base directory %s", baseDir) return nil @@ -127,6 +134,25 @@ func (m *RefreshManager) UpdateBaseDir(baseDir string) { } } +// SetOnTokenRefreshed 设置 token 刷新成功后的回调函数 +// 可以在任何时候调用,支持运行时更新回调 +// callback: 回调函数,接收 tokenID(文件名)和新的 token 数据 +func (m *RefreshManager) SetOnTokenRefreshed(callback func(tokenID string, tokenData *KiroTokenData)) { + m.mu.Lock() + defer m.mu.Unlock() + + m.onTokenRefreshed = callback + + // 如果 refresher 已经创建,使用并发安全的方式更新它的回调 + if m.refresher != nil { + m.refresher.callbackMu.Lock() + m.refresher.onTokenRefreshed = callback + m.refresher.callbackMu.Unlock() + } + + log.Debug("refresh manager: token refresh callback registered") +} + // InitializeAndStart 初始化并启动后台刷新(便捷方法) func InitializeAndStart(baseDir string, cfg *config.Config) { manager := GetRefreshManager() diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index 4506601d..ed6014a2 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -581,18 +581,30 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req // Check if token is expired before making request if e.isTokenExpired(accessToken) { - log.Infof("kiro: access token expired, attempting refresh before request") - refreshedAuth, refreshErr := e.Refresh(ctx, auth) - if refreshErr != nil { - log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr) - } else if refreshedAuth != nil { - auth = refreshedAuth - // Persist the refreshed auth to file so subsequent requests use it - if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { - log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) - } + log.Infof("kiro: access token expired, attempting recovery") + + // 方案 B: 先尝试从文件重新加载 token(后台刷新器可能已更新文件) + reloadedAuth, reloadErr := e.reloadAuthFromFile(auth) + if reloadErr == nil && reloadedAuth != nil { + // 文件中有更新的 token,使用它 + auth = reloadedAuth accessToken, profileArn = kiroCredentials(auth) - log.Infof("kiro: token refreshed successfully before request") + log.Infof("kiro: recovered token from file (background refresh), expires_at: %v", auth.Metadata["expires_at"]) + } else { + // 文件中的 token 也过期了,执行主动刷新 + log.Debugf("kiro: file reload failed (%v), attempting active refresh", reloadErr) + refreshedAuth, refreshErr := e.Refresh(ctx, auth) + if refreshErr != nil { + log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr) + } else if refreshedAuth != nil { + auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + } + accessToken, profileArn = kiroCredentials(auth) + log.Infof("kiro: token refreshed successfully before request") + } } } @@ -979,18 +991,30 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut // Check if token is expired before making request if e.isTokenExpired(accessToken) { - log.Infof("kiro: access token expired, attempting refresh before stream request") - refreshedAuth, refreshErr := e.Refresh(ctx, auth) - if refreshErr != nil { - log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr) - } else if refreshedAuth != nil { - auth = refreshedAuth - // Persist the refreshed auth to file so subsequent requests use it - if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { - log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) - } + log.Infof("kiro: access token expired, attempting recovery before stream request") + + // 方案 B: 先尝试从文件重新加载 token(后台刷新器可能已更新文件) + reloadedAuth, reloadErr := e.reloadAuthFromFile(auth) + if reloadErr == nil && reloadedAuth != nil { + // 文件中有更新的 token,使用它 + auth = reloadedAuth accessToken, profileArn = kiroCredentials(auth) - log.Infof("kiro: token refreshed successfully before stream request") + log.Infof("kiro: recovered token from file (background refresh) for stream, expires_at: %v", auth.Metadata["expires_at"]) + } else { + // 文件中的 token 也过期了,执行主动刷新 + log.Debugf("kiro: file reload failed (%v), attempting active refresh for stream", reloadErr) + refreshedAuth, refreshErr := e.Refresh(ctx, auth) + if refreshErr != nil { + log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr) + } else if refreshedAuth != nil { + auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + } + accessToken, profileArn = kiroCredentials(auth) + log.Infof("kiro: token refreshed successfully before stream request") + } } } @@ -3689,6 +3713,121 @@ func (e *KiroExecutor) persistRefreshedAuth(auth *cliproxyauth.Auth) error { return nil } +// reloadAuthFromFile 从文件重新加载 auth 数据(方案 B: Fallback 机制) +// 当内存中的 token 已过期时,尝试从文件读取最新的 token +// 这解决了后台刷新器已更新文件但内存中 Auth 对象尚未同步的时间差问题 +func (e *KiroExecutor) reloadAuthFromFile(auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + if auth == nil { + return nil, fmt.Errorf("kiro executor: cannot reload nil auth") + } + + // 确定文件路径 + var authPath string + if auth.Attributes != nil { + if p := strings.TrimSpace(auth.Attributes["path"]); p != "" { + authPath = p + } + } + if authPath == "" { + fileName := strings.TrimSpace(auth.FileName) + if fileName == "" { + return nil, fmt.Errorf("kiro executor: auth has no file path or filename for reload") + } + if filepath.IsAbs(fileName) { + authPath = fileName + } else if e.cfg != nil && e.cfg.AuthDir != "" { + authPath = filepath.Join(e.cfg.AuthDir, fileName) + } else { + return nil, fmt.Errorf("kiro executor: cannot determine auth file path for reload") + } + } + + // 读取文件 + raw, err := os.ReadFile(authPath) + if err != nil { + return nil, fmt.Errorf("kiro executor: failed to read auth file %s: %w", authPath, err) + } + + // 解析 JSON + var metadata map[string]any + if err := json.Unmarshal(raw, &metadata); err != nil { + return nil, fmt.Errorf("kiro executor: failed to parse auth file %s: %w", authPath, err) + } + + // 检查文件中的 token 是否比内存中的更新 + fileExpiresAt, _ := metadata["expires_at"].(string) + fileAccessToken, _ := metadata["access_token"].(string) + memExpiresAt, _ := auth.Metadata["expires_at"].(string) + memAccessToken, _ := auth.Metadata["access_token"].(string) + + // 文件中必须有有效的 access_token + if fileAccessToken == "" { + return nil, fmt.Errorf("kiro executor: auth file has no access_token field") + } + + // 如果有 expires_at,检查是否过期 + if fileExpiresAt != "" { + fileExpTime, parseErr := time.Parse(time.RFC3339, fileExpiresAt) + if parseErr == nil { + // 如果文件中的 token 也已过期,不使用它 + if time.Now().After(fileExpTime) { + log.Debugf("kiro executor: file token also expired at %s, not using", fileExpiresAt) + return nil, fmt.Errorf("kiro executor: file token also expired") + } + } + } + + // 判断文件中的 token 是否比内存中的更新 + // 条件1: access_token 不同(说明已刷新) + // 条件2: expires_at 更新(说明已刷新) + isNewer := false + + // 优先检查 access_token 是否变化 + if fileAccessToken != memAccessToken { + isNewer = true + log.Debugf("kiro executor: file access_token differs from memory, using file token") + } + + // 如果 access_token 相同,检查 expires_at + if !isNewer && fileExpiresAt != "" && memExpiresAt != "" { + fileExpTime, fileParseErr := time.Parse(time.RFC3339, fileExpiresAt) + memExpTime, memParseErr := time.Parse(time.RFC3339, memExpiresAt) + if fileParseErr == nil && memParseErr == nil && fileExpTime.After(memExpTime) { + isNewer = true + log.Debugf("kiro executor: file expires_at (%s) is newer than memory (%s)", fileExpiresAt, memExpiresAt) + } + } + + // 如果文件中没有 expires_at 但 access_token 相同,无法判断是否更新 + if !isNewer && fileExpiresAt == "" && fileAccessToken == memAccessToken { + return nil, fmt.Errorf("kiro executor: cannot determine if file token is newer (no expires_at, same access_token)") + } + + if !isNewer { + log.Debugf("kiro executor: file token not newer than memory token") + return nil, fmt.Errorf("kiro executor: file token not newer") + } + + // 创建更新后的 auth 对象 + updated := auth.Clone() + updated.Metadata = metadata + updated.UpdatedAt = time.Now() + + // 同步更新 Attributes + if updated.Attributes == nil { + updated.Attributes = make(map[string]string) + } + if accessToken, ok := metadata["access_token"].(string); ok { + updated.Attributes["access_token"] = accessToken + } + if profileArn, ok := metadata["profile_arn"].(string); ok { + updated.Attributes["profile_arn"] = profileArn + } + + log.Infof("kiro executor: reloaded auth from file %s, new expires_at: %s", authPath, fileExpiresAt) + return updated, nil +} + // isTokenExpired checks if a JWT access token has expired. // Returns true if the token is expired or cannot be parsed. func (e *KiroExecutor) isTokenExpired(accessToken string) bool { diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index 77006cf8..8141ca07 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -145,3 +145,111 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { w.clientsMutex.RUnlock() return snapshotCoreAuths(cfg, w.authDir) } + +// NotifyTokenRefreshed 处理后台刷新器的 token 更新通知 +// 当后台刷新器成功刷新 token 后调用此方法,更新内存中的 Auth 对象 +// tokenID: token 文件名(如 kiro-xxx.json) +// accessToken: 新的 access token +// refreshToken: 新的 refresh token +// expiresAt: 新的过期时间 +func (w *Watcher) NotifyTokenRefreshed(tokenID, accessToken, refreshToken, expiresAt string) { + if w == nil { + return + } + + w.clientsMutex.Lock() + defer w.clientsMutex.Unlock() + + // 遍历 currentAuths,找到匹配的 Auth 并更新 + updated := false + for id, auth := range w.currentAuths { + if auth == nil || auth.Metadata == nil { + continue + } + + // 检查是否是 kiro 类型的 auth + authType, _ := auth.Metadata["type"].(string) + if authType != "kiro" { + continue + } + + // 多种匹配方式,解决不同来源的 auth 对象字段差异 + matched := false + + // 1. 通过 auth.ID 匹配(ID 可能包含文件名) + if !matched && auth.ID != "" { + if auth.ID == tokenID || strings.HasSuffix(auth.ID, "/"+tokenID) || strings.HasSuffix(auth.ID, "\\"+tokenID) { + matched = true + } + // ID 可能是 "kiro-xxx" 格式(无扩展名),tokenID 是 "kiro-xxx.json" + if !matched && strings.TrimSuffix(tokenID, ".json") == auth.ID { + matched = true + } + } + + // 2. 通过 auth.Attributes["path"] 匹配 + if !matched && auth.Attributes != nil { + if authPath := auth.Attributes["path"]; authPath != "" { + // 提取文件名部分进行比较 + pathBase := authPath + if idx := strings.LastIndexAny(authPath, "/\\"); idx >= 0 { + pathBase = authPath[idx+1:] + } + if pathBase == tokenID || strings.TrimSuffix(pathBase, ".json") == strings.TrimSuffix(tokenID, ".json") { + matched = true + } + } + } + + // 3. 通过 auth.FileName 匹配(原有逻辑) + if !matched && auth.FileName != "" { + if auth.FileName == tokenID || strings.HasSuffix(auth.FileName, "/"+tokenID) || strings.HasSuffix(auth.FileName, "\\"+tokenID) { + matched = true + } + } + + if matched { + // 更新内存中的 token + auth.Metadata["access_token"] = accessToken + auth.Metadata["refresh_token"] = refreshToken + auth.Metadata["expires_at"] = expiresAt + auth.Metadata["last_refresh"] = time.Now().Format(time.RFC3339) + auth.UpdatedAt = time.Now() + auth.LastRefreshedAt = time.Now() + + log.Infof("watcher: updated in-memory auth for token %s (auth ID: %s)", tokenID, id) + updated = true + + // 同时更新 runtimeAuths 中的副本(如果存在) + if w.runtimeAuths != nil { + if runtimeAuth, ok := w.runtimeAuths[id]; ok && runtimeAuth != nil { + if runtimeAuth.Metadata == nil { + runtimeAuth.Metadata = make(map[string]any) + } + runtimeAuth.Metadata["access_token"] = accessToken + runtimeAuth.Metadata["refresh_token"] = refreshToken + runtimeAuth.Metadata["expires_at"] = expiresAt + runtimeAuth.Metadata["last_refresh"] = time.Now().Format(time.RFC3339) + runtimeAuth.UpdatedAt = time.Now() + runtimeAuth.LastRefreshedAt = time.Now() + } + } + + // 发送更新通知到 authQueue + if w.authQueue != nil { + go func(authClone *coreauth.Auth) { + update := AuthUpdate{ + Action: AuthUpdateActionModify, + ID: authClone.ID, + Auth: authClone, + } + w.dispatchAuthUpdates([]AuthUpdate{update}) + }(auth.Clone()) + } + } + } + + if !updated { + log.Debugf("watcher: no matching auth found for token %s, will be picked up on next file scan", tokenID) + } +} diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 885304ad..750eb885 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -98,6 +98,16 @@ func (s *Service) RegisterUsagePlugin(plugin usage.Plugin) { usage.RegisterPlugin(plugin) } +// GetWatcher returns the underlying WatcherWrapper instance. +// This allows external components (e.g., RefreshManager) to interact with the watcher. +// Returns nil if the service or watcher is not initialized. +func (s *Service) GetWatcher() *WatcherWrapper { + if s == nil { + return nil + } + return s.watcher +} + // newDefaultAuthManager creates a default authentication manager with all supported providers. func newDefaultAuthManager() *sdkAuth.Manager { return sdkAuth.NewManager( @@ -575,6 +585,18 @@ func (s *Service) Run(ctx context.Context) error { } watcherWrapper.SetConfig(s.cfg) + // 方案 A: 连接 Kiro 后台刷新器回调到 Watcher + // 当后台刷新器成功刷新 token 后,立即通知 Watcher 更新内存中的 Auth 对象 + // 这解决了后台刷新与内存 Auth 对象之间的时间差问题 + kiroauth.GetRefreshManager().SetOnTokenRefreshed(func(tokenID string, tokenData *kiroauth.KiroTokenData) { + if tokenData == nil || watcherWrapper == nil { + return + } + log.Debugf("kiro refresh callback: notifying watcher for token %s", tokenID) + watcherWrapper.NotifyTokenRefreshed(tokenID, tokenData.AccessToken, tokenData.RefreshToken, tokenData.ExpiresAt) + }) + log.Debug("kiro: connected background refresh callback to watcher") + watcherCtx, watcherCancel := context.WithCancel(context.Background()) s.watcherCancel = watcherCancel if err = watcherWrapper.Start(watcherCtx); err != nil { diff --git a/sdk/cliproxy/types.go b/sdk/cliproxy/types.go index 1521dffe..ee8f761d 100644 --- a/sdk/cliproxy/types.go +++ b/sdk/cliproxy/types.go @@ -89,6 +89,7 @@ type WatcherWrapper struct { snapshotAuths func() []*coreauth.Auth setUpdateQueue func(queue chan<- watcher.AuthUpdate) dispatchRuntimeUpdate func(update watcher.AuthUpdate) bool + notifyTokenRefreshed func(tokenID, accessToken, refreshToken, expiresAt string) // 方案 A: 后台刷新通知 } // Start proxies to the underlying watcher Start implementation. @@ -146,3 +147,16 @@ func (w *WatcherWrapper) SetAuthUpdateQueue(queue chan<- watcher.AuthUpdate) { } w.setUpdateQueue(queue) } + +// NotifyTokenRefreshed 通知 Watcher 后台刷新器已更新 token +// 这是方案 A 的核心方法,用于解决后台刷新与内存 Auth 对象的时间差问题 +// tokenID: token 文件名(如 kiro-xxx.json) +// accessToken: 新的 access token +// refreshToken: 新的 refresh token +// expiresAt: 新的过期时间(RFC3339 格式) +func (w *WatcherWrapper) NotifyTokenRefreshed(tokenID, accessToken, refreshToken, expiresAt string) { + if w == nil || w.notifyTokenRefreshed == nil { + return + } + w.notifyTokenRefreshed(tokenID, accessToken, refreshToken, expiresAt) +} diff --git a/sdk/cliproxy/watcher.go b/sdk/cliproxy/watcher.go index caeadf19..e6e91bdd 100644 --- a/sdk/cliproxy/watcher.go +++ b/sdk/cliproxy/watcher.go @@ -31,5 +31,8 @@ func defaultWatcherFactory(configPath, authDir string, reload func(*config.Confi dispatchRuntimeUpdate: func(update watcher.AuthUpdate) bool { return w.DispatchRuntimeAuthUpdate(update) }, + notifyTokenRefreshed: func(tokenID, accessToken, refreshToken, expiresAt string) { + w.NotifyTokenRefreshed(tokenID, accessToken, refreshToken, expiresAt) + }, }, nil }