mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-06-01 23:19:24 +08:00
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cmd"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
@@ -533,6 +534,13 @@ func main() {
|
||||
}
|
||||
// Start the main proxy service
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
|
||||
// 初始化并启动 Kiro token 后台刷新
|
||||
if cfg.AuthDir != "" {
|
||||
kiro.InitializeAndStart(cfg.AuthDir, cfg)
|
||||
defer kiro.StopGlobalRefreshManager()
|
||||
}
|
||||
|
||||
cmd.StartService(cfg, configFilePath, password)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,14 +50,16 @@ func WithConcurrency(concurrency int) RefresherOption {
|
||||
}
|
||||
|
||||
type BackgroundRefresher struct {
|
||||
interval time.Duration
|
||||
batchSize int
|
||||
concurrency int
|
||||
tokenRepo TokenRepository
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
oauth *KiroOAuth
|
||||
ssoClient *SSOOIDCClient
|
||||
interval time.Duration
|
||||
batchSize int
|
||||
concurrency int
|
||||
tokenRepo TokenRepository
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
oauth *KiroOAuth
|
||||
ssoClient *SSOOIDCClient
|
||||
callbackMu sync.RWMutex // 保护回调函数的并发访问
|
||||
onTokenRefreshed func(tokenID string, tokenData *KiroTokenData) // 刷新成功回调
|
||||
}
|
||||
|
||||
func NewBackgroundRefresher(repo TokenRepository, opts ...RefresherOption) *BackgroundRefresher {
|
||||
@@ -84,6 +86,17 @@ func WithConfig(cfg *config.Config) RefresherOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithOnTokenRefreshed sets the callback function to be called when a token is successfully refreshed.
|
||||
// The callback receives the token ID (filename) and the new token data.
|
||||
// This allows external components (e.g., Watcher) to be notified of token updates.
|
||||
func WithOnTokenRefreshed(callback func(tokenID string, tokenData *KiroTokenData)) RefresherOption {
|
||||
return func(r *BackgroundRefresher) {
|
||||
r.callbackMu.Lock()
|
||||
r.onTokenRefreshed = callback
|
||||
r.callbackMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *BackgroundRefresher) Start(ctx context.Context) {
|
||||
r.wg.Add(1)
|
||||
go func() {
|
||||
@@ -188,5 +201,24 @@ func (r *BackgroundRefresher) refreshSingle(ctx context.Context, token *Token) {
|
||||
|
||||
if err := r.tokenRepo.UpdateToken(token); err != nil {
|
||||
log.Printf("failed to update token %s: %v", token.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 方案 A: 刷新成功后触发回调,通知 Watcher 更新内存中的 Auth 对象
|
||||
r.callbackMu.RLock()
|
||||
callback := r.onTokenRefreshed
|
||||
r.callbackMu.RUnlock()
|
||||
|
||||
if callback != nil {
|
||||
// 使用 defer recover 隔离回调 panic,防止崩溃整个进程
|
||||
func() {
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
log.Printf("background refresh: callback panic for token %s: %v", token.ID, rec)
|
||||
}
|
||||
}()
|
||||
log.Printf("background refresh: notifying token refresh callback for %s", token.ID)
|
||||
callback(token.ID, newTokenData)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
171
internal/auth/kiro/refresh_manager.go
Normal file
171
internal/auth/kiro/refresh_manager.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// RefreshManager 是后台刷新器的单例管理器
|
||||
type RefreshManager struct {
|
||||
mu sync.Mutex
|
||||
refresher *BackgroundRefresher
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
started bool
|
||||
onTokenRefreshed func(tokenID string, tokenData *KiroTokenData) // 刷新成功回调
|
||||
}
|
||||
|
||||
var (
|
||||
globalRefreshManager *RefreshManager
|
||||
managerOnce sync.Once
|
||||
)
|
||||
|
||||
// GetRefreshManager 获取全局刷新管理器实例
|
||||
func GetRefreshManager() *RefreshManager {
|
||||
managerOnce.Do(func() {
|
||||
globalRefreshManager = &RefreshManager{}
|
||||
})
|
||||
return globalRefreshManager
|
||||
}
|
||||
|
||||
// Initialize 初始化后台刷新器
|
||||
// baseDir: token 文件所在的目录
|
||||
// cfg: 应用配置
|
||||
func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.started {
|
||||
log.Debug("refresh manager: already initialized")
|
||||
return nil
|
||||
}
|
||||
|
||||
if baseDir == "" {
|
||||
log.Warn("refresh manager: base directory not provided, skipping initialization")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 创建 token 存储库
|
||||
repo := NewFileTokenRepository(baseDir)
|
||||
|
||||
// 创建后台刷新器,配置参数
|
||||
opts := []RefresherOption{
|
||||
WithInterval(time.Minute), // 每分钟检查一次
|
||||
WithBatchSize(50), // 每批最多处理 50 个 token
|
||||
WithConcurrency(10), // 最多 10 个并发刷新
|
||||
WithConfig(cfg), // 设置 OAuth 和 SSO 客户端
|
||||
}
|
||||
|
||||
// 如果已设置回调,传递给 BackgroundRefresher
|
||||
if m.onTokenRefreshed != nil {
|
||||
opts = append(opts, WithOnTokenRefreshed(m.onTokenRefreshed))
|
||||
}
|
||||
|
||||
m.refresher = NewBackgroundRefresher(repo, opts...)
|
||||
|
||||
log.Infof("refresh manager: initialized with base directory %s", baseDir)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start 启动后台刷新
|
||||
func (m *RefreshManager) Start() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.started {
|
||||
log.Debug("refresh manager: already started")
|
||||
return
|
||||
}
|
||||
|
||||
if m.refresher == nil {
|
||||
log.Warn("refresh manager: not initialized, cannot start")
|
||||
return
|
||||
}
|
||||
|
||||
m.ctx, m.cancel = context.WithCancel(context.Background())
|
||||
m.refresher.Start(m.ctx)
|
||||
m.started = true
|
||||
|
||||
log.Info("refresh manager: background refresh started")
|
||||
}
|
||||
|
||||
// Stop 停止后台刷新
|
||||
func (m *RefreshManager) Stop() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if !m.started {
|
||||
return
|
||||
}
|
||||
|
||||
if m.cancel != nil {
|
||||
m.cancel()
|
||||
}
|
||||
|
||||
if m.refresher != nil {
|
||||
m.refresher.Stop()
|
||||
}
|
||||
|
||||
m.started = false
|
||||
log.Info("refresh manager: background refresh stopped")
|
||||
}
|
||||
|
||||
// IsRunning 检查后台刷新是否正在运行
|
||||
func (m *RefreshManager) IsRunning() bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.started
|
||||
}
|
||||
|
||||
// UpdateBaseDir 更新 token 目录(用于运行时配置更改)
|
||||
func (m *RefreshManager) UpdateBaseDir(baseDir string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.refresher != nil && m.refresher.tokenRepo != nil {
|
||||
if repo, ok := m.refresher.tokenRepo.(*FileTokenRepository); ok {
|
||||
repo.SetBaseDir(baseDir)
|
||||
log.Infof("refresh manager: updated base directory to %s", baseDir)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetOnTokenRefreshed 设置 token 刷新成功后的回调函数
|
||||
// 可以在任何时候调用,支持运行时更新回调
|
||||
// callback: 回调函数,接收 tokenID(文件名)和新的 token 数据
|
||||
func (m *RefreshManager) SetOnTokenRefreshed(callback func(tokenID string, tokenData *KiroTokenData)) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.onTokenRefreshed = callback
|
||||
|
||||
// 如果 refresher 已经创建,使用并发安全的方式更新它的回调
|
||||
if m.refresher != nil {
|
||||
m.refresher.callbackMu.Lock()
|
||||
m.refresher.onTokenRefreshed = callback
|
||||
m.refresher.callbackMu.Unlock()
|
||||
}
|
||||
|
||||
log.Debug("refresh manager: token refresh callback registered")
|
||||
}
|
||||
|
||||
// InitializeAndStart 初始化并启动后台刷新(便捷方法)
|
||||
func InitializeAndStart(baseDir string, cfg *config.Config) {
|
||||
manager := GetRefreshManager()
|
||||
if err := manager.Initialize(baseDir, cfg); err != nil {
|
||||
log.Errorf("refresh manager: initialization failed: %v", err)
|
||||
return
|
||||
}
|
||||
manager.Start()
|
||||
}
|
||||
|
||||
// StopGlobalRefreshManager 停止全局刷新管理器
|
||||
func StopGlobalRefreshManager() {
|
||||
if globalRefreshManager != nil {
|
||||
globalRefreshManager.Stop()
|
||||
}
|
||||
}
|
||||
@@ -26,13 +26,13 @@ type KiroTokenStorage struct {
|
||||
// LastRefresh is the timestamp of the last token refresh
|
||||
LastRefresh string `json:"last_refresh"`
|
||||
// ClientID is the OAuth client ID (required for token refresh)
|
||||
ClientID string `json:"clientId,omitempty"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
// ClientSecret is the OAuth client secret (required for token refresh)
|
||||
ClientSecret string `json:"clientSecret,omitempty"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
// Region is the AWS region
|
||||
Region string `json:"region,omitempty"`
|
||||
// StartURL is the AWS Identity Center start URL (for IDC auth)
|
||||
StartURL string `json:"startUrl,omitempty"`
|
||||
StartURL string `json:"start_url,omitempty"`
|
||||
// Email is the user's email address
|
||||
Email string `json:"email,omitempty"`
|
||||
}
|
||||
|
||||
273
internal/auth/kiro/token_repository.go
Normal file
273
internal/auth/kiro/token_repository.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// FileTokenRepository 实现 TokenRepository 接口,基于文件系统存储
|
||||
type FileTokenRepository struct {
|
||||
mu sync.RWMutex
|
||||
baseDir string
|
||||
}
|
||||
|
||||
// NewFileTokenRepository 创建一个新的文件 token 存储库
|
||||
func NewFileTokenRepository(baseDir string) *FileTokenRepository {
|
||||
return &FileTokenRepository{
|
||||
baseDir: baseDir,
|
||||
}
|
||||
}
|
||||
|
||||
// SetBaseDir 设置基础目录
|
||||
func (r *FileTokenRepository) SetBaseDir(dir string) {
|
||||
r.mu.Lock()
|
||||
r.baseDir = strings.TrimSpace(dir)
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
// FindOldestUnverified 查找需要刷新的 token(按最后验证时间排序)
|
||||
func (r *FileTokenRepository) FindOldestUnverified(limit int) []*Token {
|
||||
r.mu.RLock()
|
||||
baseDir := r.baseDir
|
||||
r.mu.RUnlock()
|
||||
|
||||
if baseDir == "" {
|
||||
log.Debug("token repository: base directory not configured")
|
||||
return nil
|
||||
}
|
||||
|
||||
var tokens []*Token
|
||||
|
||||
err := filepath.WalkDir(baseDir, func(path string, d fs.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return nil // 忽略错误,继续遍历
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if !strings.HasSuffix(strings.ToLower(d.Name()), ".json") {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 只处理 kiro 相关的 token 文件
|
||||
if !strings.HasPrefix(d.Name(), "kiro-") {
|
||||
return nil
|
||||
}
|
||||
|
||||
token, err := r.readTokenFile(path)
|
||||
if err != nil {
|
||||
log.Debugf("token repository: failed to read token file %s: %v", path, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if token != nil && token.RefreshToken != "" {
|
||||
// 检查 token 是否需要刷新(过期前 5 分钟)
|
||||
if token.ExpiresAt.IsZero() || time.Until(token.ExpiresAt) < 5*time.Minute {
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Warnf("token repository: error walking directory: %v", err)
|
||||
}
|
||||
|
||||
// 按最后验证时间排序(最旧的优先)
|
||||
sort.Slice(tokens, func(i, j int) bool {
|
||||
return tokens[i].LastVerified.Before(tokens[j].LastVerified)
|
||||
})
|
||||
|
||||
// 限制返回数量
|
||||
if limit > 0 && len(tokens) > limit {
|
||||
tokens = tokens[:limit]
|
||||
}
|
||||
|
||||
return tokens
|
||||
}
|
||||
|
||||
// UpdateToken 更新 token 并持久化到文件
|
||||
func (r *FileTokenRepository) UpdateToken(token *Token) error {
|
||||
if token == nil {
|
||||
return fmt.Errorf("token repository: token is nil")
|
||||
}
|
||||
|
||||
r.mu.RLock()
|
||||
baseDir := r.baseDir
|
||||
r.mu.RUnlock()
|
||||
|
||||
if baseDir == "" {
|
||||
return fmt.Errorf("token repository: base directory not configured")
|
||||
}
|
||||
|
||||
// 构建文件路径
|
||||
filePath := filepath.Join(baseDir, token.ID)
|
||||
if !strings.HasSuffix(filePath, ".json") {
|
||||
filePath += ".json"
|
||||
}
|
||||
|
||||
// 读取现有文件内容
|
||||
existingData := make(map[string]any)
|
||||
if data, err := os.ReadFile(filePath); err == nil {
|
||||
_ = json.Unmarshal(data, &existingData)
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
existingData["access_token"] = token.AccessToken
|
||||
existingData["refresh_token"] = token.RefreshToken
|
||||
existingData["last_refresh"] = time.Now().Format(time.RFC3339)
|
||||
|
||||
if !token.ExpiresAt.IsZero() {
|
||||
existingData["expires_at"] = token.ExpiresAt.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// 保持原有的关键字段
|
||||
if token.ClientID != "" {
|
||||
existingData["client_id"] = token.ClientID
|
||||
}
|
||||
if token.ClientSecret != "" {
|
||||
existingData["client_secret"] = token.ClientSecret
|
||||
}
|
||||
if token.AuthMethod != "" {
|
||||
existingData["auth_method"] = token.AuthMethod
|
||||
}
|
||||
if token.Region != "" {
|
||||
existingData["region"] = token.Region
|
||||
}
|
||||
if token.StartURL != "" {
|
||||
existingData["start_url"] = token.StartURL
|
||||
}
|
||||
|
||||
// 序列化并写入文件
|
||||
raw, err := json.MarshalIndent(existingData, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("token repository: marshal failed: %w", err)
|
||||
}
|
||||
|
||||
// 原子写入:先写入临时文件,再重命名
|
||||
tmpPath := filePath + ".tmp"
|
||||
if err := os.WriteFile(tmpPath, raw, 0o600); err != nil {
|
||||
return fmt.Errorf("token repository: write temp file failed: %w", err)
|
||||
}
|
||||
if err := os.Rename(tmpPath, filePath); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return fmt.Errorf("token repository: rename failed: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("token repository: updated token %s", token.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// readTokenFile 从文件读取 token
|
||||
func (r *FileTokenRepository) readTokenFile(path string) (*Token, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var metadata map[string]any
|
||||
if err := json.Unmarshal(data, &metadata); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查是否是 kiro token
|
||||
tokenType, _ := metadata["type"].(string)
|
||||
if tokenType != "kiro" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// 检查 auth_method
|
||||
authMethod, _ := metadata["auth_method"].(string)
|
||||
if authMethod != "idc" && authMethod != "builder-id" {
|
||||
return nil, nil // 只处理 IDC 和 Builder ID token
|
||||
}
|
||||
|
||||
token := &Token{
|
||||
ID: filepath.Base(path),
|
||||
AuthMethod: authMethod,
|
||||
}
|
||||
|
||||
// 解析各字段
|
||||
if v, ok := metadata["access_token"].(string); ok {
|
||||
token.AccessToken = v
|
||||
}
|
||||
if v, ok := metadata["refresh_token"].(string); ok {
|
||||
token.RefreshToken = v
|
||||
}
|
||||
if v, ok := metadata["client_id"].(string); ok {
|
||||
token.ClientID = v
|
||||
}
|
||||
if v, ok := metadata["client_secret"].(string); ok {
|
||||
token.ClientSecret = v
|
||||
}
|
||||
if v, ok := metadata["region"].(string); ok {
|
||||
token.Region = v
|
||||
}
|
||||
if v, ok := metadata["start_url"].(string); ok {
|
||||
token.StartURL = v
|
||||
}
|
||||
if v, ok := metadata["provider"].(string); ok {
|
||||
token.Provider = v
|
||||
}
|
||||
|
||||
// 解析时间字段
|
||||
if v, ok := metadata["expires_at"].(string); ok {
|
||||
if t, err := time.Parse(time.RFC3339, v); err == nil {
|
||||
token.ExpiresAt = t
|
||||
}
|
||||
}
|
||||
if v, ok := metadata["last_refresh"].(string); ok {
|
||||
if t, err := time.Parse(time.RFC3339, v); err == nil {
|
||||
token.LastVerified = t
|
||||
}
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// ListKiroTokens 列出所有 Kiro token(用于调试)
|
||||
func (r *FileTokenRepository) ListKiroTokens(ctx context.Context) ([]*Token, error) {
|
||||
r.mu.RLock()
|
||||
baseDir := r.baseDir
|
||||
r.mu.RUnlock()
|
||||
|
||||
if baseDir == "" {
|
||||
return nil, fmt.Errorf("token repository: base directory not configured")
|
||||
}
|
||||
|
||||
var tokens []*Token
|
||||
|
||||
err := filepath.WalkDir(baseDir, func(path string, d fs.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return nil
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if !strings.HasPrefix(d.Name(), "kiro-") || !strings.HasSuffix(d.Name(), ".json") {
|
||||
return nil
|
||||
}
|
||||
|
||||
token, err := r.readTokenFile(path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if token != nil {
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return tokens, err
|
||||
}
|
||||
@@ -581,18 +581,30 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
|
||||
// Check if token is expired before making request
|
||||
if e.isTokenExpired(accessToken) {
|
||||
log.Infof("kiro: access token expired, attempting refresh before request")
|
||||
refreshedAuth, refreshErr := e.Refresh(ctx, auth)
|
||||
if refreshErr != nil {
|
||||
log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr)
|
||||
} else if refreshedAuth != nil {
|
||||
auth = refreshedAuth
|
||||
// Persist the refreshed auth to file so subsequent requests use it
|
||||
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
|
||||
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
|
||||
}
|
||||
log.Infof("kiro: access token expired, attempting recovery")
|
||||
|
||||
// 方案 B: 先尝试从文件重新加载 token(后台刷新器可能已更新文件)
|
||||
reloadedAuth, reloadErr := e.reloadAuthFromFile(auth)
|
||||
if reloadErr == nil && reloadedAuth != nil {
|
||||
// 文件中有更新的 token,使用它
|
||||
auth = reloadedAuth
|
||||
accessToken, profileArn = kiroCredentials(auth)
|
||||
log.Infof("kiro: token refreshed successfully before request")
|
||||
log.Infof("kiro: recovered token from file (background refresh), expires_at: %v", auth.Metadata["expires_at"])
|
||||
} else {
|
||||
// 文件中的 token 也过期了,执行主动刷新
|
||||
log.Debugf("kiro: file reload failed (%v), attempting active refresh", reloadErr)
|
||||
refreshedAuth, refreshErr := e.Refresh(ctx, auth)
|
||||
if refreshErr != nil {
|
||||
log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr)
|
||||
} else if refreshedAuth != nil {
|
||||
auth = refreshedAuth
|
||||
// Persist the refreshed auth to file so subsequent requests use it
|
||||
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
|
||||
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
|
||||
}
|
||||
accessToken, profileArn = kiroCredentials(auth)
|
||||
log.Infof("kiro: token refreshed successfully before request")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -979,18 +991,30 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
|
||||
// Check if token is expired before making request
|
||||
if e.isTokenExpired(accessToken) {
|
||||
log.Infof("kiro: access token expired, attempting refresh before stream request")
|
||||
refreshedAuth, refreshErr := e.Refresh(ctx, auth)
|
||||
if refreshErr != nil {
|
||||
log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr)
|
||||
} else if refreshedAuth != nil {
|
||||
auth = refreshedAuth
|
||||
// Persist the refreshed auth to file so subsequent requests use it
|
||||
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
|
||||
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
|
||||
}
|
||||
log.Infof("kiro: access token expired, attempting recovery before stream request")
|
||||
|
||||
// 方案 B: 先尝试从文件重新加载 token(后台刷新器可能已更新文件)
|
||||
reloadedAuth, reloadErr := e.reloadAuthFromFile(auth)
|
||||
if reloadErr == nil && reloadedAuth != nil {
|
||||
// 文件中有更新的 token,使用它
|
||||
auth = reloadedAuth
|
||||
accessToken, profileArn = kiroCredentials(auth)
|
||||
log.Infof("kiro: token refreshed successfully before stream request")
|
||||
log.Infof("kiro: recovered token from file (background refresh) for stream, expires_at: %v", auth.Metadata["expires_at"])
|
||||
} else {
|
||||
// 文件中的 token 也过期了,执行主动刷新
|
||||
log.Debugf("kiro: file reload failed (%v), attempting active refresh for stream", reloadErr)
|
||||
refreshedAuth, refreshErr := e.Refresh(ctx, auth)
|
||||
if refreshErr != nil {
|
||||
log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr)
|
||||
} else if refreshedAuth != nil {
|
||||
auth = refreshedAuth
|
||||
// Persist the refreshed auth to file so subsequent requests use it
|
||||
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
|
||||
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
|
||||
}
|
||||
accessToken, profileArn = kiroCredentials(auth)
|
||||
log.Infof("kiro: token refreshed successfully before stream request")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3617,6 +3641,13 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c
|
||||
if tokenData.ClientSecret != "" {
|
||||
updated.Metadata["client_secret"] = tokenData.ClientSecret
|
||||
}
|
||||
// Preserve region and start_url for IDC token refresh
|
||||
if tokenData.Region != "" {
|
||||
updated.Metadata["region"] = tokenData.Region
|
||||
}
|
||||
if tokenData.StartURL != "" {
|
||||
updated.Metadata["start_url"] = tokenData.StartURL
|
||||
}
|
||||
|
||||
if updated.Attributes == nil {
|
||||
updated.Attributes = make(map[string]string)
|
||||
@@ -3682,6 +3713,121 @@ func (e *KiroExecutor) persistRefreshedAuth(auth *cliproxyauth.Auth) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// reloadAuthFromFile 从文件重新加载 auth 数据(方案 B: Fallback 机制)
|
||||
// 当内存中的 token 已过期时,尝试从文件读取最新的 token
|
||||
// 这解决了后台刷新器已更新文件但内存中 Auth 对象尚未同步的时间差问题
|
||||
func (e *KiroExecutor) reloadAuthFromFile(auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
if auth == nil {
|
||||
return nil, fmt.Errorf("kiro executor: cannot reload nil auth")
|
||||
}
|
||||
|
||||
// 确定文件路径
|
||||
var authPath string
|
||||
if auth.Attributes != nil {
|
||||
if p := strings.TrimSpace(auth.Attributes["path"]); p != "" {
|
||||
authPath = p
|
||||
}
|
||||
}
|
||||
if authPath == "" {
|
||||
fileName := strings.TrimSpace(auth.FileName)
|
||||
if fileName == "" {
|
||||
return nil, fmt.Errorf("kiro executor: auth has no file path or filename for reload")
|
||||
}
|
||||
if filepath.IsAbs(fileName) {
|
||||
authPath = fileName
|
||||
} else if e.cfg != nil && e.cfg.AuthDir != "" {
|
||||
authPath = filepath.Join(e.cfg.AuthDir, fileName)
|
||||
} else {
|
||||
return nil, fmt.Errorf("kiro executor: cannot determine auth file path for reload")
|
||||
}
|
||||
}
|
||||
|
||||
// 读取文件
|
||||
raw, err := os.ReadFile(authPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kiro executor: failed to read auth file %s: %w", authPath, err)
|
||||
}
|
||||
|
||||
// 解析 JSON
|
||||
var metadata map[string]any
|
||||
if err := json.Unmarshal(raw, &metadata); err != nil {
|
||||
return nil, fmt.Errorf("kiro executor: failed to parse auth file %s: %w", authPath, err)
|
||||
}
|
||||
|
||||
// 检查文件中的 token 是否比内存中的更新
|
||||
fileExpiresAt, _ := metadata["expires_at"].(string)
|
||||
fileAccessToken, _ := metadata["access_token"].(string)
|
||||
memExpiresAt, _ := auth.Metadata["expires_at"].(string)
|
||||
memAccessToken, _ := auth.Metadata["access_token"].(string)
|
||||
|
||||
// 文件中必须有有效的 access_token
|
||||
if fileAccessToken == "" {
|
||||
return nil, fmt.Errorf("kiro executor: auth file has no access_token field")
|
||||
}
|
||||
|
||||
// 如果有 expires_at,检查是否过期
|
||||
if fileExpiresAt != "" {
|
||||
fileExpTime, parseErr := time.Parse(time.RFC3339, fileExpiresAt)
|
||||
if parseErr == nil {
|
||||
// 如果文件中的 token 也已过期,不使用它
|
||||
if time.Now().After(fileExpTime) {
|
||||
log.Debugf("kiro executor: file token also expired at %s, not using", fileExpiresAt)
|
||||
return nil, fmt.Errorf("kiro executor: file token also expired")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 判断文件中的 token 是否比内存中的更新
|
||||
// 条件1: access_token 不同(说明已刷新)
|
||||
// 条件2: expires_at 更新(说明已刷新)
|
||||
isNewer := false
|
||||
|
||||
// 优先检查 access_token 是否变化
|
||||
if fileAccessToken != memAccessToken {
|
||||
isNewer = true
|
||||
log.Debugf("kiro executor: file access_token differs from memory, using file token")
|
||||
}
|
||||
|
||||
// 如果 access_token 相同,检查 expires_at
|
||||
if !isNewer && fileExpiresAt != "" && memExpiresAt != "" {
|
||||
fileExpTime, fileParseErr := time.Parse(time.RFC3339, fileExpiresAt)
|
||||
memExpTime, memParseErr := time.Parse(time.RFC3339, memExpiresAt)
|
||||
if fileParseErr == nil && memParseErr == nil && fileExpTime.After(memExpTime) {
|
||||
isNewer = true
|
||||
log.Debugf("kiro executor: file expires_at (%s) is newer than memory (%s)", fileExpiresAt, memExpiresAt)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果文件中没有 expires_at 但 access_token 相同,无法判断是否更新
|
||||
if !isNewer && fileExpiresAt == "" && fileAccessToken == memAccessToken {
|
||||
return nil, fmt.Errorf("kiro executor: cannot determine if file token is newer (no expires_at, same access_token)")
|
||||
}
|
||||
|
||||
if !isNewer {
|
||||
log.Debugf("kiro executor: file token not newer than memory token")
|
||||
return nil, fmt.Errorf("kiro executor: file token not newer")
|
||||
}
|
||||
|
||||
// 创建更新后的 auth 对象
|
||||
updated := auth.Clone()
|
||||
updated.Metadata = metadata
|
||||
updated.UpdatedAt = time.Now()
|
||||
|
||||
// 同步更新 Attributes
|
||||
if updated.Attributes == nil {
|
||||
updated.Attributes = make(map[string]string)
|
||||
}
|
||||
if accessToken, ok := metadata["access_token"].(string); ok {
|
||||
updated.Attributes["access_token"] = accessToken
|
||||
}
|
||||
if profileArn, ok := metadata["profile_arn"].(string); ok {
|
||||
updated.Attributes["profile_arn"] = profileArn
|
||||
}
|
||||
|
||||
log.Infof("kiro executor: reloaded auth from file %s, new expires_at: %s", authPath, fileExpiresAt)
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// isTokenExpired checks if a JWT access token has expired.
|
||||
// Returns true if the token is expired or cannot be parsed.
|
||||
func (e *KiroExecutor) isTokenExpired(accessToken string) bool {
|
||||
|
||||
@@ -145,3 +145,111 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
w.clientsMutex.RUnlock()
|
||||
return snapshotCoreAuths(cfg, w.authDir)
|
||||
}
|
||||
|
||||
// NotifyTokenRefreshed 处理后台刷新器的 token 更新通知
|
||||
// 当后台刷新器成功刷新 token 后调用此方法,更新内存中的 Auth 对象
|
||||
// tokenID: token 文件名(如 kiro-xxx.json)
|
||||
// accessToken: 新的 access token
|
||||
// refreshToken: 新的 refresh token
|
||||
// expiresAt: 新的过期时间
|
||||
func (w *Watcher) NotifyTokenRefreshed(tokenID, accessToken, refreshToken, expiresAt string) {
|
||||
if w == nil {
|
||||
return
|
||||
}
|
||||
|
||||
w.clientsMutex.Lock()
|
||||
defer w.clientsMutex.Unlock()
|
||||
|
||||
// 遍历 currentAuths,找到匹配的 Auth 并更新
|
||||
updated := false
|
||||
for id, auth := range w.currentAuths {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查是否是 kiro 类型的 auth
|
||||
authType, _ := auth.Metadata["type"].(string)
|
||||
if authType != "kiro" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 多种匹配方式,解决不同来源的 auth 对象字段差异
|
||||
matched := false
|
||||
|
||||
// 1. 通过 auth.ID 匹配(ID 可能包含文件名)
|
||||
if !matched && auth.ID != "" {
|
||||
if auth.ID == tokenID || strings.HasSuffix(auth.ID, "/"+tokenID) || strings.HasSuffix(auth.ID, "\\"+tokenID) {
|
||||
matched = true
|
||||
}
|
||||
// ID 可能是 "kiro-xxx" 格式(无扩展名),tokenID 是 "kiro-xxx.json"
|
||||
if !matched && strings.TrimSuffix(tokenID, ".json") == auth.ID {
|
||||
matched = true
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 通过 auth.Attributes["path"] 匹配
|
||||
if !matched && auth.Attributes != nil {
|
||||
if authPath := auth.Attributes["path"]; authPath != "" {
|
||||
// 提取文件名部分进行比较
|
||||
pathBase := authPath
|
||||
if idx := strings.LastIndexAny(authPath, "/\\"); idx >= 0 {
|
||||
pathBase = authPath[idx+1:]
|
||||
}
|
||||
if pathBase == tokenID || strings.TrimSuffix(pathBase, ".json") == strings.TrimSuffix(tokenID, ".json") {
|
||||
matched = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 通过 auth.FileName 匹配(原有逻辑)
|
||||
if !matched && auth.FileName != "" {
|
||||
if auth.FileName == tokenID || strings.HasSuffix(auth.FileName, "/"+tokenID) || strings.HasSuffix(auth.FileName, "\\"+tokenID) {
|
||||
matched = true
|
||||
}
|
||||
}
|
||||
|
||||
if matched {
|
||||
// 更新内存中的 token
|
||||
auth.Metadata["access_token"] = accessToken
|
||||
auth.Metadata["refresh_token"] = refreshToken
|
||||
auth.Metadata["expires_at"] = expiresAt
|
||||
auth.Metadata["last_refresh"] = time.Now().Format(time.RFC3339)
|
||||
auth.UpdatedAt = time.Now()
|
||||
auth.LastRefreshedAt = time.Now()
|
||||
|
||||
log.Infof("watcher: updated in-memory auth for token %s (auth ID: %s)", tokenID, id)
|
||||
updated = true
|
||||
|
||||
// 同时更新 runtimeAuths 中的副本(如果存在)
|
||||
if w.runtimeAuths != nil {
|
||||
if runtimeAuth, ok := w.runtimeAuths[id]; ok && runtimeAuth != nil {
|
||||
if runtimeAuth.Metadata == nil {
|
||||
runtimeAuth.Metadata = make(map[string]any)
|
||||
}
|
||||
runtimeAuth.Metadata["access_token"] = accessToken
|
||||
runtimeAuth.Metadata["refresh_token"] = refreshToken
|
||||
runtimeAuth.Metadata["expires_at"] = expiresAt
|
||||
runtimeAuth.Metadata["last_refresh"] = time.Now().Format(time.RFC3339)
|
||||
runtimeAuth.UpdatedAt = time.Now()
|
||||
runtimeAuth.LastRefreshedAt = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
// 发送更新通知到 authQueue
|
||||
if w.authQueue != nil {
|
||||
go func(authClone *coreauth.Auth) {
|
||||
update := AuthUpdate{
|
||||
Action: AuthUpdateActionModify,
|
||||
ID: authClone.ID,
|
||||
Auth: authClone,
|
||||
}
|
||||
w.dispatchAuthUpdates([]AuthUpdate{update})
|
||||
}(auth.Clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !updated {
|
||||
log.Debugf("watcher: no matching auth found for token %s, will be picked up on next file scan", tokenID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -98,6 +98,16 @@ func (s *Service) RegisterUsagePlugin(plugin usage.Plugin) {
|
||||
usage.RegisterPlugin(plugin)
|
||||
}
|
||||
|
||||
// GetWatcher returns the underlying WatcherWrapper instance.
|
||||
// This allows external components (e.g., RefreshManager) to interact with the watcher.
|
||||
// Returns nil if the service or watcher is not initialized.
|
||||
func (s *Service) GetWatcher() *WatcherWrapper {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return s.watcher
|
||||
}
|
||||
|
||||
// newDefaultAuthManager creates a default authentication manager with all supported providers.
|
||||
func newDefaultAuthManager() *sdkAuth.Manager {
|
||||
return sdkAuth.NewManager(
|
||||
@@ -575,6 +585,18 @@ func (s *Service) Run(ctx context.Context) error {
|
||||
}
|
||||
watcherWrapper.SetConfig(s.cfg)
|
||||
|
||||
// 方案 A: 连接 Kiro 后台刷新器回调到 Watcher
|
||||
// 当后台刷新器成功刷新 token 后,立即通知 Watcher 更新内存中的 Auth 对象
|
||||
// 这解决了后台刷新与内存 Auth 对象之间的时间差问题
|
||||
kiroauth.GetRefreshManager().SetOnTokenRefreshed(func(tokenID string, tokenData *kiroauth.KiroTokenData) {
|
||||
if tokenData == nil || watcherWrapper == nil {
|
||||
return
|
||||
}
|
||||
log.Debugf("kiro refresh callback: notifying watcher for token %s", tokenID)
|
||||
watcherWrapper.NotifyTokenRefreshed(tokenID, tokenData.AccessToken, tokenData.RefreshToken, tokenData.ExpiresAt)
|
||||
})
|
||||
log.Debug("kiro: connected background refresh callback to watcher")
|
||||
|
||||
watcherCtx, watcherCancel := context.WithCancel(context.Background())
|
||||
s.watcherCancel = watcherCancel
|
||||
if err = watcherWrapper.Start(watcherCtx); err != nil {
|
||||
|
||||
@@ -89,6 +89,7 @@ type WatcherWrapper struct {
|
||||
snapshotAuths func() []*coreauth.Auth
|
||||
setUpdateQueue func(queue chan<- watcher.AuthUpdate)
|
||||
dispatchRuntimeUpdate func(update watcher.AuthUpdate) bool
|
||||
notifyTokenRefreshed func(tokenID, accessToken, refreshToken, expiresAt string) // 方案 A: 后台刷新通知
|
||||
}
|
||||
|
||||
// Start proxies to the underlying watcher Start implementation.
|
||||
@@ -146,3 +147,16 @@ func (w *WatcherWrapper) SetAuthUpdateQueue(queue chan<- watcher.AuthUpdate) {
|
||||
}
|
||||
w.setUpdateQueue(queue)
|
||||
}
|
||||
|
||||
// NotifyTokenRefreshed 通知 Watcher 后台刷新器已更新 token
|
||||
// 这是方案 A 的核心方法,用于解决后台刷新与内存 Auth 对象的时间差问题
|
||||
// tokenID: token 文件名(如 kiro-xxx.json)
|
||||
// accessToken: 新的 access token
|
||||
// refreshToken: 新的 refresh token
|
||||
// expiresAt: 新的过期时间(RFC3339 格式)
|
||||
func (w *WatcherWrapper) NotifyTokenRefreshed(tokenID, accessToken, refreshToken, expiresAt string) {
|
||||
if w == nil || w.notifyTokenRefreshed == nil {
|
||||
return
|
||||
}
|
||||
w.notifyTokenRefreshed(tokenID, accessToken, refreshToken, expiresAt)
|
||||
}
|
||||
|
||||
@@ -31,5 +31,8 @@ func defaultWatcherFactory(configPath, authDir string, reload func(*config.Confi
|
||||
dispatchRuntimeUpdate: func(update watcher.AuthUpdate) bool {
|
||||
return w.DispatchRuntimeAuthUpdate(update)
|
||||
},
|
||||
notifyTokenRefreshed: func(tokenID, accessToken, refreshToken, expiresAt string) {
|
||||
w.NotifyTokenRefreshed(tokenID, accessToken, refreshToken, expiresAt)
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user