diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index bfac49216..574338fd7 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -104,6 +104,15 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { var lastRequest []byte lastResponseOutput := []byte("[]") pinnedAuthID := "" + sessionAuthByID := func(authID string) (*coreauth.Auth, bool) { + if h == nil || h.AuthManager == nil { + return nil, false + } + if auth, ok := h.AuthManager.GetExecutionSessionAuthByID(passthroughSessionID, authID); ok { + return auth, true + } + return h.AuthManager.GetByID(authID) + } forceTranscriptReplayNextRequest := false for { @@ -130,8 +139,8 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { appendWebsocketTimelineEvent(&wsTimelineLog, "request", payload, time.Now()) allowIncrementalInputWithPreviousResponseID := false - if pinnedAuthID != "" && h != nil && h.AuthManager != nil { - if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil { + if pinnedAuthID != "" { + if pinnedAuth, ok := sessionAuthByID(pinnedAuthID); ok && pinnedAuth != nil { allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata) } } else { @@ -146,8 +155,8 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } allowCompactionReplayBypass := false - if pinnedAuthID != "" && h != nil && h.AuthManager != nil { - if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil { + if pinnedAuthID != "" { + if pinnedAuth, ok := sessionAuthByID(pinnedAuthID); ok && pinnedAuth != nil { allowCompactionReplayBypass = responsesWebsocketAuthSupportsCompactionReplay(pinnedAuth) } } else { @@ -228,7 +237,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { if authID == "" || h == nil || h.AuthManager == nil { return } - selectedAuth, ok := h.AuthManager.GetByID(authID) + selectedAuth, ok := sessionAuthByID(authID) if !ok || selectedAuth == nil { return } diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 64a28d586..5d6a30356 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -153,9 +153,7 @@ type Manager struct { scheduler *authScheduler // homeRuntimeAuths caches auths returned by Home so websocket sessions can // reuse an established upstream credential without dispatching every turn. - homeRuntimeAuths map[string]*Auth - homeRuntimeAuthSessions map[string]map[string]struct{} - homeRuntimeAuthRefs map[string]int + homeRuntimeAuths map[string]map[string]*Auth // providerOffsets tracks per-model provider rotation state for multi-provider routing. providerOffsets map[string]int @@ -195,16 +193,14 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager { hook = NoopHook{} } manager := &Manager{ - store: store, - executors: make(map[string]ProviderExecutor), - selector: selector, - hook: hook, - auths: make(map[string]*Auth), - homeRuntimeAuths: make(map[string]*Auth), - homeRuntimeAuthSessions: make(map[string]map[string]struct{}), - homeRuntimeAuthRefs: make(map[string]int), - providerOffsets: make(map[string]int), - modelPoolOffsets: make(map[string]int), + store: store, + executors: make(map[string]ProviderExecutor), + selector: selector, + hook: hook, + auths: make(map[string]*Auth), + homeRuntimeAuths: make(map[string]map[string]*Auth), + providerOffsets: make(map[string]int), + modelPoolOffsets: make(map[string]int), } // atomic.Value requires non-nil initial value. manager.runtimeConfig.Store(&internalconfig.Config{}) @@ -2724,10 +2720,24 @@ func (m *Manager) GetByID(id string) (*Auth, bool) { defer m.mu.RUnlock() auth, ok := m.auths[id] if !ok { - auth, ok = m.homeRuntimeAuths[id] - if !ok { - return nil, false - } + return nil, false + } + return auth.Clone(), true +} + +// GetExecutionSessionAuthByID retrieves a Home runtime auth scoped to an execution session. +func (m *Manager) GetExecutionSessionAuthByID(sessionID string, authID string) (*Auth, bool) { + sessionID = strings.TrimSpace(sessionID) + authID = strings.TrimSpace(authID) + if m == nil || sessionID == "" || authID == "" { + return nil, false + } + m.mu.RLock() + defer m.mu.RUnlock() + sessionAuths := m.homeRuntimeAuths[sessionID] + auth := sessionAuths[authID] + if auth == nil { + return nil, false } return auth.Clone(), true } @@ -3218,9 +3228,7 @@ func (m *Manager) clearHomeRuntimeAuthsLocked() { if m == nil { return } - m.homeRuntimeAuths = make(map[string]*Auth) - m.homeRuntimeAuthSessions = make(map[string]map[string]struct{}) - m.homeRuntimeAuthRefs = make(map[string]int) + m.homeRuntimeAuths = make(map[string]map[string]*Auth) } func (m *Manager) clearHomeRuntimeAuthsForSessionLocked(sessionID string) { @@ -3228,21 +3236,7 @@ func (m *Manager) clearHomeRuntimeAuthsForSessionLocked(sessionID string) { if m == nil || sessionID == "" { return } - authIDs := m.homeRuntimeAuthSessions[sessionID] - if len(authIDs) == 0 { - delete(m.homeRuntimeAuthSessions, sessionID) - return - } - for authID := range authIDs { - refCount := m.homeRuntimeAuthRefs[authID] - if refCount <= 1 { - delete(m.homeRuntimeAuthRefs, authID) - delete(m.homeRuntimeAuths, authID) - continue - } - m.homeRuntimeAuthRefs[authID] = refCount - 1 - } - delete(m.homeRuntimeAuthSessions, sessionID) + delete(m.homeRuntimeAuths, sessionID) } func (m *Manager) rememberHomeRuntimeAuth(sessionID string, auth *Auth) { @@ -3256,24 +3250,14 @@ func (m *Manager) rememberHomeRuntimeAuth(sessionID string, auth *Auth) { } m.mu.Lock() if m.homeRuntimeAuths == nil { - m.homeRuntimeAuths = make(map[string]*Auth) + m.homeRuntimeAuths = make(map[string]map[string]*Auth) } - if m.homeRuntimeAuthSessions == nil { - m.homeRuntimeAuthSessions = make(map[string]map[string]struct{}) - } - if m.homeRuntimeAuthRefs == nil { - m.homeRuntimeAuthRefs = make(map[string]int) - } - m.homeRuntimeAuths[authID] = auth.Clone() - sessionAuths := m.homeRuntimeAuthSessions[sessionID] + sessionAuths := m.homeRuntimeAuths[sessionID] if sessionAuths == nil { - sessionAuths = make(map[string]struct{}) - m.homeRuntimeAuthSessions[sessionID] = sessionAuths - } - if _, exists := sessionAuths[authID]; !exists { - sessionAuths[authID] = struct{}{} - m.homeRuntimeAuthRefs[authID]++ + sessionAuths = make(map[string]*Auth) + m.homeRuntimeAuths[sessionID] = sessionAuths } + sessionAuths[authID] = auth.Clone() m.mu.Unlock() } @@ -3284,12 +3268,8 @@ func (m *Manager) homeRuntimeAuthByID(sessionID string, authID string) (*Auth, P return nil, nil, "", false } m.mu.RLock() - sessionAuths := m.homeRuntimeAuthSessions[sessionID] - if _, ok := sessionAuths[authID]; !ok { - m.mu.RUnlock() - return nil, nil, "", false - } - auth := m.homeRuntimeAuths[authID] + sessionAuths := m.homeRuntimeAuths[sessionID] + auth := sessionAuths[authID] m.mu.RUnlock() if auth == nil || !authWebsocketsEnabled(auth) { return nil, nil, "", false diff --git a/sdk/cliproxy/auth/home_websocket_reuse_test.go b/sdk/cliproxy/auth/home_websocket_reuse_test.go index 284dd076f..28d480042 100644 --- a/sdk/cliproxy/auth/home_websocket_reuse_test.go +++ b/sdk/cliproxy/auth/home_websocket_reuse_test.go @@ -27,9 +27,9 @@ func TestPickNextViaHomeReusesPinnedWebsocketAuthWithoutHomeDispatch(t *testing. } auth.EnsureIndex() manager.rememberHomeRuntimeAuth("session-1", auth) - cachedAuth, ok := manager.GetByID("home-auth-1") + cachedAuth, ok := manager.GetExecutionSessionAuthByID("session-1", "home-auth-1") if !ok || cachedAuth == nil || !authWebsocketsEnabled(cachedAuth) { - t.Fatalf("GetByID() did not expose remembered websocket home auth: auth=%#v ok=%v", cachedAuth, ok) + t.Fatalf("GetExecutionSessionAuthByID() did not expose remembered websocket home auth: auth=%#v ok=%v", cachedAuth, ok) } ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background()) @@ -56,6 +56,61 @@ func TestPickNextViaHomeReusesPinnedWebsocketAuthWithoutHomeDispatch(t *testing. } } +func TestPickNextViaHomeKeepsSameAuthIDPayloadSessionScoped(t *testing.T) { + manager := NewManager(nil, nil, nil) + manager.SetConfig(&internalconfig.Config{Home: internalconfig.HomeConfig{Enabled: true}}) + manager.RegisterExecutor(schedulerTestExecutor{}) + + manager.rememberHomeRuntimeAuth("session-1", &Auth{ + ID: "home-auth-1", + Provider: "test", + Status: StatusActive, + Attributes: map[string]string{ + "websockets": "true", + homeUpstreamModelAttributeKey: "upstream-model-a", + }, + }) + manager.rememberHomeRuntimeAuth("session-2", &Auth{ + ID: "home-auth-1", + Provider: "test", + Status: StatusActive, + Attributes: map[string]string{ + "websockets": "true", + homeUpstreamModelAttributeKey: "upstream-model-b", + }, + }) + + ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background()) + optsSession1 := cliproxyexecutor.Options{ + Metadata: map[string]any{ + cliproxyexecutor.ExecutionSessionMetadataKey: "session-1", + cliproxyexecutor.PinnedAuthMetadataKey: "home-auth-1", + }, + } + optsSession2 := cliproxyexecutor.Options{ + Metadata: map[string]any{ + cliproxyexecutor.ExecutionSessionMetadataKey: "session-2", + cliproxyexecutor.PinnedAuthMetadataKey: "home-auth-1", + }, + } + + gotSession1, _, _, errSession1 := manager.pickNextViaHome(ctx, "gpt-5.4", optsSession1, nil) + if errSession1 != nil { + t.Fatalf("pickNextViaHome(session-1) error = %v", errSession1) + } + if got := gotSession1.Attributes[homeUpstreamModelAttributeKey]; got != "upstream-model-a" { + t.Fatalf("pickNextViaHome(session-1) upstream model = %q, want upstream-model-a", got) + } + + gotSession2, _, _, errSession2 := manager.pickNextViaHome(ctx, "gpt-5.4", optsSession2, nil) + if errSession2 != nil { + t.Fatalf("pickNextViaHome(session-2) error = %v", errSession2) + } + if got := gotSession2.Attributes[homeUpstreamModelAttributeKey]; got != "upstream-model-b" { + t.Fatalf("pickNextViaHome(session-2) upstream model = %q, want upstream-model-b", got) + } +} + func TestPickNextViaHomeDoesNotReuseTriedPinnedWebsocketAuth(t *testing.T) { manager := NewManager(nil, nil, nil) manager.SetConfig(&internalconfig.Config{Home: internalconfig.HomeConfig{Enabled: true}}) @@ -135,10 +190,12 @@ func TestPickNextViaHomeDoesNotReusePinnedNonWebsocketAuth(t *testing.T) { manager.RegisterExecutor(schedulerTestExecutor{}) manager.mu.Lock() - manager.homeRuntimeAuths["home-auth-1"] = &Auth{ - ID: "home-auth-1", - Provider: "test", - Status: StatusActive, + manager.homeRuntimeAuths["session-1"] = map[string]*Auth{ + "home-auth-1": &Auth{ + ID: "home-auth-1", + Provider: "test", + Status: StatusActive, + }, } manager.mu.Unlock() @@ -175,12 +232,12 @@ func TestHomeRuntimeAuthsClearWhenHomeDisabled(t *testing.T) { }, }) - if _, ok := manager.GetByID("home-auth-1"); !ok { + if _, ok := manager.GetExecutionSessionAuthByID("session-1", "home-auth-1"); !ok { t.Fatal("expected remembered home auth before disabling home") } manager.SetConfig(&internalconfig.Config{}) - if _, ok := manager.GetByID("home-auth-1"); ok { + if _, ok := manager.GetExecutionSessionAuthByID("session-1", "home-auth-1"); ok { t.Fatal("remembered home auth was not cleared when home was disabled") } } @@ -199,12 +256,15 @@ func TestCloseExecutionSessionClearsHomeRuntimeAuthForSession(t *testing.T) { manager.rememberHomeRuntimeAuth("session-2", auth) manager.CloseExecutionSession("session-1") - if _, ok := manager.GetByID("home-auth-1"); !ok { - t.Fatal("shared home auth was cleared while another session still referenced it") + if _, ok := manager.GetExecutionSessionAuthByID("session-1", "home-auth-1"); ok { + t.Fatal("home auth for closed session was not cleared") + } + if _, ok := manager.GetExecutionSessionAuthByID("session-2", "home-auth-1"); !ok { + t.Fatal("home auth for another session was cleared") } manager.CloseExecutionSession("session-2") - if _, ok := manager.GetByID("home-auth-1"); ok { + if _, ok := manager.GetExecutionSessionAuthByID("session-2", "home-auth-1"); ok { t.Fatal("home auth was not cleared when its last session closed") } }