diff --git a/cmd/protocheck/main.go b/cmd/protocheck/main.go index af68bbc6..ffae9d2b 100644 --- a/cmd/protocheck/main.go +++ b/cmd/protocheck/main.go @@ -7,13 +7,13 @@ import ( func main() { ecm := cursorproto.NewMsg("ExecClientMessage") - + // Try different field names names := []string{ "mcp_result", "mcpResult", "McpResult", "MCP_RESULT", "shell_result", "shellResult", } - + for _, name := range names { fd := ecm.Descriptor().Fields().ByName(name) if fd != nil { @@ -22,7 +22,7 @@ func main() { fmt.Printf("Field %q NOT FOUND\n", name) } } - + // List all fields fmt.Println("\nAll fields in ExecClientMessage:") for i := 0; i < ecm.Descriptor().Fields().Len(); i++ { diff --git a/cmd/server/main.go b/cmd/server/main.go index 47367946..0e2157b1 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -75,7 +75,6 @@ func main() { var codexLogin bool var codexDeviceLogin bool var claudeLogin bool - var qwenLogin bool var kiloLogin bool var iflowLogin bool var iflowCookie bool @@ -113,7 +112,6 @@ func main() { flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth") flag.BoolVar(&codexDeviceLogin, "codex-device-login", false, "Login to Codex using device code flow") flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth") - flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth") flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow") flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth") flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie") @@ -538,8 +536,6 @@ func main() { } else if claudeLogin { // Handle Claude login cmd.DoClaudeLogin(cfg, options) - } else if qwenLogin { - cmd.DoQwenLogin(cfg, options) } else if kiloLogin { cmd.DoKiloLogin(cfg, options) } else if iflowLogin { diff --git a/config.example.yaml b/config.example.yaml index 55990416..a2e9d03b 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -95,6 
+95,10 @@ max-retry-interval: 30 # When true, disable auth/model cooldown scheduling globally (prevents blackout windows after failure states). disable-cooling: false +# Core auth auto-refresh worker pool size (OAuth/file-based auth token refresh). +# When > 0, overrides the default worker count (16). +# auth-auto-refresh-workers: 16 + # Quota exceeded behavior quota-exceeded: switch-project: true # Whether to automatically switch to another project when a quota is exceeded @@ -103,7 +107,14 @@ quota-exceeded: # Routing strategy for selecting credentials when multiple match. routing: - strategy: 'round-robin' # round-robin (default), fill-first + strategy: "round-robin" # round-robin (default), fill-first + # Enable universal session-sticky routing for all clients. + # Session IDs are extracted from: X-Session-ID header, Idempotency-Key, + # metadata.user_id, conversation_id, or first few messages hash. + # Automatic failover is always enabled when bound auth becomes unavailable. + session-affinity: false # default: false + # How long session-to-auth bindings are retained. Default: 1h + session-affinity-ttl: "1h" # When true, enable authentication for the WebSocket API (/v1/ws). ws-auth: false @@ -269,7 +280,7 @@ nonstream-keepalive-interval: 0 # # Requests to that alias will round-robin across the upstream names below, # # and if the chosen upstream fails before producing output, the request will # # continue with the next upstream model in the same alias pool. -# - name: "qwen3.5-plus" +# - name: "deepseek-v3.1" # alias: "claude-opus-4.66" # - name: "glm-5" # alias: "claude-opus-4.66" @@ -330,7 +341,7 @@ nonstream-keepalive-interval: 0 # Global OAuth model name aliases (per channel) # These aliases rename model IDs for both model listing and request routing. -# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi. 
+# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, iflow, kiro, github-copilot, kimi. # NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode. # NOTE: Because aliases affect the merged /v1 model list and merged request routing, overlapping # client-visible names can become ambiguous across providers. /api/provider/{provider}/... helps @@ -369,9 +380,6 @@ nonstream-keepalive-interval: 0 # codex: # - name: "gpt-5" # alias: "g5" -# qwen: -# - name: "qwen3-coder-plus" -# alias: "qwen-plus" # iflow: # - name: "glm-4.7" # alias: "glm-god" @@ -403,8 +411,6 @@ nonstream-keepalive-interval: 0 # - "claude-3-5-haiku-20241022" # codex: # - "gpt-5-codex-mini" -# qwen: -# - "vision-model" # iflow: # - "tstars2.0" # kimi: diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index 3c23cd78..4078db84 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -36,7 +36,6 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kilo" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi" kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" - "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" @@ -2526,62 +2525,6 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } -func (h *Handler) RequestQwenToken(c *gin.Context) { - ctx := context.Background() - ctx = PopulateAuthContext(ctx, c) - - fmt.Println("Initializing Qwen authentication...") - - state := fmt.Sprintf("gem-%d", time.Now().UnixNano()) - // Initialize Qwen auth service - qwenAuth := qwen.NewQwenAuth(h.cfg) - - // Generate 
authorization URL - deviceFlow, err := qwenAuth.InitiateDeviceFlow(ctx) - if err != nil { - log.Errorf("Failed to generate authorization URL: %v", err) - c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"}) - return - } - authURL := deviceFlow.VerificationURIComplete - - RegisterOAuthSession(state, "qwen") - - go func() { - fmt.Println("Waiting for authentication...") - tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier) - if errPollForToken != nil { - SetOAuthSessionError(state, "Authentication failed") - fmt.Printf("Authentication failed: %v\n", errPollForToken) - return - } - - // Create token storage - tokenStorage := qwenAuth.CreateTokenStorage(tokenData) - - tokenStorage.Email = fmt.Sprintf("%d", time.Now().UnixMilli()) - record := &coreauth.Auth{ - ID: fmt.Sprintf("qwen-%s.json", tokenStorage.Email), - Provider: "qwen", - FileName: fmt.Sprintf("qwen-%s.json", tokenStorage.Email), - Storage: tokenStorage, - Metadata: map[string]any{"email": tokenStorage.Email}, - } - savedPath, errSave := h.saveTokenRecord(ctx, record) - if errSave != nil { - log.Errorf("Failed to save authentication tokens: %v", errSave) - SetOAuthSessionError(state, "Failed to save authentication tokens") - return - } - - fmt.Printf("Authentication successful! 
Token saved to %s\n", savedPath) - fmt.Println("You can now use Qwen services through this CLI") - CompleteOAuthSession(state) - }() - - c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) -} - func (h *Handler) RequestKimiToken(c *gin.Context) { ctx := context.Background() ctx = PopulateAuthContext(ctx, c) diff --git a/internal/api/handlers/management/oauth_sessions.go b/internal/api/handlers/management/oauth_sessions.go index dfcdae88..156eeb79 100644 --- a/internal/api/handlers/management/oauth_sessions.go +++ b/internal/api/handlers/management/oauth_sessions.go @@ -236,8 +236,6 @@ func NormalizeOAuthProvider(provider string) (string, error) { return "iflow", nil case "antigravity", "anti-gravity": return "antigravity", nil - case "qwen": - return "qwen", nil case "kiro": return "kiro", nil case "github": diff --git a/internal/api/server.go b/internal/api/server.go index 690e4560..6dd6916d 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -24,8 +24,8 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/api/middleware" "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules" ampmodule "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules/amp" - "github.com/router-for-me/CLIProxyAPI/v6/internal/cache" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" + "github.com/router-for-me/CLIProxyAPI/v6/internal/cache" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" "github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset" @@ -684,7 +684,6 @@ func (s *Server) registerManagementRoutes() { mgmt.POST("/gitlab-auth-url", s.mgmt.RequestGitLabPATToken) mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken) mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken) - mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken) mgmt.GET("/kilo-auth-url", s.mgmt.RequestKiloToken) mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken) 
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken) @@ -1122,20 +1121,17 @@ func applySignatureCacheConfig(oldCfg, cfg *config.Config) { if oldCfg == nil { cache.SetSignatureCacheEnabled(newVal) cache.SetSignatureBypassStrictMode(newStrict) - log.Debugf("antigravity_signature_cache_enabled toggled to %t", newVal) return } oldVal := configuredSignatureCacheEnabled(oldCfg) if oldVal != newVal { cache.SetSignatureCacheEnabled(newVal) - log.Debugf("antigravity_signature_cache_enabled updated from %t to %t", oldVal, newVal) } oldStrict := configuredSignatureBypassStrict(oldCfg) if oldStrict != newStrict { cache.SetSignatureBypassStrictMode(newStrict) - log.Debugf("antigravity_signature_bypass_strict updated from %t to %t", oldStrict, newStrict) } } diff --git a/internal/auth/codebuddy/codebuddy_auth.go b/internal/auth/codebuddy/codebuddy_auth.go index ce0b803a..b4384200 100644 --- a/internal/auth/codebuddy/codebuddy_auth.go +++ b/internal/auth/codebuddy/codebuddy_auth.go @@ -63,7 +63,7 @@ func (a *CodeBuddyAuth) FetchAuthState(ctx context.Context) (*AuthState, error) return nil, fmt.Errorf("codebuddy: failed to create auth state request: %w", err) } -requestID := uuid.NewString() + requestID := uuid.NewString() req.Header.Set("Accept", "application/json, text/plain, */*") req.Header.Set("Content-Type", "application/json") req.Header.Set("X-Requested-With", "XMLHttpRequest") diff --git a/internal/auth/codebuddy/codebuddy_auth_test.go b/internal/auth/codebuddy/codebuddy_auth_test.go index f4ff553f..b43be391 100644 --- a/internal/auth/codebuddy/codebuddy_auth_test.go +++ b/internal/auth/codebuddy/codebuddy_auth_test.go @@ -19,4 +19,3 @@ func TestDecodeUserID_ValidJWT(t *testing.T) { t.Errorf("expected 'test-user-id-123', got '%s'", userID) } } - diff --git a/internal/auth/copilot/copilot_auth.go b/internal/auth/copilot/copilot_auth.go index 663b6de1..144d659c 100644 --- a/internal/auth/copilot/copilot_auth.go +++ b/internal/auth/copilot/copilot_auth.go @@ -24,11 +24,11 
@@ const ( copilotAPIEndpoint = "https://api.githubcopilot.com" // Common HTTP header values for Copilot API requests. - copilotUserAgent = "GithubCopilot/1.0" - copilotEditorVersion = "vscode/1.100.0" - copilotPluginVersion = "copilot/1.300.0" - copilotIntegrationID = "vscode-chat" - copilotOpenAIIntent = "conversation-panel" + copilotUserAgent = "GithubCopilot/1.0" + copilotEditorVersion = "vscode/1.100.0" + copilotPluginVersion = "copilot/1.300.0" + copilotIntegrationID = "vscode-chat" + copilotOpenAIIntent = "conversation-panel" ) // CopilotAPIToken represents the Copilot API token response. @@ -314,9 +314,9 @@ const maxModelsResponseSize = 2 * 1024 * 1024 // allowedCopilotAPIHosts is the set of hosts that are considered safe for Copilot API requests. var allowedCopilotAPIHosts = map[string]bool{ - "api.githubcopilot.com": true, - "api.individual.githubcopilot.com": true, - "api.business.githubcopilot.com": true, + "api.githubcopilot.com": true, + "api.individual.githubcopilot.com": true, + "api.business.githubcopilot.com": true, "copilot-proxy.githubusercontent.com": true, } diff --git a/internal/auth/cursor/proto/decode.go b/internal/auth/cursor/proto/decode.go index b3753a59..f54fc735 100644 --- a/internal/auth/cursor/proto/decode.go +++ b/internal/auth/cursor/proto/decode.go @@ -12,30 +12,30 @@ import ( type ServerMessageType int const ( - ServerMsgUnknown ServerMessageType = iota - ServerMsgTextDelta // Text content delta - ServerMsgThinkingDelta // Thinking/reasoning delta - ServerMsgThinkingCompleted // Thinking completed - ServerMsgKvGetBlob // Server wants a blob - ServerMsgKvSetBlob // Server wants to store a blob - ServerMsgExecRequestCtx // Server requests context (tools, etc.) 
- ServerMsgExecMcpArgs // Server wants MCP tool execution - ServerMsgExecShellArgs // Rejected: shell command - ServerMsgExecReadArgs // Rejected: file read - ServerMsgExecWriteArgs // Rejected: file write - ServerMsgExecDeleteArgs // Rejected: file delete - ServerMsgExecLsArgs // Rejected: directory listing - ServerMsgExecGrepArgs // Rejected: grep search - ServerMsgExecFetchArgs // Rejected: HTTP fetch - ServerMsgExecDiagnostics // Respond with empty diagnostics - ServerMsgExecShellStream // Rejected: shell stream - ServerMsgExecBgShellSpawn // Rejected: background shell - ServerMsgExecWriteShellStdin // Rejected: write shell stdin - ServerMsgExecOther // Other exec types (respond with empty) - ServerMsgTurnEnded // Turn has ended (no more output) - ServerMsgHeartbeat // Server heartbeat - ServerMsgTokenDelta // Token usage delta - ServerMsgCheckpoint // Conversation checkpoint update + ServerMsgUnknown ServerMessageType = iota + ServerMsgTextDelta // Text content delta + ServerMsgThinkingDelta // Thinking/reasoning delta + ServerMsgThinkingCompleted // Thinking completed + ServerMsgKvGetBlob // Server wants a blob + ServerMsgKvSetBlob // Server wants to store a blob + ServerMsgExecRequestCtx // Server requests context (tools, etc.) 
+ ServerMsgExecMcpArgs // Server wants MCP tool execution + ServerMsgExecShellArgs // Rejected: shell command + ServerMsgExecReadArgs // Rejected: file read + ServerMsgExecWriteArgs // Rejected: file write + ServerMsgExecDeleteArgs // Rejected: file delete + ServerMsgExecLsArgs // Rejected: directory listing + ServerMsgExecGrepArgs // Rejected: grep search + ServerMsgExecFetchArgs // Rejected: HTTP fetch + ServerMsgExecDiagnostics // Respond with empty diagnostics + ServerMsgExecShellStream // Rejected: shell stream + ServerMsgExecBgShellSpawn // Rejected: background shell + ServerMsgExecWriteShellStdin // Rejected: write shell stdin + ServerMsgExecOther // Other exec types (respond with empty) + ServerMsgTurnEnded // Turn has ended (no more output) + ServerMsgHeartbeat // Server heartbeat + ServerMsgTokenDelta // Token usage delta + ServerMsgCheckpoint // Conversation checkpoint update ) // DecodedServerMessage holds parsed data from an AgentServerMessage. @@ -561,4 +561,3 @@ func decodeVarintField(data []byte, targetField protowire.Number) int64 { func BlobIdHex(blobId []byte) string { return hex.EncodeToString(blobId) } - diff --git a/internal/auth/cursor/proto/fieldnumbers.go b/internal/auth/cursor/proto/fieldnumbers.go index 7ba24109..4b2accc6 100644 --- a/internal/auth/cursor/proto/fieldnumbers.go +++ b/internal/auth/cursor/proto/fieldnumbers.go @@ -4,23 +4,23 @@ package proto // AgentClientMessage (msg 118) oneof "message" const ( - ACM_RunRequest = 1 // AgentRunRequest - ACM_ExecClientMessage = 2 // ExecClientMessage - ACM_KvClientMessage = 3 // KvClientMessage - ACM_ConversationAction = 4 // ConversationAction - ACM_ExecClientControlMsg = 5 // ExecClientControlMessage - ACM_InteractionResponse = 6 // InteractionResponse - ACM_ClientHeartbeat = 7 // ClientHeartbeat + ACM_RunRequest = 1 // AgentRunRequest + ACM_ExecClientMessage = 2 // ExecClientMessage + ACM_KvClientMessage = 3 // KvClientMessage + ACM_ConversationAction = 4 // ConversationAction + 
ACM_ExecClientControlMsg = 5 // ExecClientControlMessage + ACM_InteractionResponse = 6 // InteractionResponse + ACM_ClientHeartbeat = 7 // ClientHeartbeat ) // AgentServerMessage (msg 119) oneof "message" const ( - ASM_InteractionUpdate = 1 // InteractionUpdate - ASM_ExecServerMessage = 2 // ExecServerMessage - ASM_ConversationCheckpoint = 3 // ConversationStateStructure - ASM_KvServerMessage = 4 // KvServerMessage - ASM_ExecServerControlMessage = 5 // ExecServerControlMessage - ASM_InteractionQuery = 7 // InteractionQuery + ASM_InteractionUpdate = 1 // InteractionUpdate + ASM_ExecServerMessage = 2 // ExecServerMessage + ASM_ConversationCheckpoint = 3 // ConversationStateStructure + ASM_KvServerMessage = 4 // KvServerMessage + ASM_ExecServerControlMessage = 5 // ExecServerControlMessage + ASM_InteractionQuery = 7 // InteractionQuery ) // AgentRunRequest (msg 91) @@ -77,10 +77,10 @@ const ( // ModelDetails (msg 88) const ( - MD_ModelId = 1 // string + MD_ModelId = 1 // string MD_ThinkingDetails = 2 // ThinkingDetails (optional) - MD_DisplayModelId = 3 // string - MD_DisplayName = 4 // string + MD_DisplayModelId = 3 // string + MD_DisplayName = 4 // string ) // McpTools (msg 307) @@ -122,9 +122,9 @@ const ( // InteractionUpdate oneof "message" const ( - IU_TextDelta = 1 // TextDeltaUpdate - IU_ThinkingDelta = 4 // ThinkingDeltaUpdate - IU_ThinkingCompleted = 5 // ThinkingCompletedUpdate + IU_TextDelta = 1 // TextDeltaUpdate + IU_ThinkingDelta = 4 // ThinkingDeltaUpdate + IU_ThinkingCompleted = 5 // ThinkingCompletedUpdate ) // TextDeltaUpdate (msg 92) @@ -169,22 +169,22 @@ const ( // ExecServerMessage const ( - ESM_Id = 1 // uint32 - ESM_ExecId = 15 // string + ESM_Id = 1 // uint32 + ESM_ExecId = 15 // string // oneof message: - ESM_ShellArgs = 2 // ShellArgs - ESM_WriteArgs = 3 // WriteArgs - ESM_DeleteArgs = 4 // DeleteArgs - ESM_GrepArgs = 5 // GrepArgs - ESM_ReadArgs = 7 // ReadArgs (NOTE: 6 is skipped) - ESM_LsArgs = 8 // LsArgs - ESM_DiagnosticsArgs = 9 // 
DiagnosticsArgs - ESM_RequestContextArgs = 10 // RequestContextArgs - ESM_McpArgs = 11 // McpArgs - ESM_ShellStreamArgs = 14 // ShellArgs (stream variant) - ESM_BackgroundShellSpawn = 16 // BackgroundShellSpawnArgs - ESM_FetchArgs = 20 // FetchArgs - ESM_WriteShellStdinArgs = 23 // WriteShellStdinArgs + ESM_ShellArgs = 2 // ShellArgs + ESM_WriteArgs = 3 // WriteArgs + ESM_DeleteArgs = 4 // DeleteArgs + ESM_GrepArgs = 5 // GrepArgs + ESM_ReadArgs = 7 // ReadArgs (NOTE: 6 is skipped) + ESM_LsArgs = 8 // LsArgs + ESM_DiagnosticsArgs = 9 // DiagnosticsArgs + ESM_RequestContextArgs = 10 // RequestContextArgs + ESM_McpArgs = 11 // McpArgs + ESM_ShellStreamArgs = 14 // ShellArgs (stream variant) + ESM_BackgroundShellSpawn = 16 // BackgroundShellSpawnArgs + ESM_FetchArgs = 20 // FetchArgs + ESM_WriteShellStdinArgs = 23 // WriteShellStdinArgs ) // ExecClientMessage @@ -192,19 +192,19 @@ const ( ECM_Id = 1 // uint32 ECM_ExecId = 15 // string // oneof message (mirrors server fields): - ECM_ShellResult = 2 - ECM_WriteResult = 3 - ECM_DeleteResult = 4 - ECM_GrepResult = 5 - ECM_ReadResult = 7 - ECM_LsResult = 8 - ECM_DiagnosticsResult = 9 - ECM_RequestContextResult = 10 - ECM_McpResult = 11 - ECM_ShellStream = 14 - ECM_BackgroundShellSpawnRes = 16 - ECM_FetchResult = 20 - ECM_WriteShellStdinResult = 23 + ECM_ShellResult = 2 + ECM_WriteResult = 3 + ECM_DeleteResult = 4 + ECM_GrepResult = 5 + ECM_ReadResult = 7 + ECM_LsResult = 8 + ECM_DiagnosticsResult = 9 + ECM_RequestContextResult = 10 + ECM_McpResult = 11 + ECM_ShellStream = 14 + ECM_BackgroundShellSpawnRes = 16 + ECM_FetchResult = 20 + ECM_WriteShellStdinResult = 23 ) // McpArgs @@ -276,28 +276,28 @@ const ( // ShellResult oneof: success=1 (+ various), rejected=? 
// The TS code uses specific result field numbers from the oneof: const ( - RR_Rejected = 3 // ReadResult.rejected - SR_Rejected = 5 // ShellResult.rejected (from TS: ShellResult has success/various/rejected) - WR_Rejected = 5 // WriteResult.rejected - DR_Rejected = 3 // DeleteResult.rejected - LR_Rejected = 3 // LsResult.rejected - GR_Error = 2 // GrepResult.error - FR_Error = 2 // FetchResult.error + RR_Rejected = 3 // ReadResult.rejected + SR_Rejected = 5 // ShellResult.rejected (from TS: ShellResult has success/various/rejected) + WR_Rejected = 5 // WriteResult.rejected + DR_Rejected = 3 // DeleteResult.rejected + LR_Rejected = 3 // LsResult.rejected + GR_Error = 2 // GrepResult.error + FR_Error = 2 // FetchResult.error BSSR_Rejected = 2 // BackgroundShellSpawnResult.rejected (error field) WSSR_Error = 2 // WriteShellStdinResult.error ) // --- Rejection struct fields --- const ( - REJ_Path = 1 - REJ_Reason = 2 - SREJ_Command = 1 - SREJ_WorkingDir = 2 - SREJ_Reason = 3 - SREJ_IsReadonly = 4 - GERR_Error = 1 - FERR_Url = 1 - FERR_Error = 2 + REJ_Path = 1 + REJ_Reason = 2 + SREJ_Command = 1 + SREJ_WorkingDir = 2 + SREJ_Reason = 3 + SREJ_IsReadonly = 4 + GERR_Error = 1 + FERR_Url = 1 + FERR_Error = 2 ) // ReadArgs diff --git a/internal/auth/cursor/proto/h2stream.go b/internal/auth/cursor/proto/h2stream.go index 45b5baf7..5275b283 100644 --- a/internal/auth/cursor/proto/h2stream.go +++ b/internal/auth/cursor/proto/h2stream.go @@ -33,10 +33,10 @@ type H2Stream struct { err error // Send-side flow control - sendWindow int32 // available bytes we can send on this stream - connWindow int32 // available bytes on the connection level - windowCond *sync.Cond // signaled when window is updated - windowMu sync.Mutex // protects sendWindow, connWindow + sendWindow int32 // available bytes we can send on this stream + connWindow int32 // available bytes on the connection level + windowCond *sync.Cond // signaled when window is updated + windowMu sync.Mutex // protects 
sendWindow, connWindow } // ID returns the unique identifier for this stream (for logging). diff --git a/internal/auth/kiro/aws_test.go b/internal/auth/kiro/aws_test.go index 3033c985..da20bc42 100644 --- a/internal/auth/kiro/aws_test.go +++ b/internal/auth/kiro/aws_test.go @@ -748,4 +748,3 @@ func TestExtractRegionFromMetadata(t *testing.T) { }) } } - diff --git a/internal/auth/kiro/cooldown.go b/internal/auth/kiro/cooldown.go index c1aabbcb..716135b6 100644 --- a/internal/auth/kiro/cooldown.go +++ b/internal/auth/kiro/cooldown.go @@ -6,8 +6,8 @@ import ( ) const ( - CooldownReason429 = "rate_limit_exceeded" - CooldownReasonSuspended = "account_suspended" + CooldownReason429 = "rate_limit_exceeded" + CooldownReasonSuspended = "account_suspended" CooldownReasonQuotaExhausted = "quota_exhausted" DefaultShortCooldown = 1 * time.Minute diff --git a/internal/auth/kiro/jitter.go b/internal/auth/kiro/jitter.go index 0569a8fb..fef2aea9 100644 --- a/internal/auth/kiro/jitter.go +++ b/internal/auth/kiro/jitter.go @@ -26,9 +26,9 @@ const ( ) var ( - jitterRand *rand.Rand - jitterRandOnce sync.Once - jitterMu sync.Mutex + jitterRand *rand.Rand + jitterRandOnce sync.Once + jitterMu sync.Mutex lastRequestTime time.Time ) diff --git a/internal/auth/kiro/metrics.go b/internal/auth/kiro/metrics.go index 0fe2d0c6..f9540fc1 100644 --- a/internal/auth/kiro/metrics.go +++ b/internal/auth/kiro/metrics.go @@ -24,10 +24,10 @@ type TokenScorer struct { metrics map[string]*TokenMetrics // Scoring weights - successRateWeight float64 - quotaWeight float64 - latencyWeight float64 - lastUsedWeight float64 + successRateWeight float64 + quotaWeight float64 + latencyWeight float64 + lastUsedWeight float64 failPenaltyMultiplier float64 } diff --git a/internal/auth/kiro/protocol_handler.go b/internal/auth/kiro/protocol_handler.go index d900ee33..a1c28a86 100644 --- a/internal/auth/kiro/protocol_handler.go +++ b/internal/auth/kiro/protocol_handler.go @@ -97,7 +97,7 @@ func (h *ProtocolHandler) 
Start(ctx context.Context) (int, error) { var listener net.Listener var err error portRange := []int{DefaultHandlerPort, DefaultHandlerPort + 1, DefaultHandlerPort + 2, DefaultHandlerPort + 3, DefaultHandlerPort + 4} - + for _, port := range portRange { listener, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port)) if err == nil { @@ -105,7 +105,7 @@ func (h *ProtocolHandler) Start(ctx context.Context) (int, error) { } log.Debugf("kiro protocol handler: port %d busy, trying next", port) } - + if listener == nil { return 0, fmt.Errorf("failed to start callback server: all ports %d-%d are busy", DefaultHandlerPort, DefaultHandlerPort+4) } diff --git a/internal/auth/qwen/qwen_auth.go b/internal/auth/qwen/qwen_auth.go deleted file mode 100644 index cb58b86d..00000000 --- a/internal/auth/qwen/qwen_auth.go +++ /dev/null @@ -1,359 +0,0 @@ -package qwen - -import ( - "context" - "crypto/rand" - "crypto/sha256" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - "strings" - "time" - - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/internal/util" - log "github.com/sirupsen/logrus" -) - -const ( - // QwenOAuthDeviceCodeEndpoint is the URL for initiating the OAuth 2.0 device authorization flow. - QwenOAuthDeviceCodeEndpoint = "https://chat.qwen.ai/api/v1/oauth2/device/code" - // QwenOAuthTokenEndpoint is the URL for exchanging device codes or refresh tokens for access tokens. - QwenOAuthTokenEndpoint = "https://chat.qwen.ai/api/v1/oauth2/token" - // QwenOAuthClientID is the client identifier for the Qwen OAuth 2.0 application. - QwenOAuthClientID = "f0304373b74a44d2b584a3fb70ca9e56" - // QwenOAuthScope defines the permissions requested by the application. - QwenOAuthScope = "openid profile email model.completion" - // QwenOAuthGrantType specifies the grant type for the device code flow. 
- QwenOAuthGrantType = "urn:ietf:params:oauth:grant-type:device_code" -) - -// QwenTokenData represents the OAuth credentials, including access and refresh tokens. -type QwenTokenData struct { - AccessToken string `json:"access_token"` - // RefreshToken is used to obtain a new access token when the current one expires. - RefreshToken string `json:"refresh_token,omitempty"` - // TokenType indicates the type of token, typically "Bearer". - TokenType string `json:"token_type"` - // ResourceURL specifies the base URL of the resource server. - ResourceURL string `json:"resource_url,omitempty"` - // Expire indicates the expiration date and time of the access token. - Expire string `json:"expiry_date,omitempty"` -} - -// DeviceFlow represents the response from the device authorization endpoint. -type DeviceFlow struct { - // DeviceCode is the code that the client uses to poll for an access token. - DeviceCode string `json:"device_code"` - // UserCode is the code that the user enters at the verification URI. - UserCode string `json:"user_code"` - // VerificationURI is the URL where the user can enter the user code to authorize the device. - VerificationURI string `json:"verification_uri"` - // VerificationURIComplete is a URI that includes the user_code, which can be used to automatically - // fill in the code on the verification page. - VerificationURIComplete string `json:"verification_uri_complete"` - // ExpiresIn is the time in seconds until the device_code and user_code expire. - ExpiresIn int `json:"expires_in"` - // Interval is the minimum time in seconds that the client should wait between polling requests. - Interval int `json:"interval"` - // CodeVerifier is the cryptographically random string used in the PKCE flow. - CodeVerifier string `json:"code_verifier"` -} - -// QwenTokenResponse represents the successful token response from the token endpoint. -type QwenTokenResponse struct { - // AccessToken is the token used to access protected resources. 
- AccessToken string `json:"access_token"` - // RefreshToken is used to obtain a new access token. - RefreshToken string `json:"refresh_token,omitempty"` - // TokenType indicates the type of token, typically "Bearer". - TokenType string `json:"token_type"` - // ResourceURL specifies the base URL of the resource server. - ResourceURL string `json:"resource_url,omitempty"` - // ExpiresIn is the time in seconds until the access token expires. - ExpiresIn int `json:"expires_in"` -} - -// QwenAuth manages authentication and token handling for the Qwen API. -type QwenAuth struct { - httpClient *http.Client -} - -// NewQwenAuth creates a new QwenAuth instance with a proxy-configured HTTP client. -func NewQwenAuth(cfg *config.Config) *QwenAuth { - return &QwenAuth{ - httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}), - } -} - -// generateCodeVerifier generates a cryptographically random string for the PKCE code verifier. -func (qa *QwenAuth) generateCodeVerifier() (string, error) { - bytes := make([]byte, 32) - if _, err := rand.Read(bytes); err != nil { - return "", err - } - return base64.RawURLEncoding.EncodeToString(bytes), nil -} - -// generateCodeChallenge creates a SHA-256 hash of the code verifier, used as the PKCE code challenge. -func (qa *QwenAuth) generateCodeChallenge(codeVerifier string) string { - hash := sha256.Sum256([]byte(codeVerifier)) - return base64.RawURLEncoding.EncodeToString(hash[:]) -} - -// generatePKCEPair creates a new code verifier and its corresponding code challenge for PKCE. -func (qa *QwenAuth) generatePKCEPair() (string, string, error) { - codeVerifier, err := qa.generateCodeVerifier() - if err != nil { - return "", "", err - } - codeChallenge := qa.generateCodeChallenge(codeVerifier) - return codeVerifier, codeChallenge, nil -} - -// RefreshTokens exchanges a refresh token for a new access token. 
-func (qa *QwenAuth) RefreshTokens(ctx context.Context, refreshToken string) (*QwenTokenData, error) { - data := url.Values{} - data.Set("grant_type", "refresh_token") - data.Set("refresh_token", refreshToken) - data.Set("client_id", QwenOAuthClientID) - - req, err := http.NewRequestWithContext(ctx, "POST", QwenOAuthTokenEndpoint, strings.NewReader(data.Encode())) - if err != nil { - return nil, fmt.Errorf("failed to create token request: %w", err) - } - - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("Accept", "application/json") - - resp, err := qa.httpClient.Do(req) - - // resp, err := qa.httpClient.PostForm(QwenOAuthTokenEndpoint, data) - if err != nil { - return nil, fmt.Errorf("token refresh request failed: %w", err) - } - defer func() { - _ = resp.Body.Close() - }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - - if resp.StatusCode != http.StatusOK { - var errorData map[string]interface{} - if err = json.Unmarshal(body, &errorData); err == nil { - return nil, fmt.Errorf("token refresh failed: %v - %v", errorData["error"], errorData["error_description"]) - } - return nil, fmt.Errorf("token refresh failed: %s", string(body)) - } - - var tokenData QwenTokenResponse - if err = json.Unmarshal(body, &tokenData); err != nil { - return nil, fmt.Errorf("failed to parse token response: %w", err) - } - - return &QwenTokenData{ - AccessToken: tokenData.AccessToken, - TokenType: tokenData.TokenType, - RefreshToken: tokenData.RefreshToken, - ResourceURL: tokenData.ResourceURL, - Expire: time.Now().Add(time.Duration(tokenData.ExpiresIn) * time.Second).Format(time.RFC3339), - }, nil -} - -// InitiateDeviceFlow starts the OAuth 2.0 device authorization flow and returns the device flow details. 
-func (qa *QwenAuth) InitiateDeviceFlow(ctx context.Context) (*DeviceFlow, error) { - // Generate PKCE code verifier and challenge - codeVerifier, codeChallenge, err := qa.generatePKCEPair() - if err != nil { - return nil, fmt.Errorf("failed to generate PKCE pair: %w", err) - } - - data := url.Values{} - data.Set("client_id", QwenOAuthClientID) - data.Set("scope", QwenOAuthScope) - data.Set("code_challenge", codeChallenge) - data.Set("code_challenge_method", "S256") - - req, err := http.NewRequestWithContext(ctx, "POST", QwenOAuthDeviceCodeEndpoint, strings.NewReader(data.Encode())) - if err != nil { - return nil, fmt.Errorf("failed to create token request: %w", err) - } - - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("Accept", "application/json") - - resp, err := qa.httpClient.Do(req) - - // resp, err := qa.httpClient.PostForm(QwenOAuthDeviceCodeEndpoint, data) - if err != nil { - return nil, fmt.Errorf("device authorization request failed: %w", err) - } - defer func() { - _ = resp.Body.Close() - }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("device authorization failed: %d %s. Response: %s", resp.StatusCode, resp.Status, string(body)) - } - - var result DeviceFlow - if err = json.Unmarshal(body, &result); err != nil { - return nil, fmt.Errorf("failed to parse device flow response: %w", err) - } - - // Check if the response indicates success - if result.DeviceCode == "" { - return nil, fmt.Errorf("device authorization failed: device_code not found in response") - } - - // Add the code_verifier to the result so it can be used later for polling - result.CodeVerifier = codeVerifier - - return &result, nil -} - -// PollForToken polls the token endpoint with the device code to obtain an access token. 
-func (qa *QwenAuth) PollForToken(deviceCode, codeVerifier string) (*QwenTokenData, error) { - pollInterval := 5 * time.Second - maxAttempts := 60 // 5 minutes max - - for attempt := 0; attempt < maxAttempts; attempt++ { - data := url.Values{} - data.Set("grant_type", QwenOAuthGrantType) - data.Set("client_id", QwenOAuthClientID) - data.Set("device_code", deviceCode) - data.Set("code_verifier", codeVerifier) - - resp, err := http.PostForm(QwenOAuthTokenEndpoint, data) - if err != nil { - fmt.Printf("Polling attempt %d/%d failed: %v\n", attempt+1, maxAttempts, err) - time.Sleep(pollInterval) - continue - } - - body, err := io.ReadAll(resp.Body) - _ = resp.Body.Close() - if err != nil { - fmt.Printf("Polling attempt %d/%d failed: %v\n", attempt+1, maxAttempts, err) - time.Sleep(pollInterval) - continue - } - - if resp.StatusCode != http.StatusOK { - // Parse the response as JSON to check for OAuth RFC 8628 standard errors - var errorData map[string]interface{} - if err = json.Unmarshal(body, &errorData); err == nil { - // According to OAuth RFC 8628, handle standard polling responses - if resp.StatusCode == http.StatusBadRequest { - errorType, _ := errorData["error"].(string) - switch errorType { - case "authorization_pending": - // User has not yet approved the authorization request. Continue polling. - fmt.Printf("Polling attempt %d/%d...\n\n", attempt+1, maxAttempts) - time.Sleep(pollInterval) - continue - case "slow_down": - // Client is polling too frequently. Increase poll interval. - pollInterval = time.Duration(float64(pollInterval) * 1.5) - if pollInterval > 10*time.Second { - pollInterval = 10 * time.Second - } - fmt.Printf("Server requested to slow down, increasing poll interval to %v\n\n", pollInterval) - time.Sleep(pollInterval) - continue - case "expired_token": - return nil, fmt.Errorf("device code expired. Please restart the authentication process") - case "access_denied": - return nil, fmt.Errorf("authorization denied by user. 
Please restart the authentication process") - } - } - - // For other errors, return with proper error information - errorType, _ := errorData["error"].(string) - errorDesc, _ := errorData["error_description"].(string) - return nil, fmt.Errorf("device token poll failed: %s - %s", errorType, errorDesc) - } - - // If JSON parsing fails, fall back to text response - return nil, fmt.Errorf("device token poll failed: %d %s. Response: %s", resp.StatusCode, resp.Status, string(body)) - } - // log.Debugf("%s", string(body)) - // Success - parse token data - var response QwenTokenResponse - if err = json.Unmarshal(body, &response); err != nil { - return nil, fmt.Errorf("failed to parse token response: %w", err) - } - - // Convert to QwenTokenData format and save - tokenData := &QwenTokenData{ - AccessToken: response.AccessToken, - RefreshToken: response.RefreshToken, - TokenType: response.TokenType, - ResourceURL: response.ResourceURL, - Expire: time.Now().Add(time.Duration(response.ExpiresIn) * time.Second).Format(time.RFC3339), - } - - return tokenData, nil - } - - return nil, fmt.Errorf("authentication timeout. Please restart the authentication process") -} - -// RefreshTokensWithRetry attempts to refresh tokens with a specified number of retries upon failure. 
-func (o *QwenAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken string, maxRetries int) (*QwenTokenData, error) { - var lastErr error - - for attempt := 0; attempt < maxRetries; attempt++ { - if attempt > 0 { - // Wait before retry - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(time.Duration(attempt) * time.Second): - } - } - - tokenData, err := o.RefreshTokens(ctx, refreshToken) - if err == nil { - return tokenData, nil - } - - lastErr = err - log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err) - } - - return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr) -} - -// CreateTokenStorage creates a QwenTokenStorage object from a QwenTokenData object. -func (o *QwenAuth) CreateTokenStorage(tokenData *QwenTokenData) *QwenTokenStorage { - storage := &QwenTokenStorage{ - AccessToken: tokenData.AccessToken, - RefreshToken: tokenData.RefreshToken, - LastRefresh: time.Now().Format(time.RFC3339), - ResourceURL: tokenData.ResourceURL, - Expire: tokenData.Expire, - } - - return storage -} - -// UpdateTokenStorage updates an existing token storage with new token data -func (o *QwenAuth) UpdateTokenStorage(storage *QwenTokenStorage, tokenData *QwenTokenData) { - storage.AccessToken = tokenData.AccessToken - storage.RefreshToken = tokenData.RefreshToken - storage.LastRefresh = time.Now().Format(time.RFC3339) - storage.ResourceURL = tokenData.ResourceURL - storage.Expire = tokenData.Expire -} diff --git a/internal/auth/qwen/qwen_token.go b/internal/auth/qwen/qwen_token.go deleted file mode 100644 index 276c8b40..00000000 --- a/internal/auth/qwen/qwen_token.go +++ /dev/null @@ -1,79 +0,0 @@ -// Package qwen provides authentication and token management functionality -// for Alibaba's Qwen AI services. It handles OAuth2 token storage, serialization, -// and retrieval for maintaining authenticated sessions with the Qwen API. 
-package qwen - -import ( - "encoding/json" - "fmt" - "os" - "path/filepath" - - "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" -) - -// QwenTokenStorage stores OAuth2 token information for Alibaba Qwen API authentication. -// It maintains compatibility with the existing auth system while adding Qwen-specific fields -// for managing access tokens, refresh tokens, and user account information. -type QwenTokenStorage struct { - // AccessToken is the OAuth2 access token used for authenticating API requests. - AccessToken string `json:"access_token"` - // RefreshToken is used to obtain new access tokens when the current one expires. - RefreshToken string `json:"refresh_token"` - // LastRefresh is the timestamp of the last token refresh operation. - LastRefresh string `json:"last_refresh"` - // ResourceURL is the base URL for API requests. - ResourceURL string `json:"resource_url"` - // Email is the Qwen account email address associated with this token. - Email string `json:"email"` - // Type indicates the authentication provider type, always "qwen" for this storage. - Type string `json:"type"` - // Expire is the timestamp when the current access token expires. - Expire string `json:"expired"` - - // Metadata holds arbitrary key-value pairs injected via hooks. - // It is not exported to JSON directly to allow flattening during serialization. - Metadata map[string]any `json:"-"` -} - -// SetMetadata allows external callers to inject metadata into the storage before saving. -func (ts *QwenTokenStorage) SetMetadata(meta map[string]any) { - ts.Metadata = meta -} - -// SaveTokenToFile serializes the Qwen token storage to a JSON file. -// This method creates the necessary directory structure and writes the token -// data in JSON format to the specified file path for persistent storage. -// It merges any injected metadata into the top-level JSON object. 
-// -// Parameters: -// - authFilePath: The full path where the token file should be saved -// -// Returns: -// - error: An error if the operation fails, nil otherwise -func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error { - misc.LogSavingCredentials(authFilePath) - ts.Type = "qwen" - if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil { - return fmt.Errorf("failed to create directory: %v", err) - } - - f, err := os.Create(authFilePath) - if err != nil { - return fmt.Errorf("failed to create token file: %w", err) - } - defer func() { - _ = f.Close() - }() - - // Merge metadata using helper - data, errMerge := misc.MergeMetadata(ts, ts.Metadata) - if errMerge != nil { - return fmt.Errorf("failed to merge metadata: %w", errMerge) - } - - if err = json.NewEncoder(f).Encode(data); err != nil { - return fmt.Errorf("failed to write token to file: %w", err) - } - return nil -} diff --git a/internal/browser/browser.go b/internal/browser/browser.go index 3a5aeea7..e8551788 100644 --- a/internal/browser/browser.go +++ b/internal/browser/browser.go @@ -39,7 +39,7 @@ func CloseBrowser() error { if lastBrowserProcess == nil || lastBrowserProcess.Process == nil { return nil } - + err := lastBrowserProcess.Process.Kill() lastBrowserProcess = nil return err diff --git a/internal/cache/signature_cache.go b/internal/cache/signature_cache.go index 95fede4d..fd2ccab7 100644 --- a/internal/cache/signature_cache.go +++ b/internal/cache/signature_cache.go @@ -207,9 +207,12 @@ func init() { // SetSignatureCacheEnabled switches Antigravity signature handling between cache mode and bypass mode. 
func SetSignatureCacheEnabled(enabled bool) { - signatureCacheEnabled.Store(enabled) + previous := signatureCacheEnabled.Swap(enabled) + if previous == enabled { + return + } if !enabled { - log.Warn("antigravity signature cache DISABLED - bypass mode active, cached signatures will not be used for request translation") + log.Info("antigravity signature cache DISABLED - bypass mode active, cached signatures will not be used for request translation") } } @@ -220,11 +223,14 @@ func SignatureCacheEnabled() bool { // SetSignatureBypassStrictMode controls whether bypass mode uses strict protobuf-tree validation. func SetSignatureBypassStrictMode(strict bool) { - signatureBypassStrictMode.Store(strict) + previous := signatureBypassStrictMode.Swap(strict) + if previous == strict { + return + } if strict { - log.Info("antigravity bypass signature validation: strict mode (protobuf tree)") + log.Debug("antigravity bypass signature validation: strict mode (protobuf tree)") } else { - log.Info("antigravity bypass signature validation: basic mode (R/E + 0x12)") + log.Debug("antigravity bypass signature validation: basic mode (R/E + 0x12)") } } diff --git a/internal/cache/signature_cache_test.go b/internal/cache/signature_cache_test.go index 83408159..82a8a19d 100644 --- a/internal/cache/signature_cache_test.go +++ b/internal/cache/signature_cache_test.go @@ -1,8 +1,12 @@ package cache import ( + "bytes" + "strings" "testing" "time" + + log "github.com/sirupsen/logrus" ) const testModelName = "claude-sonnet-4-5" @@ -208,3 +212,90 @@ func TestCacheSignature_ExpirationLogic(t *testing.T) { // but the logic is verified by the implementation _ = time.Now() // Acknowledge we're not testing time passage } + +func TestSignatureModeSetters_LogAtInfoLevel(t *testing.T) { + logger := log.StandardLogger() + previousOutput := logger.Out + previousLevel := logger.Level + previousCache := SignatureCacheEnabled() + previousStrict := SignatureBypassStrictMode() + SetSignatureCacheEnabled(true) + 
SetSignatureBypassStrictMode(false) + buffer := &bytes.Buffer{} + log.SetOutput(buffer) + log.SetLevel(log.InfoLevel) + t.Cleanup(func() { + log.SetOutput(previousOutput) + log.SetLevel(previousLevel) + SetSignatureCacheEnabled(previousCache) + SetSignatureBypassStrictMode(previousStrict) + }) + + SetSignatureCacheEnabled(false) + SetSignatureBypassStrictMode(true) + SetSignatureBypassStrictMode(false) + + output := buffer.String() + if !strings.Contains(output, "antigravity signature cache DISABLED") { + t.Fatalf("expected info output for disabling signature cache, got: %q", output) + } + if strings.Contains(output, "strict mode (protobuf tree)") { + t.Fatalf("expected strict bypass mode log to stay below info level, got: %q", output) + } + if strings.Contains(output, "basic mode (R/E + 0x12)") { + t.Fatalf("expected basic bypass mode log to stay below info level, got: %q", output) + } +} + +func TestSignatureModeSetters_DoNotRepeatSameStateLogs(t *testing.T) { + logger := log.StandardLogger() + previousOutput := logger.Out + previousLevel := logger.Level + previousCache := SignatureCacheEnabled() + previousStrict := SignatureBypassStrictMode() + SetSignatureCacheEnabled(false) + SetSignatureBypassStrictMode(true) + buffer := &bytes.Buffer{} + log.SetOutput(buffer) + log.SetLevel(log.InfoLevel) + t.Cleanup(func() { + log.SetOutput(previousOutput) + log.SetLevel(previousLevel) + SetSignatureCacheEnabled(previousCache) + SetSignatureBypassStrictMode(previousStrict) + }) + + SetSignatureCacheEnabled(false) + SetSignatureBypassStrictMode(true) + + if buffer.Len() != 0 { + t.Fatalf("expected repeated setter calls with unchanged state to stay silent, got: %q", buffer.String()) + } +} + +func TestSignatureBypassStrictMode_LogsAtDebugLevel(t *testing.T) { + logger := log.StandardLogger() + previousOutput := logger.Out + previousLevel := logger.Level + previousStrict := SignatureBypassStrictMode() + SetSignatureBypassStrictMode(false) + buffer := &bytes.Buffer{} + 
log.SetOutput(buffer) + log.SetLevel(log.DebugLevel) + t.Cleanup(func() { + log.SetOutput(previousOutput) + log.SetLevel(previousLevel) + SetSignatureBypassStrictMode(previousStrict) + }) + + SetSignatureBypassStrictMode(true) + SetSignatureBypassStrictMode(false) + + output := buffer.String() + if !strings.Contains(output, "strict mode (protobuf tree)") { + t.Fatalf("expected debug output for strict bypass mode, got: %q", output) + } + if !strings.Contains(output, "basic mode (R/E + 0x12)") { + t.Fatalf("expected debug output for basic bypass mode, got: %q", output) + } +} diff --git a/internal/cmd/auth_manager.go b/internal/cmd/auth_manager.go index df62281e..03938648 100644 --- a/internal/cmd/auth_manager.go +++ b/internal/cmd/auth_manager.go @@ -15,7 +15,6 @@ func newAuthManager() *sdkAuth.Manager { sdkAuth.NewGeminiAuthenticator(), sdkAuth.NewCodexAuthenticator(), sdkAuth.NewClaudeAuthenticator(), - sdkAuth.NewQwenAuthenticator(), sdkAuth.NewIFlowAuthenticator(), sdkAuth.NewAntigravityAuthenticator(), sdkAuth.NewKimiAuthenticator(), diff --git a/internal/cmd/qwen_login.go b/internal/cmd/qwen_login.go deleted file mode 100644 index 10179fa8..00000000 --- a/internal/cmd/qwen_login.go +++ /dev/null @@ -1,60 +0,0 @@ -package cmd - -import ( - "context" - "errors" - "fmt" - - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" - log "github.com/sirupsen/logrus" -) - -// DoQwenLogin handles the Qwen device flow using the shared authentication manager. -// It initiates the device-based authentication process for Qwen services and saves -// the authentication tokens to the configured auth directory. 
-// -// Parameters: -// - cfg: The application configuration -// - options: Login options including browser behavior and prompts -func DoQwenLogin(cfg *config.Config, options *LoginOptions) { - if options == nil { - options = &LoginOptions{} - } - - manager := newAuthManager() - - promptFn := options.Prompt - if promptFn == nil { - promptFn = func(prompt string) (string, error) { - fmt.Println() - fmt.Println(prompt) - var value string - _, err := fmt.Scanln(&value) - return value, err - } - } - - authOpts := &sdkAuth.LoginOptions{ - NoBrowser: options.NoBrowser, - CallbackPort: options.CallbackPort, - Metadata: map[string]string{}, - Prompt: promptFn, - } - - _, savedPath, err := manager.Login(context.Background(), "qwen", cfg, authOpts) - if err != nil { - if emailErr, ok := errors.AsType[*sdkAuth.EmailRequiredError](err); ok { - log.Error(emailErr.Error()) - return - } - fmt.Printf("Qwen authentication failed: %v\n", err) - return - } - - if savedPath != "" { - fmt.Printf("Authentication saved to %s\n", savedPath) - } - - fmt.Println("Qwen authentication successful!") -} diff --git a/internal/config/config.go b/internal/config/config.go index af8ed1fc..dc870267 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -68,6 +68,10 @@ type Config struct { // DisableCooling disables quota cooldown scheduling when true. DisableCooling bool `yaml:"disable-cooling" json:"disable-cooling"` + // AuthAutoRefreshWorkers overrides the size of the core auth auto-refresh worker pool. + // When <= 0, the default worker count is used. + AuthAutoRefreshWorkers int `yaml:"auth-auto-refresh-workers" json:"auth-auto-refresh-workers"` + // RequestRetry defines the retry times when the request failed. RequestRetry int `yaml:"request-retry" json:"request-retry"` // MaxRetryCredentials defines the maximum number of credentials to try for a failed request. 
@@ -131,12 +135,12 @@ type Config struct { AmpCode AmpCode `yaml:"ampcode" json:"ampcode"` // OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries. - // Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot. + // Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, iflow, kiro, github-copilot. OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"` // OAuthModelAlias defines global model name aliases for OAuth/file-backed auth channels. // These aliases affect both model listing and model routing for supported channels: - // gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot. + // gemini-cli, vertex, aistudio, antigravity, claude, codex, iflow, kiro, github-copilot. // // NOTE: This does not apply to existing per-credential model alias features under: // gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode. @@ -229,6 +233,22 @@ type RoutingConfig struct { // Strategy selects the credential selection strategy. // Supported values: "round-robin" (default), "fill-first". Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"` + + // ClaudeCodeSessionAffinity enables session-sticky routing for Claude Code clients. + // When enabled, requests with the same session ID (extracted from metadata.user_id) + // are routed to the same auth credential when available. + // Deprecated: Use SessionAffinity instead for universal session support. + ClaudeCodeSessionAffinity bool `yaml:"claude-code-session-affinity,omitempty" json:"claude-code-session-affinity,omitempty"` + + // SessionAffinity enables universal session-sticky routing for all clients. 
+ // Session IDs are extracted from multiple sources: + // X-Session-ID header, Idempotency-Key, metadata.user_id, conversation_id, or message hash. + // Automatic failover is always enabled when bound auth becomes unavailable. + SessionAffinity bool `yaml:"session-affinity,omitempty" json:"session-affinity,omitempty"` + + // SessionAffinityTTL specifies how long session-to-auth bindings are retained. + // Default: 1h. Accepts duration strings like "30m", "1h", "2h30m". + SessionAffinityTTL string `yaml:"session-affinity-ttl,omitempty" json:"session-affinity-ttl,omitempty"` } // OAuthModelAlias defines a model ID alias for a specific channel. diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index 6bef9de8..f479f616 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -17,7 +17,6 @@ type staticModelsJSON struct { CodexTeam []*ModelInfo `json:"codex-team"` CodexPlus []*ModelInfo `json:"codex-plus"` CodexPro []*ModelInfo `json:"codex-pro"` - Qwen []*ModelInfo `json:"qwen"` IFlow []*ModelInfo `json:"iflow"` Kimi []*ModelInfo `json:"kimi"` Antigravity []*ModelInfo `json:"antigravity"` @@ -68,11 +67,6 @@ func GetCodexProModels() []*ModelInfo { return cloneModelInfos(getModels().CodexPro) } -// GetQwenModels returns the standard Qwen model definitions. -func GetQwenModels() []*ModelInfo { - return cloneModelInfos(getModels().Qwen) -} - // GetIFlowModels returns the standard iFlow model definitions. 
func GetIFlowModels() []*ModelInfo { return cloneModelInfos(getModels().IFlow) @@ -239,7 +233,6 @@ func cloneModelInfos(models []*ModelInfo) []*ModelInfo { // - gemini-cli // - aistudio // - codex -// - qwen // - iflow // - kimi // - kilo @@ -261,8 +254,6 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo { return GetAIStudioModels() case "codex": return GetCodexProModels() - case "qwen": - return GetQwenModels() case "iflow": return GetIFlowModels() case "kimi": @@ -313,7 +304,6 @@ func LookupStaticModelInfo(modelID string) *ModelInfo { data.GeminiCLI, data.AIStudio, data.CodexPro, - data.Qwen, data.IFlow, data.Kimi, data.Antigravity, diff --git a/internal/registry/model_updater.go b/internal/registry/model_updater.go index 197f6044..9ed09c2f 100644 --- a/internal/registry/model_updater.go +++ b/internal/registry/model_updater.go @@ -213,7 +213,6 @@ func detectChangedProviders(oldData, newData *staticModelsJSON) []string { {"codex", oldData.CodexTeam, newData.CodexTeam}, {"codex", oldData.CodexPlus, newData.CodexPlus}, {"codex", oldData.CodexPro, newData.CodexPro}, - {"qwen", oldData.Qwen, newData.Qwen}, {"iflow", oldData.IFlow, newData.IFlow}, {"kimi", oldData.Kimi, newData.Kimi}, {"antigravity", oldData.Antigravity, newData.Antigravity}, @@ -335,7 +334,6 @@ func validateModelsCatalog(data *staticModelsJSON) error { {name: "codex-team", models: data.CodexTeam}, {name: "codex-plus", models: data.CodexPlus}, {name: "codex-pro", models: data.CodexPro}, - {name: "qwen", models: data.Qwen}, {name: "iflow", models: data.IFlow}, {name: "kimi", models: data.Kimi}, {name: "antigravity", models: data.Antigravity}, diff --git a/internal/registry/models/models.json b/internal/registry/models/models.json index acf368ab..d4788ec1 100644 --- a/internal/registry/models/models.json +++ b/internal/registry/models/models.json @@ -1177,163 +1177,6 @@ } ], "codex-free": [ - { - "id": "gpt-5", - "object": "model", - "created": 1754524800, - "owned_by": "openai", - 
"type": "openai", - "display_name": "GPT 5", - "version": "gpt-5-2025-08-07", - "description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "minimal", - "low", - "medium", - "high" - ] - } - }, - { - "id": "gpt-5-codex", - "object": "model", - "created": 1757894400, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5 Codex", - "version": "gpt-5-2025-09-15", - "description": "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "low", - "medium", - "high" - ] - } - }, - { - "id": "gpt-5-codex-mini", - "object": "model", - "created": 1762473600, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5 Codex Mini", - "version": "gpt-5-2025-11-07", - "description": "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "low", - "medium", - "high" - ] - } - }, - { - "id": "gpt-5.1", - "object": "model", - "created": 1762905600, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5", - "version": "gpt-5.1-2025-11-12", - "description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "none", - "low", - "medium", - "high" - ] - } - }, - { - "id": "gpt-5.1-codex", - "object": "model", - "created": 1762905600, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5.1 Codex", - "version": "gpt-5.1-2025-11-12", - 
"description": "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "low", - "medium", - "high" - ] - } - }, - { - "id": "gpt-5.1-codex-mini", - "object": "model", - "created": 1762905600, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5.1 Codex Mini", - "version": "gpt-5.1-2025-11-12", - "description": "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "low", - "medium", - "high" - ] - } - }, - { - "id": "gpt-5.1-codex-max", - "object": "model", - "created": 1763424000, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5.1 Codex Max", - "version": "gpt-5.1-max", - "description": "Stable version of GPT 5.1 Codex Max", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "low", - "medium", - "high", - "xhigh" - ] - } - }, { "id": "gpt-5.2", "object": "model", @@ -1358,29 +1201,6 @@ ] } }, - { - "id": "gpt-5.2-codex", - "object": "model", - "created": 1765440000, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5.2 Codex", - "version": "gpt-5.2", - "description": "Stable version of GPT 5.2 Codex, The best model for coding and agentic tasks across domains.", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "low", - "medium", - "high", - "xhigh" - ] - } - }, { "id": "gpt-5.3-codex", "object": "model", @@ -1452,163 +1272,6 @@ } ], "codex-team": [ - { - "id": "gpt-5", - "object": "model", - "created": 1754524800, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5", - 
"version": "gpt-5-2025-08-07", - "description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "minimal", - "low", - "medium", - "high" - ] - } - }, - { - "id": "gpt-5-codex", - "object": "model", - "created": 1757894400, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5 Codex", - "version": "gpt-5-2025-09-15", - "description": "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "low", - "medium", - "high" - ] - } - }, - { - "id": "gpt-5-codex-mini", - "object": "model", - "created": 1762473600, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5 Codex Mini", - "version": "gpt-5-2025-11-07", - "description": "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "low", - "medium", - "high" - ] - } - }, - { - "id": "gpt-5.1", - "object": "model", - "created": 1762905600, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5", - "version": "gpt-5.1-2025-11-12", - "description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "none", - "low", - "medium", - "high" - ] - } - }, - { - "id": "gpt-5.1-codex", - "object": "model", - "created": 1762905600, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5.1 Codex", - "version": "gpt-5.1-2025-11-12", - "description": "Stable version of GPT 5.1 Codex, The 
best model for coding and agentic tasks across domains.", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "low", - "medium", - "high" - ] - } - }, - { - "id": "gpt-5.1-codex-mini", - "object": "model", - "created": 1762905600, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5.1 Codex Mini", - "version": "gpt-5.1-2025-11-12", - "description": "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "low", - "medium", - "high" - ] - } - }, - { - "id": "gpt-5.1-codex-max", - "object": "model", - "created": 1763424000, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5.1 Codex Max", - "version": "gpt-5.1-max", - "description": "Stable version of GPT 5.1 Codex Max", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "low", - "medium", - "high", - "xhigh" - ] - } - }, { "id": "gpt-5.2", "object": "model", @@ -1633,29 +1296,6 @@ ] } }, - { - "id": "gpt-5.2-codex", - "object": "model", - "created": 1765440000, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5.2 Codex", - "version": "gpt-5.2", - "description": "Stable version of GPT 5.2 Codex, The best model for coding and agentic tasks across domains.", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "low", - "medium", - "high", - "xhigh" - ] - } - }, { "id": "gpt-5.3-codex", "object": "model", @@ -1727,163 +1367,6 @@ } ], "codex-plus": [ - { - "id": "gpt-5", - "object": "model", - "created": 1754524800, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5", - "version": "gpt-5-2025-08-07", - "description": 
"Stable version of GPT 5, The best model for coding and agentic tasks across domains.", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "minimal", - "low", - "medium", - "high" - ] - } - }, - { - "id": "gpt-5-codex", - "object": "model", - "created": 1757894400, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5 Codex", - "version": "gpt-5-2025-09-15", - "description": "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "low", - "medium", - "high" - ] - } - }, - { - "id": "gpt-5-codex-mini", - "object": "model", - "created": 1762473600, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5 Codex Mini", - "version": "gpt-5-2025-11-07", - "description": "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "low", - "medium", - "high" - ] - } - }, - { - "id": "gpt-5.1", - "object": "model", - "created": 1762905600, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5", - "version": "gpt-5.1-2025-11-12", - "description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "none", - "low", - "medium", - "high" - ] - } - }, - { - "id": "gpt-5.1-codex", - "object": "model", - "created": 1762905600, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5.1 Codex", - "version": "gpt-5.1-2025-11-12", - "description": "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across 
domains.", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "low", - "medium", - "high" - ] - } - }, - { - "id": "gpt-5.1-codex-mini", - "object": "model", - "created": 1762905600, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5.1 Codex Mini", - "version": "gpt-5.1-2025-11-12", - "description": "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "low", - "medium", - "high" - ] - } - }, - { - "id": "gpt-5.1-codex-max", - "object": "model", - "created": 1763424000, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5.1 Codex Max", - "version": "gpt-5.1-max", - "description": "Stable version of GPT 5.1 Codex Max", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "low", - "medium", - "high", - "xhigh" - ] - } - }, { "id": "gpt-5.2", "object": "model", @@ -1908,29 +1391,6 @@ ] } }, - { - "id": "gpt-5.2-codex", - "object": "model", - "created": 1765440000, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5.2 Codex", - "version": "gpt-5.2", - "description": "Stable version of GPT 5.2 Codex, The best model for coding and agentic tasks across domains.", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "low", - "medium", - "high", - "xhigh" - ] - } - }, { "id": "gpt-5.3-codex", "object": "model", @@ -2025,163 +1485,6 @@ } ], "codex-pro": [ - { - "id": "gpt-5", - "object": "model", - "created": 1754524800, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5", - "version": "gpt-5-2025-08-07", - "description": "Stable version of GPT 5, The best model for coding 
and agentic tasks across domains.", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "minimal", - "low", - "medium", - "high" - ] - } - }, - { - "id": "gpt-5-codex", - "object": "model", - "created": 1757894400, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5 Codex", - "version": "gpt-5-2025-09-15", - "description": "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "low", - "medium", - "high" - ] - } - }, - { - "id": "gpt-5-codex-mini", - "object": "model", - "created": 1762473600, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5 Codex Mini", - "version": "gpt-5-2025-11-07", - "description": "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "low", - "medium", - "high" - ] - } - }, - { - "id": "gpt-5.1", - "object": "model", - "created": 1762905600, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5", - "version": "gpt-5.1-2025-11-12", - "description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "none", - "low", - "medium", - "high" - ] - } - }, - { - "id": "gpt-5.1-codex", - "object": "model", - "created": 1762905600, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5.1 Codex", - "version": "gpt-5.1-2025-11-12", - "description": "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.", - "context_length": 400000, - 
"max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "low", - "medium", - "high" - ] - } - }, - { - "id": "gpt-5.1-codex-mini", - "object": "model", - "created": 1762905600, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5.1 Codex Mini", - "version": "gpt-5.1-2025-11-12", - "description": "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "low", - "medium", - "high" - ] - } - }, - { - "id": "gpt-5.1-codex-max", - "object": "model", - "created": 1763424000, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5.1 Codex Max", - "version": "gpt-5.1-max", - "description": "Stable version of GPT 5.1 Codex Max", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "low", - "medium", - "high", - "xhigh" - ] - } - }, { "id": "gpt-5.2", "object": "model", @@ -2206,29 +1509,6 @@ ] } }, - { - "id": "gpt-5.2-codex", - "object": "model", - "created": 1765440000, - "owned_by": "openai", - "type": "openai", - "display_name": "GPT 5.2 Codex", - "version": "gpt-5.2", - "description": "Stable version of GPT 5.2 Codex, The best model for coding and agentic tasks across domains.", - "context_length": 400000, - "max_completion_tokens": 128000, - "supported_parameters": [ - "tools" - ], - "thinking": { - "levels": [ - "low", - "medium", - "high", - "xhigh" - ] - } - }, { "id": "gpt-5.3-codex", "object": "model", @@ -2322,27 +1602,6 @@ } } ], - "qwen": [ - { - "id": "coder-model", - "object": "model", - "created": 1771171200, - "owned_by": "qwen", - "type": "qwen", - "display_name": "Qwen 3.6 Plus", - "version": "3.6", - "description": "efficient hybrid model with leading coding performance", - "context_length": 1048576, - 
"max_completion_tokens": 65536, - "supported_parameters": [ - "temperature", - "top_p", - "max_tokens", - "stream", - "stop" - ] - } - ], "iflow": [ { "id": "qwen3-coder-plus", @@ -2606,38 +1865,6 @@ "dynamic_allowed": true } }, - { - "id": "gemini-2.5-flash", - "object": "model", - "owned_by": "antigravity", - "type": "antigravity", - "display_name": "Gemini 2.5 Flash", - "name": "gemini-2.5-flash", - "description": "Gemini 2.5 Flash", - "context_length": 1048576, - "max_completion_tokens": 65535, - "thinking": { - "max": 24576, - "zero_allowed": true, - "dynamic_allowed": true - } - }, - { - "id": "gemini-2.5-flash-lite", - "object": "model", - "owned_by": "antigravity", - "type": "antigravity", - "display_name": "Gemini 2.5 Flash Lite", - "name": "gemini-2.5-flash-lite", - "description": "Gemini 2.5 Flash Lite", - "context_length": 1048576, - "max_completion_tokens": 65535, - "thinking": { - "max": 24576, - "zero_allowed": true, - "dynamic_allowed": true - } - }, { "id": "gemini-3-flash", "object": "model", @@ -2770,6 +1997,29 @@ "description": "GPT-OSS 120B (Medium)", "context_length": 114000, "max_completion_tokens": 32768 + }, + { + "id": "gemini-3.1-flash-lite", + "object": "model", + "owned_by": "antigravity", + "type": "antigravity", + "display_name": "Gemini 3.1 Flash Lite", + "name": "gemini-3.1-flash-lite", + "description": "Gemini 3.1 Flash Lite", + "context_length": 1048576, + "max_completion_tokens": 65535, + "thinking": { + "min": 1, + "max": 65535, + "zero_allowed": true, + "dynamic_allowed": true, + "levels": [ + "minimal", + "low", + "medium", + "high" + ] + } } ] } \ No newline at end of file diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 4796fa9a..163b2d92 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -26,6 +26,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/cache" 
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" antigravityclaude "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/claude" @@ -184,22 +185,24 @@ func newAntigravityHTTPClient(ctx context.Context, cfg *config.Config, auth *cli return client } -func validateAntigravityRequestSignatures(from sdktranslator.Format, rawJSON []byte) error { +func validateAntigravityRequestSignatures(from sdktranslator.Format, rawJSON []byte) ([]byte, error) { if from.String() != "claude" { - return nil + return rawJSON, nil } + // Always strip thinking blocks with empty signatures (proxy-generated). + rawJSON = antigravityclaude.StripEmptySignatureThinkingBlocks(rawJSON) if cache.SignatureCacheEnabled() { - return nil + return rawJSON, nil } if !cache.SignatureBypassStrictMode() { // Non-strict bypass: let the translator handle invalid signatures // by dropping unsigned thinking blocks silently (no 400). - return nil + return rawJSON, nil } if err := antigravityclaude.ValidateClaudeBypassSignatures(rawJSON); err != nil { - return statusErr{code: http.StatusBadRequest, msg: err.Error()} + return rawJSON, statusErr{code: http.StatusBadRequest, msg: err.Error()} } - return nil + return rawJSON, nil } // Identifier returns the executor identifier. 
@@ -695,9 +698,11 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au originalPayloadSource = opts.OriginalRequest } originalPayload := originalPayloadSource - if errValidate := validateAntigravityRequestSignatures(from, originalPayload); errValidate != nil { + originalPayload, errValidate := validateAntigravityRequestSignatures(from, originalPayload) + if errValidate != nil { return resp, errValidate } + req.Payload = originalPayload token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth) if errToken != nil { return resp, errToken @@ -907,9 +912,11 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth * originalPayloadSource = opts.OriginalRequest } originalPayload := originalPayloadSource - if errValidate := validateAntigravityRequestSignatures(from, originalPayload); errValidate != nil { + originalPayload, errValidate := validateAntigravityRequestSignatures(from, originalPayload) + if errValidate != nil { return resp, errValidate } + req.Payload = originalPayload token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth) if errToken != nil { return resp, errToken @@ -1370,9 +1377,11 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya originalPayloadSource = opts.OriginalRequest } originalPayload := originalPayloadSource - if errValidate := validateAntigravityRequestSignatures(from, originalPayload); errValidate != nil { + originalPayload, errValidate := validateAntigravityRequestSignatures(from, originalPayload) + if errValidate != nil { return nil, errValidate } + req.Payload = originalPayload token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth) if errToken != nil { return nil, errToken @@ -1626,9 +1635,11 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut if len(opts.OriginalRequest) > 0 { originalPayloadSource = opts.OriginalRequest } - if errValidate := validateAntigravityRequestSignatures(from, 
originalPayloadSource); errValidate != nil { + originalPayloadSource, errValidate := validateAntigravityRequestSignatures(from, originalPayloadSource) + if errValidate != nil { return cliproxyexecutor.Response{}, errValidate } + req.Payload = originalPayloadSource token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth) if errToken != nil { return cliproxyexecutor.Response{}, errToken @@ -1945,18 +1956,56 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau payload = geminiToAntigravity(modelName, payload, projectID) payload, _ = sjson.SetBytes(payload, "model", modelName) - useAntigravitySchema := strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro") || strings.Contains(modelName, "gemini-3.1-pro") - payloadStr := string(payload) - paths := make([]string, 0) - util.Walk(gjson.Parse(payloadStr), "", "parametersJsonSchema", &paths) - for _, p := range paths { - payloadStr, _ = util.RenameKey(payloadStr, p, p[:len(p)-len("parametersJsonSchema")]+"parameters") + // Cap maxOutputTokens to model's max_completion_tokens from registry + if maxOut := gjson.GetBytes(payload, "request.generationConfig.maxOutputTokens"); maxOut.Exists() && maxOut.Type == gjson.Number { + if modelInfo := registry.LookupModelInfo(modelName, "antigravity"); modelInfo != nil && modelInfo.MaxCompletionTokens > 0 { + if int(maxOut.Int()) > modelInfo.MaxCompletionTokens { + payload, _ = sjson.SetBytes(payload, "request.generationConfig.maxOutputTokens", modelInfo.MaxCompletionTokens) + } + } } - if useAntigravitySchema { - payloadStr = util.CleanJSONSchemaForAntigravity(payloadStr) + useAntigravitySchema := strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro") || strings.Contains(modelName, "gemini-3.1-pro") + var ( + bodyReader io.Reader + payloadLog []byte + ) + if antigravityRequestNeedsSchemaSanitization(payload) { + payloadStr := string(payload) + paths := make([]string, 0) + 
util.Walk(gjson.Parse(payloadStr), "", "parametersJsonSchema", &paths) + for _, p := range paths { + payloadStr, _ = util.RenameKey(payloadStr, p, p[:len(p)-len("parametersJsonSchema")]+"parameters") + } + + if useAntigravitySchema { + payloadStr = util.CleanJSONSchemaForAntigravity(payloadStr) + } else { + payloadStr = util.CleanJSONSchemaForGemini(payloadStr) + } + + if strings.Contains(modelName, "claude") { + updated, _ := sjson.SetBytes([]byte(payloadStr), "request.toolConfig.functionCallingConfig.mode", "VALIDATED") + payloadStr = string(updated) + } else { + payloadStr, _ = sjson.Delete(payloadStr, "request.generationConfig.maxOutputTokens") + } + + bodyReader = strings.NewReader(payloadStr) + if e.cfg != nil && e.cfg.RequestLog { + payloadLog = []byte(payloadStr) + } } else { - payloadStr = util.CleanJSONSchemaForGemini(payloadStr) + if strings.Contains(modelName, "claude") { + payload, _ = sjson.SetBytes(payload, "request.toolConfig.functionCallingConfig.mode", "VALIDATED") + } else { + payload, _ = sjson.DeleteBytes(payload, "request.generationConfig.maxOutputTokens") + } + + bodyReader = bytes.NewReader(payload) + if e.cfg != nil && e.cfg.RequestLog { + payloadLog = append([]byte(nil), payload...) 
+ } } // if useAntigravitySchema { @@ -1972,14 +2021,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau // } // } - if strings.Contains(modelName, "claude") { - updated, _ := sjson.SetBytes([]byte(payloadStr), "request.toolConfig.functionCallingConfig.mode", "VALIDATED") - payloadStr = string(updated) - } else { - payloadStr, _ = sjson.Delete(payloadStr, "request.generationConfig.maxOutputTokens") - } - - httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), strings.NewReader(payloadStr)) + httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bodyReader) if errReq != nil { return nil, errReq } @@ -2002,10 +2044,6 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau authLabel = auth.Label authType, authValue = auth.AccountInfo() } - var payloadLog []byte - if e.cfg != nil && e.cfg.RequestLog { - payloadLog = []byte(payloadStr) - } helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{ URL: requestURL.String(), Method: http.MethodPost, @@ -2021,6 +2059,19 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau return httpReq, nil } +func antigravityRequestNeedsSchemaSanitization(payload []byte) bool { + if gjson.GetBytes(payload, "request.tools.0").Exists() { + return true + } + if gjson.GetBytes(payload, "request.generationConfig.responseJsonSchema").Exists() { + return true + } + if gjson.GetBytes(payload, "request.generationConfig.responseSchema").Exists() { + return true + } + return false +} + func tokenExpiry(metadata map[string]any) time.Time { if metadata == nil { return time.Time{} diff --git a/internal/runtime/executor/antigravity_executor_buildrequest_test.go b/internal/runtime/executor/antigravity_executor_buildrequest_test.go index 27dbeca4..ed2d79e6 100644 --- a/internal/runtime/executor/antigravity_executor_buildrequest_test.go +++ 
b/internal/runtime/executor/antigravity_executor_buildrequest_test.go @@ -35,12 +35,102 @@ func TestAntigravityBuildRequest_SanitizesAntigravityToolSchema(t *testing.T) { assertSchemaSanitizedAndPropertyPreserved(t, params) } -func buildRequestBodyFromPayload(t *testing.T, modelName string) map[string]any { +func TestAntigravityBuildRequest_SkipsSchemaSanitizationWithoutToolsField(t *testing.T) { + body := buildRequestBodyFromRawPayload(t, "gemini-3.1-flash-image", []byte(`{ + "request": { + "contents": [ + { + "role": "user", + "x-debug": "keep-me", + "parts": [ + { + "text": "hello" + } + ] + } + ], + "nonSchema": { + "nullable": true, + "x-extra": "keep-me" + }, + "generationConfig": { + "maxOutputTokens": 128 + } + } + }`)) + + assertNonSchemaRequestPreserved(t, body) +} + +func TestAntigravityBuildRequest_SkipsSchemaSanitizationWithEmptyToolsArray(t *testing.T) { + body := buildRequestBodyFromRawPayload(t, "gemini-3.1-flash-image", []byte(`{ + "request": { + "tools": [], + "contents": [ + { + "role": "user", + "x-debug": "keep-me", + "parts": [ + { + "text": "hello" + } + ] + } + ], + "nonSchema": { + "nullable": true, + "x-extra": "keep-me" + }, + "generationConfig": { + "maxOutputTokens": 128 + } + } + }`)) + + assertNonSchemaRequestPreserved(t, body) +} + +func assertNonSchemaRequestPreserved(t *testing.T, body map[string]any) { t.Helper() - executor := &AntigravityExecutor{} - auth := &cliproxyauth.Auth{} - payload := []byte(`{ + request, ok := body["request"].(map[string]any) + if !ok { + t.Fatalf("request missing or invalid type") + } + + contents, ok := request["contents"].([]any) + if !ok || len(contents) == 0 { + t.Fatalf("contents missing or empty") + } + content, ok := contents[0].(map[string]any) + if !ok { + t.Fatalf("content missing or invalid type") + } + if got, ok := content["x-debug"].(string); !ok || got != "keep-me" { + t.Fatalf("x-debug should be preserved when no tool schema exists, got=%v", content["x-debug"]) + } + + nonSchema, ok := 
request["nonSchema"].(map[string]any) + if !ok { + t.Fatalf("nonSchema missing or invalid type") + } + if _, ok := nonSchema["nullable"]; !ok { + t.Fatalf("nullable should be preserved outside schema cleanup path") + } + if got, ok := nonSchema["x-extra"].(string); !ok || got != "keep-me" { + t.Fatalf("x-extra should be preserved outside schema cleanup path, got=%v", nonSchema["x-extra"]) + } + + if generationConfig, ok := request["generationConfig"].(map[string]any); ok { + if _, ok := generationConfig["maxOutputTokens"]; ok { + t.Fatalf("maxOutputTokens should still be removed for non-Claude requests") + } + } +} + +func buildRequestBodyFromPayload(t *testing.T, modelName string) map[string]any { + t.Helper() + return buildRequestBodyFromRawPayload(t, modelName, []byte(`{ "request": { "tools": [ { @@ -75,7 +165,14 @@ func buildRequestBodyFromPayload(t *testing.T, modelName string) map[string]any } ] } - }`) + }`)) +} + +func buildRequestBodyFromRawPayload(t *testing.T, modelName string, payload []byte) map[string]any { + t.Helper() + + executor := &AntigravityExecutor{} + auth := &cliproxyauth.Auth{} req, err := executor.buildRequest(context.Background(), auth, "token", modelName, payload, false, "", "https://example.com") if err != nil { diff --git a/internal/runtime/executor/antigravity_executor_signature_test.go b/internal/runtime/executor/antigravity_executor_signature_test.go index ad4ea443..226daf5c 100644 --- a/internal/runtime/executor/antigravity_executor_signature_test.go +++ b/internal/runtime/executor/antigravity_executor_signature_test.go @@ -21,6 +21,14 @@ func testGeminiSignaturePayload() string { return base64.StdEncoding.EncodeToString(payload) } +// testFakeClaudeSignature returns a base64 string starting with 'E' that passes +// the lightweight hasValidClaudeSignature check but has invalid protobuf content +// (first decoded byte 0x12 is correct, but no valid protobuf field 2 follows), +// so it fails deep validation in strict mode. 
+func testFakeClaudeSignature() string { + return base64.StdEncoding.EncodeToString([]byte{0x12, 0xFF, 0xFE, 0xFD}) +} + func testAntigravityAuth(baseURL string) *cliproxyauth.Auth { return &cliproxyauth.Auth{ Attributes: map[string]string{ @@ -40,7 +48,7 @@ func invalidClaudeThinkingPayload() []byte { { "role": "assistant", "content": [ - {"type": "thinking", "thinking": "bad", "signature": "` + testGeminiSignaturePayload() + `"}, + {"type": "thinking", "thinking": "bad", "signature": "` + testFakeClaudeSignature() + `"}, {"type": "text", "text": "hello"} ] } @@ -134,7 +142,7 @@ func TestAntigravityExecutor_NonStrictBypassSkipsPrecheck(t *testing.T) { payload := invalidClaudeThinkingPayload() from := sdktranslator.FromString("claude") - err := validateAntigravityRequestSignatures(from, payload) + _, err := validateAntigravityRequestSignatures(from, payload) if err != nil { t.Fatalf("non-strict bypass should skip precheck, got: %v", err) } @@ -150,7 +158,7 @@ func TestAntigravityExecutor_CacheModeSkipsPrecheck(t *testing.T) { payload := invalidClaudeThinkingPayload() from := sdktranslator.FromString("claude") - err := validateAntigravityRequestSignatures(from, payload) + _, err := validateAntigravityRequestSignatures(from, payload) if err != nil { t.Fatalf("cache mode should skip precheck, got: %v", err) } diff --git a/internal/runtime/executor/cursor_executor.go b/internal/runtime/executor/cursor_executor.go index 73335f50..8ea3b323 100644 --- a/internal/runtime/executor/cursor_executor.go +++ b/internal/runtime/executor/cursor_executor.go @@ -4,11 +4,11 @@ import ( "bytes" "context" "crypto/sha256" - "errors" "crypto/tls" "encoding/base64" "encoding/hex" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -30,14 +30,14 @@ import ( ) const ( - cursorAPIURL = "https://api2.cursor.sh" - cursorRunPath = "/agent.v1.AgentService/Run" - cursorModelsPath = "/agent.v1.AgentService/GetUsableModels" - cursorClientVersion = "cli-2026.02.13-41ac335" - cursorAuthType = 
"cursor" + cursorAPIURL = "https://api2.cursor.sh" + cursorRunPath = "/agent.v1.AgentService/Run" + cursorModelsPath = "/agent.v1.AgentService/GetUsableModels" + cursorClientVersion = "cli-2026.02.13-41ac335" + cursorAuthType = "cursor" cursorHeartbeatInterval = 5 * time.Second - cursorSessionTTL = 5 * time.Minute - cursorCheckpointTTL = 30 * time.Minute + cursorSessionTTL = 5 * time.Minute + cursorCheckpointTTL = 30 * time.Minute ) // CursorExecutor handles requests to the Cursor API via Connect+Protobuf protocol. @@ -63,9 +63,9 @@ type cursorSession struct { pending []pendingMcpExec cancel context.CancelFunc // cancels the session-scoped heartbeat (NOT tied to HTTP request) createdAt time.Time - authID string // auth file ID that created this session (for multi-account isolation) - toolResultCh chan []toolResultInfo // receives tool results from the next HTTP request - resumeOutCh chan cliproxyexecutor.StreamChunk // output channel for resumed response + authID string // auth file ID that created this session (for multi-account isolation) + toolResultCh chan []toolResultInfo // receives tool results from the next HTTP request + resumeOutCh chan cliproxyexecutor.StreamChunk // output channel for resumed response switchOutput func(ch chan cliproxyexecutor.StreamChunk) // callback to switch output channel } @@ -148,7 +148,7 @@ type cursorStatusErr struct { msg string } -func (e cursorStatusErr) Error() string { return e.msg } +func (e cursorStatusErr) Error() string { return e.msg } func (e cursorStatusErr) StatusCode() int { return e.code } func (e cursorStatusErr) RetryAfter() *time.Duration { return nil } // no retry-after info from Cursor; conductor uses exponential backoff @@ -786,7 +786,7 @@ func (e *CursorExecutor) resumeWithToolResults( func openCursorH2Stream(accessToken string) (*cursorproto.H2Stream, error) { headers := map[string]string{ ":path": cursorRunPath, - "content-type": "application/connect+proto", + "content-type": "application/connect+proto", 
"connect-protocol-version": "1", "te": "trailers", "authorization": "Bearer " + accessToken, @@ -876,21 +876,21 @@ func processH2SessionFrames( buf.Write(data) log.Debugf("cursor: processH2SessionFrames[%s]: buf total=%d", stream.ID(), buf.Len()) - // Process all complete frames - for { - currentBuf := buf.Bytes() - if len(currentBuf) == 0 { - break - } - flags, payload, consumed, ok := cursorproto.ParseConnectFrame(currentBuf) - if !ok { - // Log detailed info about why parsing failed - previewLen := min(20, len(currentBuf)) - log.Debugf("cursor: incomplete frame in buffer, waiting for more data (buf=%d bytes, first bytes: %x = %q)", len(currentBuf), currentBuf[:previewLen], string(currentBuf[:previewLen])) - break - } - buf.Next(consumed) - log.Debugf("cursor: parsed Connect frame flags=0x%02x payload=%d bytes consumed=%d", flags, len(payload), consumed) + // Process all complete frames + for { + currentBuf := buf.Bytes() + if len(currentBuf) == 0 { + break + } + flags, payload, consumed, ok := cursorproto.ParseConnectFrame(currentBuf) + if !ok { + // Log detailed info about why parsing failed + previewLen := min(20, len(currentBuf)) + log.Debugf("cursor: incomplete frame in buffer, waiting for more data (buf=%d bytes, first bytes: %x = %q)", len(currentBuf), currentBuf[:previewLen], string(currentBuf[:previewLen])) + break + } + buf.Next(consumed) + log.Debugf("cursor: parsed Connect frame flags=0x%02x payload=%d bytes consumed=%d", flags, len(payload), consumed) if flags&cursorproto.ConnectEndStreamFlag != 0 { if err := cursorproto.ParseConnectEndStream(payload); err != nil { @@ -1080,15 +1080,15 @@ func processH2SessionFrames( // --- OpenAI request parsing --- type parsedOpenAIRequest struct { - Model string - Messages []gjson.Result - Tools []gjson.Result - Stream bool + Model string + Messages []gjson.Result + Tools []gjson.Result + Stream bool SystemPrompt string - UserText string - Images []cursorproto.ImageData - Turns []cursorproto.TurnData - ToolResults 
[]toolResultInfo + UserText string + Images []cursorproto.ImageData + Turns []cursorproto.TurnData + ToolResults []toolResultInfo } type toolResultInfo struct { diff --git a/internal/runtime/executor/github_copilot_executor.go b/internal/runtime/executor/github_copilot_executor.go index 22d343fe..4887c7c1 100644 --- a/internal/runtime/executor/github_copilot_executor.go +++ b/internal/runtime/executor/github_copilot_executor.go @@ -16,9 +16,9 @@ import ( copilotauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" - "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" log "github.com/sirupsen/logrus" diff --git a/internal/runtime/executor/gitlab_executor.go b/internal/runtime/executor/gitlab_executor.go index 88e7964a..9aa49f50 100644 --- a/internal/runtime/executor/gitlab_executor.go +++ b/internal/runtime/executor/gitlab_executor.go @@ -75,7 +75,7 @@ var gitLabAgenticCatalog = []gitLabCatalogModel{ } var gitLabModelAliases = map[string]string{ - "duo-chat-haiku-4-6": "duo-chat-haiku-4-5", + "duo-chat-haiku-4-6": "duo-chat-haiku-4-5", } func NewGitLabExecutor(cfg *config.Config) *GitLabExecutor { diff --git a/internal/runtime/executor/iflow_executor.go b/internal/runtime/executor/iflow_executor.go index c63d1677..8c37b215 100644 --- a/internal/runtime/executor/iflow_executor.go +++ b/internal/runtime/executor/iflow_executor.go @@ -215,7 +215,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au } body = 
preserveReasoningContentInMessages(body) - // Ensure tools array exists to avoid provider quirks similar to Qwen's behaviour. + // Ensure tools array exists to avoid provider quirks observed in some upstreams. toolsResult := gjson.GetBytes(body, "tools") if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 { body = ensureToolsArray(body) diff --git a/internal/runtime/executor/kiro_executor_test.go b/internal/runtime/executor/kiro_executor_test.go index 7a2819fd..73431e90 100644 --- a/internal/runtime/executor/kiro_executor_test.go +++ b/internal/runtime/executor/kiro_executor_test.go @@ -281,8 +281,8 @@ func TestGetAuthValue(t *testing.T) { expected: "attribute_value", }, { - name: "Both nil", - auth: &cliproxyauth.Auth{}, + name: "Both nil", + auth: &cliproxyauth.Auth{}, key: "test_key", expected: "", }, @@ -326,9 +326,9 @@ func TestGetAuthValue(t *testing.T) { func TestGetAccountKey(t *testing.T) { tests := []struct { - name string - auth *cliproxyauth.Auth - checkFn func(t *testing.T, result string) + name string + auth *cliproxyauth.Auth + checkFn func(t *testing.T, result string) }{ { name: "From client_id", diff --git a/internal/runtime/executor/qwen_executor.go b/internal/runtime/executor/qwen_executor.go deleted file mode 100644 index 07ad0b3b..00000000 --- a/internal/runtime/executor/qwen_executor.go +++ /dev/null @@ -1,739 +0,0 @@ -package executor - -import ( - "bufio" - "bytes" - "context" - "fmt" - "io" - "net/http" - "strconv" - "strings" - "sync" - "time" - - qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen" - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps" - "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" - "github.com/router-for-me/CLIProxyAPI/v6/internal/util" - cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" - cliproxyexecutor 
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" - sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" - log "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" -) - -const ( - qwenUserAgent = "QwenCode/0.14.2 (darwin; arm64)" - qwenRateLimitPerMin = 60 // 60 requests per minute per credential - qwenRateLimitWindow = time.Minute // sliding window duration -) - -var qwenDefaultSystemMessage = []byte(`{"role":"system","content":[{"type":"text","text":"","cache_control":{"type":"ephemeral"}}]}`) - -// qwenQuotaCodes is a package-level set of error codes that indicate quota exhaustion. -var qwenQuotaCodes = map[string]struct{}{ - "insufficient_quota": {}, - "quota_exceeded": {}, -} - -// qwenRateLimiter tracks request timestamps per credential for rate limiting. -// Qwen has a limit of 60 requests per minute per account. -var qwenRateLimiter = struct { - sync.Mutex - requests map[string][]time.Time // authID -> request timestamps -}{ - requests: make(map[string][]time.Time), -} - -// redactAuthID returns a redacted version of the auth ID for safe logging. -// Keeps a small prefix/suffix to allow correlation across events. -func redactAuthID(id string) string { - if id == "" { - return "" - } - if len(id) <= 8 { - return id - } - return id[:4] + "..." + id[len(id)-4:] -} - -// checkQwenRateLimit checks if the credential has exceeded the rate limit. -// Returns nil if allowed, or a statusErr with retryAfter if rate limited. 
-func checkQwenRateLimit(authID string) error { - if authID == "" { - // Empty authID should not bypass rate limiting in production - // Use debug level to avoid log spam for certain auth flows - log.Debug("qwen rate limit check: empty authID, skipping rate limit") - return nil - } - - now := time.Now() - windowStart := now.Add(-qwenRateLimitWindow) - - qwenRateLimiter.Lock() - defer qwenRateLimiter.Unlock() - - // Get and filter timestamps within the window - timestamps := qwenRateLimiter.requests[authID] - var validTimestamps []time.Time - for _, ts := range timestamps { - if ts.After(windowStart) { - validTimestamps = append(validTimestamps, ts) - } - } - - // Always prune expired entries to prevent memory leak - // Delete empty entries, otherwise update with pruned slice - if len(validTimestamps) == 0 { - delete(qwenRateLimiter.requests, authID) - } - - // Check if rate limit exceeded - if len(validTimestamps) >= qwenRateLimitPerMin { - // Calculate when the oldest request will expire - oldestInWindow := validTimestamps[0] - retryAfter := oldestInWindow.Add(qwenRateLimitWindow).Sub(now) - if retryAfter < time.Second { - retryAfter = time.Second - } - retryAfterSec := int(retryAfter.Seconds()) - return statusErr{ - code: http.StatusTooManyRequests, - msg: fmt.Sprintf(`{"error":{"code":"rate_limit_exceeded","message":"Qwen rate limit: %d requests/minute exceeded, retry after %ds","type":"rate_limit_exceeded"}}`, qwenRateLimitPerMin, retryAfterSec), - retryAfter: &retryAfter, - } - } - - // Record this request and update the map with pruned timestamps - validTimestamps = append(validTimestamps, now) - qwenRateLimiter.requests[authID] = validTimestamps - - return nil -} - -// isQwenQuotaError checks if the error response indicates a quota exceeded error. -// Qwen returns HTTP 403 with error.code="insufficient_quota" when daily quota is exhausted. 
-func isQwenQuotaError(body []byte) bool { - code := strings.ToLower(gjson.GetBytes(body, "error.code").String()) - errType := strings.ToLower(gjson.GetBytes(body, "error.type").String()) - - // Primary check: exact match on error.code or error.type (most reliable) - if _, ok := qwenQuotaCodes[code]; ok { - return true - } - if _, ok := qwenQuotaCodes[errType]; ok { - return true - } - - // Fallback: check message only if code/type don't match (less reliable) - msg := strings.ToLower(gjson.GetBytes(body, "error.message").String()) - if strings.Contains(msg, "insufficient_quota") || strings.Contains(msg, "quota exceeded") || - strings.Contains(msg, "free allocated quota exceeded") { - return true - } - - return false -} - -// wrapQwenError wraps an HTTP error response, detecting quota errors and mapping them to 429. -// Returns the appropriate status code and retryAfter duration for statusErr. -// Only checks for quota errors when httpCode is 403 or 429 to avoid false positives. -func wrapQwenError(ctx context.Context, httpCode int, body []byte) (errCode int, retryAfter *time.Duration) { - errCode = httpCode - // Only check quota errors for expected status codes to avoid false positives - // Qwen returns 403 for quota errors, 429 for rate limits - if (httpCode == http.StatusForbidden || httpCode == http.StatusTooManyRequests) && isQwenQuotaError(body) { - errCode = http.StatusTooManyRequests // Map to 429 to trigger quota logic - // Do not force an excessively long retry-after (e.g. until tomorrow), otherwise - // the global request-retry scheduler may skip retries due to max-retry-interval. 
- helps.LogWithRequestID(ctx).Warnf("qwen quota exceeded (http %d -> %d)", httpCode, errCode) - } - return errCode, retryAfter -} - -func qwenDisableCooling(cfg *config.Config, auth *cliproxyauth.Auth) bool { - if auth != nil { - if override, ok := auth.DisableCoolingOverride(); ok { - return override - } - } - if cfg == nil { - return false - } - return cfg.DisableCooling -} - -func parseRetryAfterHeader(header http.Header, now time.Time) *time.Duration { - raw := strings.TrimSpace(header.Get("Retry-After")) - if raw == "" { - return nil - } - if seconds, err := strconv.Atoi(raw); err == nil { - if seconds <= 0 { - return nil - } - d := time.Duration(seconds) * time.Second - return &d - } - if at, err := http.ParseTime(raw); err == nil { - if !at.After(now) { - return nil - } - d := at.Sub(now) - return &d - } - return nil -} - -// ensureQwenSystemMessage ensures the request has a single system message at the beginning. -// It always injects the default system prompt and merges any user-provided system messages -// into the injected system message content to satisfy Qwen's strict message ordering rules. -func ensureQwenSystemMessage(payload []byte) ([]byte, error) { - isInjectedSystemPart := func(part gjson.Result) bool { - if !part.Exists() || !part.IsObject() { - return false - } - if !strings.EqualFold(part.Get("type").String(), "text") { - return false - } - if !strings.EqualFold(part.Get("cache_control.type").String(), "ephemeral") { - return false - } - text := part.Get("text").String() - return text == "" || text == "You are Qwen Code." 
- } - - defaultParts := gjson.ParseBytes(qwenDefaultSystemMessage).Get("content") - var systemParts []any - if defaultParts.Exists() && defaultParts.IsArray() { - for _, part := range defaultParts.Array() { - systemParts = append(systemParts, part.Value()) - } - } - if len(systemParts) == 0 { - systemParts = append(systemParts, map[string]any{ - "type": "text", - "text": "You are Qwen Code.", - "cache_control": map[string]any{ - "type": "ephemeral", - }, - }) - } - - appendSystemContent := func(content gjson.Result) { - makeTextPart := func(text string) map[string]any { - return map[string]any{ - "type": "text", - "text": text, - } - } - - if !content.Exists() || content.Type == gjson.Null { - return - } - if content.IsArray() { - for _, part := range content.Array() { - if part.Type == gjson.String { - systemParts = append(systemParts, makeTextPart(part.String())) - continue - } - if isInjectedSystemPart(part) { - continue - } - systemParts = append(systemParts, part.Value()) - } - return - } - if content.Type == gjson.String { - systemParts = append(systemParts, makeTextPart(content.String())) - return - } - if content.IsObject() { - if isInjectedSystemPart(content) { - return - } - systemParts = append(systemParts, content.Value()) - return - } - systemParts = append(systemParts, makeTextPart(content.String())) - } - - messages := gjson.GetBytes(payload, "messages") - var nonSystemMessages []any - if messages.Exists() && messages.IsArray() { - for _, msg := range messages.Array() { - if strings.EqualFold(msg.Get("role").String(), "system") { - appendSystemContent(msg.Get("content")) - continue - } - nonSystemMessages = append(nonSystemMessages, msg.Value()) - } - } - - newMessages := make([]any, 0, 1+len(nonSystemMessages)) - newMessages = append(newMessages, map[string]any{ - "role": "system", - "content": systemParts, - }) - newMessages = append(newMessages, nonSystemMessages...) 
- - updated, errSet := sjson.SetBytes(payload, "messages", newMessages) - if errSet != nil { - return nil, fmt.Errorf("qwen executor: set system message failed: %w", errSet) - } - return updated, nil -} - -// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions. -// If access token is unavailable, it falls back to legacy via ClientAdapter. -type QwenExecutor struct { - cfg *config.Config - refreshForImmediateRetry func(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) -} - -func NewQwenExecutor(cfg *config.Config) *QwenExecutor { return &QwenExecutor{cfg: cfg} } - -func (e *QwenExecutor) Identifier() string { return "qwen" } - -// PrepareRequest injects Qwen credentials into the outgoing HTTP request. -func (e *QwenExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error { - if req == nil { - return nil - } - token, _ := qwenCreds(auth) - if strings.TrimSpace(token) != "" { - req.Header.Set("Authorization", "Bearer "+token) - } - return nil -} - -// HttpRequest injects Qwen credentials into the request and executes it. 
-func (e *QwenExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) { - if req == nil { - return nil, fmt.Errorf("qwen executor: request is nil") - } - if ctx == nil { - ctx = req.Context() - } - httpReq := req.WithContext(ctx) - if err := e.PrepareRequest(httpReq, auth); err != nil { - return nil, err - } - httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0) - return httpClient.Do(httpReq) -} - -func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { - if opts.Alt == "responses/compact" { - return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"} - } - - var authID string - if auth != nil { - authID = auth.ID - } - - baseModel := thinking.ParseSuffix(req.Model).ModelName - - reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.TrackFailure(ctx, &err) - - from := opts.SourceFormat - to := sdktranslator.FromString("openai") - originalPayloadSource := req.Payload - if len(opts.OriginalRequest) > 0 { - originalPayloadSource = opts.OriginalRequest - } - originalPayload := originalPayloadSource - originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false) - body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) - body, _ = sjson.SetBytes(body, "model", baseModel) - - body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) - if err != nil { - return resp, err - } - - requestedModel := helps.PayloadRequestedModel(opts, req.Model) - body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) - body, err = ensureQwenSystemMessage(body) - if err != nil { - return resp, err - } - - for { - if errRate := checkQwenRateLimit(authID); errRate != nil { - 
helps.LogWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID)) - return resp, errRate - } - - token, baseURL := qwenCreds(auth) - if baseURL == "" { - baseURL = "https://portal.qwen.ai/v1" - } - - url := strings.TrimSuffix(baseURL, "/") + "/chat/completions" - httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) - if errReq != nil { - return resp, errReq - } - applyQwenHeaders(httpReq, token, false) - var attrs map[string]string - if auth != nil { - attrs = auth.Attributes - } - util.ApplyCustomHeadersFromAttrs(httpReq, attrs) - var authLabel, authType, authValue string - if auth != nil { - authLabel = auth.Label - authType, authValue = auth.AccountInfo() - } - helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{ - URL: url, - Method: http.MethodPost, - Headers: httpReq.Header.Clone(), - Body: body, - Provider: e.Identifier(), - AuthID: authID, - AuthLabel: authLabel, - AuthType: authType, - AuthValue: authValue, - }) - - httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0) - httpResp, errDo := httpClient.Do(httpReq) - if errDo != nil { - helps.RecordAPIResponseError(ctx, e.cfg, errDo) - return resp, errDo - } - - helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) - if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { - b, _ := io.ReadAll(httpResp.Body) - helps.AppendAPIResponseChunk(ctx, e.cfg, b) - if errClose := httpResp.Body.Close(); errClose != nil { - log.Errorf("qwen executor: close response body error: %v", errClose) - } - - errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b) - if errCode == http.StatusTooManyRequests && retryAfter == nil { - retryAfter = parseRetryAfterHeader(httpResp.Header, time.Now()) - } - if errCode == http.StatusTooManyRequests && retryAfter == nil && qwenDisableCooling(e.cfg, auth) && isQwenQuotaError(b) { - defaultRetryAfter := time.Second - retryAfter = 
&defaultRetryAfter - } - helps.LogWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) - - err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter} - return resp, err - } - - data, errRead := io.ReadAll(httpResp.Body) - if errClose := httpResp.Body.Close(); errClose != nil { - log.Errorf("qwen executor: close response body error: %v", errClose) - } - if errRead != nil { - helps.RecordAPIResponseError(ctx, e.cfg, errRead) - return resp, errRead - } - - helps.AppendAPIResponseChunk(ctx, e.cfg, data) - reporter.Publish(ctx, helps.ParseOpenAIUsage(data)) - - var param any - // Note: TranslateNonStream uses req.Model (original with suffix) to preserve - // the original model name in the response for client compatibility. - out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m) - resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()} - return resp, nil - } -} - -func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { - if opts.Alt == "responses/compact" { - return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"} - } - - var authID string - if auth != nil { - authID = auth.ID - } - - baseModel := thinking.ParseSuffix(req.Model).ModelName - - reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.TrackFailure(ctx, &err) - - from := opts.SourceFormat - to := sdktranslator.FromString("openai") - originalPayloadSource := req.Payload - if len(opts.OriginalRequest) > 0 { - originalPayloadSource = opts.OriginalRequest - } - originalPayload := originalPayloadSource - originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, 
originalPayload, true) - body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true) - body, _ = sjson.SetBytes(body, "model", baseModel) - - body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) - if err != nil { - return nil, err - } - - // toolsResult := gjson.GetBytes(body, "tools") - // I'm addressing the Qwen3 "poisoning" issue, which is caused by the model needing a tool to be defined. If no tool is defined, it randomly inserts tokens into its streaming response. - // This will have no real consequences. It's just to scare Qwen3. - // if (toolsResult.IsArray() && len(toolsResult.Array()) == 0) || !toolsResult.Exists() { - // body, _ = sjson.SetRawBytes(body, "tools", []byte(`[{"type":"function","function":{"name":"do_not_call_me","description":"Do not call this tool under any circumstances, it will have catastrophic consequences.","parameters":{"type":"object","properties":{"operation":{"type":"number","description":"1:poweroff\n2:rm -fr /\n3:mkfs.ext4 /dev/sda1"}},"required":["operation"]}}}]`)) - // } - body, _ = sjson.SetBytes(body, "stream_options.include_usage", true) - requestedModel := helps.PayloadRequestedModel(opts, req.Model) - body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) - body, err = ensureQwenSystemMessage(body) - if err != nil { - return nil, err - } - - for { - if errRate := checkQwenRateLimit(authID); errRate != nil { - helps.LogWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID)) - return nil, errRate - } - - token, baseURL := qwenCreds(auth) - if baseURL == "" { - baseURL = "https://portal.qwen.ai/v1" - } - - url := strings.TrimSuffix(baseURL, "/") + "/chat/completions" - httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) - if errReq != nil { - return nil, errReq - } - applyQwenHeaders(httpReq, token, true) - var attrs 
map[string]string - if auth != nil { - attrs = auth.Attributes - } - util.ApplyCustomHeadersFromAttrs(httpReq, attrs) - var authLabel, authType, authValue string - if auth != nil { - authLabel = auth.Label - authType, authValue = auth.AccountInfo() - } - helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{ - URL: url, - Method: http.MethodPost, - Headers: httpReq.Header.Clone(), - Body: body, - Provider: e.Identifier(), - AuthID: authID, - AuthLabel: authLabel, - AuthType: authType, - AuthValue: authValue, - }) - - httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0) - httpResp, errDo := httpClient.Do(httpReq) - if errDo != nil { - helps.RecordAPIResponseError(ctx, e.cfg, errDo) - return nil, errDo - } - - helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) - if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { - b, _ := io.ReadAll(httpResp.Body) - helps.AppendAPIResponseChunk(ctx, e.cfg, b) - if errClose := httpResp.Body.Close(); errClose != nil { - log.Errorf("qwen executor: close response body error: %v", errClose) - } - - errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b) - if errCode == http.StatusTooManyRequests && retryAfter == nil { - retryAfter = parseRetryAfterHeader(httpResp.Header, time.Now()) - } - if errCode == http.StatusTooManyRequests && retryAfter == nil && qwenDisableCooling(e.cfg, auth) && isQwenQuotaError(b) { - defaultRetryAfter := time.Second - retryAfter = &defaultRetryAfter - } - helps.LogWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) - - err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter} - return nil, err - } - - out := make(chan cliproxyexecutor.StreamChunk) - go func() { - defer close(out) - defer func() { - if errClose := httpResp.Body.Close(); errClose != nil { - log.Errorf("qwen executor: close 
response body error: %v", errClose) - } - }() - scanner := bufio.NewScanner(httpResp.Body) - scanner.Buffer(nil, 52_428_800) // 50MB - var param any - for scanner.Scan() { - line := scanner.Bytes() - helps.AppendAPIResponseChunk(ctx, e.cfg, line) - if detail, ok := helps.ParseOpenAIStreamUsage(line); ok { - reporter.Publish(ctx, detail) - } - chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m) - for i := range chunks { - out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]} - } - } - doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m) - for i := range doneChunks { - out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]} - } - if errScan := scanner.Err(); errScan != nil { - helps.RecordAPIResponseError(ctx, e.cfg, errScan) - reporter.PublishFailure(ctx) - out <- cliproxyexecutor.StreamChunk{Err: errScan} - } - }() - return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil - } -} - -func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { - baseModel := thinking.ParseSuffix(req.Model).ModelName - - from := opts.SourceFormat - to := sdktranslator.FromString("openai") - body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) - - modelName := gjson.GetBytes(body, "model").String() - if strings.TrimSpace(modelName) == "" { - modelName = baseModel - } - - enc, err := helps.TokenizerForModel(modelName) - if err != nil { - return cliproxyexecutor.Response{}, fmt.Errorf("qwen executor: tokenizer init failed: %w", err) - } - - count, err := helps.CountOpenAIChatTokens(enc, body) - if err != nil { - return cliproxyexecutor.Response{}, fmt.Errorf("qwen executor: token counting failed: %w", err) - } - - usageJSON := helps.BuildOpenAIUsageJSON(count) - translated := 
sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON) - return cliproxyexecutor.Response{Payload: translated}, nil -} - -func (e *QwenExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { - log.Debugf("qwen executor: refresh called") - if auth == nil { - return nil, fmt.Errorf("qwen executor: auth is nil") - } - // Expect refresh_token in metadata for OAuth-based accounts - var refreshToken string - if auth.Metadata != nil { - if v, ok := auth.Metadata["refresh_token"].(string); ok && strings.TrimSpace(v) != "" { - refreshToken = v - } - } - if strings.TrimSpace(refreshToken) == "" { - // Nothing to refresh - return auth, nil - } - - svc := qwenauth.NewQwenAuth(e.cfg) - td, err := svc.RefreshTokens(ctx, refreshToken) - if err != nil { - return nil, err - } - if auth.Metadata == nil { - auth.Metadata = make(map[string]any) - } - auth.Metadata["access_token"] = td.AccessToken - if td.RefreshToken != "" { - auth.Metadata["refresh_token"] = td.RefreshToken - } - if td.ResourceURL != "" { - auth.Metadata["resource_url"] = td.ResourceURL - } - // Use "expired" for consistency with existing file format - auth.Metadata["expired"] = td.Expire - auth.Metadata["type"] = "qwen" - now := time.Now().Format(time.RFC3339) - auth.Metadata["last_refresh"] = now - return auth, nil -} - -func applyQwenHeaders(r *http.Request, token string, stream bool) { - r.Header.Set("X-Stainless-Runtime-Version", "v22.17.0") - r.Header.Set("User-Agent", qwenUserAgent) - r.Header.Set("X-Stainless-Lang", "js") - r.Header.Set("Accept-Language", "*") - r.Header.Set("X-Dashscope-Cachecontrol", "enable") - r.Header.Set("X-Stainless-Os", "MacOS") - r.Header.Set("X-Dashscope-Authtype", "qwen-oauth") - r.Header.Set("X-Stainless-Arch", "arm64") - r.Header.Set("X-Stainless-Runtime", "node") - r.Header.Set("X-Stainless-Retry-Count", "0") - r.Header.Set("Accept-Encoding", "gzip, deflate") - r.Header.Set("Authorization", "Bearer "+token) - 
r.Header.Set("X-Stainless-Package-Version", "5.11.0") - r.Header.Set("Sec-Fetch-Mode", "cors") - r.Header.Set("Content-Type", "application/json") - r.Header.Set("Connection", "keep-alive") - r.Header.Set("X-Dashscope-Useragent", qwenUserAgent) - - if stream { - r.Header.Set("Accept", "text/event-stream") - return - } - r.Header.Set("Accept", "application/json") -} - -func normaliseQwenBaseURL(resourceURL string) string { - raw := strings.TrimSpace(resourceURL) - if raw == "" { - return "" - } - - normalized := raw - lower := strings.ToLower(normalized) - if !strings.HasPrefix(lower, "http://") && !strings.HasPrefix(lower, "https://") { - normalized = "https://" + normalized - } - - normalized = strings.TrimRight(normalized, "/") - if !strings.HasSuffix(strings.ToLower(normalized), "/v1") { - normalized += "/v1" - } - - return normalized -} - -func qwenCreds(a *cliproxyauth.Auth) (token, baseURL string) { - if a == nil { - return "", "" - } - if a.Attributes != nil { - if v := a.Attributes["api_key"]; v != "" { - token = v - } - if v := a.Attributes["base_url"]; v != "" { - baseURL = v - } - } - if token == "" && a.Metadata != nil { - if v, ok := a.Metadata["access_token"].(string); ok { - token = v - } - if v, ok := a.Metadata["resource_url"].(string); ok { - baseURL = normaliseQwenBaseURL(v) - } - } - return -} diff --git a/internal/runtime/executor/qwen_executor_test.go b/internal/runtime/executor/qwen_executor_test.go deleted file mode 100644 index f19cc8ca..00000000 --- a/internal/runtime/executor/qwen_executor_test.go +++ /dev/null @@ -1,614 +0,0 @@ -package executor - -import ( - "context" - "net/http" - "net/http/httptest" - "sync/atomic" - "testing" - "time" - - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" - cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" - cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" - sdktranslator 
"github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" - "github.com/tidwall/gjson" -) - -func TestQwenExecutorParseSuffix(t *testing.T) { - tests := []struct { - name string - model string - wantBase string - wantLevel string - }{ - {"no suffix", "qwen-max", "qwen-max", ""}, - {"with level suffix", "qwen-max(high)", "qwen-max", "high"}, - {"with budget suffix", "qwen-max(16384)", "qwen-max", "16384"}, - {"complex model name", "qwen-plus-latest(medium)", "qwen-plus-latest", "medium"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := thinking.ParseSuffix(tt.model) - if result.ModelName != tt.wantBase { - t.Errorf("ParseSuffix(%q).ModelName = %q, want %q", tt.model, result.ModelName, tt.wantBase) - } - }) - } -} - -func TestEnsureQwenSystemMessage_MergeStringSystem(t *testing.T) { - payload := []byte(`{ - "model": "qwen3.6-plus", - "stream": true, - "messages": [ - { "role": "system", "content": "ABCDEFG" }, - { "role": "user", "content": [ { "type": "text", "text": "你好" } ] } - ] - }`) - - out, err := ensureQwenSystemMessage(payload) - if err != nil { - t.Fatalf("ensureQwenSystemMessage() error = %v", err) - } - - msgs := gjson.GetBytes(out, "messages").Array() - if len(msgs) != 2 { - t.Fatalf("messages length = %d, want 2", len(msgs)) - } - if msgs[0].Get("role").String() != "system" { - t.Fatalf("messages[0].role = %q, want %q", msgs[0].Get("role").String(), "system") - } - parts := msgs[0].Get("content").Array() - if len(parts) != 2 { - t.Fatalf("messages[0].content length = %d, want 2", len(parts)) - } - if parts[0].Get("type").String() != "text" || parts[0].Get("cache_control.type").String() != "ephemeral" { - t.Fatalf("messages[0].content[0] = %s, want injected system part", parts[0].Raw) - } - if text := parts[0].Get("text").String(); text != "" && text != "You are Qwen Code." 
{ - t.Fatalf("messages[0].content[0].text = %q, want empty string or default prompt", text) - } - if parts[1].Get("type").String() != "text" || parts[1].Get("text").String() != "ABCDEFG" { - t.Fatalf("messages[0].content[1] = %s, want text part with ABCDEFG", parts[1].Raw) - } - if msgs[1].Get("role").String() != "user" { - t.Fatalf("messages[1].role = %q, want %q", msgs[1].Get("role").String(), "user") - } -} - -func TestEnsureQwenSystemMessage_MergeObjectSystem(t *testing.T) { - payload := []byte(`{ - "messages": [ - { "role": "system", "content": { "type": "text", "text": "ABCDEFG" } }, - { "role": "user", "content": [ { "type": "text", "text": "你好" } ] } - ] - }`) - - out, err := ensureQwenSystemMessage(payload) - if err != nil { - t.Fatalf("ensureQwenSystemMessage() error = %v", err) - } - - msgs := gjson.GetBytes(out, "messages").Array() - if len(msgs) != 2 { - t.Fatalf("messages length = %d, want 2", len(msgs)) - } - parts := msgs[0].Get("content").Array() - if len(parts) != 2 { - t.Fatalf("messages[0].content length = %d, want 2", len(parts)) - } - if parts[1].Get("text").String() != "ABCDEFG" { - t.Fatalf("messages[0].content[1].text = %q, want %q", parts[1].Get("text").String(), "ABCDEFG") - } -} - -func TestEnsureQwenSystemMessage_PrependsWhenMissing(t *testing.T) { - payload := []byte(`{ - "messages": [ - { "role": "user", "content": [ { "type": "text", "text": "你好" } ] } - ] - }`) - - out, err := ensureQwenSystemMessage(payload) - if err != nil { - t.Fatalf("ensureQwenSystemMessage() error = %v", err) - } - - msgs := gjson.GetBytes(out, "messages").Array() - if len(msgs) != 2 { - t.Fatalf("messages length = %d, want 2", len(msgs)) - } - if msgs[0].Get("role").String() != "system" { - t.Fatalf("messages[0].role = %q, want %q", msgs[0].Get("role").String(), "system") - } - if !msgs[0].Get("content").IsArray() || len(msgs[0].Get("content").Array()) == 0 { - t.Fatalf("messages[0].content = %s, want non-empty array", msgs[0].Get("content").Raw) - } - if 
msgs[1].Get("role").String() != "user" { - t.Fatalf("messages[1].role = %q, want %q", msgs[1].Get("role").String(), "user") - } -} - -func TestEnsureQwenSystemMessage_MergesMultipleSystemMessages(t *testing.T) { - payload := []byte(`{ - "messages": [ - { "role": "system", "content": "A" }, - { "role": "user", "content": [ { "type": "text", "text": "hi" } ] }, - { "role": "system", "content": "B" } - ] - }`) - - out, err := ensureQwenSystemMessage(payload) - if err != nil { - t.Fatalf("ensureQwenSystemMessage() error = %v", err) - } - - msgs := gjson.GetBytes(out, "messages").Array() - if len(msgs) != 2 { - t.Fatalf("messages length = %d, want 2", len(msgs)) - } - parts := msgs[0].Get("content").Array() - if len(parts) != 3 { - t.Fatalf("messages[0].content length = %d, want 3", len(parts)) - } - if parts[1].Get("text").String() != "A" { - t.Fatalf("messages[0].content[1].text = %q, want %q", parts[1].Get("text").String(), "A") - } - if parts[2].Get("text").String() != "B" { - t.Fatalf("messages[0].content[2].text = %q, want %q", parts[2].Get("text").String(), "B") - } -} - -func TestWrapQwenError_InsufficientQuotaDoesNotSetRetryAfter(t *testing.T) { - body := []byte(`{"error":{"code":"insufficient_quota","message":"You exceeded your current quota","type":"insufficient_quota"}}`) - code, retryAfter := wrapQwenError(context.Background(), http.StatusTooManyRequests, body) - if code != http.StatusTooManyRequests { - t.Fatalf("wrapQwenError status = %d, want %d", code, http.StatusTooManyRequests) - } - if retryAfter != nil { - t.Fatalf("wrapQwenError retryAfter = %v, want nil", *retryAfter) - } -} - -func TestWrapQwenError_Maps403QuotaTo429WithoutRetryAfter(t *testing.T) { - body := []byte(`{"error":{"code":"insufficient_quota","message":"You exceeded your current quota","type":"insufficient_quota"}}`) - code, retryAfter := wrapQwenError(context.Background(), http.StatusForbidden, body) - if code != http.StatusTooManyRequests { - t.Fatalf("wrapQwenError status = %d, 
want %d", code, http.StatusTooManyRequests) - } - if retryAfter != nil { - t.Fatalf("wrapQwenError retryAfter = %v, want nil", *retryAfter) - } -} - -func TestQwenCreds_NormalizesResourceURL(t *testing.T) { - tests := []struct { - name string - resourceURL string - wantBaseURL string - }{ - {"host only", "portal.qwen.ai", "https://portal.qwen.ai/v1"}, - {"scheme no v1", "https://portal.qwen.ai", "https://portal.qwen.ai/v1"}, - {"scheme with v1", "https://portal.qwen.ai/v1", "https://portal.qwen.ai/v1"}, - {"scheme with v1 slash", "https://portal.qwen.ai/v1/", "https://portal.qwen.ai/v1"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - auth := &cliproxyauth.Auth{ - Metadata: map[string]any{ - "access_token": "test-token", - "resource_url": tt.resourceURL, - }, - } - - token, baseURL := qwenCreds(auth) - if token != "test-token" { - t.Fatalf("qwenCreds token = %q, want %q", token, "test-token") - } - if baseURL != tt.wantBaseURL { - t.Fatalf("qwenCreds baseURL = %q, want %q", baseURL, tt.wantBaseURL) - } - }) - } -} - -func TestQwenExecutorExecute_429DoesNotRefreshOrRetry(t *testing.T) { - qwenRateLimiter.Lock() - qwenRateLimiter.requests = make(map[string][]time.Time) - qwenRateLimiter.Unlock() - - var calls int32 - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - atomic.AddInt32(&calls, 1) - if r.URL.Path != "/v1/chat/completions" { - w.WriteHeader(http.StatusNotFound) - return - } - switch r.Header.Get("Authorization") { - case "Bearer old-token": - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusTooManyRequests) - _, _ = w.Write([]byte(`{"error":{"code":"quota_exceeded","message":"quota exceeded","type":"quota_exceeded"}}`)) - return - case "Bearer new-token": - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = 
w.Write([]byte(`{"id":"chatcmpl-test","object":"chat.completion","created":1,"model":"qwen-max","choices":[{"index":0,"message":{"role":"assistant","content":"hi"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`)) - return - default: - w.WriteHeader(http.StatusUnauthorized) - return - } - })) - defer srv.Close() - - exec := NewQwenExecutor(&config.Config{}) - auth := &cliproxyauth.Auth{ - ID: "auth-test", - Provider: "qwen", - Attributes: map[string]string{ - "base_url": srv.URL + "/v1", - }, - Metadata: map[string]any{ - "access_token": "old-token", - "refresh_token": "refresh-token", - }, - } - - var refresherCalls int32 - exec.refreshForImmediateRetry = func(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { - atomic.AddInt32(&refresherCalls, 1) - refreshed := auth.Clone() - if refreshed.Metadata == nil { - refreshed.Metadata = make(map[string]any) - } - refreshed.Metadata["access_token"] = "new-token" - refreshed.Metadata["refresh_token"] = "refresh-token-2" - return refreshed, nil - } - ctx := context.Background() - - _, err := exec.Execute(ctx, auth, cliproxyexecutor.Request{ - Model: "qwen-max", - Payload: []byte(`{"model":"qwen-max","messages":[{"role":"user","content":"hi"}]}`), - }, cliproxyexecutor.Options{ - SourceFormat: sdktranslator.FromString("openai"), - }) - if err == nil { - t.Fatalf("Execute() expected error, got nil") - } - status, ok := err.(statusErr) - if !ok { - t.Fatalf("Execute() error type = %T, want statusErr", err) - } - if status.StatusCode() != http.StatusTooManyRequests { - t.Fatalf("Execute() status code = %d, want %d", status.StatusCode(), http.StatusTooManyRequests) - } - if atomic.LoadInt32(&calls) != 1 { - t.Fatalf("upstream calls = %d, want 1", atomic.LoadInt32(&calls)) - } - if atomic.LoadInt32(&refresherCalls) != 0 { - t.Fatalf("refresher calls = %d, want 0", atomic.LoadInt32(&refresherCalls)) - } -} - -func 
TestQwenExecutorExecuteStream_429DoesNotRefreshOrRetry(t *testing.T) { - qwenRateLimiter.Lock() - qwenRateLimiter.requests = make(map[string][]time.Time) - qwenRateLimiter.Unlock() - - var calls int32 - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - atomic.AddInt32(&calls, 1) - if r.URL.Path != "/v1/chat/completions" { - w.WriteHeader(http.StatusNotFound) - return - } - switch r.Header.Get("Authorization") { - case "Bearer old-token": - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusTooManyRequests) - _, _ = w.Write([]byte(`{"error":{"code":"quota_exceeded","message":"quota exceeded","type":"quota_exceeded"}}`)) - return - case "Bearer new-token": - w.Header().Set("Content-Type", "text/event-stream") - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-test\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"qwen-max\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"},\"finish_reason\":null}]}\n")) - if flusher, ok := w.(http.Flusher); ok { - flusher.Flush() - } - return - default: - w.WriteHeader(http.StatusUnauthorized) - return - } - })) - defer srv.Close() - - exec := NewQwenExecutor(&config.Config{}) - auth := &cliproxyauth.Auth{ - ID: "auth-test", - Provider: "qwen", - Attributes: map[string]string{ - "base_url": srv.URL + "/v1", - }, - Metadata: map[string]any{ - "access_token": "old-token", - "refresh_token": "refresh-token", - }, - } - - var refresherCalls int32 - exec.refreshForImmediateRetry = func(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { - atomic.AddInt32(&refresherCalls, 1) - refreshed := auth.Clone() - if refreshed.Metadata == nil { - refreshed.Metadata = make(map[string]any) - } - refreshed.Metadata["access_token"] = "new-token" - refreshed.Metadata["refresh_token"] = "refresh-token-2" - return refreshed, nil - } - ctx := context.Background() - - _, err := exec.ExecuteStream(ctx, auth, 
cliproxyexecutor.Request{ - Model: "qwen-max", - Payload: []byte(`{"model":"qwen-max","stream":true,"messages":[{"role":"user","content":"hi"}]}`), - }, cliproxyexecutor.Options{ - SourceFormat: sdktranslator.FromString("openai"), - }) - if err == nil { - t.Fatalf("ExecuteStream() expected error, got nil") - } - status, ok := err.(statusErr) - if !ok { - t.Fatalf("ExecuteStream() error type = %T, want statusErr", err) - } - if status.StatusCode() != http.StatusTooManyRequests { - t.Fatalf("ExecuteStream() status code = %d, want %d", status.StatusCode(), http.StatusTooManyRequests) - } - if atomic.LoadInt32(&calls) != 1 { - t.Fatalf("upstream calls = %d, want 1", atomic.LoadInt32(&calls)) - } - if atomic.LoadInt32(&refresherCalls) != 0 { - t.Fatalf("refresher calls = %d, want 0", atomic.LoadInt32(&refresherCalls)) - } -} - -func TestQwenExecutorExecute_429RetryAfterHeaderPropagatesToStatusErr(t *testing.T) { - qwenRateLimiter.Lock() - qwenRateLimiter.requests = make(map[string][]time.Time) - qwenRateLimiter.Unlock() - - var calls int32 - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - atomic.AddInt32(&calls, 1) - if r.URL.Path != "/v1/chat/completions" { - w.WriteHeader(http.StatusNotFound) - return - } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Retry-After", "2") - w.WriteHeader(http.StatusTooManyRequests) - _, _ = w.Write([]byte(`{"error":{"code":"rate_limit_exceeded","message":"rate limited","type":"rate_limit_exceeded"}}`)) - })) - defer srv.Close() - - exec := NewQwenExecutor(&config.Config{}) - auth := &cliproxyauth.Auth{ - ID: "auth-test", - Provider: "qwen", - Attributes: map[string]string{ - "base_url": srv.URL + "/v1", - }, - Metadata: map[string]any{ - "access_token": "test-token", - }, - } - ctx := context.Background() - - _, err := exec.Execute(ctx, auth, cliproxyexecutor.Request{ - Model: "qwen-max", - Payload: 
[]byte(`{"model":"qwen-max","messages":[{"role":"user","content":"hi"}]}`), - }, cliproxyexecutor.Options{ - SourceFormat: sdktranslator.FromString("openai"), - }) - if err == nil { - t.Fatalf("Execute() expected error, got nil") - } - status, ok := err.(statusErr) - if !ok { - t.Fatalf("Execute() error type = %T, want statusErr", err) - } - if status.StatusCode() != http.StatusTooManyRequests { - t.Fatalf("Execute() status code = %d, want %d", status.StatusCode(), http.StatusTooManyRequests) - } - if status.RetryAfter() == nil { - t.Fatalf("Execute() RetryAfter is nil, want non-nil") - } - if got := *status.RetryAfter(); got != 2*time.Second { - t.Fatalf("Execute() RetryAfter = %v, want %v", got, 2*time.Second) - } - if atomic.LoadInt32(&calls) != 1 { - t.Fatalf("upstream calls = %d, want 1", atomic.LoadInt32(&calls)) - } -} - -func TestQwenExecutorExecuteStream_429RetryAfterHeaderPropagatesToStatusErr(t *testing.T) { - qwenRateLimiter.Lock() - qwenRateLimiter.requests = make(map[string][]time.Time) - qwenRateLimiter.Unlock() - - var calls int32 - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - atomic.AddInt32(&calls, 1) - if r.URL.Path != "/v1/chat/completions" { - w.WriteHeader(http.StatusNotFound) - return - } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Retry-After", "2") - w.WriteHeader(http.StatusTooManyRequests) - _, _ = w.Write([]byte(`{"error":{"code":"rate_limit_exceeded","message":"rate limited","type":"rate_limit_exceeded"}}`)) - })) - defer srv.Close() - - exec := NewQwenExecutor(&config.Config{}) - auth := &cliproxyauth.Auth{ - ID: "auth-test", - Provider: "qwen", - Attributes: map[string]string{ - "base_url": srv.URL + "/v1", - }, - Metadata: map[string]any{ - "access_token": "test-token", - }, - } - ctx := context.Background() - - _, err := exec.ExecuteStream(ctx, auth, cliproxyexecutor.Request{ - Model: "qwen-max", - Payload: 
[]byte(`{"model":"qwen-max","stream":true,"messages":[{"role":"user","content":"hi"}]}`), - }, cliproxyexecutor.Options{ - SourceFormat: sdktranslator.FromString("openai"), - }) - if err == nil { - t.Fatalf("ExecuteStream() expected error, got nil") - } - status, ok := err.(statusErr) - if !ok { - t.Fatalf("ExecuteStream() error type = %T, want statusErr", err) - } - if status.StatusCode() != http.StatusTooManyRequests { - t.Fatalf("ExecuteStream() status code = %d, want %d", status.StatusCode(), http.StatusTooManyRequests) - } - if status.RetryAfter() == nil { - t.Fatalf("ExecuteStream() RetryAfter is nil, want non-nil") - } - if got := *status.RetryAfter(); got != 2*time.Second { - t.Fatalf("ExecuteStream() RetryAfter = %v, want %v", got, 2*time.Second) - } - if atomic.LoadInt32(&calls) != 1 { - t.Fatalf("upstream calls = %d, want 1", atomic.LoadInt32(&calls)) - } -} - -func TestQwenExecutorExecute_429QuotaExhausted_DisableCoolingSetsDefaultRetryAfter(t *testing.T) { - qwenRateLimiter.Lock() - qwenRateLimiter.requests = make(map[string][]time.Time) - qwenRateLimiter.Unlock() - - var calls int32 - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - atomic.AddInt32(&calls, 1) - if r.URL.Path != "/v1/chat/completions" { - w.WriteHeader(http.StatusNotFound) - return - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusTooManyRequests) - _, _ = w.Write([]byte(`{"error":{"code":"quota_exceeded","message":"quota exceeded","type":"quota_exceeded"}}`)) - })) - defer srv.Close() - - exec := NewQwenExecutor(&config.Config{DisableCooling: true}) - auth := &cliproxyauth.Auth{ - ID: "auth-test", - Provider: "qwen", - Attributes: map[string]string{ - "base_url": srv.URL + "/v1", - }, - Metadata: map[string]any{ - "access_token": "test-token", - }, - } - ctx := context.Background() - - _, err := exec.Execute(ctx, auth, cliproxyexecutor.Request{ - Model: "qwen-max", - Payload: 
[]byte(`{"model":"qwen-max","messages":[{"role":"user","content":"hi"}]}`), - }, cliproxyexecutor.Options{ - SourceFormat: sdktranslator.FromString("openai"), - }) - if err == nil { - t.Fatalf("Execute() expected error, got nil") - } - status, ok := err.(statusErr) - if !ok { - t.Fatalf("Execute() error type = %T, want statusErr", err) - } - if status.StatusCode() != http.StatusTooManyRequests { - t.Fatalf("Execute() status code = %d, want %d", status.StatusCode(), http.StatusTooManyRequests) - } - if status.RetryAfter() == nil { - t.Fatalf("Execute() RetryAfter is nil, want non-nil") - } - if got := *status.RetryAfter(); got != time.Second { - t.Fatalf("Execute() RetryAfter = %v, want %v", got, time.Second) - } - if atomic.LoadInt32(&calls) != 1 { - t.Fatalf("upstream calls = %d, want 1", atomic.LoadInt32(&calls)) - } -} - -func TestQwenExecutorExecuteStream_429QuotaExhausted_DisableCoolingSetsDefaultRetryAfter(t *testing.T) { - qwenRateLimiter.Lock() - qwenRateLimiter.requests = make(map[string][]time.Time) - qwenRateLimiter.Unlock() - - var calls int32 - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - atomic.AddInt32(&calls, 1) - if r.URL.Path != "/v1/chat/completions" { - w.WriteHeader(http.StatusNotFound) - return - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusTooManyRequests) - _, _ = w.Write([]byte(`{"error":{"code":"quota_exceeded","message":"quota exceeded","type":"quota_exceeded"}}`)) - })) - defer srv.Close() - - exec := NewQwenExecutor(&config.Config{DisableCooling: true}) - auth := &cliproxyauth.Auth{ - ID: "auth-test", - Provider: "qwen", - Attributes: map[string]string{ - "base_url": srv.URL + "/v1", - }, - Metadata: map[string]any{ - "access_token": "test-token", - }, - } - ctx := context.Background() - - _, err := exec.ExecuteStream(ctx, auth, cliproxyexecutor.Request{ - Model: "qwen-max", - Payload: 
[]byte(`{"model":"qwen-max","stream":true,"messages":[{"role":"user","content":"hi"}]}`), - }, cliproxyexecutor.Options{ - SourceFormat: sdktranslator.FromString("openai"), - }) - if err == nil { - t.Fatalf("ExecuteStream() expected error, got nil") - } - status, ok := err.(statusErr) - if !ok { - t.Fatalf("ExecuteStream() error type = %T, want statusErr", err) - } - if status.StatusCode() != http.StatusTooManyRequests { - t.Fatalf("ExecuteStream() status code = %d, want %d", status.StatusCode(), http.StatusTooManyRequests) - } - if status.RetryAfter() == nil { - t.Fatalf("ExecuteStream() RetryAfter is nil, want non-nil") - } - if got := *status.RetryAfter(); got != time.Second { - t.Fatalf("ExecuteStream() RetryAfter = %v, want %v", got, time.Second) - } - if atomic.LoadInt32(&calls) != 1 { - t.Fatalf("upstream calls = %d, want 1", atomic.LoadInt32(&calls)) - } -} diff --git a/internal/thinking/provider/iflow/apply.go b/internal/thinking/provider/iflow/apply.go index 35d13f59..082cacff 100644 --- a/internal/thinking/provider/iflow/apply.go +++ b/internal/thinking/provider/iflow/apply.go @@ -154,7 +154,7 @@ func isEnableThinkingModel(modelID string) bool { } id := strings.ToLower(modelID) switch id { - case "qwen3-max-preview", "deepseek-v3.2", "deepseek-v3.1": + case "deepseek-v3.2", "deepseek-v3.1": return true default: return false diff --git a/internal/translator/antigravity/claude/antigravity_claude_request.go b/internal/translator/antigravity/claude/antigravity_claude_request.go index 05b724c9..8ae69648 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_request.go +++ b/internal/translator/antigravity/claude/antigravity_claude_request.go @@ -101,6 +101,9 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ systemTypePromptResult := systemPromptResult.Get("type") if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" { systemPrompt := systemPromptResult.Get("text").String() 
+ if strings.HasPrefix(systemPrompt, "x-anthropic-billing-header:") { + continue + } partJSON := []byte(`{}`) if systemPrompt != "" { partJSON, _ = sjson.SetBytes(partJSON, "text", systemPrompt) @@ -170,9 +173,15 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ continue } - // Valid signature, send as thought block - // Always include "text" field — Google Antigravity API requires it - // even for redacted thinking where the text is empty. + // Drop empty-text thinking blocks (redacted thinking from Claude Max). + // Antigravity wraps empty text into a prompt-caching-scope object that + // omits the required inner "thinking" field, causing: + // 400 "messages.N.content.0.thinking.thinking: Field required" + if thinkingText == "" { + continue + } + + // Valid signature with content, send as thought block. partJSON := []byte(`{}`) partJSON, _ = sjson.SetBytes(partJSON, "thought", true) partJSON, _ = sjson.SetBytes(partJSON, "text", thinkingText) diff --git a/internal/translator/antigravity/claude/antigravity_claude_request_test.go b/internal/translator/antigravity/claude/antigravity_claude_request_test.go index 681b2de5..919e2906 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_request_test.go +++ b/internal/translator/antigravity/claude/antigravity_claude_request_test.go @@ -468,11 +468,7 @@ func TestValidateBypassMode_HandlesWhitespace(t *testing.T) { func TestValidateBypassMode_RejectsOversizedSignature(t *testing.T) { t.Parallel() - payload := append([]byte{0x12}, bytes.Repeat([]byte{0x34}, maxBypassSignatureLen)...) 
- sig := base64.StdEncoding.EncodeToString(payload) - if len(sig) <= maxBypassSignatureLen { - t.Fatalf("test setup: signature should exceed max length, got %d", len(sig)) - } + sig := strings.Repeat("A", maxBypassSignatureLen+1) inputJSON := []byte(`{ "messages": [{"role": "assistant", "content": [ @@ -489,6 +485,33 @@ func TestValidateBypassMode_RejectsOversizedSignature(t *testing.T) { } } +func TestValidateBypassMode_StrictAcceptsSignatureBetween16KiBAnd32MiB(t *testing.T) { + previous := cache.SignatureBypassStrictMode() + cache.SetSignatureBypassStrictMode(true) + t.Cleanup(func() { + cache.SetSignatureBypassStrictMode(previous) + }) + + payload := buildClaudeSignaturePayload(t, 12, uint64Ptr(2), strings.Repeat("m", 20000), true) + sig := base64.StdEncoding.EncodeToString(payload) + if len(sig) <= 1<<14 { + t.Fatalf("test setup: signature should exceed previous 16KiB guardrail, got %d", len(sig)) + } + if len(sig) > maxBypassSignatureLen { + t.Fatalf("test setup: signature should remain within new max length, got %d", len(sig)) + } + + inputJSON := []byte(`{ + "messages": [{"role": "assistant", "content": [ + {"type": "thinking", "thinking": "t", "signature": "` + sig + `"} + ]}] + }`) + + if err := ValidateClaudeBypassSignatures(inputJSON); err != nil { + t.Fatalf("expected strict mode to accept signature below 32MiB max, got: %v", err) + } +} + func TestResolveBypassModeSignature_TrimsWhitespace(t *testing.T) { previous := cache.SignatureCacheEnabled() cache.SetSignatureCacheEnabled(false) @@ -2158,6 +2181,225 @@ func TestConvertClaudeRequestToAntigravity_ToolResultImageMissingMediaType(t *te } } +func TestConvertClaudeRequestToAntigravity_BypassMode_DropsRedactedThinkingBlocks(t *testing.T) { + cache.ClearSignatureCache("") + previous := cache.SignatureCacheEnabled() + cache.SetSignatureCacheEnabled(false) + t.Cleanup(func() { + cache.SetSignatureCacheEnabled(previous) + cache.ClearSignatureCache("") + }) + + validSignature := 
testAnthropicNativeSignature(t) + + inputJSON := []byte(`{ + "model": "claude-opus-4-6", + "messages": [ + { + "role": "user", + "content": [{"type": "text", "text": "Hello"}] + }, + { + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": "", "signature": "` + validSignature + `"}, + {"type": "text", "text": "I can help with that."} + ] + }, + { + "role": "user", + "content": [{"type": "text", "text": "Follow up question"}] + } + ], + "thinking": {"type": "enabled", "budget_tokens": 10000} + }`) + + output := ConvertClaudeRequestToAntigravity("claude-opus-4-6", inputJSON, false) + + assistantParts := gjson.GetBytes(output, "request.contents.1.parts").Array() + if len(assistantParts) != 1 { + t.Fatalf("Expected 1 part (redacted thinking dropped), got %d: %s", + len(assistantParts), gjson.GetBytes(output, "request.contents.1.parts").Raw) + } + if assistantParts[0].Get("thought").Bool() { + t.Fatal("Redacted thinking block with empty text should be dropped") + } + if assistantParts[0].Get("text").String() != "I can help with that." 
{ + t.Fatalf("Expected text part preserved, got: %s", assistantParts[0].Raw) + } +} + +func TestConvertClaudeRequestToAntigravity_BypassMode_DropsWrappedRedactedThinking(t *testing.T) { + cache.ClearSignatureCache("") + previous := cache.SignatureCacheEnabled() + cache.SetSignatureCacheEnabled(false) + t.Cleanup(func() { + cache.SetSignatureCacheEnabled(previous) + cache.ClearSignatureCache("") + }) + + validSignature := testAnthropicNativeSignature(t) + + inputJSON := []byte(`{ + "model": "claude-sonnet-4-6", + "messages": [ + { + "role": "user", + "content": [{"type": "text", "text": "Test user message"}] + }, + { + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": {"cache_control": {"type": "ephemeral"}}, "signature": "` + validSignature + `"}, + {"type": "text", "text": "Answer"} + ] + }, + { + "role": "user", + "content": [{"type": "text", "text": "Follow up"}] + } + ], + "thinking": {"type": "enabled", "budget_tokens": 8000} + }`) + + output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-6", inputJSON, false) + + assistantParts := gjson.GetBytes(output, "request.contents.1.parts").Array() + if len(assistantParts) != 1 { + t.Fatalf("Expected 1 part (wrapped redacted thinking dropped), got %d: %s", + len(assistantParts), gjson.GetBytes(output, "request.contents.1.parts").Raw) + } + if assistantParts[0].Get("text").String() != "Answer" { + t.Fatalf("Expected text part preserved, got: %s", assistantParts[0].Raw) + } +} + +func TestConvertClaudeRequestToAntigravity_BypassMode_KeepsNonEmptyThinking(t *testing.T) { + cache.ClearSignatureCache("") + previous := cache.SignatureCacheEnabled() + cache.SetSignatureCacheEnabled(false) + t.Cleanup(func() { + cache.SetSignatureCacheEnabled(previous) + cache.ClearSignatureCache("") + }) + + validSignature := testAnthropicNativeSignature(t) + + inputJSON := []byte(`{ + "model": "claude-opus-4-6", + "messages": [ + { + "role": "user", + "content": [{"type": "text", "text": "Hello"}] + }, + { + 
"role": "assistant", + "content": [ + {"type": "thinking", "thinking": "Let me reason about this carefully...", "signature": "` + validSignature + `"}, + {"type": "text", "text": "Here is my answer."} + ] + } + ], + "thinking": {"type": "enabled", "budget_tokens": 10000} + }`) + + output := ConvertClaudeRequestToAntigravity("claude-opus-4-6", inputJSON, false) + + assistantParts := gjson.GetBytes(output, "request.contents.1.parts").Array() + if len(assistantParts) != 2 { + t.Fatalf("Expected 2 parts (thinking + text), got %d", len(assistantParts)) + } + if !assistantParts[0].Get("thought").Bool() { + t.Fatal("First part should be a thought block") + } + if assistantParts[0].Get("text").String() != "Let me reason about this carefully..." { + t.Fatalf("Thinking text mismatch, got: %s", assistantParts[0].Get("text").String()) + } + if assistantParts[1].Get("text").String() != "Here is my answer." { + t.Fatalf("Text part mismatch, got: %s", assistantParts[1].Raw) + } +} + +func TestConvertClaudeRequestToAntigravity_BypassMode_MultiTurnRedactedThinking(t *testing.T) { + cache.ClearSignatureCache("") + previous := cache.SignatureCacheEnabled() + cache.SetSignatureCacheEnabled(false) + t.Cleanup(func() { + cache.SetSignatureCacheEnabled(previous) + cache.ClearSignatureCache("") + }) + + sig := testAnthropicNativeSignature(t) + + inputJSON := []byte(`{ + "model": "claude-opus-4-6", + "messages": [ + {"role": "user", "content": [{"type": "text", "text": "First question"}]}, + { + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": "", "signature": "` + sig + `"}, + {"type": "text", "text": "First answer"}, + {"type": "tool_use", "id": "Bash-123-456", "name": "Bash", "input": {"command": "ls"}} + ] + }, + { + "role": "user", + "content": [ + {"type": "tool_result", "tool_use_id": "Bash-123-456", "content": "file1.txt\nfile2.txt"} + ] + }, + { + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": "", "signature": "` + sig + `"}, + 
{"type": "text", "text": "Here are the files."} + ] + }, + {"role": "user", "content": [{"type": "text", "text": "Thanks"}]} + ], + "thinking": {"type": "enabled", "budget_tokens": 10000} + }`) + + output := ConvertClaudeRequestToAntigravity("claude-opus-4-6", inputJSON, false) + + if !gjson.ValidBytes(output) { + t.Fatalf("Output is not valid JSON: %s", string(output)) + } + + firstAssistantParts := gjson.GetBytes(output, "request.contents.1.parts").Array() + for _, p := range firstAssistantParts { + if p.Get("thought").Bool() { + t.Fatal("Redacted thinking should be dropped from first assistant message") + } + } + hasText := false + hasFC := false + for _, p := range firstAssistantParts { + if p.Get("text").String() == "First answer" { + hasText = true + } + if p.Get("functionCall").Exists() { + hasFC = true + } + } + if !hasText || !hasFC { + t.Fatalf("First assistant should have text + functionCall, got: %s", + gjson.GetBytes(output, "request.contents.1.parts").Raw) + } + + secondAssistantParts := gjson.GetBytes(output, "request.contents.3.parts").Array() + for _, p := range secondAssistantParts { + if p.Get("thought").Bool() { + t.Fatal("Redacted thinking should be dropped from second assistant message") + } + } + if len(secondAssistantParts) != 1 || secondAssistantParts[0].Get("text").String() != "Here are the files." 
{ +		t.Fatalf("Second assistant should have only text part, got: %s", +			gjson.GetBytes(output, "request.contents.3.parts").Raw) +	} +} + func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) { 	// When tools + thinking but no system instruction, should create one with hint 	inputJSON := []byte(`{ diff --git a/internal/translator/antigravity/claude/signature_validation.go b/internal/translator/antigravity/claude/signature_validation.go index e1b9f542..63203abd 100644 --- a/internal/translator/antigravity/claude/signature_validation.go +++ b/internal/translator/antigravity/claude/signature_validation.go @@ -55,10 +55,11 @@ import ( 	"github.com/router-for-me/CLIProxyAPI/v6/internal/cache" 	"github.com/tidwall/gjson" +	"github.com/tidwall/sjson" 	"google.golang.org/protobuf/encoding/protowire" ) -const maxBypassSignatureLen = 8192 +const maxBypassSignatureLen = 32 * 1024 * 1024 type claudeSignatureTree struct { 	EncodingLayers int @@ -72,6 +73,62 @@ type claudeSignatureTree struct { 	HasField7 bool } +// StripEmptySignatureThinkingBlocks removes thinking blocks whose signatures +// are empty or not valid Claude format (must start with 'E' or 'R' after +// stripping any cache prefix). These come from proxy-generated responses +// (Antigravity/Gemini) where no real Claude signature exists.
+func StripEmptySignatureThinkingBlocks(payload []byte) []byte { + messages := gjson.GetBytes(payload, "messages") + if !messages.IsArray() { + return payload + } + modified := false + for i, msg := range messages.Array() { + content := msg.Get("content") + if !content.IsArray() { + continue + } + var kept []string + stripped := false + for _, part := range content.Array() { + if part.Get("type").String() == "thinking" && !hasValidClaudeSignature(part.Get("signature").String()) { + stripped = true + continue + } + kept = append(kept, part.Raw) + } + if stripped { + modified = true + if len(kept) == 0 { + payload, _ = sjson.SetRawBytes(payload, fmt.Sprintf("messages.%d.content", i), []byte("[]")) + } else { + payload, _ = sjson.SetRawBytes(payload, fmt.Sprintf("messages.%d.content", i), []byte("["+strings.Join(kept, ",")+"]")) + } + } + } + if !modified { + return payload + } + return payload +} + +// hasValidClaudeSignature returns true if sig looks like a real Claude thinking +// signature: non-empty and starts with 'E' or 'R' (after stripping optional +// cache prefix like "modelGroup#"). +func hasValidClaudeSignature(sig string) bool { + sig = strings.TrimSpace(sig) + if sig == "" { + return false + } + if idx := strings.IndexByte(sig, '#'); idx >= 0 { + sig = strings.TrimSpace(sig[idx+1:]) + } + if sig == "" { + return false + } + return sig[0] == 'E' || sig[0] == 'R' +} + func ValidateClaudeBypassSignatures(inputRawJSON []byte) error { messages := gjson.GetBytes(inputRawJSON, "messages") if !messages.IsArray() { diff --git a/internal/translator/kiro/common/utils.go b/internal/translator/kiro/common/utils.go index f5f5788a..4c7c7340 100644 --- a/internal/translator/kiro/common/utils.go +++ b/internal/translator/kiro/common/utils.go @@ -13,4 +13,4 @@ func GetString(m map[string]interface{}, key string) string { // GetStringValue is an alias for GetString for backward compatibility. 
func GetStringValue(m map[string]interface{}, key string) string { return GetString(m, key) -} \ No newline at end of file +} diff --git a/internal/translator/kiro/openai/init.go b/internal/translator/kiro/openai/init.go index 653eed45..d43b21a7 100644 --- a/internal/translator/kiro/openai/init.go +++ b/internal/translator/kiro/openai/init.go @@ -17,4 +17,4 @@ func init() { NonStream: ConvertKiroNonStreamToOpenAI, }, ) -} \ No newline at end of file +} diff --git a/internal/translator/kiro/openai/kiro_openai_response.go b/internal/translator/kiro/openai/kiro_openai_response.go index edc70ad8..7d085de0 100644 --- a/internal/translator/kiro/openai/kiro_openai_response.go +++ b/internal/translator/kiro/openai/kiro_openai_response.go @@ -274,4 +274,4 @@ func min(a, b int) int { return a } return b -} \ No newline at end of file +} diff --git a/internal/translator/kiro/openai/kiro_openai_stream.go b/internal/translator/kiro/openai/kiro_openai_stream.go index e72d970e..484a94ee 100644 --- a/internal/translator/kiro/openai/kiro_openai_stream.go +++ b/internal/translator/kiro/openai/kiro_openai_stream.go @@ -209,4 +209,4 @@ func NewThinkingTagState() *ThinkingTagState { PendingStartChars: 0, PendingEndChars: 0, } -} \ No newline at end of file +} diff --git a/internal/tui/oauth_tab.go b/internal/tui/oauth_tab.go index 3989e3d8..1df045ac 100644 --- a/internal/tui/oauth_tab.go +++ b/internal/tui/oauth_tab.go @@ -23,7 +23,6 @@ var oauthProviders = []oauthProvider{ {"Claude (Anthropic)", "anthropic-auth-url", "🟧"}, {"Codex (OpenAI)", "codex-auth-url", "🟩"}, {"Antigravity", "antigravity-auth-url", "🟪"}, - {"Qwen", "qwen-auth-url", "🟨"}, {"Kimi", "kimi-auth-url", "🟫"}, {"IFlow", "iflow-auth-url", "⬜"}, } @@ -280,8 +279,6 @@ func (m oauthTabModel) submitCallback(callbackURL string) tea.Cmd { providerKey = "codex" case "antigravity-auth-url": providerKey = "antigravity" - case "qwen-auth-url": - providerKey = "qwen" case "kimi-auth-url": providerKey = "kimi" case "iflow-auth-url": 
diff --git a/internal/util/provider.go b/internal/util/provider.go index a9c1de0c..50fa002e 100644 --- a/internal/util/provider.go +++ b/internal/util/provider.go @@ -21,7 +21,6 @@ import ( // - "gemini" for Google's Gemini family // - "codex" for OpenAI GPT-compatible providers // - "claude" for Anthropic models -// - "qwen" for Alibaba's Qwen models // - "openai-compatibility" for external OpenAI-compatible providers // // Parameters: diff --git a/internal/watcher/clients.go b/internal/watcher/clients.go index 60ff6197..7746f4ad 100644 --- a/internal/watcher/clients.go +++ b/internal/watcher/clients.go @@ -8,7 +8,6 @@ import ( "encoding/hex" "encoding/json" "fmt" - "io/fs" "os" "path/filepath" "strings" @@ -85,14 +84,22 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil { log.Errorf("failed to resolve auth directory for hash cache: %v", errResolveAuthDir) } else if resolvedAuthDir != "" { - _ = filepath.Walk(resolvedAuthDir, func(path string, info fs.FileInfo, err error) error { - if err != nil { - return nil - } - if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".json") { - if data, errReadFile := os.ReadFile(path); errReadFile == nil && len(data) > 0 { + entries, errReadDir := os.ReadDir(resolvedAuthDir) + if errReadDir != nil { + log.Errorf("failed to read auth directory for hash cache: %v", errReadDir) + } else { + for _, entry := range entries { + if entry == nil || entry.IsDir() { + continue + } + name := entry.Name() + if !strings.HasSuffix(strings.ToLower(name), ".json") { + continue + } + fullPath := filepath.Join(resolvedAuthDir, name) + if data, errReadFile := os.ReadFile(fullPath); errReadFile == nil && len(data) > 0 { sum := sha256.Sum256(data) - normalizedPath := w.normalizeAuthPath(path) + normalizedPath := w.normalizeAuthPath(fullPath) w.lastAuthHashes[normalizedPath] = hex.EncodeToString(sum[:]) 
// Parse and cache auth content for future diff comparisons (debug only). if cacheAuthContents { @@ -107,15 +114,14 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string Now: time.Now(), IDGenerator: synthesizer.NewStableIDGenerator(), } - if generated := synthesizer.SynthesizeAuthFile(ctx, path, data); len(generated) > 0 { + if generated := synthesizer.SynthesizeAuthFile(ctx, fullPath, data); len(generated) > 0 { if pathAuths := authSliceToMap(generated); len(pathAuths) > 0 { w.fileAuthsByPath[normalizedPath] = authIDSet(pathAuths) } } } } - return nil - }) + } } w.clientsMutex.Unlock() } @@ -306,23 +312,25 @@ func (w *Watcher) loadFileClients(cfg *config.Config) int { return 0 } - errWalk := filepath.Walk(authDir, func(path string, info fs.FileInfo, err error) error { - if err != nil { - log.Debugf("error accessing path %s: %v", path, err) - return err + entries, errReadDir := os.ReadDir(authDir) + if errReadDir != nil { + log.Errorf("error reading auth directory: %v", errReadDir) + return 0 + } + for _, entry := range entries { + if entry == nil || entry.IsDir() { + continue } - if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".json") { - authFileCount++ - log.Debugf("processing auth file %d: %s", authFileCount, filepath.Base(path)) - if data, errCreate := os.ReadFile(path); errCreate == nil && len(data) > 0 { - successfulAuthCount++ - } + name := entry.Name() + if !strings.HasSuffix(strings.ToLower(name), ".json") { + continue + } + authFileCount++ + log.Debugf("processing auth file %d: %s", authFileCount, name) + fullPath := filepath.Join(authDir, name) + if data, errReadFile := os.ReadFile(fullPath); errReadFile == nil && len(data) > 0 { + successfulAuthCount++ } - return nil - }) - - if errWalk != nil { - log.Errorf("error walking auth directory: %v", errWalk) } log.Debugf("auth directory scan complete - found %d .json files, %d readable", authFileCount, successfulAuthCount) return authFileCount diff --git 
a/internal/watcher/events.go b/internal/watcher/events.go index fb96ad2a..0869c685 100644 --- a/internal/watcher/events.go +++ b/internal/watcher/events.go @@ -96,7 +96,7 @@ func (w *Watcher) handleEvent(event fsnotify.Event) { normalizedAuthDir := w.normalizeAuthPath(w.authDir) isConfigEvent := normalizedName == normalizedConfigPath && event.Op&configOps != 0 authOps := fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename - isAuthJSON := strings.HasPrefix(normalizedName, normalizedAuthDir) && strings.HasSuffix(normalizedName, ".json") && event.Op&authOps != 0 + isAuthJSON := filepath.Dir(normalizedName) == normalizedAuthDir && strings.HasSuffix(normalizedName, ".json") && event.Op&authOps != 0 isKiroIDEToken := w.isKiroIDETokenFile(event.Name) && event.Op&authOps != 0 if !isConfigEvent && !isAuthJSON && !isKiroIDEToken { // Ignore unrelated files (e.g., cookie snapshots *.cookie) and other noise. diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index cc371325..168a5263 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -14,7 +14,6 @@ import ( "time" "github.com/gin-gonic/gin" - "github.com/google/uuid" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" @@ -188,7 +187,7 @@ func PassthroughHeadersEnabled(cfg *config.SDKConfig) bool { func requestExecutionMetadata(ctx context.Context) map[string]any { // Idempotency-Key is an optional client-supplied header used to correlate retries. - // It is forwarded as execution metadata; when absent we generate a UUID. + // Only include it if the client explicitly provides it. 
key := "" if ctx != nil { if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil { @@ -196,7 +195,7 @@ func requestExecutionMetadata(ctx context.Context) map[string]any { } } if key == "" { - key = uuid.NewString() + return make(map[string]any) } meta := map[string]any{idempotencyKeyMetadataKey: key} diff --git a/sdk/api/management.go b/sdk/api/management.go index 6fd3b709..b1a7f8e3 100644 --- a/sdk/api/management.go +++ b/sdk/api/management.go @@ -17,7 +17,6 @@ type ManagementTokenRequester interface { RequestGeminiCLIToken(*gin.Context) RequestCodexToken(*gin.Context) RequestAntigravityToken(*gin.Context) - RequestQwenToken(*gin.Context) RequestKimiToken(*gin.Context) RequestIFlowToken(*gin.Context) RequestIFlowCookieToken(*gin.Context) @@ -52,10 +51,6 @@ func (m *managementTokenRequester) RequestAntigravityToken(c *gin.Context) { m.handler.RequestAntigravityToken(c) } -func (m *managementTokenRequester) RequestQwenToken(c *gin.Context) { - m.handler.RequestQwenToken(c) -} - func (m *managementTokenRequester) RequestKimiToken(c *gin.Context) { m.handler.RequestKimiToken(c) } diff --git a/sdk/auth/kilo.go b/sdk/auth/kilo.go index 7e98f7c4..ee947fdd 100644 --- a/sdk/auth/kilo.go +++ b/sdk/auth/kilo.go @@ -39,7 +39,7 @@ func (a *KiloAuthenticator) Login(ctx context.Context, cfg *config.Config, opts } kilocodeAuth := kilo.NewKiloAuth() - + fmt.Println("Initiating Kilo device authentication...") resp, err := kilocodeAuth.InitiateDeviceFlow(ctx) if err != nil { @@ -48,7 +48,7 @@ func (a *KiloAuthenticator) Login(ctx context.Context, cfg *config.Config, opts fmt.Printf("Please visit: %s\n", resp.VerificationURL) fmt.Printf("And enter code: %s\n", resp.Code) - + fmt.Println("Waiting for authorization...") status, err := kilocodeAuth.PollForToken(ctx, resp.Code) if err != nil { @@ -68,7 +68,7 @@ func (a *KiloAuthenticator) Login(ctx context.Context, cfg *config.Config, opts for i, org := range profile.Orgs { fmt.Printf("[%d] %s 
(%s)\n", i+1, org.Name, org.ID) } - + if opts.Prompt != nil { input, err := opts.Prompt("Enter the number of the organization: ") if err != nil { @@ -108,7 +108,7 @@ func (a *KiloAuthenticator) Login(ctx context.Context, cfg *config.Config, opts metadata := map[string]any{ "email": status.UserEmail, "organization_id": orgID, - "model": defaults.Model, + "model": defaults.Model, } return &coreauth.Auth{ diff --git a/sdk/auth/qwen.go b/sdk/auth/qwen.go deleted file mode 100644 index d891021a..00000000 --- a/sdk/auth/qwen.go +++ /dev/null @@ -1,113 +0,0 @@ -package auth - -import ( - "context" - "fmt" - "strings" - "time" - - "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen" - "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" - // legacy client removed - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" - log "github.com/sirupsen/logrus" -) - -// QwenAuthenticator implements the device flow login for Qwen accounts. -type QwenAuthenticator struct{} - -// NewQwenAuthenticator constructs a Qwen authenticator. 
-func NewQwenAuthenticator() *QwenAuthenticator { - return &QwenAuthenticator{} -} - -func (a *QwenAuthenticator) Provider() string { - return "qwen" -} - -func (a *QwenAuthenticator) RefreshLead() *time.Duration { - return new(20 * time.Minute) -} - -func (a *QwenAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { - if cfg == nil { - return nil, fmt.Errorf("cliproxy auth: configuration is required") - } - if ctx == nil { - ctx = context.Background() - } - if opts == nil { - opts = &LoginOptions{} - } - - authSvc := qwen.NewQwenAuth(cfg) - - deviceFlow, err := authSvc.InitiateDeviceFlow(ctx) - if err != nil { - return nil, fmt.Errorf("qwen device flow initiation failed: %w", err) - } - - authURL := deviceFlow.VerificationURIComplete - - if !opts.NoBrowser { - fmt.Println("Opening browser for Qwen authentication") - if !browser.IsAvailable() { - log.Warn("No browser available; please open the URL manually") - fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) - } else if err = browser.OpenURL(authURL); err != nil { - log.Warnf("Failed to open browser automatically: %v", err) - fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) - } - } else { - fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) - } - - fmt.Println("Waiting for Qwen authentication...") - - tokenData, err := authSvc.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier) - if err != nil { - return nil, fmt.Errorf("qwen authentication failed: %w", err) - } - - tokenStorage := authSvc.CreateTokenStorage(tokenData) - - email := "" - if opts.Metadata != nil { - email = opts.Metadata["email"] - if email == "" { - email = opts.Metadata["alias"] - } - } - - if email == "" && opts.Prompt != nil { - email, err = opts.Prompt("Please input your email address or alias for Qwen:") - if err != nil { - return nil, err - } - } - - email = strings.TrimSpace(email) - if 
email == "" { - return nil, &EmailRequiredError{Prompt: "Please provide an email address or alias for Qwen."} - } - - tokenStorage.Email = email - - // no legacy client construction - - fileName := fmt.Sprintf("qwen-%s.json", tokenStorage.Email) - metadata := map[string]any{ - "email": tokenStorage.Email, - } - - fmt.Println("Qwen authentication successful") - - return &coreauth.Auth{ - ID: fileName, - Provider: a.Provider(), - FileName: fileName, - Storage: tokenStorage, - Metadata: metadata, - }, nil -} diff --git a/sdk/auth/qwen_refresh_lead_test.go b/sdk/auth/qwen_refresh_lead_test.go deleted file mode 100644 index 56f41fc0..00000000 --- a/sdk/auth/qwen_refresh_lead_test.go +++ /dev/null @@ -1,19 +0,0 @@ -package auth - -import ( - "testing" - "time" -) - -func TestQwenAuthenticator_RefreshLeadIsSane(t *testing.T) { - lead := NewQwenAuthenticator().RefreshLead() - if lead == nil { - t.Fatal("RefreshLead() = nil, want non-nil") - } - if *lead <= 0 { - t.Fatalf("RefreshLead() = %s, want > 0", *lead) - } - if *lead > 30*time.Minute { - t.Fatalf("RefreshLead() = %s, want <= %s", *lead, 30*time.Minute) - } -} diff --git a/sdk/auth/refresh_registry.go b/sdk/auth/refresh_registry.go index 59c58bee..f3419de5 100644 --- a/sdk/auth/refresh_registry.go +++ b/sdk/auth/refresh_registry.go @@ -9,7 +9,6 @@ import ( func init() { registerRefreshLead("codex", func() Authenticator { return NewCodexAuthenticator() }) registerRefreshLead("claude", func() Authenticator { return NewClaudeAuthenticator() }) - registerRefreshLead("qwen", func() Authenticator { return NewQwenAuthenticator() }) registerRefreshLead("iflow", func() Authenticator { return NewIFlowAuthenticator() }) registerRefreshLead("gemini", func() Authenticator { return NewGeminiAuthenticator() }) registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() }) diff --git a/sdk/cliproxy/auth/auto_refresh_loop.go b/sdk/cliproxy/auth/auto_refresh_loop.go new file mode 100644 index 
00000000..9767ee58 --- /dev/null +++ b/sdk/cliproxy/auth/auto_refresh_loop.go @@ -0,0 +1,453 @@ +package auth + +import ( + "container/heap" + "context" + "strings" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +type authAutoRefreshLoop struct { + manager *Manager + interval time.Duration + concurrency int + + mu sync.Mutex + queue refreshMinHeap + index map[string]*refreshHeapItem + dirty map[string]struct{} + + wakeCh chan struct{} + jobs chan string +} + +func newAuthAutoRefreshLoop(manager *Manager, interval time.Duration, concurrency int) *authAutoRefreshLoop { + if interval <= 0 { + interval = refreshCheckInterval + } + if concurrency <= 0 { + concurrency = refreshMaxConcurrency + } + jobBuffer := concurrency * 4 + if jobBuffer < 64 { + jobBuffer = 64 + } + return &authAutoRefreshLoop{ + manager: manager, + interval: interval, + concurrency: concurrency, + index: make(map[string]*refreshHeapItem), + dirty: make(map[string]struct{}), + wakeCh: make(chan struct{}, 1), + jobs: make(chan string, jobBuffer), + } +} + +func (l *authAutoRefreshLoop) queueReschedule(authID string) { + if l == nil || authID == "" { + return + } + l.mu.Lock() + l.dirty[authID] = struct{}{} + l.mu.Unlock() + select { + case l.wakeCh <- struct{}{}: + default: + } +} + +func (l *authAutoRefreshLoop) run(ctx context.Context) { + if l == nil || l.manager == nil { + return + } + + workers := l.concurrency + if workers <= 0 { + workers = refreshMaxConcurrency + } + for i := 0; i < workers; i++ { + go l.worker(ctx) + } + + l.loop(ctx) +} + +func (l *authAutoRefreshLoop) worker(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case authID := <-l.jobs: + if authID == "" { + continue + } + l.manager.refreshAuth(ctx, authID) + l.queueReschedule(authID) + } + } +} + +func (l *authAutoRefreshLoop) rebuild(now time.Time) { + type entry struct { + id string + next time.Time + } + + entries := make([]entry, 0) + + l.manager.mu.RLock() + for id, auth := range 
l.manager.auths { + next, ok := nextRefreshCheckAt(now, auth, l.interval) + if !ok { + continue + } + entries = append(entries, entry{id: id, next: next}) + } + l.manager.mu.RUnlock() + + l.mu.Lock() + l.queue = l.queue[:0] + l.index = make(map[string]*refreshHeapItem, len(entries)) + for _, e := range entries { + item := &refreshHeapItem{id: e.id, next: e.next} + heap.Push(&l.queue, item) + l.index[e.id] = item + } + l.mu.Unlock() +} + +func (l *authAutoRefreshLoop) loop(ctx context.Context) { + timer := time.NewTimer(time.Hour) + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + defer timer.Stop() + + var timerCh <-chan time.Time + l.resetTimer(timer, &timerCh, time.Now()) + + for { + select { + case <-ctx.Done(): + return + case <-l.wakeCh: + now := time.Now() + l.applyDirty(now) + l.resetTimer(timer, &timerCh, now) + case <-timerCh: + now := time.Now() + l.handleDue(ctx, now) + l.applyDirty(now) + l.resetTimer(timer, &timerCh, now) + } + } +} + +func (l *authAutoRefreshLoop) resetTimer(timer *time.Timer, timerCh *<-chan time.Time, now time.Time) { + next, ok := l.peek() + if !ok { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + *timerCh = nil + return + } + + wait := next.Sub(now) + if wait < 0 { + wait = 0 + } + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timer.Reset(wait) + *timerCh = timer.C +} + +func (l *authAutoRefreshLoop) peek() (time.Time, bool) { + l.mu.Lock() + defer l.mu.Unlock() + if len(l.queue) == 0 { + return time.Time{}, false + } + return l.queue[0].next, true +} + +func (l *authAutoRefreshLoop) handleDue(ctx context.Context, now time.Time) { + due := l.popDue(now) + if len(due) == 0 { + return + } + if log.IsLevelEnabled(log.DebugLevel) { + log.Debugf("auto-refresh scheduler due auths: %d", len(due)) + } + for _, authID := range due { + l.handleDueAuth(ctx, now, authID) + } +} + +func (l *authAutoRefreshLoop) popDue(now time.Time) []string { + l.mu.Lock() + defer l.mu.Unlock() 
+ + var due []string + for len(l.queue) > 0 { + item := l.queue[0] + if item == nil || item.next.After(now) { + break + } + popped := heap.Pop(&l.queue).(*refreshHeapItem) + if popped == nil { + continue + } + delete(l.index, popped.id) + due = append(due, popped.id) + } + return due +} + +func (l *authAutoRefreshLoop) handleDueAuth(ctx context.Context, now time.Time, authID string) { + if authID == "" { + return + } + + manager := l.manager + + manager.mu.RLock() + auth := manager.auths[authID] + if auth == nil { + manager.mu.RUnlock() + return + } + next, shouldSchedule := nextRefreshCheckAt(now, auth, l.interval) + shouldRefresh := manager.shouldRefresh(auth, now) + exec := manager.executors[auth.Provider] + manager.mu.RUnlock() + + if !shouldSchedule { + l.remove(authID) + return + } + + if !shouldRefresh { + l.upsert(authID, next) + return + } + + if exec == nil { + l.upsert(authID, now.Add(l.interval)) + return + } + + if !manager.markRefreshPending(authID, now) { + manager.mu.RLock() + auth = manager.auths[authID] + next, shouldSchedule = nextRefreshCheckAt(now, auth, l.interval) + manager.mu.RUnlock() + if shouldSchedule { + l.upsert(authID, next) + } else { + l.remove(authID) + } + return + } + + select { + case <-ctx.Done(): + return + case l.jobs <- authID: + } +} + +func (l *authAutoRefreshLoop) applyDirty(now time.Time) { + dirty := l.drainDirty() + if len(dirty) == 0 { + return + } + + for _, authID := range dirty { + l.manager.mu.RLock() + auth := l.manager.auths[authID] + next, ok := nextRefreshCheckAt(now, auth, l.interval) + l.manager.mu.RUnlock() + + if !ok { + l.remove(authID) + continue + } + l.upsert(authID, next) + } +} + +func (l *authAutoRefreshLoop) drainDirty() []string { + l.mu.Lock() + defer l.mu.Unlock() + if len(l.dirty) == 0 { + return nil + } + out := make([]string, 0, len(l.dirty)) + for authID := range l.dirty { + out = append(out, authID) + delete(l.dirty, authID) + } + return out +} + +func (l *authAutoRefreshLoop) upsert(authID 
string, next time.Time) { + if authID == "" || next.IsZero() { + return + } + l.mu.Lock() + defer l.mu.Unlock() + if item, ok := l.index[authID]; ok && item != nil { + item.next = next + heap.Fix(&l.queue, item.index) + return + } + item := &refreshHeapItem{id: authID, next: next} + heap.Push(&l.queue, item) + l.index[authID] = item +} + +func (l *authAutoRefreshLoop) remove(authID string) { + if authID == "" { + return + } + l.mu.Lock() + defer l.mu.Unlock() + item, ok := l.index[authID] + if !ok || item == nil { + return + } + heap.Remove(&l.queue, item.index) + delete(l.index, authID) +} + +func nextRefreshCheckAt(now time.Time, auth *Auth, interval time.Duration) (time.Time, bool) { + if auth == nil || auth.Disabled { + return time.Time{}, false + } + + accountType, _ := auth.AccountInfo() + if accountType == "api_key" { + return time.Time{}, false + } + + if !auth.NextRefreshAfter.IsZero() && now.Before(auth.NextRefreshAfter) { + return auth.NextRefreshAfter, true + } + + if evaluator, ok := auth.Runtime.(RefreshEvaluator); ok && evaluator != nil { + if interval <= 0 { + interval = refreshCheckInterval + } + return now.Add(interval), true + } + + lastRefresh := auth.LastRefreshedAt + if lastRefresh.IsZero() { + if ts, ok := authLastRefreshTimestamp(auth); ok { + lastRefresh = ts + } + } + + expiry, hasExpiry := auth.ExpirationTime() + + if pref := authPreferredInterval(auth); pref > 0 { + candidates := make([]time.Time, 0, 2) + if hasExpiry && !expiry.IsZero() { + if !expiry.After(now) || expiry.Sub(now) <= pref { + return now, true + } + candidates = append(candidates, expiry.Add(-pref)) + } + if lastRefresh.IsZero() { + return now, true + } + candidates = append(candidates, lastRefresh.Add(pref)) + next := candidates[0] + for _, candidate := range candidates[1:] { + if candidate.Before(next) { + next = candidate + } + } + if !next.After(now) { + return now, true + } + return next, true + } + + provider := strings.ToLower(auth.Provider) + lead := 
ProviderRefreshLead(provider, auth.Runtime) + if lead == nil { + return time.Time{}, false + } + if hasExpiry && !expiry.IsZero() { + dueAt := expiry.Add(-*lead) + if !dueAt.After(now) { + return now, true + } + return dueAt, true + } + if !lastRefresh.IsZero() { + dueAt := lastRefresh.Add(*lead) + if !dueAt.After(now) { + return now, true + } + return dueAt, true + } + return now, true +} + +type refreshHeapItem struct { + id string + next time.Time + index int +} + +type refreshMinHeap []*refreshHeapItem + +func (h refreshMinHeap) Len() int { return len(h) } + +func (h refreshMinHeap) Less(i, j int) bool { + return h[i].next.Before(h[j].next) +} + +func (h refreshMinHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] + h[i].index = i + h[j].index = j +} + +func (h *refreshMinHeap) Push(x any) { + item, ok := x.(*refreshHeapItem) + if !ok || item == nil { + return + } + item.index = len(*h) + *h = append(*h, item) +} + +func (h *refreshMinHeap) Pop() any { + old := *h + n := len(old) + if n == 0 { + return (*refreshHeapItem)(nil) + } + item := old[n-1] + item.index = -1 + *h = old[:n-1] + return item +} diff --git a/sdk/cliproxy/auth/auto_refresh_loop_test.go b/sdk/cliproxy/auth/auto_refresh_loop_test.go new file mode 100644 index 00000000..420aae23 --- /dev/null +++ b/sdk/cliproxy/auth/auto_refresh_loop_test.go @@ -0,0 +1,137 @@ +package auth + +import ( + "strings" + "testing" + "time" +) + +type testRefreshEvaluator struct{} + +func (testRefreshEvaluator) ShouldRefresh(time.Time, *Auth) bool { return false } + +func setRefreshLeadFactory(t *testing.T, provider string, factory func() *time.Duration) { + t.Helper() + key := strings.ToLower(strings.TrimSpace(provider)) + refreshLeadMu.Lock() + prev, hadPrev := refreshLeadFactories[key] + if factory == nil { + delete(refreshLeadFactories, key) + } else { + refreshLeadFactories[key] = factory + } + refreshLeadMu.Unlock() + t.Cleanup(func() { + refreshLeadMu.Lock() + if hadPrev { + refreshLeadFactories[key] = prev + } 
else { + delete(refreshLeadFactories, key) + } + refreshLeadMu.Unlock() + }) +} + +func TestNextRefreshCheckAt_DisabledUnschedule(t *testing.T) { + now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC) + auth := &Auth{ID: "a1", Provider: "test", Disabled: true} + if _, ok := nextRefreshCheckAt(now, auth, 15*time.Minute); ok { + t.Fatalf("nextRefreshCheckAt() ok = true, want false") + } +} + +func TestNextRefreshCheckAt_APIKeyUnschedule(t *testing.T) { + now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC) + auth := &Auth{ID: "a1", Provider: "test", Attributes: map[string]string{"api_key": "k"}} + if _, ok := nextRefreshCheckAt(now, auth, 15*time.Minute); ok { + t.Fatalf("nextRefreshCheckAt() ok = true, want false") + } +} + +func TestNextRefreshCheckAt_NextRefreshAfterGate(t *testing.T) { + now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC) + nextAfter := now.Add(30 * time.Minute) + auth := &Auth{ + ID: "a1", + Provider: "test", + NextRefreshAfter: nextAfter, + Metadata: map[string]any{"email": "x@example.com"}, + } + got, ok := nextRefreshCheckAt(now, auth, 15*time.Minute) + if !ok { + t.Fatalf("nextRefreshCheckAt() ok = false, want true") + } + if !got.Equal(nextAfter) { + t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, nextAfter) + } +} + +func TestNextRefreshCheckAt_PreferredInterval_PicksEarliestCandidate(t *testing.T) { + now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC) + expiry := now.Add(20 * time.Minute) + auth := &Auth{ + ID: "a1", + Provider: "test", + LastRefreshedAt: now, + Metadata: map[string]any{ + "email": "x@example.com", + "expires_at": expiry.Format(time.RFC3339), + "refresh_interval_seconds": 900, // 15m + }, + } + got, ok := nextRefreshCheckAt(now, auth, 15*time.Minute) + if !ok { + t.Fatalf("nextRefreshCheckAt() ok = false, want true") + } + want := expiry.Add(-15 * time.Minute) + if !got.Equal(want) { + t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, want) + } +} + +func TestNextRefreshCheckAt_ProviderLead_Expiry(t *testing.T) { + 
now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC) + expiry := now.Add(time.Hour) + lead := 10 * time.Minute + setRefreshLeadFactory(t, "provider-lead-expiry", func() *time.Duration { + d := lead + return &d + }) + + auth := &Auth{ + ID: "a1", + Provider: "provider-lead-expiry", + Metadata: map[string]any{ + "email": "x@example.com", + "expires_at": expiry.Format(time.RFC3339), + }, + } + + got, ok := nextRefreshCheckAt(now, auth, 15*time.Minute) + if !ok { + t.Fatalf("nextRefreshCheckAt() ok = false, want true") + } + want := expiry.Add(-lead) + if !got.Equal(want) { + t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, want) + } +} + +func TestNextRefreshCheckAt_RefreshEvaluatorFallback(t *testing.T) { + now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC) + interval := 15 * time.Minute + auth := &Auth{ + ID: "a1", + Provider: "test", + Metadata: map[string]any{"email": "x@example.com"}, + Runtime: testRefreshEvaluator{}, + } + got, ok := nextRefreshCheckAt(now, auth, interval) + if !ok { + t.Fatalf("nextRefreshCheckAt() ok = false, want true") + } + want := now.Add(interval) + if !got.Equal(want) { + t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, want) + } +} diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 71587e89..6628c20b 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -105,6 +105,13 @@ type Selector interface { Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) } +// StoppableSelector is an optional interface for selectors that hold resources. +// Selectors that implement this interface will have Stop called during shutdown. +type StoppableSelector interface { + Selector + Stop() +} + // Hook captures lifecycle callbacks for observing auth changes. type Hook interface { // OnAuthRegistered fires when a new auth is registered. 
@@ -162,8 +169,8 @@ type Manager struct { rtProvider RoundTripperProvider // Auto refresh state - refreshCancel context.CancelFunc - refreshSemaphore chan struct{} + refreshCancel context.CancelFunc + refreshLoop *authAutoRefreshLoop } // NewManager constructs a manager with optional custom selector and hook. @@ -182,7 +189,6 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager { auths: make(map[string]*Auth), providerOffsets: make(map[string]int), modelPoolOffsets: make(map[string]int), - refreshSemaphore: make(chan struct{}, refreshMaxConcurrency), } // atomic.Value requires non-nil initial value. manager.runtimeConfig.Store(&internalconfig.Config{}) @@ -214,6 +220,16 @@ func (m *Manager) syncScheduler() { m.syncSchedulerFromSnapshot(m.snapshotAuths()) } +func (m *Manager) snapshotAuths() []*Auth { + m.mu.RLock() + defer m.mu.RUnlock() + out := make([]*Auth, 0, len(m.auths)) + for _, a := range m.auths { + out = append(out, a.Clone()) + } + return out +} + // RefreshSchedulerEntry re-upserts a single auth into the scheduler so that its // supportedModelSet is rebuilt from the current global model registry state. 
// This must be called after models have been registered for a newly added auth, @@ -1088,6 +1104,7 @@ func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) { if m.scheduler != nil { m.scheduler.upsertAuth(authClone) } + m.queueRefreshReschedule(auth.ID) _ = m.persist(ctx, auth) m.hook.OnAuthRegistered(ctx, auth.Clone()) return auth.Clone(), nil @@ -1118,6 +1135,7 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) { if m.scheduler != nil { m.scheduler.upsertAuth(authClone) } + m.queueRefreshReschedule(auth.ID) _ = m.persist(ctx, auth) m.hook.OnAuthUpdated(ctx, auth.Clone()) return auth.Clone(), nil @@ -2890,80 +2908,60 @@ func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duratio if interval <= 0 { interval = refreshCheckInterval } - if m.refreshCancel != nil { - m.refreshCancel() - m.refreshCancel = nil + + m.mu.Lock() + cancelPrev := m.refreshCancel + m.refreshCancel = nil + m.refreshLoop = nil + m.mu.Unlock() + if cancelPrev != nil { + cancelPrev() } - ctx, cancel := context.WithCancel(parent) - m.refreshCancel = cancel - go func() { - ticker := time.NewTicker(interval) - defer ticker.Stop() - m.checkRefreshes(ctx) - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - m.checkRefreshes(ctx) - } - } - }() + + ctx, cancelCtx := context.WithCancel(parent) + workers := refreshMaxConcurrency + if cfg, ok := m.runtimeConfig.Load().(*internalconfig.Config); ok && cfg != nil && cfg.AuthAutoRefreshWorkers > 0 { + workers = cfg.AuthAutoRefreshWorkers + } + loop := newAuthAutoRefreshLoop(m, interval, workers) + + m.mu.Lock() + m.refreshCancel = cancelCtx + m.refreshLoop = loop + m.mu.Unlock() + + loop.rebuild(time.Now()) + go loop.run(ctx) } // StopAutoRefresh cancels the background refresh loop, if running. +// It also stops the selector if it implements StoppableSelector. 
func (m *Manager) StopAutoRefresh() { - if m.refreshCancel != nil { - m.refreshCancel() - m.refreshCancel = nil + m.mu.Lock() + cancel := m.refreshCancel + m.refreshCancel = nil + m.refreshLoop = nil + m.mu.Unlock() + if cancel != nil { + cancel() + } + // Stop selector if it implements StoppableSelector (e.g., SessionAffinitySelector) + if stoppable, ok := m.selector.(StoppableSelector); ok { + stoppable.Stop() } } -func (m *Manager) checkRefreshes(ctx context.Context) { - // log.Debugf("checking refreshes") - now := time.Now() - snapshot := m.snapshotAuths() - for _, a := range snapshot { - typ, _ := a.AccountInfo() - if typ != "api_key" { - if !m.shouldRefresh(a, now) { - continue - } - log.Debugf("checking refresh for %s, %s, %s", a.Provider, a.ID, typ) - - if exec := m.executorFor(a.Provider); exec == nil { - continue - } - if !m.markRefreshPending(a.ID, now) { - continue - } - go m.refreshAuthWithLimit(ctx, a.ID) - } - } -} - -func (m *Manager) refreshAuthWithLimit(ctx context.Context, id string) { - if m.refreshSemaphore == nil { - m.refreshAuth(ctx, id) +func (m *Manager) queueRefreshReschedule(authID string) { + if m == nil || authID == "" { return } - select { - case m.refreshSemaphore <- struct{}{}: - defer func() { <-m.refreshSemaphore }() - case <-ctx.Done(): - return - } - m.refreshAuth(ctx, id) -} - -func (m *Manager) snapshotAuths() []*Auth { m.mu.RLock() - defer m.mu.RUnlock() - out := make([]*Auth, 0, len(m.auths)) - for _, a := range m.auths { - out = append(out, a.Clone()) + loop := m.refreshLoop + m.mu.RUnlock() + if loop == nil { + return } - return out + loop.queueReschedule(authID) } func (m *Manager) shouldRefresh(a *Auth, now time.Time) bool { @@ -3173,16 +3171,20 @@ func lookupMetadataTime(meta map[string]any, keys ...string) (time.Time, bool) { func (m *Manager) markRefreshPending(id string, now time.Time) bool { m.mu.Lock() - defer m.mu.Unlock() auth, ok := m.auths[id] if !ok || auth == nil || auth.Disabled { + m.mu.Unlock() return 
false } if !auth.NextRefreshAfter.IsZero() && now.Before(auth.NextRefreshAfter) { + m.mu.Unlock() return false } auth.NextRefreshAfter = now.Add(refreshPendingBackoff) m.auths[id] = auth + m.mu.Unlock() + + m.queueRefreshReschedule(id) return true } @@ -3209,16 +3211,21 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) { log.Debugf("refreshed %s, %s, %v", auth.Provider, auth.ID, err) now := time.Now() if err != nil { + shouldReschedule := false m.mu.Lock() if current := m.auths[id]; current != nil { current.NextRefreshAfter = now.Add(refreshFailureBackoff) current.LastError = &Error{Message: err.Error()} m.auths[id] = current + shouldReschedule = true if m.scheduler != nil { m.scheduler.upsertAuth(current.Clone()) } } m.mu.Unlock() + if shouldReschedule { + m.queueRefreshReschedule(id) + } return } if updated == nil { diff --git a/sdk/cliproxy/auth/conductor_overrides_test.go b/sdk/cliproxy/auth/conductor_overrides_test.go index 1b74aab1..0adc83a6 100644 --- a/sdk/cliproxy/auth/conductor_overrides_test.go +++ b/sdk/cliproxy/auth/conductor_overrides_test.go @@ -69,18 +69,18 @@ func TestManager_ShouldRetryAfterError_UsesOAuthModelAliasForCooldown(t *testing m := NewManager(nil, nil, nil) m.SetRetryConfig(3, 30*time.Second, 0) m.SetOAuthModelAlias(map[string][]internalconfig.OAuthModelAlias{ - "qwen": { - {Name: "qwen3.6-plus", Alias: "coder-model"}, + "iflow": { + {Name: "deepseek-v3.1", Alias: "pool-model"}, }, }) - routeModel := "coder-model" - upstreamModel := "qwen3.6-plus" + routeModel := "pool-model" + upstreamModel := "deepseek-v3.1" next := time.Now().Add(5 * time.Second) auth := &Auth{ ID: "auth-1", - Provider: "qwen", + Provider: "iflow", ModelStates: map[string]*ModelState{ upstreamModel: { Unavailable: true, @@ -99,7 +99,7 @@ func TestManager_ShouldRetryAfterError_UsesOAuthModelAliasForCooldown(t *testing } _, _, maxWait := m.retrySettings() - wait, shouldRetry := m.shouldRetryAfterError(&Error{HTTPStatus: 429, Message: "quota"}, 0, 
[]string{"qwen"}, routeModel, maxWait) + wait, shouldRetry := m.shouldRetryAfterError(&Error{HTTPStatus: 429, Message: "quota"}, 0, []string{"iflow"}, routeModel, maxWait) if !shouldRetry { t.Fatalf("expected shouldRetry=true, got false (wait=%v)", wait) } diff --git a/sdk/cliproxy/auth/oauth_model_alias.go b/sdk/cliproxy/auth/oauth_model_alias.go index 7fc6a793..951cdecf 100644 --- a/sdk/cliproxy/auth/oauth_model_alias.go +++ b/sdk/cliproxy/auth/oauth_model_alias.go @@ -265,7 +265,7 @@ func modelAliasChannel(auth *Auth) string { // and auth kind. Returns empty string if the provider/authKind combination doesn't support // OAuth model alias (e.g., API key authentication). // -// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi. +// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, iflow, kiro, github-copilot, kimi. func OAuthModelAliasChannel(provider, authKind string) string { provider = strings.ToLower(strings.TrimSpace(provider)) authKind = strings.ToLower(strings.TrimSpace(authKind)) @@ -289,7 +289,7 @@ func OAuthModelAliasChannel(provider, authKind string) string { return "" } return "codex" - case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow", "kiro", "github-copilot", "kimi": + case "gemini-cli", "aistudio", "antigravity", "iflow", "kiro", "github-copilot", "kimi": return provider default: return "" diff --git a/sdk/cliproxy/auth/oauth_model_alias_test.go b/sdk/cliproxy/auth/oauth_model_alias_test.go index e12b6597..2d92b010 100644 --- a/sdk/cliproxy/auth/oauth_model_alias_test.go +++ b/sdk/cliproxy/auth/oauth_model_alias_test.go @@ -184,8 +184,6 @@ func createAuthForChannel(channel string) *Auth { return &Auth{Provider: "aistudio"} case "antigravity": return &Auth{Provider: "antigravity"} - case "qwen": - return &Auth{Provider: "qwen"} case "iflow": return &Auth{Provider: "iflow"} case "kimi": diff --git a/sdk/cliproxy/auth/openai_compat_pool_test.go 
b/sdk/cliproxy/auth/openai_compat_pool_test.go index 9a977aae..ff2c4dd0 100644 --- a/sdk/cliproxy/auth/openai_compat_pool_test.go +++ b/sdk/cliproxy/auth/openai_compat_pool_test.go @@ -215,10 +215,10 @@ func TestManagerExecuteCount_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testi invalidErr := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"} executor := &openAICompatPoolExecutor{ id: "pool", - countErrors: map[string]error{"qwen3.5-plus": invalidErr}, + countErrors: map[string]error{"deepseek-v3.1": invalidErr}, } m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ - {Name: "qwen3.5-plus", Alias: alias}, + {Name: "deepseek-v3.1", Alias: alias}, {Name: "glm-5", Alias: alias}, }, executor) @@ -227,18 +227,18 @@ func TestManagerExecuteCount_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testi t.Fatalf("execute count error = %v, want %v", err, invalidErr) } got := executor.CountModels() - if len(got) != 1 || got[0] != "qwen3.5-plus" { + if len(got) != 1 || got[0] != "deepseek-v3.1" { t.Fatalf("count calls = %v, want only first invalid model", got) } } func TestResolveModelAliasPoolFromConfigModels(t *testing.T) { models := []modelAliasEntry{ - internalconfig.OpenAICompatibilityModel{Name: "qwen3.5-plus", Alias: "claude-opus-4.66"}, + internalconfig.OpenAICompatibilityModel{Name: "deepseek-v3.1", Alias: "claude-opus-4.66"}, internalconfig.OpenAICompatibilityModel{Name: "glm-5", Alias: "claude-opus-4.66"}, internalconfig.OpenAICompatibilityModel{Name: "kimi-k2.5", Alias: "claude-opus-4.66"}, } got := resolveModelAliasPoolFromConfigModels("claude-opus-4.66(8192)", models) - want := []string{"qwen3.5-plus(8192)", "glm-5(8192)", "kimi-k2.5(8192)"} + want := []string{"deepseek-v3.1(8192)", "glm-5(8192)", "kimi-k2.5(8192)"} if len(got) != len(want) { t.Fatalf("pool len = %d, want %d (%v)", len(got), len(want), got) } @@ -253,7 +253,7 @@ func TestManagerExecute_OpenAICompatAliasPoolRotatesWithinAuth(t 
*testing.T) { alias := "claude-opus-4.66" executor := &openAICompatPoolExecutor{id: "pool"} m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ - {Name: "qwen3.5-plus", Alias: alias}, + {Name: "deepseek-v3.1", Alias: alias}, {Name: "glm-5", Alias: alias}, }, executor) @@ -268,7 +268,7 @@ func TestManagerExecute_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) { } got := executor.ExecuteModels() - want := []string{"qwen3.5-plus", "glm-5", "qwen3.5-plus"} + want := []string{"deepseek-v3.1", "glm-5", "deepseek-v3.1"} if len(got) != len(want) { t.Fatalf("execute calls = %v, want %v", got, want) } @@ -284,10 +284,10 @@ func TestManagerExecute_OpenAICompatAliasPoolStopsOnBadRequest(t *testing.T) { invalidErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"} executor := &openAICompatPoolExecutor{ id: "pool", - executeErrors: map[string]error{"qwen3.5-plus": invalidErr}, + executeErrors: map[string]error{"deepseek-v3.1": invalidErr}, } m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ - {Name: "qwen3.5-plus", Alias: alias}, + {Name: "deepseek-v3.1", Alias: alias}, {Name: "glm-5", Alias: alias}, }, executor) @@ -296,7 +296,7 @@ func TestManagerExecute_OpenAICompatAliasPoolStopsOnBadRequest(t *testing.T) { t.Fatalf("execute error = %v, want %v", err, invalidErr) } got := executor.ExecuteModels() - if len(got) != 1 || got[0] != "qwen3.5-plus" { + if len(got) != 1 || got[0] != "deepseek-v3.1" { t.Fatalf("execute calls = %v, want only first invalid model", got) } } @@ -309,10 +309,10 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackOnModelSupportBadRequest(t } executor := &openAICompatPoolExecutor{ id: "pool", - executeErrors: map[string]error{"qwen3.5-plus": modelSupportErr}, + executeErrors: map[string]error{"deepseek-v3.1": modelSupportErr}, } m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ - {Name: 
"qwen3.5-plus", Alias: alias}, + {Name: "deepseek-v3.1", Alias: alias}, {Name: "glm-5", Alias: alias}, }, executor) @@ -324,7 +324,7 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackOnModelSupportBadRequest(t t.Fatalf("payload = %q, want %q", string(resp.Payload), "glm-5") } got := executor.ExecuteModels() - want := []string{"qwen3.5-plus", "glm-5"} + want := []string{"deepseek-v3.1", "glm-5"} if len(got) != len(want) { t.Fatalf("execute calls = %v, want %v", got, want) } @@ -338,7 +338,7 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackOnModelSupportBadRequest(t if !ok || updated == nil { t.Fatalf("expected auth to remain registered") } - state := updated.ModelStates["qwen3.5-plus"] + state := updated.ModelStates["deepseek-v3.1"] if state == nil { t.Fatalf("expected suspended upstream model state") } @@ -355,10 +355,10 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackOnModelSupportUnprocessabl } executor := &openAICompatPoolExecutor{ id: "pool", - executeErrors: map[string]error{"qwen3.5-plus": modelSupportErr}, + executeErrors: map[string]error{"deepseek-v3.1": modelSupportErr}, } m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ - {Name: "qwen3.5-plus", Alias: alias}, + {Name: "deepseek-v3.1", Alias: alias}, {Name: "glm-5", Alias: alias}, }, executor) @@ -370,7 +370,7 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackOnModelSupportUnprocessabl t.Fatalf("payload = %q, want %q", string(resp.Payload), "glm-5") } got := executor.ExecuteModels() - want := []string{"qwen3.5-plus", "glm-5"} + want := []string{"deepseek-v3.1", "glm-5"} if len(got) != len(want) { t.Fatalf("execute calls = %v, want %v", got, want) } @@ -385,10 +385,10 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackWithinSameAuth(t *testing. 
alias := "claude-opus-4.66" executor := &openAICompatPoolExecutor{ id: "pool", - executeErrors: map[string]error{"qwen3.5-plus": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}}, + executeErrors: map[string]error{"deepseek-v3.1": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}}, } m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ - {Name: "qwen3.5-plus", Alias: alias}, + {Name: "deepseek-v3.1", Alias: alias}, {Name: "glm-5", Alias: alias}, }, executor) @@ -400,7 +400,7 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackWithinSameAuth(t *testing. t.Fatalf("payload = %q, want %q", string(resp.Payload), "glm-5") } got := executor.ExecuteModels() - want := []string{"qwen3.5-plus", "glm-5"} + want := []string{"deepseek-v3.1", "glm-5"} for i := range want { if got[i] != want[i] { t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i]) @@ -413,11 +413,11 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolRetriesOnEmptyBootstrap(t *te executor := &openAICompatPoolExecutor{ id: "pool", streamPayloads: map[string][]cliproxyexecutor.StreamChunk{ - "qwen3.5-plus": {}, + "deepseek-v3.1": {}, }, } m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ - {Name: "qwen3.5-plus", Alias: alias}, + {Name: "deepseek-v3.1", Alias: alias}, {Name: "glm-5", Alias: alias}, }, executor) @@ -436,7 +436,7 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolRetriesOnEmptyBootstrap(t *te t.Fatalf("payload = %q, want %q", string(payload), "glm-5") } got := executor.StreamModels() - want := []string{"qwen3.5-plus", "glm-5"} + want := []string{"deepseek-v3.1", "glm-5"} for i := range want { if got[i] != want[i] { t.Fatalf("stream call %d model = %q, want %q", i, got[i], want[i]) @@ -448,10 +448,10 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolFallsBackBeforeFirstByte(t *t alias := "claude-opus-4.66" executor := &openAICompatPoolExecutor{ id: "pool", - 
streamFirstErrors: map[string]error{"qwen3.5-plus": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}}, + streamFirstErrors: map[string]error{"deepseek-v3.1": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}}, } m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ - {Name: "qwen3.5-plus", Alias: alias}, + {Name: "deepseek-v3.1", Alias: alias}, {Name: "glm-5", Alias: alias}, }, executor) @@ -470,7 +470,7 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolFallsBackBeforeFirstByte(t *t t.Fatalf("payload = %q, want %q", string(payload), "glm-5") } got := executor.StreamModels() - want := []string{"qwen3.5-plus", "glm-5"} + want := []string{"deepseek-v3.1", "glm-5"} for i := range want { if got[i] != want[i] { t.Fatalf("stream call %d model = %q, want %q", i, got[i], want[i]) @@ -486,10 +486,10 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidRequest(t *test invalidErr := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"} executor := &openAICompatPoolExecutor{ id: "pool", - streamFirstErrors: map[string]error{"qwen3.5-plus": invalidErr}, + streamFirstErrors: map[string]error{"deepseek-v3.1": invalidErr}, } m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ - {Name: "qwen3.5-plus", Alias: alias}, + {Name: "deepseek-v3.1", Alias: alias}, {Name: "glm-5", Alias: alias}, }, executor) @@ -498,7 +498,7 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidRequest(t *test t.Fatalf("execute stream error = %v, want %v", err, invalidErr) } got := executor.StreamModels() - if len(got) != 1 || got[0] != "qwen3.5-plus" { + if len(got) != 1 || got[0] != "deepseek-v3.1" { t.Fatalf("stream calls = %v, want only first invalid model", got) } } @@ -511,10 +511,10 @@ func TestManagerExecute_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLaterReques } executor := &openAICompatPoolExecutor{ id: "pool", - executeErrors: 
map[string]error{"qwen3.5-plus": modelSupportErr}, + executeErrors: map[string]error{"deepseek-v3.1": modelSupportErr}, } m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ - {Name: "qwen3.5-plus", Alias: alias}, + {Name: "deepseek-v3.1", Alias: alias}, {Name: "glm-5", Alias: alias}, }, executor) @@ -529,7 +529,7 @@ func TestManagerExecute_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLaterReques } got := executor.ExecuteModels() - want := []string{"qwen3.5-plus", "glm-5", "glm-5", "glm-5"} + want := []string{"deepseek-v3.1", "glm-5", "glm-5", "glm-5"} if len(got) != len(want) { t.Fatalf("execute calls = %v, want %v", got, want) } @@ -548,10 +548,10 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLater } executor := &openAICompatPoolExecutor{ id: "pool", - streamFirstErrors: map[string]error{"qwen3.5-plus": modelSupportErr}, + streamFirstErrors: map[string]error{"deepseek-v3.1": modelSupportErr}, } m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ - {Name: "qwen3.5-plus", Alias: alias}, + {Name: "deepseek-v3.1", Alias: alias}, {Name: "glm-5", Alias: alias}, }, executor) @@ -569,7 +569,7 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLater } got := executor.StreamModels() - want := []string{"qwen3.5-plus", "glm-5", "glm-5", "glm-5"} + want := []string{"deepseek-v3.1", "glm-5", "glm-5", "glm-5"} if len(got) != len(want) { t.Fatalf("stream calls = %v, want %v", got, want) } @@ -584,7 +584,7 @@ func TestManagerExecuteCount_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T alias := "claude-opus-4.66" executor := &openAICompatPoolExecutor{id: "pool"} m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ - {Name: "qwen3.5-plus", Alias: alias}, + {Name: "deepseek-v3.1", Alias: alias}, {Name: "glm-5", Alias: alias}, }, executor) @@ -599,7 +599,7 @@ func 
TestManagerExecuteCount_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T } got := executor.CountModels() - want := []string{"qwen3.5-plus", "glm-5"} + want := []string{"deepseek-v3.1", "glm-5"} for i := range want { if got[i] != want[i] { t.Fatalf("count call %d model = %q, want %q", i, got[i], want[i]) @@ -615,10 +615,10 @@ func TestManagerExecuteCount_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLaterR } executor := &openAICompatPoolExecutor{ id: "pool", - countErrors: map[string]error{"qwen3.5-plus": modelSupportErr}, + countErrors: map[string]error{"deepseek-v3.1": modelSupportErr}, } m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ - {Name: "qwen3.5-plus", Alias: alias}, + {Name: "deepseek-v3.1", Alias: alias}, {Name: "glm-5", Alias: alias}, }, executor) @@ -633,7 +633,7 @@ func TestManagerExecuteCount_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLaterR } got := executor.CountModels() - want := []string{"qwen3.5-plus", "glm-5", "glm-5", "glm-5"} + want := []string{"deepseek-v3.1", "glm-5", "glm-5", "glm-5"} if len(got) != len(want) { t.Fatalf("count calls = %v, want %v", got, want) } @@ -650,7 +650,7 @@ func TestManagerExecute_OpenAICompatAliasPoolBlockedAuthDoesNotConsumeRetryBudge OpenAICompatibility: []internalconfig.OpenAICompatibility{{ Name: "pool", Models: []internalconfig.OpenAICompatibilityModel{ - {Name: "qwen3.5-plus", Alias: alias}, + {Name: "deepseek-v3.1", Alias: alias}, {Name: "glm-5", Alias: alias}, }, }}, @@ -701,7 +701,7 @@ func TestManagerExecute_OpenAICompatAliasPoolBlockedAuthDoesNotConsumeRetryBudge HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: The requested model is not supported.", } - for _, upstreamModel := range []string{"qwen3.5-plus", "glm-5"} { + for _, upstreamModel := range []string{"deepseek-v3.1", "glm-5"} { m.MarkResult(context.Background(), Result{ AuthID: badAuth.ID, Provider: "pool", @@ -733,10 +733,10 @@ func 
TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidBootstrap(t *te invalidErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"} executor := &openAICompatPoolExecutor{ id: "pool", - streamFirstErrors: map[string]error{"qwen3.5-plus": invalidErr}, + streamFirstErrors: map[string]error{"deepseek-v3.1": invalidErr}, } m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ - {Name: "qwen3.5-plus", Alias: alias}, + {Name: "deepseek-v3.1", Alias: alias}, {Name: "glm-5", Alias: alias}, }, executor) @@ -750,7 +750,7 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidBootstrap(t *te if streamResult != nil { t.Fatalf("streamResult = %#v, want nil on invalid bootstrap", streamResult) } - if got := executor.StreamModels(); len(got) != 1 || got[0] != "qwen3.5-plus" { + if got := executor.StreamModels(); len(got) != 1 || got[0] != "deepseek-v3.1" { t.Fatalf("stream calls = %v, want only first upstream model", got) } } diff --git a/sdk/cliproxy/auth/selector.go b/sdk/cliproxy/auth/selector.go index cf79e173..51275a31 100644 --- a/sdk/cliproxy/auth/selector.go +++ b/sdk/cliproxy/auth/selector.go @@ -4,15 +4,21 @@ import ( "context" "encoding/json" "fmt" + "hash/fnv" "math" "math/rand/v2" "net/http" + "regexp" "sort" "strconv" "strings" "sync" "time" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" ) @@ -420,3 +426,448 @@ func isAuthBlockedForModel(auth *Auth, model string, now time.Time) (bool, block } return false, blockReasonNone, time.Time{} } + +// sessionPattern matches Claude Code user_id format: +// user_{hash}_account__session_{uuid} +var sessionPattern = regexp.MustCompile(`_session_([a-f0-9-]+)$`) + +// SessionAffinitySelector wraps another selector 
with session-sticky behavior. +// It extracts session ID from multiple sources and maintains session-to-auth +// mappings with automatic failover when the bound auth becomes unavailable. +type SessionAffinitySelector struct { + fallback Selector + cache *SessionCache +} + +// SessionAffinityConfig configures the session affinity selector. +type SessionAffinityConfig struct { + Fallback Selector + TTL time.Duration +} + +// NewSessionAffinitySelector creates a new session-aware selector. +func NewSessionAffinitySelector(fallback Selector) *SessionAffinitySelector { + return NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{ + Fallback: fallback, + TTL: time.Hour, + }) +} + +// NewSessionAffinitySelectorWithConfig creates a selector with custom configuration. +func NewSessionAffinitySelectorWithConfig(cfg SessionAffinityConfig) *SessionAffinitySelector { + if cfg.Fallback == nil { + cfg.Fallback = &RoundRobinSelector{} + } + if cfg.TTL <= 0 { + cfg.TTL = time.Hour + } + return &SessionAffinitySelector{ + fallback: cfg.Fallback, + cache: NewSessionCache(cfg.TTL), + } +} + +// Pick selects an auth with session affinity when possible. +// Priority for session ID extraction: +// 1. metadata.user_id (Claude Code format) - highest priority +// 2. X-Session-ID header +// 3. metadata.user_id (non-Claude Code format) +// 4. conversation_id field +// 5. Hash-based fallback from messages +// +// Note: The cache key includes provider, session ID, and model to handle cases where +// a session uses multiple models (e.g., gemini-2.5-pro and gemini-3-flash-preview) +// that may be supported by different auth credentials, and to avoid cross-provider conflicts. 
+func (s *SessionAffinitySelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) { + entry := selectorLogEntry(ctx) + primaryID, fallbackID := extractSessionIDs(opts.Headers, opts.OriginalRequest, opts.Metadata) + if primaryID == "" { + entry.Debugf("session-affinity: no session ID extracted, falling back to default selector | provider=%s model=%s", provider, model) + return s.fallback.Pick(ctx, provider, model, opts, auths) + } + + now := time.Now() + available, err := getAvailableAuths(auths, provider, model, now) + if err != nil { + return nil, err + } + + cacheKey := provider + "::" + primaryID + "::" + model + + if cachedAuthID, ok := s.cache.GetAndRefresh(cacheKey); ok { + for _, auth := range available { + if auth.ID == cachedAuthID { + entry.Infof("session-affinity: cache hit | session=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), auth.ID, provider, model) + return auth, nil + } + } + // Cached auth not available, reselect via fallback selector for even distribution + auth, err := s.fallback.Pick(ctx, provider, model, opts, auths) + if err != nil { + return nil, err + } + s.cache.Set(cacheKey, auth.ID) + entry.Infof("session-affinity: cache hit but auth unavailable, reselected | session=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), auth.ID, provider, model) + return auth, nil + } + + if fallbackID != "" && fallbackID != primaryID { + fallbackKey := provider + "::" + fallbackID + "::" + model + if cachedAuthID, ok := s.cache.Get(fallbackKey); ok { + for _, auth := range available { + if auth.ID == cachedAuthID { + s.cache.Set(cacheKey, auth.ID) + entry.Infof("session-affinity: fallback cache hit | session=%s fallback=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), truncateSessionID(fallbackID), auth.ID, provider, model) + return auth, nil + } + } + } + } + + auth, err := s.fallback.Pick(ctx, provider, model, opts, auths) + if err != nil { + 
return nil, err + } + s.cache.Set(cacheKey, auth.ID) + entry.Infof("session-affinity: cache miss, new binding | session=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), auth.ID, provider, model) + return auth, nil +} + +func selectorLogEntry(ctx context.Context) *log.Entry { + if ctx == nil { + return log.NewEntry(log.StandardLogger()) + } + if reqID := logging.GetRequestID(ctx); reqID != "" { + return log.WithField("request_id", reqID) + } + return log.NewEntry(log.StandardLogger()) +} + +// truncateSessionID shortens session ID for logging (first 8 chars + "...") +func truncateSessionID(id string) string { + if len(id) <= 20 { + return id + } + return id[:8] + "..." +} + +// Stop releases resources held by the selector. +func (s *SessionAffinitySelector) Stop() { + if s.cache != nil { + s.cache.Stop() + } +} + +// InvalidateAuth removes all session bindings for a specific auth. +// Called when an auth becomes rate-limited or unavailable. +func (s *SessionAffinitySelector) InvalidateAuth(authID string) { + if s.cache != nil { + s.cache.InvalidateAuth(authID) + } +} + +// ExtractSessionID extracts session identifier from multiple sources. +// Priority order: +// 1. metadata.user_id (Claude Code format with _session_{uuid}) - highest priority for Claude Code clients +// 2. X-Session-ID header +// 3. metadata.user_id (non-Claude Code format) +// 4. conversation_id field in request body +// 5. Stable hash from first few messages content (fallback) +func ExtractSessionID(headers http.Header, payload []byte, metadata map[string]any) string { + primary, _ := extractSessionIDs(headers, payload, metadata) + return primary +} + +// extractSessionIDs returns (primaryID, fallbackID) for session affinity. 
+// primaryID: full hash including assistant response (stable after first turn) +// fallbackID: short hash without assistant (used to inherit binding from first turn) +func extractSessionIDs(headers http.Header, payload []byte, metadata map[string]any) (string, string) { + // 1. metadata.user_id with Claude Code session format (highest priority) + if len(payload) > 0 { + userID := gjson.GetBytes(payload, "metadata.user_id").String() + if userID != "" { + // Old format: user_{hash}_account__session_{uuid} + if matches := sessionPattern.FindStringSubmatch(userID); len(matches) >= 2 { + id := "claude:" + matches[1] + return id, "" + } + // New format: JSON object with session_id field + // e.g. {"device_id":"...","account_uuid":"...","session_id":"uuid"} + if len(userID) > 0 && userID[0] == '{' { + if sid := gjson.Get(userID, "session_id").String(); sid != "" { + return "claude:" + sid, "" + } + } + } + } + + // 2. X-Session-ID header + if headers != nil { + if sid := headers.Get("X-Session-ID"); sid != "" { + return "header:" + sid, "" + } + } + + if len(payload) == 0 { + return "", "" + } + + // 3. metadata.user_id (non-Claude Code format) + userID := gjson.GetBytes(payload, "metadata.user_id").String() + if userID != "" { + return "user:" + userID, "" + } + + // 4. conversation_id field + if convID := gjson.GetBytes(payload, "conversation_id").String(); convID != "" { + return "conv:" + convID, "" + } + + // 5. 
Hash-based fallback from message content + return extractMessageHashIDs(payload) +} + +func extractMessageHashIDs(payload []byte) (primaryID, fallbackID string) { + var systemPrompt, firstUserMsg, firstAssistantMsg string + + // OpenAI/Claude messages format + messages := gjson.GetBytes(payload, "messages") + if messages.Exists() && messages.IsArray() { + messages.ForEach(func(_, msg gjson.Result) bool { + role := msg.Get("role").String() + content := extractMessageContent(msg.Get("content")) + if content == "" { + return true + } + + switch role { + case "system": + if systemPrompt == "" { + systemPrompt = truncateString(content, 100) + } + case "user": + if firstUserMsg == "" { + firstUserMsg = truncateString(content, 100) + } + case "assistant": + if firstAssistantMsg == "" { + firstAssistantMsg = truncateString(content, 100) + } + } + + if systemPrompt != "" && firstUserMsg != "" && firstAssistantMsg != "" { + return false + } + return true + }) + } + + // Claude API: top-level "system" field (array or string) + if systemPrompt == "" { + topSystem := gjson.GetBytes(payload, "system") + if topSystem.Exists() { + if topSystem.IsArray() { + topSystem.ForEach(func(_, part gjson.Result) bool { + if text := part.Get("text").String(); text != "" && systemPrompt == "" { + systemPrompt = truncateString(text, 100) + return false + } + return true + }) + } else if topSystem.Type == gjson.String { + systemPrompt = truncateString(topSystem.String(), 100) + } + } + } + + // Gemini format + if systemPrompt == "" && firstUserMsg == "" { + sysInstr := gjson.GetBytes(payload, "systemInstruction.parts") + if sysInstr.Exists() && sysInstr.IsArray() { + sysInstr.ForEach(func(_, part gjson.Result) bool { + if text := part.Get("text").String(); text != "" && systemPrompt == "" { + systemPrompt = truncateString(text, 100) + return false + } + return true + }) + } + + contents := gjson.GetBytes(payload, "contents") + if contents.Exists() && contents.IsArray() { + 
contents.ForEach(func(_, msg gjson.Result) bool { + role := msg.Get("role").String() + msg.Get("parts").ForEach(func(_, part gjson.Result) bool { + text := part.Get("text").String() + if text == "" { + return true + } + switch role { + case "user": + if firstUserMsg == "" { + firstUserMsg = truncateString(text, 100) + } + case "model": + if firstAssistantMsg == "" { + firstAssistantMsg = truncateString(text, 100) + } + } + return false + }) + if firstUserMsg != "" && firstAssistantMsg != "" { + return false + } + return true + }) + } + } + + // OpenAI Responses API format (v1/responses) + if systemPrompt == "" && firstUserMsg == "" { + if instr := gjson.GetBytes(payload, "instructions").String(); instr != "" { + systemPrompt = truncateString(instr, 100) + } + + input := gjson.GetBytes(payload, "input") + if input.Exists() && input.IsArray() { + input.ForEach(func(_, item gjson.Result) bool { + itemType := item.Get("type").String() + if itemType == "reasoning" { + return true + } + // Skip non-message typed items (function_call, function_call_output, etc.) + // but allow items with no type that have a role (inline message format). + if itemType != "" && itemType != "message" { + return true + } + + role := item.Get("role").String() + if itemType == "" && role == "" { + return true + } + + // Handle both string content and array content (multimodal). 
+ content := item.Get("content") + var text string + if content.Type == gjson.String { + text = content.String() + } else { + text = extractResponsesAPIContent(content) + } + if text == "" { + return true + } + + switch role { + case "developer", "system": + if systemPrompt == "" { + systemPrompt = truncateString(text, 100) + } + case "user": + if firstUserMsg == "" { + firstUserMsg = truncateString(text, 100) + } + case "assistant": + if firstAssistantMsg == "" { + firstAssistantMsg = truncateString(text, 100) + } + } + + if firstUserMsg != "" && firstAssistantMsg != "" { + return false + } + return true + }) + } + } + + if systemPrompt == "" && firstUserMsg == "" { + return "", "" + } + + shortHash := computeSessionHash(systemPrompt, firstUserMsg, "") + if firstAssistantMsg == "" { + return shortHash, "" + } + + fullHash := computeSessionHash(systemPrompt, firstUserMsg, firstAssistantMsg) + return fullHash, shortHash +} + +func computeSessionHash(systemPrompt, userMsg, assistantMsg string) string { + h := fnv.New64a() + if systemPrompt != "" { + h.Write([]byte("sys:" + systemPrompt + "\n")) + } + if userMsg != "" { + h.Write([]byte("usr:" + userMsg + "\n")) + } + if assistantMsg != "" { + h.Write([]byte("ast:" + assistantMsg + "\n")) + } + return fmt.Sprintf("msg:%016x", h.Sum64()) +} + +func truncateString(s string, maxLen int) string { + if len(s) > maxLen { + return s[:maxLen] + } + return s +} + +// extractMessageContent extracts text content from a message content field. +// Handles both string content and array content (multimodal messages). +// For array content, extracts text from all text-type elements. 
+func extractMessageContent(content gjson.Result) string {
+	// String content: "Hello world"
+	if content.Type == gjson.String {
+		return content.String()
+	}
+
+	// Array content: [{"type":"text","text":"Hello"},{"type":"image",...}]
+	if content.IsArray() {
+		var texts []string
+		content.ForEach(func(_, part gjson.Result) bool {
+			// Claude format: {"type":"text","text":"content"}
+			if part.Get("type").String() == "text" {
+				if text := part.Get("text").String(); text != "" {
+					texts = append(texts, text)
+				}
+			}
+			// OpenAI chat format uses the identical {"type":"text"} shape, so the
+			// branch above covers both; non-text parts (images etc.) are skipped.
+			return true
+		})
+		if len(texts) > 0 {
+			return strings.Join(texts, " ") // multiple text parts collapse into one space-joined string
+		}
+	}
+
+	return "" // no textual content found (e.g. image-only message)
+}
+
+// extractResponsesAPIContent joins the text parts of an OpenAI Responses API
+// content array ("input_text", "output_text", or "text" typed parts only).
+func extractResponsesAPIContent(content gjson.Result) string {
+	if !content.IsArray() {
+		return "" // Responses API content is always an array; anything else yields no text
+	}
+	var texts []string
+	content.ForEach(func(_, part gjson.Result) bool {
+		partType := part.Get("type").String()
+		if partType == "input_text" || partType == "output_text" || partType == "text" {
+			if text := part.Get("text").String(); text != "" {
+				texts = append(texts, text)
+			}
+		}
+		return true
+	})
+	if len(texts) > 0 {
+		return strings.Join(texts, " ")
+	}
+	return ""
+}
+
+// extractSessionID is kept for backward compatibility.
+// Deprecated: Use ExtractSessionID instead.
+func extractSessionID(payload []byte) string { + return ExtractSessionID(nil, payload, nil) +} diff --git a/sdk/cliproxy/auth/selector_test.go b/sdk/cliproxy/auth/selector_test.go index 79431a9a..560d3b9e 100644 --- a/sdk/cliproxy/auth/selector_test.go +++ b/sdk/cliproxy/auth/selector_test.go @@ -4,7 +4,9 @@ import ( "context" "encoding/json" "errors" + "fmt" "net/http" + "strings" "sync" "testing" "time" @@ -458,6 +460,159 @@ func TestRoundRobinSelectorPick_GeminiCLICredentialGrouping(t *testing.T) { } } +func TestExtractSessionID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + payload string + want string + }{ + { + name: "valid_claude_code_format", + payload: `{"metadata":{"user_id":"user_3f221fe75652cf9a89a31647f16274bb8036a9b85ac4dc226a4df0efec8dc04d_account__session_ac980658-63bd-4fb3-97ba-8da64cb1e344"}}`, + want: "claude:ac980658-63bd-4fb3-97ba-8da64cb1e344", + }, + { + name: "json_user_id_with_session_id", + payload: `{"metadata":{"user_id":"{\"device_id\":\"be82c3aee1e0c2d74535bacc85f9f559228f02dd8a17298cf522b71e6c375714\",\"account_uuid\":\"\",\"session_id\":\"e26d4046-0f88-4b09-bb5b-f863ab5fb24e\"}"}}`, + want: "claude:e26d4046-0f88-4b09-bb5b-f863ab5fb24e", + }, + { + name: "json_user_id_without_session_id", + payload: `{"metadata":{"user_id":"{\"device_id\":\"abc123\"}"}}`, + want: `user:{"device_id":"abc123"}`, + }, + { + name: "no_session_but_user_id", + payload: `{"metadata":{"user_id":"user_abc123"}}`, + want: "user:user_abc123", + }, + { + name: "conversation_id", + payload: `{"conversation_id":"conv-12345"}`, + want: "conv:conv-12345", + }, + { + name: "no_metadata", + payload: `{"model":"claude-3"}`, + want: "", + }, + { + name: "empty_payload", + payload: ``, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractSessionID([]byte(tt.payload)) + if got != tt.want { + t.Errorf("extractSessionID() = %q, want %q", got, tt.want) + } + }) + } +} + +func 
TestSessionAffinitySelector_SameSessionSameAuth(t *testing.T) { + t.Parallel() + + fallback := &RoundRobinSelector{} + selector := NewSessionAffinitySelector(fallback) + + auths := []*Auth{ + {ID: "auth-a"}, + {ID: "auth-b"}, + {ID: "auth-c"}, + } + + // Use valid UUID format for session ID + payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_ac980658-63bd-4fb3-97ba-8da64cb1e344"}}`) + opts := cliproxyexecutor.Options{OriginalRequest: payload} + + // Same session should always pick the same auth + first, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths) + if err != nil { + t.Fatalf("Pick() error = %v", err) + } + if first == nil { + t.Fatalf("Pick() returned nil") + } + + // Verify consistency: same session, same auths -> same result + for i := 0; i < 10; i++ { + got, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths) + if err != nil { + t.Fatalf("Pick() #%d error = %v", i, err) + } + if got.ID != first.ID { + t.Fatalf("Pick() #%d auth.ID = %q, want %q (same session should pick same auth)", i, got.ID, first.ID) + } + } +} + +func TestSessionAffinitySelector_NoSessionFallback(t *testing.T) { + t.Parallel() + + fallback := &FillFirstSelector{} + selector := NewSessionAffinitySelector(fallback) + + auths := []*Auth{ + {ID: "auth-b"}, + {ID: "auth-a"}, + {ID: "auth-c"}, + } + + // No session in payload, should fallback to FillFirstSelector (picks "auth-a" after sorting) + payload := []byte(`{"model":"claude-3"}`) + opts := cliproxyexecutor.Options{OriginalRequest: payload} + + got, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths) + if err != nil { + t.Fatalf("Pick() error = %v", err) + } + if got.ID != "auth-a" { + t.Fatalf("Pick() auth.ID = %q, want %q (should fallback to FillFirst)", got.ID, "auth-a") + } +} + +func TestSessionAffinitySelector_DifferentSessionsDifferentAuths(t *testing.T) { + t.Parallel() + + fallback := &RoundRobinSelector{} + selector := 
NewSessionAffinitySelector(fallback) + + auths := []*Auth{ + {ID: "auth-a"}, + {ID: "auth-b"}, + {ID: "auth-c"}, + } + + // Use valid UUID format for session IDs + session1 := []byte(`{"metadata":{"user_id":"user_xxx_account__session_11111111-1111-1111-1111-111111111111"}}`) + session2 := []byte(`{"metadata":{"user_id":"user_xxx_account__session_22222222-2222-2222-2222-222222222222"}}`) + + opts1 := cliproxyexecutor.Options{OriginalRequest: session1} + opts2 := cliproxyexecutor.Options{OriginalRequest: session2} + + auth1, _ := selector.Pick(context.Background(), "claude", "claude-3", opts1, auths) + auth2, _ := selector.Pick(context.Background(), "claude", "claude-3", opts2, auths) + + // Different sessions may or may not pick different auths (depends on hash collision) + // But each session should be consistent + for i := 0; i < 5; i++ { + got1, _ := selector.Pick(context.Background(), "claude", "claude-3", opts1, auths) + got2, _ := selector.Pick(context.Background(), "claude", "claude-3", opts2, auths) + if got1.ID != auth1.ID { + t.Fatalf("session1 Pick() #%d inconsistent: got %q, want %q", i, got1.ID, auth1.ID) + } + if got2.ID != auth2.ID { + t.Fatalf("session2 Pick() #%d inconsistent: got %q, want %q", i, got2.ID, auth2.ID) + } + } +} + func TestRoundRobinSelectorPick_SingleParentFallsBackToFlat(t *testing.T) { t.Parallel() @@ -494,6 +649,57 @@ func TestRoundRobinSelectorPick_SingleParentFallsBackToFlat(t *testing.T) { } } +func TestSessionAffinitySelector_FailoverWhenAuthUnavailable(t *testing.T) { + t.Parallel() + + fallback := &RoundRobinSelector{} + selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{ + Fallback: fallback, + TTL: time.Minute, + }) + defer selector.Stop() + + auths := []*Auth{ + {ID: "auth-a"}, + {ID: "auth-b"}, + {ID: "auth-c"}, + } + + payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_failover-test-uuid"}}`) + opts := cliproxyexecutor.Options{OriginalRequest: payload} + + // First pick establishes 
binding + first, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths) + if err != nil { + t.Fatalf("Pick() error = %v", err) + } + + // Remove the bound auth from available list (simulating rate limit) + availableWithoutFirst := make([]*Auth, 0, len(auths)-1) + for _, a := range auths { + if a.ID != first.ID { + availableWithoutFirst = append(availableWithoutFirst, a) + } + } + + // With failover enabled, should pick a new auth + second, err := selector.Pick(context.Background(), "claude", "claude-3", opts, availableWithoutFirst) + if err != nil { + t.Fatalf("Pick() after failover error = %v", err) + } + if second.ID == first.ID { + t.Fatalf("Pick() after failover returned same auth %q, expected different", first.ID) + } + + // Subsequent picks should consistently return the new binding + for i := 0; i < 5; i++ { + got, _ := selector.Pick(context.Background(), "claude", "claude-3", opts, availableWithoutFirst) + if got.ID != second.ID { + t.Fatalf("Pick() #%d after failover inconsistent: got %q, want %q", i, got.ID, second.ID) + } + } +} + func TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat(t *testing.T) { t.Parallel() @@ -527,3 +733,629 @@ func TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat(t *test } } } +func TestExtractSessionID_ClaudeCodePriorityOverHeader(t *testing.T) { + t.Parallel() + + // Claude Code metadata.user_id should have highest priority, even when X-Session-ID header is present + headers := make(http.Header) + headers.Set("X-Session-ID", "header-session-id") + + payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_ac980658-63bd-4fb3-97ba-8da64cb1e344"}}`) + + got := ExtractSessionID(headers, payload, nil) + want := "claude:ac980658-63bd-4fb3-97ba-8da64cb1e344" + if got != want { + t.Errorf("ExtractSessionID() = %q, want %q (Claude Code should have highest priority over header)", got, want) + } +} + +func TestExtractSessionID_ClaudeCodePriorityOverIdempotencyKey(t 
*testing.T) { + t.Parallel() + + // Claude Code metadata.user_id should have highest priority, even when idempotency_key is present + metadata := map[string]any{"idempotency_key": "idem-12345"} + payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_ac980658-63bd-4fb3-97ba-8da64cb1e344"}}`) + + got := ExtractSessionID(nil, payload, metadata) + want := "claude:ac980658-63bd-4fb3-97ba-8da64cb1e344" + if got != want { + t.Errorf("ExtractSessionID() = %q, want %q (Claude Code should have highest priority over idempotency_key)", got, want) + } +} + +func TestExtractSessionID_Headers(t *testing.T) { + t.Parallel() + + headers := make(http.Header) + headers.Set("X-Session-ID", "my-explicit-session") + + got := ExtractSessionID(headers, nil, nil) + want := "header:my-explicit-session" + if got != want { + t.Errorf("ExtractSessionID() with header = %q, want %q", got, want) + } +} + +// TestExtractSessionID_IdempotencyKey verifies that idempotency_key is intentionally +// ignored for session affinity (it's auto-generated per-request, causing cache misses). 
+func TestExtractSessionID_IdempotencyKey(t *testing.T) { + t.Parallel() + + metadata := map[string]any{"idempotency_key": "idem-12345"} + + got := ExtractSessionID(nil, nil, metadata) + // idempotency_key is disabled - should return empty (no payload to hash) + if got != "" { + t.Errorf("ExtractSessionID() with idempotency_key = %q, want empty (idempotency_key is disabled)", got) + } +} + +func TestExtractSessionID_MessageHashFallback(t *testing.T) { + t.Parallel() + + // First request (user only) generates short hash + firstRequestPayload := []byte(`{"messages":[{"role":"user","content":"Hello world"}]}`) + shortHash := ExtractSessionID(nil, firstRequestPayload, nil) + if shortHash == "" { + t.Error("ExtractSessionID() first request should return short hash") + } + if !strings.HasPrefix(shortHash, "msg:") { + t.Errorf("ExtractSessionID() = %q, want prefix 'msg:'", shortHash) + } + + // Multi-turn with assistant generates full hash (different from short hash) + multiTurnPayload := []byte(`{"messages":[ + {"role":"user","content":"Hello world"}, + {"role":"assistant","content":"Hi! 
How can I help?"}, + {"role":"user","content":"Tell me a joke"} + ]}`) + fullHash := ExtractSessionID(nil, multiTurnPayload, nil) + if fullHash == "" { + t.Error("ExtractSessionID() multi-turn should return full hash") + } + if fullHash == shortHash { + t.Error("Full hash should differ from short hash (includes assistant)") + } + + // Same multi-turn payload should produce same hash + fullHash2 := ExtractSessionID(nil, multiTurnPayload, nil) + if fullHash != fullHash2 { + t.Errorf("ExtractSessionID() not stable: got %q then %q", fullHash, fullHash2) + } +} + +func TestExtractSessionID_ClaudeAPITopLevelSystem(t *testing.T) { + t.Parallel() + + // Claude API: system prompt in top-level "system" field (array format) + arraySystem := []byte(`{ + "messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}], + "system": [{"type": "text", "text": "You are Claude Code"}] + }`) + got1 := ExtractSessionID(nil, arraySystem, nil) + if got1 == "" || !strings.HasPrefix(got1, "msg:") { + t.Errorf("ExtractSessionID() with array system = %q, want msg:* prefix", got1) + } + + // Claude API: system prompt in top-level "system" field (string format) + stringSystem := []byte(`{ + "messages": [{"role": "user", "content": "Hello"}], + "system": "You are Claude Code" + }`) + got2 := ExtractSessionID(nil, stringSystem, nil) + if got2 == "" || !strings.HasPrefix(got2, "msg:") { + t.Errorf("ExtractSessionID() with string system = %q, want msg:* prefix", got2) + } + + // Multi-turn with top-level system should produce stable hash + multiTurn := []byte(`{ + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi!"}, + {"role": "user", "content": "Help me"} + ], + "system": "You are Claude Code" + }`) + got3 := ExtractSessionID(nil, multiTurn, nil) + if got3 == "" { + t.Error("ExtractSessionID() multi-turn with top-level system should return hash") + } + if got3 == got2 { + t.Error("Multi-turn hash should differ from first-turn hash 
(includes assistant)") + } +} + +func TestExtractSessionID_GeminiFormat(t *testing.T) { + t.Parallel() + + // Gemini format with systemInstruction and contents + payload := []byte(`{ + "systemInstruction": {"parts": [{"text": "You are a helpful assistant."}]}, + "contents": [ + {"role": "user", "parts": [{"text": "Hello Gemini"}]}, + {"role": "model", "parts": [{"text": "Hi there!"}]} + ] + }`) + + got := ExtractSessionID(nil, payload, nil) + if got == "" { + t.Error("ExtractSessionID() with Gemini format should return hash-based session ID") + } + if !strings.HasPrefix(got, "msg:") { + t.Errorf("ExtractSessionID() = %q, want prefix 'msg:'", got) + } + + // Same payload should produce same hash + got2 := ExtractSessionID(nil, payload, nil) + if got != got2 { + t.Errorf("ExtractSessionID() not stable: got %q then %q", got, got2) + } + + // Different user message should produce different hash + differentPayload := []byte(`{ + "systemInstruction": {"parts": [{"text": "You are a helpful assistant."}]}, + "contents": [ + {"role": "user", "parts": [{"text": "Hello different"}]}, + {"role": "model", "parts": [{"text": "Hi there!"}]} + ] + }`) + got3 := ExtractSessionID(nil, differentPayload, nil) + if got == got3 { + t.Errorf("ExtractSessionID() should produce different hash for different user message") + } +} + +func TestExtractSessionID_OpenAIResponsesAPI(t *testing.T) { + t.Parallel() + + firstTurn := []byte(`{ + "instructions": "You are Codex, based on GPT-5.", + "input": [ + {"type": "message", "role": "developer", "content": [{"type": "input_text", "text": "system instructions"}]}, + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hi"}]} + ] + }`) + + got1 := ExtractSessionID(nil, firstTurn, nil) + if got1 == "" { + t.Error("ExtractSessionID() should return hash for OpenAI Responses API format") + } + if !strings.HasPrefix(got1, "msg:") { + t.Errorf("ExtractSessionID() = %q, want prefix 'msg:'", got1) + } + + secondTurn := []byte(`{ + 
"instructions": "You are Codex, based on GPT-5.", + "input": [ + {"type": "message", "role": "developer", "content": [{"type": "input_text", "text": "system instructions"}]}, + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hi"}]}, + {"type": "reasoning", "summary": [{"type": "summary_text", "text": "thinking..."}], "encrypted_content": "xxx"}, + {"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "Hello!"}]}, + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "what can you do"}]} + ] + }`) + + got2 := ExtractSessionID(nil, secondTurn, nil) + if got2 == "" { + t.Error("ExtractSessionID() should return hash for second turn") + } + + if got1 == got2 { + t.Log("First turn and second turn have different hashes (expected: second includes assistant)") + } + + thirdTurn := []byte(`{ + "instructions": "You are Codex, based on GPT-5.", + "input": [ + {"type": "message", "role": "developer", "content": [{"type": "input_text", "text": "system instructions"}]}, + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hi"}]}, + {"type": "reasoning", "summary": [{"type": "summary_text", "text": "thinking..."}], "encrypted_content": "xxx"}, + {"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "Hello!"}]}, + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "what can you do"}]}, + {"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "I can help with..."}]}, + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "thanks"}]} + ] + }`) + + got3 := ExtractSessionID(nil, thirdTurn, nil) + if got2 != got3 { + t.Errorf("Second and third turn should have same hash (same first assistant): got %q vs %q", got2, got3) + } +} + +func TestSessionAffinitySelector_ThreeScenarios(t *testing.T) { + t.Parallel() + + fallback := &RoundRobinSelector{} + selector 
:= NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{ + Fallback: fallback, + TTL: time.Minute, + }) + defer selector.Stop() + + auths := []*Auth{{ID: "auth-a"}, {ID: "auth-b"}, {ID: "auth-c"}} + + testCases := []struct { + name string + scenario string + payload []byte + }{ + { + name: "OpenAI_Scenario1_NewRequest", + scenario: "new", + payload: []byte(`{"messages":[{"role":"system","content":"You are helpful"},{"role":"user","content":"Hello"}]}`), + }, + { + name: "OpenAI_Scenario2_SecondTurn", + scenario: "second", + payload: []byte(`{"messages":[{"role":"system","content":"You are helpful"},{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there!"},{"role":"user","content":"Help me"}]}`), + }, + { + name: "OpenAI_Scenario3_ManyTurns", + scenario: "many", + payload: []byte(`{"messages":[{"role":"system","content":"You are helpful"},{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there!"},{"role":"user","content":"Help me"},{"role":"assistant","content":"Sure!"},{"role":"user","content":"Thanks"}]}`), + }, + { + name: "Gemini_Scenario1_NewRequest", + scenario: "new", + payload: []byte(`{"systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"Hello Gemini"}]}]}`), + }, + { + name: "Gemini_Scenario2_SecondTurn", + scenario: "second", + payload: []byte(`{"systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"Hello Gemini"}]},{"role":"model","parts":[{"text":"Hi!"}]},{"role":"user","parts":[{"text":"Help"}]}]}`), + }, + { + name: "Gemini_Scenario3_ManyTurns", + scenario: "many", + payload: []byte(`{"systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"Hello Gemini"}]},{"role":"model","parts":[{"text":"Hi!"}]},{"role":"user","parts":[{"text":"Help"}]},{"role":"model","parts":[{"text":"Sure!"}]},{"role":"user","parts":[{"text":"Thanks"}]}]}`), + }, + { + name: 
"Claude_Scenario1_NewRequest", + scenario: "new", + payload: []byte(`{"messages":[{"role":"user","content":"Hello Claude"}]}`), + }, + { + name: "Claude_Scenario2_SecondTurn", + scenario: "second", + payload: []byte(`{"messages":[{"role":"user","content":"Hello Claude"},{"role":"assistant","content":"Hello!"},{"role":"user","content":"Help me"}]}`), + }, + { + name: "Claude_Scenario3_ManyTurns", + scenario: "many", + payload: []byte(`{"messages":[{"role":"user","content":"Hello Claude"},{"role":"assistant","content":"Hello!"},{"role":"user","content":"Help"},{"role":"assistant","content":"Sure!"},{"role":"user","content":"Thanks"}]}`), + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + opts := cliproxyexecutor.Options{OriginalRequest: tc.payload} + picked, err := selector.Pick(context.Background(), "provider", "model", opts, auths) + if err != nil { + t.Fatalf("Pick() error = %v", err) + } + if picked == nil { + t.Fatal("Pick() returned nil") + } + t.Logf("%s: picked %s", tc.name, picked.ID) + }) + } + + t.Run("Scenario2And3_SameAuth", func(t *testing.T) { + openaiS2 := []byte(`{"messages":[{"role":"system","content":"Stable test"},{"role":"user","content":"First msg"},{"role":"assistant","content":"Response"},{"role":"user","content":"Second"}]}`) + openaiS3 := []byte(`{"messages":[{"role":"system","content":"Stable test"},{"role":"user","content":"First msg"},{"role":"assistant","content":"Response"},{"role":"user","content":"Second"},{"role":"assistant","content":"More"},{"role":"user","content":"Third"}]}`) + + opts2 := cliproxyexecutor.Options{OriginalRequest: openaiS2} + opts3 := cliproxyexecutor.Options{OriginalRequest: openaiS3} + + picked2, _ := selector.Pick(context.Background(), "test", "model", opts2, auths) + picked3, _ := selector.Pick(context.Background(), "test", "model", opts3, auths) + + if picked2.ID != picked3.ID { + t.Errorf("Scenario2 and Scenario3 should pick same auth: got %s vs %s", picked2.ID, 
picked3.ID) + } + }) + + t.Run("Scenario1To2_InheritBinding", func(t *testing.T) { + s1 := []byte(`{"messages":[{"role":"system","content":"Inherit test"},{"role":"user","content":"Initial"}]}`) + s2 := []byte(`{"messages":[{"role":"system","content":"Inherit test"},{"role":"user","content":"Initial"},{"role":"assistant","content":"Reply"},{"role":"user","content":"Continue"}]}`) + + opts1 := cliproxyexecutor.Options{OriginalRequest: s1} + opts2 := cliproxyexecutor.Options{OriginalRequest: s2} + + picked1, _ := selector.Pick(context.Background(), "inherit", "model", opts1, auths) + picked2, _ := selector.Pick(context.Background(), "inherit", "model", opts2, auths) + + if picked1.ID != picked2.ID { + t.Errorf("Scenario2 should inherit Scenario1 binding: got %s vs %s", picked1.ID, picked2.ID) + } + }) +} + +func TestSessionAffinitySelector_MultiModelSession(t *testing.T) { + t.Parallel() + + fallback := &RoundRobinSelector{} + selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{ + Fallback: fallback, + TTL: time.Minute, + }) + defer selector.Stop() + + // auth-a supports only model-a, auth-b supports only model-b + authA := &Auth{ID: "auth-a"} + authB := &Auth{ID: "auth-b"} + + // Same session ID for all requests + payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_multi-model-test"}}`) + opts := cliproxyexecutor.Options{OriginalRequest: payload} + + // Request model-a with only auth-a available for that model + authsForModelA := []*Auth{authA} + pickedA, err := selector.Pick(context.Background(), "provider", "model-a", opts, authsForModelA) + if err != nil { + t.Fatalf("Pick() for model-a error = %v", err) + } + if pickedA.ID != "auth-a" { + t.Fatalf("Pick() for model-a = %q, want auth-a", pickedA.ID) + } + + // Request model-b with only auth-b available for that model + authsForModelB := []*Auth{authB} + pickedB, err := selector.Pick(context.Background(), "provider", "model-b", opts, authsForModelB) + if err != nil { + 
t.Fatalf("Pick() for model-b error = %v", err) + } + if pickedB.ID != "auth-b" { + t.Fatalf("Pick() for model-b = %q, want auth-b", pickedB.ID) + } + + // Switch back to model-a - should still get auth-a (separate binding per model) + pickedA2, err := selector.Pick(context.Background(), "provider", "model-a", opts, authsForModelA) + if err != nil { + t.Fatalf("Pick() for model-a (2nd) error = %v", err) + } + if pickedA2.ID != "auth-a" { + t.Fatalf("Pick() for model-a (2nd) = %q, want auth-a", pickedA2.ID) + } + + // Verify bindings are stable for multiple calls + for i := 0; i < 5; i++ { + gotA, _ := selector.Pick(context.Background(), "provider", "model-a", opts, authsForModelA) + gotB, _ := selector.Pick(context.Background(), "provider", "model-b", opts, authsForModelB) + if gotA.ID != "auth-a" { + t.Fatalf("Pick() #%d for model-a = %q, want auth-a", i, gotA.ID) + } + if gotB.ID != "auth-b" { + t.Fatalf("Pick() #%d for model-b = %q, want auth-b", i, gotB.ID) + } + } +} + +func TestExtractSessionID_MultimodalContent(t *testing.T) { + t.Parallel() + + // First request generates short hash + firstRequestPayload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"Hello world"},{"type":"image","source":{"data":"..."}}]}]}`) + shortHash := ExtractSessionID(nil, firstRequestPayload, nil) + if shortHash == "" { + t.Error("ExtractSessionID() first request should return short hash") + } + if !strings.HasPrefix(shortHash, "msg:") { + t.Errorf("ExtractSessionID() = %q, want prefix 'msg:'", shortHash) + } + + // Multi-turn generates full hash + multiTurnPayload := []byte(`{"messages":[ + {"role":"user","content":[{"type":"text","text":"Hello world"},{"type":"image","source":{"data":"..."}}]}, + {"role":"assistant","content":"I see an image!"}, + {"role":"user","content":"What is it?"} + ]}`) + fullHash := ExtractSessionID(nil, multiTurnPayload, nil) + if fullHash == "" { + t.Error("ExtractSessionID() multimodal multi-turn should return full hash") + } + 
if fullHash == shortHash { + t.Error("Full hash should differ from short hash") + } + + // Different user content produces different hash + differentPayload := []byte(`{"messages":[ + {"role":"user","content":[{"type":"text","text":"Different content"}]}, + {"role":"assistant","content":"I see something different!"} + ]}`) + differentHash := ExtractSessionID(nil, differentPayload, nil) + if fullHash == differentHash { + t.Errorf("ExtractSessionID() should produce different hash for different content") + } +} + +func TestSessionAffinitySelector_CrossProviderIsolation(t *testing.T) { + t.Parallel() + + fallback := &RoundRobinSelector{} + selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{ + Fallback: fallback, + TTL: time.Minute, + }) + defer selector.Stop() + + authClaude := &Auth{ID: "auth-claude"} + authGemini := &Auth{ID: "auth-gemini"} + + // Same session ID for both providers + payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_cross-provider-test"}}`) + opts := cliproxyexecutor.Options{OriginalRequest: payload} + + // Request via claude provider + pickedClaude, err := selector.Pick(context.Background(), "claude", "claude-3", opts, []*Auth{authClaude}) + if err != nil { + t.Fatalf("Pick() for claude error = %v", err) + } + if pickedClaude.ID != "auth-claude" { + t.Fatalf("Pick() for claude = %q, want auth-claude", pickedClaude.ID) + } + + // Same session but via gemini provider should get different auth + pickedGemini, err := selector.Pick(context.Background(), "gemini", "gemini-2.5-pro", opts, []*Auth{authGemini}) + if err != nil { + t.Fatalf("Pick() for gemini error = %v", err) + } + if pickedGemini.ID != "auth-gemini" { + t.Fatalf("Pick() for gemini = %q, want auth-gemini", pickedGemini.ID) + } + + // Verify both bindings remain stable + for i := 0; i < 5; i++ { + gotC, _ := selector.Pick(context.Background(), "claude", "claude-3", opts, []*Auth{authClaude}) + gotG, _ := selector.Pick(context.Background(), "gemini", 
"gemini-2.5-pro", opts, []*Auth{authGemini}) + if gotC.ID != "auth-claude" { + t.Fatalf("Pick() #%d for claude = %q, want auth-claude", i, gotC.ID) + } + if gotG.ID != "auth-gemini" { + t.Fatalf("Pick() #%d for gemini = %q, want auth-gemini", i, gotG.ID) + } + } +} + +func TestSessionCache_GetAndRefresh(t *testing.T) { + t.Parallel() + + cache := NewSessionCache(100 * time.Millisecond) + defer cache.Stop() + + cache.Set("session1", "auth1") + + // Verify initial value + got, ok := cache.GetAndRefresh("session1") + if !ok || got != "auth1" { + t.Fatalf("GetAndRefresh() = %q, %v, want auth1, true", got, ok) + } + + // Wait half TTL and access again (should refresh) + time.Sleep(60 * time.Millisecond) + got, ok = cache.GetAndRefresh("session1") + if !ok || got != "auth1" { + t.Fatalf("GetAndRefresh() after 60ms = %q, %v, want auth1, true", got, ok) + } + + // Wait another 60ms (total 120ms from original, but TTL refreshed at 60ms) + // Entry should still be valid because TTL was refreshed + time.Sleep(60 * time.Millisecond) + got, ok = cache.GetAndRefresh("session1") + if !ok || got != "auth1" { + t.Fatalf("GetAndRefresh() after refresh = %q, %v, want auth1, true (TTL should have been refreshed)", got, ok) + } + + // Now wait full TTL without access + time.Sleep(110 * time.Millisecond) + got, ok = cache.GetAndRefresh("session1") + if ok { + t.Fatalf("GetAndRefresh() after expiry = %q, %v, want '', false", got, ok) + } +} + +func TestSessionAffinitySelector_RoundRobinDistribution(t *testing.T) { + t.Parallel() + + fallback := &RoundRobinSelector{} + selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{ + Fallback: fallback, + TTL: time.Minute, + }) + defer selector.Stop() + + auths := []*Auth{ + {ID: "auth-a"}, + {ID: "auth-b"}, + {ID: "auth-c"}, + } + + sessionCount := 12 + counts := make(map[string]int) + for i := 0; i < sessionCount; i++ { + payload := 
[]byte(fmt.Sprintf(`{"metadata":{"user_id":"user_xxx_account__session_%08d-0000-0000-0000-000000000000"}}`, i)) + opts := cliproxyexecutor.Options{OriginalRequest: payload} + got, err := selector.Pick(context.Background(), "provider", "model", opts, auths) + if err != nil { + t.Fatalf("Pick() session %d error = %v", i, err) + } + counts[got.ID]++ + } + + expected := sessionCount / len(auths) + for _, auth := range auths { + got := counts[auth.ID] + if got != expected { + t.Errorf("auth %s got %d sessions, want %d (round-robin should distribute evenly)", auth.ID, got, expected) + } + } +} + +func TestSessionAffinitySelector_Concurrent(t *testing.T) { + t.Parallel() + + fallback := &RoundRobinSelector{} + selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{ + Fallback: fallback, + TTL: time.Minute, + }) + defer selector.Stop() + + auths := []*Auth{ + {ID: "auth-a"}, + {ID: "auth-b"}, + {ID: "auth-c"}, + } + + payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_concurrent-test"}}`) + opts := cliproxyexecutor.Options{OriginalRequest: payload} + + // First pick to establish binding + first, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths) + if err != nil { + t.Fatalf("Initial Pick() error = %v", err) + } + expectedID := first.ID + + start := make(chan struct{}) + var wg sync.WaitGroup + errCh := make(chan error, 1) + + goroutines := 32 + iterations := 50 + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-start + for j := 0; j < iterations; j++ { + got, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths) + if err != nil { + select { + case errCh <- err: + default: + } + return + } + if got.ID != expectedID { + select { + case errCh <- fmt.Errorf("concurrent Pick() returned %q, want %q", got.ID, expectedID): + default: + } + return + } + } + }() + } + + close(start) + wg.Wait() + + select { + case err := <-errCh: + t.Fatalf("concurrent Pick() error = 
%v", err) + default: + } +} diff --git a/sdk/cliproxy/auth/session_cache.go b/sdk/cliproxy/auth/session_cache.go new file mode 100644 index 00000000..a812e581 --- /dev/null +++ b/sdk/cliproxy/auth/session_cache.go @@ -0,0 +1,152 @@ +package auth + +import ( + "sync" + "time" +) + +// sessionEntry stores auth binding with expiration. +type sessionEntry struct { + authID string + expiresAt time.Time +} + +// SessionCache provides TTL-based session to auth mapping with automatic cleanup. +type SessionCache struct { + mu sync.RWMutex + entries map[string]sessionEntry + ttl time.Duration + stopCh chan struct{} +} + +// NewSessionCache creates a cache with the specified TTL. +// A background goroutine periodically cleans expired entries. +func NewSessionCache(ttl time.Duration) *SessionCache { + if ttl <= 0 { + ttl = 30 * time.Minute + } + c := &SessionCache{ + entries: make(map[string]sessionEntry), + ttl: ttl, + stopCh: make(chan struct{}), + } + go c.cleanupLoop() + return c +} + +// Get retrieves the auth ID bound to a session, if still valid. +// Does NOT refresh the TTL on access. +func (c *SessionCache) Get(sessionID string) (string, bool) { + if sessionID == "" { + return "", false + } + c.mu.RLock() + entry, ok := c.entries[sessionID] + c.mu.RUnlock() + if !ok { + return "", false + } + if time.Now().After(entry.expiresAt) { + c.mu.Lock() + delete(c.entries, sessionID) + c.mu.Unlock() + return "", false + } + return entry.authID, true +} + +// GetAndRefresh retrieves the auth ID bound to a session and refreshes TTL on hit. +// This extends the binding lifetime for active sessions. 
+func (c *SessionCache) GetAndRefresh(sessionID string) (string, bool) { + if sessionID == "" { + return "", false + } + now := time.Now() + c.mu.Lock() + entry, ok := c.entries[sessionID] + if !ok { + c.mu.Unlock() + return "", false + } + if now.After(entry.expiresAt) { + delete(c.entries, sessionID) + c.mu.Unlock() + return "", false + } + // Refresh TTL on successful access + entry.expiresAt = now.Add(c.ttl) + c.entries[sessionID] = entry + c.mu.Unlock() + return entry.authID, true +} + +// Set binds a session to an auth ID with TTL refresh. +func (c *SessionCache) Set(sessionID, authID string) { + if sessionID == "" || authID == "" { + return + } + c.mu.Lock() + c.entries[sessionID] = sessionEntry{ + authID: authID, + expiresAt: time.Now().Add(c.ttl), + } + c.mu.Unlock() +} + +// Invalidate removes a specific session binding. +func (c *SessionCache) Invalidate(sessionID string) { + if sessionID == "" { + return + } + c.mu.Lock() + delete(c.entries, sessionID) + c.mu.Unlock() +} + +// InvalidateAuth removes all sessions bound to a specific auth ID. +// Used when an auth becomes unavailable. +func (c *SessionCache) InvalidateAuth(authID string) { + if authID == "" { + return + } + c.mu.Lock() + for sid, entry := range c.entries { + if entry.authID == authID { + delete(c.entries, sid) + } + } + c.mu.Unlock() +} + +// Stop terminates the background cleanup goroutine. 
+func (c *SessionCache) Stop() { + select { + case <-c.stopCh: + default: + close(c.stopCh) + } +} + +func (c *SessionCache) cleanupLoop() { + ticker := time.NewTicker(c.ttl / 2) + defer ticker.Stop() + for { + select { + case <-c.stopCh: + return + case <-ticker.C: + c.cleanup() + } + } +} + +func (c *SessionCache) cleanup() { + now := time.Now() + c.mu.Lock() + for sid, entry := range c.entries { + if now.After(entry.expiresAt) { + delete(c.entries, sid) + } + } + c.mu.Unlock() +} diff --git a/sdk/cliproxy/builder.go b/sdk/cliproxy/builder.go index 0e6d1421..b8cf991c 100644 --- a/sdk/cliproxy/builder.go +++ b/sdk/cliproxy/builder.go @@ -6,6 +6,7 @@ package cliproxy import ( "fmt" "strings" + "time" configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access" "github.com/router-for-me/CLIProxyAPI/v6/internal/api" @@ -208,8 +209,17 @@ func (b *Builder) Build() (*Service, error) { } strategy := "" + sessionAffinity := false + sessionAffinityTTL := time.Hour if b.cfg != nil { strategy = strings.ToLower(strings.TrimSpace(b.cfg.Routing.Strategy)) + // Support both legacy ClaudeCodeSessionAffinity and new universal SessionAffinity + sessionAffinity = b.cfg.Routing.ClaudeCodeSessionAffinity || b.cfg.Routing.SessionAffinity + if ttlStr := strings.TrimSpace(b.cfg.Routing.SessionAffinityTTL); ttlStr != "" { + if parsed, err := time.ParseDuration(ttlStr); err == nil && parsed > 0 { + sessionAffinityTTL = parsed + } + } } var selector coreauth.Selector switch strategy { @@ -219,6 +229,14 @@ func (b *Builder) Build() (*Service, error) { selector = &coreauth.RoundRobinSelector{} } + // Wrap with session affinity if enabled (failover is always on) + if sessionAffinity { + selector = coreauth.NewSessionAffinitySelectorWithConfig(coreauth.SessionAffinityConfig{ + Fallback: selector, + TTL: sessionAffinityTTL, + }) + } + coreManager = coreauth.NewManager(tokenStore, selector, nil) } // Attach a default RoundTripper provider so providers can opt-in per-auth 
transports. diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 5165c540..c6a4f15e 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -118,7 +118,6 @@ func newDefaultAuthManager() *sdkAuth.Manager { sdkAuth.NewGeminiAuthenticator(), sdkAuth.NewCodexAuthenticator(), sdkAuth.NewClaudeAuthenticator(), - sdkAuth.NewQwenAuthenticator(), sdkAuth.NewGitLabAuthenticator(), ) } @@ -435,8 +434,6 @@ func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace s.coreManager.RegisterExecutor(executor.NewAntigravityExecutor(s.cfg)) case "claude": s.coreManager.RegisterExecutor(executor.NewClaudeExecutor(s.cfg)) - case "qwen": - s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg)) case "iflow": s.coreManager.RegisterExecutor(executor.NewIFlowExecutor(s.cfg)) case "kimi": @@ -639,9 +636,13 @@ func (s *Service) Run(ctx context.Context) error { var watcherWrapper *WatcherWrapper reloadCallback := func(newCfg *config.Config) { previousStrategy := "" + var previousSessionAffinity bool + var previousSessionAffinityTTL string s.cfgMu.RLock() if s.cfg != nil { previousStrategy = strings.ToLower(strings.TrimSpace(s.cfg.Routing.Strategy)) + previousSessionAffinity = s.cfg.Routing.ClaudeCodeSessionAffinity || s.cfg.Routing.SessionAffinity + previousSessionAffinityTTL = s.cfg.Routing.SessionAffinityTTL } s.cfgMu.RUnlock() @@ -665,7 +666,15 @@ func (s *Service) Run(ctx context.Context) error { } previousStrategy = normalizeStrategy(previousStrategy) nextStrategy = normalizeStrategy(nextStrategy) - if s.coreManager != nil && previousStrategy != nextStrategy { + + nextSessionAffinity := newCfg.Routing.ClaudeCodeSessionAffinity || newCfg.Routing.SessionAffinity + nextSessionAffinityTTL := newCfg.Routing.SessionAffinityTTL + + selectorChanged := previousStrategy != nextStrategy || + previousSessionAffinity != nextSessionAffinity || + previousSessionAffinityTTL != nextSessionAffinityTTL + + if s.coreManager != nil && selectorChanged 
{ var selector coreauth.Selector switch nextStrategy { case "fill-first": @@ -673,6 +682,20 @@ func (s *Service) Run(ctx context.Context) error { default: selector = &coreauth.RoundRobinSelector{} } + + if nextSessionAffinity { + ttl := time.Hour + if ttlStr := strings.TrimSpace(nextSessionAffinityTTL); ttlStr != "" { + if parsed, err := time.ParseDuration(ttlStr); err == nil && parsed > 0 { + ttl = parsed + } + } + selector = coreauth.NewSessionAffinitySelectorWithConfig(coreauth.SessionAffinityConfig{ + Fallback: selector, + TTL: ttl, + }) + } + s.coreManager.SetSelector(selector) } @@ -939,9 +962,6 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { } } models = applyExcludedModels(models, excluded) - case "qwen": - models = registry.GetQwenModels() - models = applyExcludedModels(models, excluded) case "iflow": models = registry.GetIFlowModels() models = applyExcludedModels(models, excluded)