From 7c24d54ca888f53bebe57a5db6e3f81a54e9ff32 Mon Sep 17 00:00:00 2001 From: sususu98 Date: Wed, 15 Apr 2026 00:48:08 +0800 Subject: [PATCH] feat(session-affinity): add session-sticky routing for multi-account load balancing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When multiple auth credentials are configured, requests from the same session are now routed to the same credential, improving upstream prompt cache hit rates and maintaining context continuity. Core components: - SessionAffinitySelector: wraps RoundRobin/FillFirst selectors with session-to-auth binding; automatic failover when bound auth is unavailable, re-binding via the fallback selector for even distribution - SessionCache: TTL-based in-memory cache with background cleanup goroutine, supporting per-session and per-auth invalidation - StoppableSelector interface: lifecycle hook for selectors holding resources, called during Manager.StopAutoRefresh() Session ID extraction priority (extractSessionIDs): 1. metadata.user_id with Claude Code session format (old user_{hash}_session_{uuid} and new JSON {session_id} format) 2. X-Session-ID header (generic client support) 3. metadata.user_id (non-Claude format, used as-is) 4. conversation_id field 5. Stable FNV hash from system prompt + first user/assistant messages (fallback for clients with no explicit session ID); returns both a full hash (primaryID) and a short hash without assistant content (fallbackID) to inherit bindings from the first turn Multi-format message hash covers OpenAI messages, Claude system array, Gemini contents/systemInstruction, and OpenAI Responses API input items (including inline messages with role but no type field). Configuration (config.yaml routing section): - session-affinity: bool (default false) - session-affinity-ttl: duration string (default "1h") - claude-code-session-affinity: bool (deprecated, alias for above) All three fields trigger selector rebuild on config hot reload. Side effect: Idempotency-Key header is no longer auto-generated with a random UUID when absent — only forwarded when explicitly provided by the client, to avoid polluting session hash extraction. --- config.example.yaml | 7 + internal/config/config.go | 16 + sdk/api/handlers/handlers.go | 5 +- sdk/cliproxy/auth/conductor.go | 12 + sdk/cliproxy/auth/selector.go | 451 ++++++++++++++++ sdk/cliproxy/auth/selector_test.go | 832 +++++++++++++++++++++++++++++ sdk/cliproxy/auth/session_cache.go | 152 ++++++ sdk/cliproxy/builder.go | 18 + sdk/cliproxy/service.go | 28 +- 9 files changed, 1517 insertions(+), 4 deletions(-) create mode 100644 sdk/cliproxy/auth/session_cache.go diff --git a/config.example.yaml b/config.example.yaml index b8440f7a..f423f818 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -103,6 +103,13 @@ quota-exceeded: # Routing strategy for selecting credentials when multiple match. routing: strategy: "round-robin" # round-robin (default), fill-first + # Enable universal session-sticky routing for all clients. + # Session IDs are extracted from: X-Session-ID header, Idempotency-Key, + # metadata.user_id, conversation_id, or first few messages hash. + # Automatic failover is always enabled when bound auth becomes unavailable. + session-affinity: false # default: false + # How long session-to-auth bindings are retained. Default: 1h + session-affinity-ttl: "1h" # When true, enable authentication for the WebSocket API (/v1/ws). ws-auth: false diff --git a/internal/config/config.go b/internal/config/config.go index 8527f6b2..7b2f9611 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -216,6 +216,22 @@ type RoutingConfig struct { // Strategy selects the credential selection strategy. // Supported values: "round-robin" (default), "fill-first". Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"` + + // ClaudeCodeSessionAffinity enables session-sticky routing for Claude Code clients. + // When enabled, requests with the same session ID (extracted from metadata.user_id) + // are routed to the same auth credential when available. + // Deprecated: Use SessionAffinity instead for universal session support. + ClaudeCodeSessionAffinity bool `yaml:"claude-code-session-affinity,omitempty" json:"claude-code-session-affinity,omitempty"` + + // SessionAffinity enables universal session-sticky routing for all clients. + // Session IDs are extracted from multiple sources: + // X-Session-ID header, Idempotency-Key, metadata.user_id, conversation_id, or message hash. + // Automatic failover is always enabled when bound auth becomes unavailable. + SessionAffinity bool `yaml:"session-affinity,omitempty" json:"session-affinity,omitempty"` + + // SessionAffinityTTL specifies how long session-to-auth bindings are retained. + // Default: 1h. Accepts duration strings like "30m", "1h", "2h30m". + SessionAffinityTTL string `yaml:"session-affinity-ttl,omitempty" json:"session-affinity-ttl,omitempty"` } // OAuthModelAlias defines a model ID alias for a specific channel. diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 1f7996c0..6734d500 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -14,7 +14,6 @@ import ( "time" "github.com/gin-gonic/gin" - "github.com/google/uuid" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" @@ -188,7 +187,7 @@ func PassthroughHeadersEnabled(cfg *config.SDKConfig) bool { func requestExecutionMetadata(ctx context.Context) map[string]any { // Idempotency-Key is an optional client-supplied header used to correlate retries. - // It is forwarded as execution metadata; when absent we generate a UUID. + // Only include it if the client explicitly provides it. key := "" if ctx != nil { if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil { @@ -196,7 +195,7 @@ func requestExecutionMetadata(ctx context.Context) map[string]any { } } if key == "" { - key = uuid.NewString() + return make(map[string]any) } meta := map[string]any{idempotencyKeyMetadataKey: key} diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 5e7d3161..f5872203 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -105,6 +105,13 @@ type Selector interface { Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) } +// StoppableSelector is an optional interface for selectors that hold resources. +// Selectors that implement this interface will have Stop called during shutdown. +type StoppableSelector interface { + Selector + Stop() +} + // Hook captures lifecycle callbacks for observing auth changes. type Hook interface { // OnAuthRegistered fires when a new auth is registered. @@ -2928,6 +2935,7 @@ func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duratio } // StopAutoRefresh cancels the background refresh loop, if running. +// It also stops the selector if it implements StoppableSelector. func (m *Manager) StopAutoRefresh() { m.mu.Lock() cancel := m.refreshCancel @@ -2937,6 +2945,10 @@ func (m *Manager) StopAutoRefresh() { if cancel != nil { cancel() } + // Stop selector if it implements StoppableSelector (e.g., SessionAffinitySelector) + if stoppable, ok := m.selector.(StoppableSelector); ok { + stoppable.Stop() + } } func (m *Manager) queueRefreshReschedule(authID string) { diff --git a/sdk/cliproxy/auth/selector.go b/sdk/cliproxy/auth/selector.go index cf79e173..51275a31 100644 --- a/sdk/cliproxy/auth/selector.go +++ b/sdk/cliproxy/auth/selector.go @@ -4,15 +4,21 @@ import ( "context" "encoding/json" "fmt" + "hash/fnv" "math" "math/rand/v2" "net/http" + "regexp" "sort" "strconv" "strings" "sync" "time" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" ) @@ -420,3 +426,448 @@ func isAuthBlockedForModel(auth *Auth, model string, now time.Time) (bool, block } return false, blockReasonNone, time.Time{} } + +// sessionPattern matches Claude Code user_id format: +// user_{hash}_account__session_{uuid} +var sessionPattern = regexp.MustCompile(`_session_([a-f0-9-]+)$`) + +// SessionAffinitySelector wraps another selector with session-sticky behavior. +// It extracts session ID from multiple sources and maintains session-to-auth +// mappings with automatic failover when the bound auth becomes unavailable. +type SessionAffinitySelector struct { + fallback Selector + cache *SessionCache +} + +// SessionAffinityConfig configures the session affinity selector. +type SessionAffinityConfig struct { + Fallback Selector + TTL time.Duration +} + +// NewSessionAffinitySelector creates a new session-aware selector. +func NewSessionAffinitySelector(fallback Selector) *SessionAffinitySelector { + return NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{ + Fallback: fallback, + TTL: time.Hour, + }) +} + +// NewSessionAffinitySelectorWithConfig creates a selector with custom configuration. +func NewSessionAffinitySelectorWithConfig(cfg SessionAffinityConfig) *SessionAffinitySelector { + if cfg.Fallback == nil { + cfg.Fallback = &RoundRobinSelector{} + } + if cfg.TTL <= 0 { + cfg.TTL = time.Hour + } + return &SessionAffinitySelector{ + fallback: cfg.Fallback, + cache: NewSessionCache(cfg.TTL), + } +} + +// Pick selects an auth with session affinity when possible. +// Priority for session ID extraction: +// 1. metadata.user_id (Claude Code format) - highest priority +// 2. X-Session-ID header +// 3. metadata.user_id (non-Claude Code format) +// 4. conversation_id field +// 5. Hash-based fallback from messages +// +// Note: The cache key includes provider, session ID, and model to handle cases where +// a session uses multiple models (e.g., gemini-2.5-pro and gemini-3-flash-preview) +// that may be supported by different auth credentials, and to avoid cross-provider conflicts. +func (s *SessionAffinitySelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) { + entry := selectorLogEntry(ctx) + primaryID, fallbackID := extractSessionIDs(opts.Headers, opts.OriginalRequest, opts.Metadata) + if primaryID == "" { + entry.Debugf("session-affinity: no session ID extracted, falling back to default selector | provider=%s model=%s", provider, model) + return s.fallback.Pick(ctx, provider, model, opts, auths) + } + + now := time.Now() + available, err := getAvailableAuths(auths, provider, model, now) + if err != nil { + return nil, err + } + + cacheKey := provider + "::" + primaryID + "::" + model + + if cachedAuthID, ok := s.cache.GetAndRefresh(cacheKey); ok { + for _, auth := range available { + if auth.ID == cachedAuthID { + entry.Infof("session-affinity: cache hit | session=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), auth.ID, provider, model) + return auth, nil + } + } + // Cached auth not available, reselect via fallback selector for even distribution + auth, err := s.fallback.Pick(ctx, provider, model, opts, auths) + if err != nil { + return nil, err + } + s.cache.Set(cacheKey, auth.ID) + entry.Infof("session-affinity: cache hit but auth unavailable, reselected | session=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), auth.ID, provider, model) + return auth, nil + } + + if fallbackID != "" && fallbackID != primaryID { + fallbackKey := provider + "::" + fallbackID + "::" + model + if cachedAuthID, ok := s.cache.Get(fallbackKey); ok { + for _, auth := range available { + if auth.ID == cachedAuthID { + s.cache.Set(cacheKey, auth.ID) + entry.Infof("session-affinity: fallback cache hit | session=%s fallback=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), truncateSessionID(fallbackID), auth.ID, provider, model) + return auth, nil + } + } + } + } + + auth, err := s.fallback.Pick(ctx, provider, model, opts, auths) + if err != nil { + return nil, err + } + s.cache.Set(cacheKey, auth.ID) + entry.Infof("session-affinity: cache miss, new binding | session=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), auth.ID, provider, model) + return auth, nil +} + +func selectorLogEntry(ctx context.Context) *log.Entry { + if ctx == nil { + return log.NewEntry(log.StandardLogger()) + } + if reqID := logging.GetRequestID(ctx); reqID != "" { + return log.WithField("request_id", reqID) + } + return log.NewEntry(log.StandardLogger()) +} + +// truncateSessionID shortens session ID for logging (first 8 chars + "...") +func truncateSessionID(id string) string { + if len(id) <= 20 { + return id + } + return id[:8] + "..." +} + +// Stop releases resources held by the selector. +func (s *SessionAffinitySelector) Stop() { + if s.cache != nil { + s.cache.Stop() + } +} + +// InvalidateAuth removes all session bindings for a specific auth. +// Called when an auth becomes rate-limited or unavailable. +func (s *SessionAffinitySelector) InvalidateAuth(authID string) { + if s.cache != nil { + s.cache.InvalidateAuth(authID) + } +} + +// ExtractSessionID extracts session identifier from multiple sources. +// Priority order: +// 1. metadata.user_id (Claude Code format with _session_{uuid}) - highest priority for Claude Code clients +// 2. X-Session-ID header +// 3. metadata.user_id (non-Claude Code format) +// 4. conversation_id field in request body +// 5. Stable hash from first few messages content (fallback) +func ExtractSessionID(headers http.Header, payload []byte, metadata map[string]any) string { + primary, _ := extractSessionIDs(headers, payload, metadata) + return primary +} + +// extractSessionIDs returns (primaryID, fallbackID) for session affinity. +// primaryID: full hash including assistant response (stable after first turn) +// fallbackID: short hash without assistant (used to inherit binding from first turn) +func extractSessionIDs(headers http.Header, payload []byte, metadata map[string]any) (string, string) { + // 1. metadata.user_id with Claude Code session format (highest priority) + if len(payload) > 0 { + userID := gjson.GetBytes(payload, "metadata.user_id").String() + if userID != "" { + // Old format: user_{hash}_account__session_{uuid} + if matches := sessionPattern.FindStringSubmatch(userID); len(matches) >= 2 { + id := "claude:" + matches[1] + return id, "" + } + // New format: JSON object with session_id field + // e.g. {"device_id":"...","account_uuid":"...","session_id":"uuid"} + if len(userID) > 0 && userID[0] == '{' { + if sid := gjson.Get(userID, "session_id").String(); sid != "" { + return "claude:" + sid, "" + } + } + } + } + + // 2. X-Session-ID header + if headers != nil { + if sid := headers.Get("X-Session-ID"); sid != "" { + return "header:" + sid, "" + } + } + + if len(payload) == 0 { + return "", "" + } + + // 3. metadata.user_id (non-Claude Code format) + userID := gjson.GetBytes(payload, "metadata.user_id").String() + if userID != "" { + return "user:" + userID, "" + } + + // 4. conversation_id field + if convID := gjson.GetBytes(payload, "conversation_id").String(); convID != "" { + return "conv:" + convID, "" + } + + // 5. Hash-based fallback from message content + return extractMessageHashIDs(payload) +} + +func extractMessageHashIDs(payload []byte) (primaryID, fallbackID string) { + var systemPrompt, firstUserMsg, firstAssistantMsg string + + // OpenAI/Claude messages format + messages := gjson.GetBytes(payload, "messages") + if messages.Exists() && messages.IsArray() { + messages.ForEach(func(_, msg gjson.Result) bool { + role := msg.Get("role").String() + content := extractMessageContent(msg.Get("content")) + if content == "" { + return true + } + + switch role { + case "system": + if systemPrompt == "" { + systemPrompt = truncateString(content, 100) + } + case "user": + if firstUserMsg == "" { + firstUserMsg = truncateString(content, 100) + } + case "assistant": + if firstAssistantMsg == "" { + firstAssistantMsg = truncateString(content, 100) + } + } + + if systemPrompt != "" && firstUserMsg != "" && firstAssistantMsg != "" { + return false + } + return true + }) + } + + // Claude API: top-level "system" field (array or string) + if systemPrompt == "" { + topSystem := gjson.GetBytes(payload, "system") + if topSystem.Exists() { + if topSystem.IsArray() { + topSystem.ForEach(func(_, part gjson.Result) bool { + if text := part.Get("text").String(); text != "" && systemPrompt == "" { + systemPrompt = truncateString(text, 100) + return false + } + return true + }) + } else if topSystem.Type == gjson.String { + systemPrompt = truncateString(topSystem.String(), 100) + } + } + } + + // Gemini format + if systemPrompt == "" && firstUserMsg == "" { + sysInstr := gjson.GetBytes(payload, "systemInstruction.parts") + if sysInstr.Exists() && sysInstr.IsArray() { + sysInstr.ForEach(func(_, part gjson.Result) bool { + if text := part.Get("text").String(); text != "" && systemPrompt == "" { + systemPrompt = truncateString(text, 100) + return false + } + return true + }) + } + + contents := gjson.GetBytes(payload, "contents") + if contents.Exists() && contents.IsArray() { + contents.ForEach(func(_, msg gjson.Result) bool { + role := msg.Get("role").String() + msg.Get("parts").ForEach(func(_, part gjson.Result) bool { + text := part.Get("text").String() + if text == "" { + return true + } + switch role { + case "user": + if firstUserMsg == "" { + firstUserMsg = truncateString(text, 100) + } + case "model": + if firstAssistantMsg == "" { + firstAssistantMsg = truncateString(text, 100) + } + } + return false + }) + if firstUserMsg != "" && firstAssistantMsg != "" { + return false + } + return true + }) + } + } + + // OpenAI Responses API format (v1/responses) + if systemPrompt == "" && firstUserMsg == "" { + if instr := gjson.GetBytes(payload, "instructions").String(); instr != "" { + systemPrompt = truncateString(instr, 100) + } + + input := gjson.GetBytes(payload, "input") + if input.Exists() && input.IsArray() { + input.ForEach(func(_, item gjson.Result) bool { + itemType := item.Get("type").String() + if itemType == "reasoning" { + return true + } + // Skip non-message typed items (function_call, function_call_output, etc.) + // but allow items with no type that have a role (inline message format). + if itemType != "" && itemType != "message" { + return true + } + + role := item.Get("role").String() + if itemType == "" && role == "" { + return true + } + + // Handle both string content and array content (multimodal). + content := item.Get("content") + var text string + if content.Type == gjson.String { + text = content.String() + } else { + text = extractResponsesAPIContent(content) + } + if text == "" { + return true + } + + switch role { + case "developer", "system": + if systemPrompt == "" { + systemPrompt = truncateString(text, 100) + } + case "user": + if firstUserMsg == "" { + firstUserMsg = truncateString(text, 100) + } + case "assistant": + if firstAssistantMsg == "" { + firstAssistantMsg = truncateString(text, 100) + } + } + + if firstUserMsg != "" && firstAssistantMsg != "" { + return false + } + return true + }) + } + } + + if systemPrompt == "" && firstUserMsg == "" { + return "", "" + } + + shortHash := computeSessionHash(systemPrompt, firstUserMsg, "") + if firstAssistantMsg == "" { + return shortHash, "" + } + + fullHash := computeSessionHash(systemPrompt, firstUserMsg, firstAssistantMsg) + return fullHash, shortHash +} + +func computeSessionHash(systemPrompt, userMsg, assistantMsg string) string { + h := fnv.New64a() + if systemPrompt != "" { + h.Write([]byte("sys:" + systemPrompt + "\n")) + } + if userMsg != "" { + h.Write([]byte("usr:" + userMsg + "\n")) + } + if assistantMsg != "" { + h.Write([]byte("ast:" + assistantMsg + "\n")) + } + return fmt.Sprintf("msg:%016x", h.Sum64()) +} + +func truncateString(s string, maxLen int) string { + if len(s) > maxLen { + return s[:maxLen] + } + return s +} + +// extractMessageContent extracts text content from a message content field. +// Handles both string content and array content (multimodal messages). +// For array content, extracts text from all text-type elements. +func extractMessageContent(content gjson.Result) string { + // String content: "Hello world" + if content.Type == gjson.String { + return content.String() + } + + // Array content: [{"type":"text","text":"Hello"},{"type":"image",...}] + if content.IsArray() { + var texts []string + content.ForEach(func(_, part gjson.Result) bool { + // Handle Claude format: {"type":"text","text":"content"} + if part.Get("type").String() == "text" { + if text := part.Get("text").String(); text != "" { + texts = append(texts, text) + } + } + // Handle OpenAI format: {"type":"text","text":"content"} + // Same structure as Claude, already handled above + return true + }) + if len(texts) > 0 { + return strings.Join(texts, " ") + } + } + + return "" +} + +func extractResponsesAPIContent(content gjson.Result) string { + if !content.IsArray() { + return "" + } + var texts []string + content.ForEach(func(_, part gjson.Result) bool { + partType := part.Get("type").String() + if partType == "input_text" || partType == "output_text" || partType == "text" { + if text := part.Get("text").String(); text != "" { + texts = append(texts, text) + } + } + return true + }) + if len(texts) > 0 { + return strings.Join(texts, " ") + } + return "" +} + +// extractSessionID is kept for backward compatibility. +// Deprecated: Use ExtractSessionID instead. +func extractSessionID(payload []byte) string { + return ExtractSessionID(nil, payload, nil) +} diff --git a/sdk/cliproxy/auth/selector_test.go b/sdk/cliproxy/auth/selector_test.go index 79431a9a..560d3b9e 100644 --- a/sdk/cliproxy/auth/selector_test.go +++ b/sdk/cliproxy/auth/selector_test.go @@ -4,7 +4,9 @@ import ( "context" "encoding/json" "errors" + "fmt" "net/http" + "strings" "sync" "testing" "time" @@ -458,6 +460,159 @@ func TestRoundRobinSelectorPick_GeminiCLICredentialGrouping(t *testing.T) { } } +func TestExtractSessionID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + payload string + want string + }{ + { + name: "valid_claude_code_format", + payload: `{"metadata":{"user_id":"user_3f221fe75652cf9a89a31647f16274bb8036a9b85ac4dc226a4df0efec8dc04d_account__session_ac980658-63bd-4fb3-97ba-8da64cb1e344"}}`, + want: "claude:ac980658-63bd-4fb3-97ba-8da64cb1e344", + }, + { + name: "json_user_id_with_session_id", + payload: `{"metadata":{"user_id":"{\"device_id\":\"be82c3aee1e0c2d74535bacc85f9f559228f02dd8a17298cf522b71e6c375714\",\"account_uuid\":\"\",\"session_id\":\"e26d4046-0f88-4b09-bb5b-f863ab5fb24e\"}"}}`, + want: "claude:e26d4046-0f88-4b09-bb5b-f863ab5fb24e", + }, + { + name: "json_user_id_without_session_id", + payload: `{"metadata":{"user_id":"{\"device_id\":\"abc123\"}"}}`, + want: `user:{"device_id":"abc123"}`, + }, + { + name: "no_session_but_user_id", + payload: `{"metadata":{"user_id":"user_abc123"}}`, + want: "user:user_abc123", + }, + { + name: "conversation_id", + payload: `{"conversation_id":"conv-12345"}`, + want: "conv:conv-12345", + }, + { + name: "no_metadata", + payload: `{"model":"claude-3"}`, + want: "", + }, + { + name: "empty_payload", + payload: ``, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractSessionID([]byte(tt.payload)) + if got != tt.want { + t.Errorf("extractSessionID() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestSessionAffinitySelector_SameSessionSameAuth(t *testing.T) { + t.Parallel() + + fallback := &RoundRobinSelector{} + selector := NewSessionAffinitySelector(fallback) + + auths := []*Auth{ + {ID: "auth-a"}, + {ID: "auth-b"}, + {ID: "auth-c"}, + } + + // Use valid UUID format for session ID + payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_ac980658-63bd-4fb3-97ba-8da64cb1e344"}}`) + opts := cliproxyexecutor.Options{OriginalRequest: payload} + + // Same session should always pick the same auth + first, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths) + if err != nil { + t.Fatalf("Pick() error = %v", err) + } + if first == nil { + t.Fatalf("Pick() returned nil") + } + + // Verify consistency: same session, same auths -> same result + for i := 0; i < 10; i++ { + got, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths) + if err != nil { + t.Fatalf("Pick() #%d error = %v", i, err) + } + if got.ID != first.ID { + t.Fatalf("Pick() #%d auth.ID = %q, want %q (same session should pick same auth)", i, got.ID, first.ID) + } + } +} + +func TestSessionAffinitySelector_NoSessionFallback(t *testing.T) { + t.Parallel() + + fallback := &FillFirstSelector{} + selector := NewSessionAffinitySelector(fallback) + + auths := []*Auth{ + {ID: "auth-b"}, + {ID: "auth-a"}, + {ID: "auth-c"}, + } + + // No session in payload, should fallback to FillFirstSelector (picks "auth-a" after sorting) + payload := []byte(`{"model":"claude-3"}`) + opts := cliproxyexecutor.Options{OriginalRequest: payload} + + got, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths) + if err != nil { + t.Fatalf("Pick() error = %v", err) + } + if got.ID != "auth-a" { + t.Fatalf("Pick() auth.ID = %q, want %q (should fallback to FillFirst)", got.ID, "auth-a") + } +} + +func TestSessionAffinitySelector_DifferentSessionsDifferentAuths(t *testing.T) { + t.Parallel() + + fallback := &RoundRobinSelector{} + selector := NewSessionAffinitySelector(fallback) + + auths := []*Auth{ + {ID: "auth-a"}, + {ID: "auth-b"}, + {ID: "auth-c"}, + } + + // Use valid UUID format for session IDs + session1 := []byte(`{"metadata":{"user_id":"user_xxx_account__session_11111111-1111-1111-1111-111111111111"}}`) + session2 := []byte(`{"metadata":{"user_id":"user_xxx_account__session_22222222-2222-2222-2222-222222222222"}}`) + + opts1 := cliproxyexecutor.Options{OriginalRequest: session1} + opts2 := cliproxyexecutor.Options{OriginalRequest: session2} + + auth1, _ := selector.Pick(context.Background(), "claude", "claude-3", opts1, auths) + auth2, _ := selector.Pick(context.Background(), "claude", "claude-3", opts2, auths) + + // Different sessions may or may not pick different auths (depends on hash collision) + // But each session should be consistent + for i := 0; i < 5; i++ { + got1, _ := selector.Pick(context.Background(), "claude", "claude-3", opts1, auths) + got2, _ := selector.Pick(context.Background(), "claude", "claude-3", opts2, auths) + if got1.ID != auth1.ID { + t.Fatalf("session1 Pick() #%d inconsistent: got %q, want %q", i, got1.ID, auth1.ID) + } + if got2.ID != auth2.ID { + t.Fatalf("session2 Pick() #%d inconsistent: got %q, want %q", i, got2.ID, auth2.ID) + } + } +} + func TestRoundRobinSelectorPick_SingleParentFallsBackToFlat(t *testing.T) { t.Parallel() @@ -494,6 +649,57 @@ func TestRoundRobinSelectorPick_SingleParentFallsBackToFlat(t *testing.T) { } } +func TestSessionAffinitySelector_FailoverWhenAuthUnavailable(t *testing.T) { + t.Parallel() + + fallback := &RoundRobinSelector{} + selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{ + Fallback: fallback, + TTL: time.Minute, + }) + defer selector.Stop() + + auths := []*Auth{ + {ID: "auth-a"}, + {ID: "auth-b"}, + {ID: "auth-c"}, + } + + payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_failover-test-uuid"}}`) + opts := cliproxyexecutor.Options{OriginalRequest: payload} + + // First pick establishes binding + first, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths) + if err != nil { + t.Fatalf("Pick() error = %v", err) + } + + // Remove the bound auth from available list (simulating rate limit) + availableWithoutFirst := make([]*Auth, 0, len(auths)-1) + for _, a := range auths { + if a.ID != first.ID { + availableWithoutFirst = append(availableWithoutFirst, a) + } + } + + // With failover enabled, should pick a new auth + second, err := selector.Pick(context.Background(), "claude", "claude-3", opts, availableWithoutFirst) + if err != nil { + t.Fatalf("Pick() after failover error = %v", err) + } + if second.ID == first.ID { + t.Fatalf("Pick() after failover returned same auth %q, expected different", first.ID) + } + + // Subsequent picks should consistently return the new binding + for i := 0; i < 5; i++ { + got, _ := selector.Pick(context.Background(), "claude", "claude-3", opts, availableWithoutFirst) + if got.ID != second.ID { + t.Fatalf("Pick() #%d after failover inconsistent: got %q, want %q", i, got.ID, second.ID) + } + } +} + func TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat(t *testing.T) { t.Parallel() @@ -527,3 +733,629 @@ func TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat(t *test } } } +func TestExtractSessionID_ClaudeCodePriorityOverHeader(t *testing.T) { + t.Parallel() + + // Claude Code metadata.user_id should have highest priority, even when X-Session-ID header is present + headers := make(http.Header) + headers.Set("X-Session-ID", "header-session-id") + + payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_ac980658-63bd-4fb3-97ba-8da64cb1e344"}}`) + + got := ExtractSessionID(headers, payload, nil) + want := "claude:ac980658-63bd-4fb3-97ba-8da64cb1e344" + if got != want { + t.Errorf("ExtractSessionID() = %q, want %q (Claude Code should have highest priority over header)", got, want) + } +} + +func TestExtractSessionID_ClaudeCodePriorityOverIdempotencyKey(t *testing.T) { + t.Parallel() + + // Claude Code metadata.user_id should have highest priority, even when idempotency_key is present + metadata := map[string]any{"idempotency_key": "idem-12345"} + payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_ac980658-63bd-4fb3-97ba-8da64cb1e344"}}`) + + got := ExtractSessionID(nil, payload, metadata) + want := "claude:ac980658-63bd-4fb3-97ba-8da64cb1e344" + if got != want { + t.Errorf("ExtractSessionID() = %q, want %q (Claude Code should have highest priority over idempotency_key)", got, want) + } +} + +func TestExtractSessionID_Headers(t *testing.T) { + t.Parallel() + + headers := make(http.Header) + headers.Set("X-Session-ID", "my-explicit-session") + + got := ExtractSessionID(headers, nil, nil) + want := "header:my-explicit-session" + if got != want { + t.Errorf("ExtractSessionID() with header = %q, want %q", got, want) + } +} + +// TestExtractSessionID_IdempotencyKey verifies that idempotency_key is intentionally +// ignored for session affinity (it's auto-generated per-request, causing cache misses). +func TestExtractSessionID_IdempotencyKey(t *testing.T) { + t.Parallel() + + metadata := map[string]any{"idempotency_key": "idem-12345"} + + got := ExtractSessionID(nil, nil, metadata) + // idempotency_key is disabled - should return empty (no payload to hash) + if got != "" { + t.Errorf("ExtractSessionID() with idempotency_key = %q, want empty (idempotency_key is disabled)", got) + } +} + +func TestExtractSessionID_MessageHashFallback(t *testing.T) { + t.Parallel() + + // First request (user only) generates short hash + firstRequestPayload := []byte(`{"messages":[{"role":"user","content":"Hello world"}]}`) + shortHash := ExtractSessionID(nil, firstRequestPayload, nil) + if shortHash == "" { + t.Error("ExtractSessionID() first request should return short hash") + } + if !strings.HasPrefix(shortHash, "msg:") { + t.Errorf("ExtractSessionID() = %q, want prefix 'msg:'", shortHash) + } + + // Multi-turn with assistant generates full hash (different from short hash) + multiTurnPayload := []byte(`{"messages":[ + {"role":"user","content":"Hello world"}, + {"role":"assistant","content":"Hi! How can I help?"}, + {"role":"user","content":"Tell me a joke"} + ]}`) + fullHash := ExtractSessionID(nil, multiTurnPayload, nil) + if fullHash == "" { + t.Error("ExtractSessionID() multi-turn should return full hash") + } + if fullHash == shortHash { + t.Error("Full hash should differ from short hash (includes assistant)") + } + + // Same multi-turn payload should produce same hash + fullHash2 := ExtractSessionID(nil, multiTurnPayload, nil) + if fullHash != fullHash2 { + t.Errorf("ExtractSessionID() not stable: got %q then %q", fullHash, fullHash2) + } +} + +func TestExtractSessionID_ClaudeAPITopLevelSystem(t *testing.T) { + t.Parallel() + + // Claude API: system prompt in top-level "system" field (array format) + arraySystem := []byte(`{ + "messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}], + "system": [{"type": "text", "text": "You are Claude Code"}] + }`) + got1 := ExtractSessionID(nil, arraySystem, nil) + if got1 == "" || !strings.HasPrefix(got1, "msg:") { + t.Errorf("ExtractSessionID() with array system = %q, want msg:* prefix", got1) + } + + // Claude API: system prompt in top-level "system" field (string format) + stringSystem := []byte(`{ + "messages": [{"role": "user", "content": "Hello"}], + "system": "You are Claude Code" + }`) + got2 := ExtractSessionID(nil, stringSystem, nil) + if got2 == "" || !strings.HasPrefix(got2, "msg:") { + t.Errorf("ExtractSessionID() with string system = %q, want msg:* prefix", got2) + } + + // Multi-turn with top-level system should produce stable hash + multiTurn := []byte(`{ + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi!"}, + {"role": "user", "content": "Help me"} + ], + "system": "You are Claude Code" + }`) + got3 := ExtractSessionID(nil, multiTurn, nil) + if got3 == "" { + t.Error("ExtractSessionID() multi-turn with top-level system should return hash") + } + if got3 == got2 { + t.Error("Multi-turn hash should differ from first-turn hash (includes assistant)") + } +} + +func TestExtractSessionID_GeminiFormat(t *testing.T) { + t.Parallel() + + // Gemini format with systemInstruction and contents + payload := []byte(`{ + "systemInstruction": {"parts": [{"text": "You are a helpful assistant."}]}, + "contents": [ + {"role": "user", "parts": [{"text": "Hello Gemini"}]}, + {"role": "model", "parts": [{"text": "Hi there!"}]} + ] + }`) + + got := ExtractSessionID(nil, payload, nil) + if got == "" { + t.Error("ExtractSessionID() with Gemini format should return hash-based session ID") + } + if !strings.HasPrefix(got, "msg:") { + t.Errorf("ExtractSessionID() = %q, want prefix 'msg:'", got) + } + + // Same payload should produce same hash + got2 := ExtractSessionID(nil, payload, nil) + if got != got2 { + t.Errorf("ExtractSessionID() not stable: got %q then %q", got, got2) + } + + // Different user message should produce different hash + differentPayload := []byte(`{ + "systemInstruction": {"parts": [{"text": "You are a helpful assistant."}]}, + "contents": [ + {"role": "user", "parts": [{"text": "Hello different"}]}, + {"role": "model", "parts": [{"text": "Hi there!"}]} + ] + }`) + got3 := ExtractSessionID(nil, differentPayload, nil) + if got == got3 { + t.Errorf("ExtractSessionID() should produce different hash for different user message") + } +} + +func TestExtractSessionID_OpenAIResponsesAPI(t *testing.T) { + t.Parallel() + + firstTurn := []byte(`{ + "instructions": "You are Codex, based on GPT-5.", + "input": [ + {"type": "message", "role": "developer", "content": [{"type": "input_text", "text": "system instructions"}]}, + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hi"}]} + ] + }`) + + got1 := ExtractSessionID(nil, firstTurn, nil) + if got1 == "" { + t.Error("ExtractSessionID() should return hash for OpenAI Responses API format") + } + if !strings.HasPrefix(got1, "msg:") { + t.Errorf("ExtractSessionID() = %q, want prefix 'msg:'", got1) + } + + secondTurn := []byte(`{ + "instructions": "You are Codex, based on GPT-5.", + "input": [ + {"type": "message", "role": "developer", "content": [{"type": "input_text", "text": "system instructions"}]}, + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hi"}]}, + {"type": "reasoning", "summary": [{"type": "summary_text", "text": "thinking..."}], "encrypted_content": "xxx"}, + {"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "Hello!"}]}, + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "what can you do"}]} + ] + }`) + + got2 := ExtractSessionID(nil, secondTurn, nil) + if got2 == "" { + t.Error("ExtractSessionID() should return hash for second turn") + } + + if got1 == got2 { + t.Log("First turn and second turn have different hashes (expected: second includes assistant)") + } + + thirdTurn := []byte(`{ + "instructions": "You are Codex, based on GPT-5.", + "input": [ + {"type": "message", "role": "developer", "content": [{"type": "input_text", "text": "system instructions"}]}, + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hi"}]}, + {"type": "reasoning", "summary": [{"type": "summary_text", "text": "thinking..."}], "encrypted_content": "xxx"}, + {"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "Hello!"}]}, + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "what can you do"}]}, + {"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "I can help with..."}]}, + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "thanks"}]} + ] + }`) + + got3 := ExtractSessionID(nil, thirdTurn, nil) + if got2 != got3 { + t.Errorf("Second and third turn should have same hash (same first assistant): got %q vs %q", got2, got3) + } +} + +func TestSessionAffinitySelector_ThreeScenarios(t *testing.T) { + t.Parallel() + + fallback := &RoundRobinSelector{} + selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{ + Fallback: fallback, + TTL: time.Minute, + }) + defer selector.Stop() + + auths := []*Auth{{ID: "auth-a"}, {ID: "auth-b"}, {ID: "auth-c"}} + + testCases := []struct { + name string + scenario string + payload []byte + }{ + { + name: "OpenAI_Scenario1_NewRequest", + scenario: "new", + payload: []byte(`{"messages":[{"role":"system","content":"You are helpful"},{"role":"user","content":"Hello"}]}`), + }, + { + name: "OpenAI_Scenario2_SecondTurn", + scenario: "second", + payload: []byte(`{"messages":[{"role":"system","content":"You are helpful"},{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there!"},{"role":"user","content":"Help me"}]}`), + }, + { + name: "OpenAI_Scenario3_ManyTurns", + scenario: "many", + payload: []byte(`{"messages":[{"role":"system","content":"You are helpful"},{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there!"},{"role":"user","content":"Help me"},{"role":"assistant","content":"Sure!"},{"role":"user","content":"Thanks"}]}`), + }, + { + name: "Gemini_Scenario1_NewRequest", + scenario: "new", + payload: []byte(`{"systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"Hello Gemini"}]}]}`), + }, + { + name: "Gemini_Scenario2_SecondTurn", + scenario: "second", + payload: []byte(`{"systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"Hello Gemini"}]},{"role":"model","parts":[{"text":"Hi!"}]},{"role":"user","parts":[{"text":"Help"}]}]}`), + }, + { + name: "Gemini_Scenario3_ManyTurns", + scenario: "many", + payload: []byte(`{"systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"Hello Gemini"}]},{"role":"model","parts":[{"text":"Hi!"}]},{"role":"user","parts":[{"text":"Help"}]},{"role":"model","parts":[{"text":"Sure!"}]},{"role":"user","parts":[{"text":"Thanks"}]}]}`), + }, + { + name: "Claude_Scenario1_NewRequest", + scenario: "new", + payload: []byte(`{"messages":[{"role":"user","content":"Hello Claude"}]}`), + }, + { + name: "Claude_Scenario2_SecondTurn", + scenario: "second", + payload: []byte(`{"messages":[{"role":"user","content":"Hello Claude"},{"role":"assistant","content":"Hello!"},{"role":"user","content":"Help me"}]}`), + }, + { + name: "Claude_Scenario3_ManyTurns", + scenario: "many", + payload: []byte(`{"messages":[{"role":"user","content":"Hello Claude"},{"role":"assistant","content":"Hello!"},{"role":"user","content":"Help"},{"role":"assistant","content":"Sure!"},{"role":"user","content":"Thanks"}]}`), + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + opts := cliproxyexecutor.Options{OriginalRequest: tc.payload} + picked, err := selector.Pick(context.Background(), "provider", "model", opts, auths) + if err != nil { + t.Fatalf("Pick() error = %v", err) + } + if picked == nil { + t.Fatal("Pick() returned nil") + } + t.Logf("%s: picked %s", tc.name, picked.ID) + }) + } + + t.Run("Scenario2And3_SameAuth", func(t *testing.T) { + openaiS2 := []byte(`{"messages":[{"role":"system","content":"Stable test"},{"role":"user","content":"First msg"},{"role":"assistant","content":"Response"},{"role":"user","content":"Second"}]}`) + openaiS3 := []byte(`{"messages":[{"role":"system","content":"Stable test"},{"role":"user","content":"First msg"},{"role":"assistant","content":"Response"},{"role":"user","content":"Second"},{"role":"assistant","content":"More"},{"role":"user","content":"Third"}]}`) + + opts2 := cliproxyexecutor.Options{OriginalRequest: openaiS2} + opts3 := cliproxyexecutor.Options{OriginalRequest: openaiS3} + + picked2, _ := selector.Pick(context.Background(), "test", "model", opts2, auths) + picked3, _ := selector.Pick(context.Background(), "test", "model", opts3, auths) + + if picked2.ID != picked3.ID { + t.Errorf("Scenario2 and Scenario3 should pick same auth: got %s vs %s", picked2.ID, picked3.ID) + } + }) + + t.Run("Scenario1To2_InheritBinding", func(t *testing.T) { + s1 := []byte(`{"messages":[{"role":"system","content":"Inherit test"},{"role":"user","content":"Initial"}]}`) + s2 := []byte(`{"messages":[{"role":"system","content":"Inherit test"},{"role":"user","content":"Initial"},{"role":"assistant","content":"Reply"},{"role":"user","content":"Continue"}]}`) + + opts1 := cliproxyexecutor.Options{OriginalRequest: s1} + opts2 := cliproxyexecutor.Options{OriginalRequest: s2} + + picked1, _ := selector.Pick(context.Background(), "inherit", "model", opts1, auths) + picked2, _ := selector.Pick(context.Background(), "inherit", "model", opts2, auths) + + if picked1.ID != picked2.ID { + t.Errorf("Scenario2 should inherit Scenario1 binding: got %s vs %s", picked1.ID, picked2.ID) + } + }) +} + +func TestSessionAffinitySelector_MultiModelSession(t *testing.T) { + t.Parallel() + + fallback := &RoundRobinSelector{} + selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{ + Fallback: fallback, + TTL: time.Minute, + }) + defer selector.Stop() + + // auth-a supports only model-a, auth-b supports only model-b + authA := &Auth{ID: "auth-a"} + authB := &Auth{ID: "auth-b"} + + // Same session ID for all requests + payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_multi-model-test"}}`) + opts := cliproxyexecutor.Options{OriginalRequest: payload} + + // Request model-a with only auth-a available for that model + authsForModelA := []*Auth{authA} + pickedA, err := selector.Pick(context.Background(), "provider", "model-a", opts, authsForModelA) + if err != nil { + t.Fatalf("Pick() for model-a error = %v", err) + } + if pickedA.ID != "auth-a" { + t.Fatalf("Pick() for model-a = %q, want auth-a", pickedA.ID) + } + + // Request model-b with only auth-b available for that model + authsForModelB := []*Auth{authB} + pickedB, err := selector.Pick(context.Background(), "provider", "model-b", opts, authsForModelB) + if err != nil { + t.Fatalf("Pick() for model-b error = %v", err) + } + if pickedB.ID != "auth-b" { + t.Fatalf("Pick() for model-b = %q, want auth-b", pickedB.ID) + } + + // Switch back to model-a - should still get auth-a (separate binding per model) + pickedA2, err := selector.Pick(context.Background(), "provider", "model-a", opts, authsForModelA) + if err != nil { + t.Fatalf("Pick() for model-a (2nd) error = %v", err) + } + if pickedA2.ID != "auth-a" { + t.Fatalf("Pick() for model-a (2nd) = %q, want auth-a", pickedA2.ID) + } + + // Verify bindings are stable for multiple calls + for i := 0; i < 5; i++ { + gotA, _ := selector.Pick(context.Background(), "provider", "model-a", opts, authsForModelA) + gotB, _ := selector.Pick(context.Background(), "provider", "model-b", opts, authsForModelB) + if gotA.ID != "auth-a" { + t.Fatalf("Pick() #%d for model-a = %q, want auth-a", i, gotA.ID) + } + if gotB.ID != "auth-b" { + t.Fatalf("Pick() #%d for model-b = %q, want auth-b", i, gotB.ID) + } + } +} + +func TestExtractSessionID_MultimodalContent(t *testing.T) { + t.Parallel() + + // First request generates short hash + firstRequestPayload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"Hello world"},{"type":"image","source":{"data":"..."}}]}]}`) + shortHash := ExtractSessionID(nil, firstRequestPayload, nil) + if shortHash == "" { + t.Error("ExtractSessionID() first request should return short hash") + } + if !strings.HasPrefix(shortHash, "msg:") { + t.Errorf("ExtractSessionID() = %q, want prefix 'msg:'", shortHash) + } + + // Multi-turn generates full hash + multiTurnPayload := []byte(`{"messages":[ + {"role":"user","content":[{"type":"text","text":"Hello world"},{"type":"image","source":{"data":"..."}}]}, + {"role":"assistant","content":"I see an image!"}, + {"role":"user","content":"What is it?"} + ]}`) + fullHash := ExtractSessionID(nil, multiTurnPayload, nil) + if fullHash == "" { + t.Error("ExtractSessionID() multimodal multi-turn should return full hash") + } + if fullHash == shortHash { + t.Error("Full hash should differ from short hash") + } + + // Different user content produces different hash + differentPayload := []byte(`{"messages":[ + {"role":"user","content":[{"type":"text","text":"Different content"}]}, + {"role":"assistant","content":"I see something different!"} + ]}`) + differentHash := ExtractSessionID(nil, differentPayload, nil) + if fullHash == differentHash { + t.Errorf("ExtractSessionID() should produce different hash for different content") + } +} + +func TestSessionAffinitySelector_CrossProviderIsolation(t *testing.T) { + t.Parallel() + + fallback := &RoundRobinSelector{} + selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{ + Fallback: fallback, + TTL: time.Minute, + }) + defer selector.Stop() + + authClaude := &Auth{ID: "auth-claude"} + authGemini := &Auth{ID: "auth-gemini"} + + // Same session ID for both providers + payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_cross-provider-test"}}`) + opts := cliproxyexecutor.Options{OriginalRequest: payload} + + // Request via claude provider + pickedClaude, err := selector.Pick(context.Background(), "claude", "claude-3", opts, []*Auth{authClaude}) + if err != nil { + t.Fatalf("Pick() for claude error = %v", err) + } + if pickedClaude.ID != "auth-claude" { + t.Fatalf("Pick() for claude = %q, want auth-claude", pickedClaude.ID) + } + + // Same session but via gemini provider should get different auth + pickedGemini, err := selector.Pick(context.Background(), "gemini", "gemini-2.5-pro", opts, []*Auth{authGemini}) + if err != nil { + t.Fatalf("Pick() for gemini error = %v", err) + } + if pickedGemini.ID != "auth-gemini" { + t.Fatalf("Pick() for gemini = %q, want auth-gemini", pickedGemini.ID) + } + + // Verify both bindings remain stable + for i := 0; i < 5; i++ { + gotC, _ := selector.Pick(context.Background(), "claude", "claude-3", opts, []*Auth{authClaude}) + gotG, _ := selector.Pick(context.Background(), "gemini", "gemini-2.5-pro", opts, []*Auth{authGemini}) + if gotC.ID != "auth-claude" { + t.Fatalf("Pick() #%d for claude = %q, want auth-claude", i, gotC.ID) + } + if gotG.ID != "auth-gemini" { + t.Fatalf("Pick() #%d for gemini = %q, want auth-gemini", i, gotG.ID) + } + } +} + +func TestSessionCache_GetAndRefresh(t *testing.T) { + t.Parallel() + + cache := NewSessionCache(100 * time.Millisecond) + defer cache.Stop() + + cache.Set("session1", "auth1") + + // Verify initial value + got, ok := cache.GetAndRefresh("session1") + if !ok || got != "auth1" { + t.Fatalf("GetAndRefresh() = %q, %v, want auth1, true", got, ok) + } + + // Wait half TTL and access again (should refresh) + time.Sleep(60 * time.Millisecond) + got, ok = cache.GetAndRefresh("session1") + if !ok || got != "auth1" { + t.Fatalf("GetAndRefresh() after 60ms = %q, %v, want auth1, true", got, ok) + } + + // Wait another 60ms (total 120ms from original, but TTL refreshed at 60ms) + // Entry should still be valid because TTL was refreshed + time.Sleep(60 * time.Millisecond) + got, ok = cache.GetAndRefresh("session1") + if !ok || got != "auth1" { + t.Fatalf("GetAndRefresh() after refresh = %q, %v, want auth1, true (TTL should have been refreshed)", got, ok) + } + + // Now wait full TTL without access + time.Sleep(110 * time.Millisecond) + got, ok = cache.GetAndRefresh("session1") + if ok { + t.Fatalf("GetAndRefresh() after expiry = %q, %v, want '', false", got, ok) + } +} + +func TestSessionAffinitySelector_RoundRobinDistribution(t *testing.T) { + t.Parallel() + + fallback := &RoundRobinSelector{} + selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{ + Fallback: fallback, + TTL: time.Minute, + }) + defer selector.Stop() + + auths := []*Auth{ + {ID: "auth-a"}, + {ID: "auth-b"}, + {ID: "auth-c"}, + } + + sessionCount := 12 + counts := make(map[string]int) + for i := 0; i < sessionCount; i++ { + payload := []byte(fmt.Sprintf(`{"metadata":{"user_id":"user_xxx_account__session_%08d-0000-0000-0000-000000000000"}}`, i)) + opts := cliproxyexecutor.Options{OriginalRequest: payload} + got, err := selector.Pick(context.Background(), "provider", "model", opts, auths) + if err != nil { + t.Fatalf("Pick() session %d error = %v", i, err) + } + counts[got.ID]++ + } + + expected := sessionCount / len(auths) + for _, auth := range auths { + got := counts[auth.ID] + if got != expected { + t.Errorf("auth %s got %d sessions, want %d (round-robin should distribute evenly)", auth.ID, got, expected) + } + } +} + +func TestSessionAffinitySelector_Concurrent(t *testing.T) { + t.Parallel() + + fallback := &RoundRobinSelector{} + selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{ + Fallback: fallback, + TTL: time.Minute, + }) + defer selector.Stop() + + auths := []*Auth{ + {ID: "auth-a"}, + {ID: "auth-b"}, + {ID: "auth-c"}, + } + + payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_concurrent-test"}}`) + opts := cliproxyexecutor.Options{OriginalRequest: payload} + + // First pick to establish binding + first, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths) + if err != nil { + t.Fatalf("Initial Pick() error = %v", err) + } + expectedID := first.ID + + start := make(chan struct{}) + var wg sync.WaitGroup + errCh := make(chan error, 1) + + goroutines := 32 + iterations := 50 + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-start + for j := 0; j < iterations; j++ { + got, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths) + if err != nil { + select { + case errCh <- err: + default: + } + return + } + if got.ID != expectedID { + select { + case errCh <- fmt.Errorf("concurrent Pick() returned %q, want %q", got.ID, expectedID): + default: + } + return + } + } + }() + } + + close(start) + wg.Wait() + + select { + case err := <-errCh: + t.Fatalf("concurrent Pick() error = %v", err) + default: + } +} diff --git a/sdk/cliproxy/auth/session_cache.go b/sdk/cliproxy/auth/session_cache.go new file mode 100644 index 00000000..a812e581 --- /dev/null +++ b/sdk/cliproxy/auth/session_cache.go @@ -0,0 +1,152 @@ +package auth + +import ( + "sync" + "time" +) + +// sessionEntry stores auth binding with expiration. +type sessionEntry struct { + authID string + expiresAt time.Time +} + +// SessionCache provides TTL-based session to auth mapping with automatic cleanup. +type SessionCache struct { + mu sync.RWMutex + entries map[string]sessionEntry + ttl time.Duration + stopCh chan struct{} +} + +// NewSessionCache creates a cache with the specified TTL. +// A background goroutine periodically cleans expired entries. +func NewSessionCache(ttl time.Duration) *SessionCache { + if ttl <= 0 { + ttl = 30 * time.Minute + } + c := &SessionCache{ + entries: make(map[string]sessionEntry), + ttl: ttl, + stopCh: make(chan struct{}), + } + go c.cleanupLoop() + return c +} + +// Get retrieves the auth ID bound to a session, if still valid. +// Does NOT refresh the TTL on access. +func (c *SessionCache) Get(sessionID string) (string, bool) { + if sessionID == "" { + return "", false + } + c.mu.RLock() + entry, ok := c.entries[sessionID] + c.mu.RUnlock() + if !ok { + return "", false + } + if time.Now().After(entry.expiresAt) { + c.mu.Lock() + delete(c.entries, sessionID) + c.mu.Unlock() + return "", false + } + return entry.authID, true +} + +// GetAndRefresh retrieves the auth ID bound to a session and refreshes TTL on hit. +// This extends the binding lifetime for active sessions. +func (c *SessionCache) GetAndRefresh(sessionID string) (string, bool) { + if sessionID == "" { + return "", false + } + now := time.Now() + c.mu.Lock() + entry, ok := c.entries[sessionID] + if !ok { + c.mu.Unlock() + return "", false + } + if now.After(entry.expiresAt) { + delete(c.entries, sessionID) + c.mu.Unlock() + return "", false + } + // Refresh TTL on successful access + entry.expiresAt = now.Add(c.ttl) + c.entries[sessionID] = entry + c.mu.Unlock() + return entry.authID, true +} + +// Set binds a session to an auth ID with TTL refresh. +func (c *SessionCache) Set(sessionID, authID string) { + if sessionID == "" || authID == "" { + return + } + c.mu.Lock() + c.entries[sessionID] = sessionEntry{ + authID: authID, + expiresAt: time.Now().Add(c.ttl), + } + c.mu.Unlock() +} + +// Invalidate removes a specific session binding. +func (c *SessionCache) Invalidate(sessionID string) { + if sessionID == "" { + return + } + c.mu.Lock() + delete(c.entries, sessionID) + c.mu.Unlock() +} + +// InvalidateAuth removes all sessions bound to a specific auth ID. +// Used when an auth becomes unavailable. +func (c *SessionCache) InvalidateAuth(authID string) { + if authID == "" { + return + } + c.mu.Lock() + for sid, entry := range c.entries { + if entry.authID == authID { + delete(c.entries, sid) + } + } + c.mu.Unlock() +} + +// Stop terminates the background cleanup goroutine. +func (c *SessionCache) Stop() { + select { + case <-c.stopCh: + default: + close(c.stopCh) + } +} + +func (c *SessionCache) cleanupLoop() { + ticker := time.NewTicker(c.ttl / 2) + defer ticker.Stop() + for { + select { + case <-c.stopCh: + return + case <-ticker.C: + c.cleanup() + } + } +} + +func (c *SessionCache) cleanup() { + now := time.Now() + c.mu.Lock() + for sid, entry := range c.entries { + if now.After(entry.expiresAt) { + delete(c.entries, sid) + } + } + c.mu.Unlock() +} diff --git a/sdk/cliproxy/builder.go b/sdk/cliproxy/builder.go index 0e6d1421..b8cf991c 100644 --- a/sdk/cliproxy/builder.go +++ b/sdk/cliproxy/builder.go @@ -6,6 +6,7 @@ package cliproxy import ( "fmt" "strings" + "time" configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access" "github.com/router-for-me/CLIProxyAPI/v6/internal/api" @@ -208,8 +209,17 @@ func (b *Builder) Build() (*Service, error) { } strategy := "" + sessionAffinity := false + sessionAffinityTTL := time.Hour if b.cfg != nil { strategy = strings.ToLower(strings.TrimSpace(b.cfg.Routing.Strategy)) + // Support both legacy ClaudeCodeSessionAffinity and new universal SessionAffinity + sessionAffinity = b.cfg.Routing.ClaudeCodeSessionAffinity || b.cfg.Routing.SessionAffinity + if ttlStr := strings.TrimSpace(b.cfg.Routing.SessionAffinityTTL); ttlStr != "" { + if parsed, err := time.ParseDuration(ttlStr); err == nil && parsed > 0 { + sessionAffinityTTL = parsed + } + } } var selector coreauth.Selector switch strategy { @@ -219,6 +229,14 @@ func (b *Builder) Build() (*Service, error) { selector = &coreauth.RoundRobinSelector{} } + // Wrap with session affinity if enabled (failover is always on) + if sessionAffinity { + selector = coreauth.NewSessionAffinitySelectorWithConfig(coreauth.SessionAffinityConfig{ + Fallback: selector, + TTL: sessionAffinityTTL, + }) + } + coreManager = coreauth.NewManager(tokenStore, selector, nil) } // Attach a default RoundTripper provider so providers can opt-in per-auth transports. diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index dd22987f..3471994e 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -612,9 +612,13 @@ func (s *Service) Run(ctx context.Context) error { var watcherWrapper *WatcherWrapper reloadCallback := func(newCfg *config.Config) { previousStrategy := "" + var previousSessionAffinity bool + var previousSessionAffinityTTL string s.cfgMu.RLock() if s.cfg != nil { previousStrategy = strings.ToLower(strings.TrimSpace(s.cfg.Routing.Strategy)) + previousSessionAffinity = s.cfg.Routing.ClaudeCodeSessionAffinity || s.cfg.Routing.SessionAffinity + previousSessionAffinityTTL = s.cfg.Routing.SessionAffinityTTL } s.cfgMu.RUnlock() @@ -638,7 +642,15 @@ func (s *Service) Run(ctx context.Context) error { } previousStrategy = normalizeStrategy(previousStrategy) nextStrategy = normalizeStrategy(nextStrategy) - if s.coreManager != nil && previousStrategy != nextStrategy { + + nextSessionAffinity := newCfg.Routing.ClaudeCodeSessionAffinity || newCfg.Routing.SessionAffinity + nextSessionAffinityTTL := newCfg.Routing.SessionAffinityTTL + + selectorChanged := previousStrategy != nextStrategy || + previousSessionAffinity != nextSessionAffinity || + previousSessionAffinityTTL != nextSessionAffinityTTL + + if s.coreManager != nil && selectorChanged { var selector coreauth.Selector switch nextStrategy { case "fill-first": @@ -646,6 +658,20 @@ func (s *Service) Run(ctx context.Context) error { default: selector = &coreauth.RoundRobinSelector{} } + + if nextSessionAffinity { + ttl := time.Hour + if ttlStr := strings.TrimSpace(nextSessionAffinityTTL); ttlStr != "" { + if parsed, err := time.ParseDuration(ttlStr); err == nil && parsed > 0 { + ttl = parsed + } + } + selector = coreauth.NewSessionAffinitySelectorWithConfig(coreauth.SessionAffinityConfig{ + Fallback: selector, + TTL: ttl, + }) + } + s.coreManager.SetSelector(selector) }