diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index 8113cdbbc..318d5dc14 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -272,6 +272,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { lastResponseID := "" var lastResponsePendingToolCallIDs []string pinnedAuthID := "" + passthroughModelName := "" sessionAuthByID := func(authID string) (*coreauth.Auth, bool) { if h == nil || h.AuthManager == nil { return nil, false @@ -307,47 +308,47 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { wsTimelineLog.BeginRequest() wsTimelineLog.Append("request", payload, time.Now()) + requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String()) + if requestModelName == "" { + requestModelName = passthroughModelName + } + if requestModelName == "" { + requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String()) + } + useCodexWebsocketPassthrough := h.responsesWebsocketUsesCodexWebsocketPassthrough(requestModelName) allowIncrementalInputWithPreviousResponseID := false - if pinnedAuthID != "" { - if pinnedAuth, ok := sessionAuthByID(pinnedAuthID); ok && pinnedAuth != nil { - allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata) - } - } else { - requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String()) - if requestModelName == "" { - requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String()) - } - allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName) - } - if forceTranscriptReplayNextRequest { - allowIncrementalInputWithPreviousResponseID = false - } - allowCompactionReplayBypass := false - if pinnedAuthID != "" { - if pinnedAuth, ok := sessionAuthByID(pinnedAuthID); ok && pinnedAuth != nil { - allowCompactionReplayBypass = responsesWebsocketAuthSupportsCompactionReplay(pinnedAuth) + if !useCodexWebsocketPassthrough { + if pinnedAuthID != "" { + if pinnedAuth, ok := sessionAuthByID(pinnedAuthID); ok && pinnedAuth != nil { + allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata) + allowCompactionReplayBypass = responsesWebsocketAuthSupportsCompactionReplay(pinnedAuth) + } + } else { + allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName) + allowCompactionReplayBypass = h.websocketUpstreamSupportsCompactionReplayForModel(requestModelName) } - } else { - requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String()) - if requestModelName == "" { - requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String()) + if forceTranscriptReplayNextRequest { + allowIncrementalInputWithPreviousResponseID = false } - allowCompactionReplayBypass = h.websocketUpstreamSupportsCompactionReplayForModel(requestModelName) } var requestJSON []byte var updatedLastRequest []byte var errMsg *interfaces.ErrorMessage - requestJSON, updatedLastRequest, errMsg = normalizeResponsesWebsocketRequestWithIncrementalState( - payload, - lastRequest, - lastResponseOutput, - lastResponseID, - lastResponsePendingToolCallIDs, - allowIncrementalInputWithPreviousResponseID, - allowCompactionReplayBypass, - ) + if useCodexWebsocketPassthrough { + requestJSON, errMsg = normalizeResponsesWebsocketPassthroughRequest(payload, requestModelName) + } else { + requestJSON, updatedLastRequest, errMsg = normalizeResponsesWebsocketRequestWithIncrementalState( + payload, + lastRequest, + lastResponseOutput, + lastResponseID, + lastResponsePendingToolCallIDs, + allowIncrementalInputWithPreviousResponseID, + allowCompactionReplayBypass, + ) + } if errMsg != nil { h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg) markAPIResponseTimestamp(c) @@ -370,7 +371,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } continue } - if shouldHandleResponsesWebsocketPrewarmLocally(payload, lastRequest, allowIncrementalInputWithPreviousResponseID) { + if !useCodexWebsocketPassthrough && shouldHandleResponsesWebsocketPrewarmLocally(payload, lastRequest, allowIncrementalInputWithPreviousResponseID) { if updated, errDelete := sjson.DeleteBytes(requestJSON, "generate"); errDelete == nil { requestJSON = updated } @@ -388,17 +389,26 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { continue } - requestJSON = repairResponsesWebsocketToolCalls(downstreamSessionKey, requestJSON) - requestJSON = dedupeResponsesWebsocketInputItemsByID(requestJSON) - updatedLastRequest = bytes.Clone(requestJSON) previousLastRequest := bytes.Clone(lastRequest) previousLastResponseOutput := bytes.Clone(lastResponseOutput) previousLastResponseID := lastResponseID previousLastResponsePendingToolCallIDs := append([]string(nil), lastResponsePendingToolCallIDs...) forcedTranscriptReplay := forceTranscriptReplayNextRequest - lastRequest = updatedLastRequest - if forcedTranscriptReplay { - forceTranscriptReplayNextRequest = false + if useCodexWebsocketPassthrough { + if modelName := strings.TrimSpace(gjson.GetBytes(requestJSON, "model").String()); modelName != "" { + passthroughModelName = modelName + } + if forcedTranscriptReplay { + forceTranscriptReplayNextRequest = false + } + } else { + requestJSON = repairResponsesWebsocketToolCalls(downstreamSessionKey, requestJSON) + requestJSON = dedupeResponsesWebsocketInputItemsByID(requestJSON) + updatedLastRequest = bytes.Clone(requestJSON) + lastRequest = updatedLastRequest + if forcedTranscriptReplay { + forceTranscriptReplayNextRequest = false + } } modelName := gjson.GetBytes(requestJSON, "model").String() @@ -433,15 +443,21 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { if shouldReleaseResponsesWebsocketPinnedAuth(forwardErrMsg) { pinnedAuthID = "" forceTranscriptReplayNextRequest = true - lastRequest = previousLastRequest - lastResponseOutput = previousLastResponseOutput - lastResponseID = previousLastResponseID - lastResponsePendingToolCallIDs = previousLastResponsePendingToolCallIDs + if useCodexWebsocketPassthrough { + passthroughModelName = "" + } else { + lastRequest = previousLastRequest + lastResponseOutput = previousLastResponseOutput + lastResponseID = previousLastResponseID + lastResponsePendingToolCallIDs = previousLastResponsePendingToolCallIDs + } continue } - lastResponseOutput = completedOutput - lastResponseID = strings.TrimSpace(completedResponseID) - lastResponsePendingToolCallIDs = append([]string(nil), completedPendingToolCallIDs...) + if !useCodexWebsocketPassthrough { + lastResponseOutput = completedOutput + lastResponseID = strings.TrimSpace(completedResponseID) + lastResponsePendingToolCallIDs = append([]string(nil), completedPendingToolCallIDs...) + } } } @@ -944,6 +960,65 @@ func (h *OpenAIResponsesAPIHandler) responsesWebsocketAvailableAuthsForModel(mod return available, modelKey } +func (h *OpenAIResponsesAPIHandler) responsesWebsocketUsesCodexWebsocketPassthrough(modelName string) bool { + modelName = strings.TrimSpace(modelName) + if h == nil || h.AuthManager == nil || modelName == "" { + return false + } + if _, ok := h.AuthManager.Executor("codex"); !ok { + return false + } + auths, _ := h.responsesWebsocketAvailableAuthsForModel(modelName) + if len(auths) == 0 { + return false + } + for _, auth := range auths { + if auth == nil { + return false + } + if !strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") { + return false + } + if !websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) { + return false + } + } + return true +} + +func normalizeResponsesWebsocketPassthroughRequest(rawJSON []byte, modelName string) ([]byte, *interfaces.ErrorMessage) { + if !json.Valid(rawJSON) { + return nil, &interfaces.ErrorMessage{ + StatusCode: http.StatusBadRequest, + Error: fmt.Errorf("invalid websocket request JSON"), + } + } + + requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String()) + switch requestType { + case wsRequestTypeCreate, wsRequestTypeAppend: + default: + return nil, &interfaces.ErrorMessage{ + StatusCode: http.StatusBadRequest, + Error: fmt.Errorf("unsupported websocket request type: %s", requestType), + } + } + + normalized := bytes.Clone(rawJSON) + if strings.TrimSpace(gjson.GetBytes(normalized, "model").String()) == "" { + modelName = strings.TrimSpace(modelName) + if modelName == "" { + return nil, &interfaces.ErrorMessage{ + StatusCode: http.StatusBadRequest, + Error: fmt.Errorf("missing model in response.create request"), + } + } + normalized, _ = sjson.SetBytes(normalized, "model", modelName) + } + normalized, _ = sjson.SetBytes(normalized, "stream", true) + return normalized, nil +} + func responsesWebsocketResolvedModelName(modelName string) string { initialSuffix := thinking.ParseSuffix(modelName) if initialSuffix.ModelName == "auto" { diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go index b67147f08..99f4e555f 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_test.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go @@ -83,6 +83,14 @@ type websocketBootstrapFallbackExecutor struct { payloads map[string][][]byte } +type websocketDirectCaptureExecutor struct { + mu sync.Mutex + authIDs []string + payloads [][]byte + done chan struct{} + doneOnce sync.Once +} + type websocketPinnedFailoverStatusError struct { status int msg string @@ -156,6 +164,63 @@ func (e *websocketBootstrapFallbackExecutor) Payloads(authID string) [][]byte { return out } +func (e *websocketDirectCaptureExecutor) Identifier() string { return "codex" } + +func (e *websocketDirectCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, errors.New("not implemented") +} + +func (e *websocketDirectCaptureExecutor) ExecuteStream(_ context.Context, auth *coreauth.Auth, req coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) { + authID := "" + if auth != nil { + authID = auth.ID + } + e.mu.Lock() + e.authIDs = append(e.authIDs, authID) + e.payloads = append(e.payloads, bytes.Clone(req.Payload)) + count := len(e.payloads) + e.mu.Unlock() + + chunks := make(chan coreexecutor.StreamChunk, 1) + responseID := fmt.Sprintf("resp-%d", count) + chunks <- coreexecutor.StreamChunk{Payload: []byte(fmt.Sprintf(`{"type":"response.completed","response":{"id":%q,"output":[{"type":"message","id":"out-%d"}]}}`, responseID, count))} + close(chunks) + if count >= 2 && e.done != nil { + e.doneOnce.Do(func() { + close(e.done) + }) + } + return &coreexecutor.StreamResult{Chunks: chunks}, nil +} + +func (e *websocketDirectCaptureExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { + return auth, nil +} + +func (e *websocketDirectCaptureExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, errors.New("not implemented") +} + +func (e *websocketDirectCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) { + return nil, errors.New("not implemented") +} + +func (e *websocketDirectCaptureExecutor) Payloads() [][]byte { + e.mu.Lock() + defer e.mu.Unlock() + out := make([][]byte, len(e.payloads)) + for i := range e.payloads { + out[i] = bytes.Clone(e.payloads[i]) + } + return out +} + +func (e *websocketDirectCaptureExecutor) AuthIDs() []string { + e.mu.Lock() + defer e.mu.Unlock() + return append([]string(nil), e.authIDs...) +} + type websocketUpstreamDisconnectExecutor struct { mu sync.Mutex subscribed chan string @@ -1497,6 +1562,85 @@ func TestResponsesWebsocketClosesOnCodexUpstreamDisconnect(t *testing.T) { } } +func TestResponsesWebsocketCodexWebsocketPassthroughPassesCompactedRequestWithoutTranscriptMerge(t *testing.T) { + gin.SetMode(gin.TestMode) + + executor := &websocketDirectCaptureExecutor{done: make(chan struct{})} + manager := coreauth.NewManager(nil, nil, nil) + manager.RegisterExecutor(executor) + auth := &coreauth.Auth{ + ID: "auth-ws", + Provider: "codex", + Status: coreauth.StatusActive, + Attributes: map[string]string{"websockets": "true"}, + } + if _, err := manager.Register(context.Background(), auth); err != nil { + t.Fatalf("Register auth: %v", err) + } + registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(auth.ID) + }) + + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + h := NewOpenAIResponsesAPIHandler(base) + firstRequest := []byte(`{"type":"response.create","model":"test-model","input":[{"type":"message","role":"user","content":"first"}]}`) + router := gin.New() + router.GET("/v1/responses/ws", h.ResponsesWebsocket) + + server := httptest.NewServer(router) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws" + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("dial websocket: %v", err) + } + defer func() { _ = conn.Close() }() + + if errWrite := conn.WriteMessage(websocket.TextMessage, firstRequest); errWrite != nil { + t.Fatalf("write first websocket message: %v", errWrite) + } + if _, _, errRead := conn.ReadMessage(); errRead != nil { + t.Fatalf("read first websocket response: %v", errRead) + } + + compactedRequest := []byte(`{"type":"response.create","input":[{"type":"compaction_summary","summary":"compressed history"},{"type":"message","role":"user","content":"after compaction"}]}`) + if errWrite := conn.WriteMessage(websocket.TextMessage, compactedRequest); errWrite != nil { + t.Fatalf("write compacted websocket message: %v", errWrite) + } + if _, _, errRead := conn.ReadMessage(); errRead != nil { + t.Fatalf("read compacted websocket response: %v", errRead) + } + + select { + case <-executor.done: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for websocket passthrough") + } + + payloads := executor.Payloads() + if len(payloads) != 2 { + t.Fatalf("passthrough payload count = %d, want 2", len(payloads)) + } + if got := gjson.GetBytes(payloads[0], "input").Raw; got != gjson.GetBytes(firstRequest, "input").Raw { + t.Fatalf("first passthrough input = %s, want %s", got, gjson.GetBytes(firstRequest, "input").Raw) + } + if got := gjson.GetBytes(payloads[1], "input").Raw; got != gjson.GetBytes(compactedRequest, "input").Raw { + t.Fatalf("compacted passthrough input = %s, want %s", got, gjson.GetBytes(compactedRequest, "input").Raw) + } + if got := gjson.GetBytes(payloads[1], "model").String(); got != "test-model" { + t.Fatalf("compacted passthrough model = %s, want test-model", got) + } + if bytes.Contains(payloads[1], []byte(`"content":"first"`)) || bytes.Contains(payloads[1], []byte(`"id":"out-1"`)) { + t.Fatalf("compacted passthrough payload contains stale transcript state: %s", payloads[1]) + } + authIDs := executor.AuthIDs() + if len(authIDs) != 2 || authIDs[0] != "auth-ws" || authIDs[1] != "auth-ws" { + t.Fatalf("passthrough auth IDs = %v, want [auth-ws auth-ws]", authIDs) + } +} + func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) { manager := coreauth.NewManager(nil, nil, nil) auth := &coreauth.Auth{