From 430e679e2a603294248d9ff90e97fa4fe8e88090 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sat, 30 May 2026 05:14:05 +0800 Subject: [PATCH] fix(auth): strip "generate" from payload during WebSocket HTTP fallback - Added `sanitizeDownstreamWebsocketFallbackRequest` to clean `generate` from payload for HTTP fallback requests. - Implemented tests to validate payload handling logic in WebSocket-to-HTTP transitions. Closes: #3556 --- .../openai/openai_responses_websocket_test.go | 151 ++++++++++++++++++ sdk/cliproxy/auth/conductor.go | 16 +- 2 files changed, 166 insertions(+), 1 deletion(-) diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go index 9f23af82d..6502ae0c8 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_test.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go @@ -77,6 +77,12 @@ type websocketPinnedFailoverExecutor struct { payloads map[string][][]byte } +type websocketBootstrapFallbackExecutor struct { + mu sync.Mutex + authIDs []string + payloads map[string][][]byte +} + type websocketPinnedFailoverStatusError struct { status int msg string @@ -86,6 +92,70 @@ func (e websocketPinnedFailoverStatusError) Error() string { return e.msg } func (e websocketPinnedFailoverStatusError) StatusCode() int { return e.status } +func (e *websocketBootstrapFallbackExecutor) Identifier() string { return "test-provider" } + +func (e *websocketBootstrapFallbackExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, errors.New("not implemented") +} + +func (e *websocketBootstrapFallbackExecutor) ExecuteStream(_ context.Context, auth *coreauth.Auth, req coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) { + authID := "" + if auth != nil { + authID = auth.ID + } + + e.mu.Lock() + if e.payloads == nil { + e.payloads = make(map[string][][]byte) + } + e.authIDs = append(e.authIDs, authID) + e.payloads[authID] = append(e.payloads[authID], bytes.Clone(req.Payload)) + e.mu.Unlock() + + chunks := make(chan coreexecutor.StreamChunk, 1) + if authID == "auth-ws" { + chunks <- coreexecutor.StreamChunk{Err: websocketPinnedFailoverStatusError{ + status: http.StatusServiceUnavailable, + msg: `{"error":{"message":"websocket bootstrap failed","type":"server_error","code":"ws_failed"}}`, + }} + close(chunks) + return &coreexecutor.StreamResult{Chunks: chunks}, nil + } + + chunks <- coreexecutor.StreamChunk{Payload: []byte(`{"type":"response.completed","response":{"id":"resp-http","output":[{"type":"message","id":"out-http"}]}}`)} + close(chunks) + return &coreexecutor.StreamResult{Chunks: chunks}, nil +} + +func (e *websocketBootstrapFallbackExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { + return auth, nil +} + +func (e *websocketBootstrapFallbackExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, errors.New("not implemented") +} + +func (e *websocketBootstrapFallbackExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) { + return nil, errors.New("not implemented") +} + +func (e *websocketBootstrapFallbackExecutor) AuthIDs() []string { + e.mu.Lock() + defer e.mu.Unlock() + return append([]string(nil), e.authIDs...) +} + +func (e *websocketBootstrapFallbackExecutor) Payloads(authID string) [][]byte { + e.mu.Lock() + defer e.mu.Unlock() + src := e.payloads[authID] + out := make([][]byte, len(src)) + for i := range src { + out[i] = bytes.Clone(src[i]) + } + return out +} + type websocketUpstreamDisconnectExecutor struct { mu sync.Mutex subscribed chan string @@ -1340,6 +1410,87 @@ func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) { } } +func TestResponsesWebsocketStripsGenerateWhenWebsocketAttemptFallsBackToHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + + selector := &orderedWebsocketSelector{order: []string{"auth-ws", "auth-http"}} + executor := &websocketBootstrapFallbackExecutor{} + manager := coreauth.NewManager(nil, selector, nil) + manager.RegisterExecutor(executor) + + authWS := &coreauth.Auth{ + ID: "auth-ws", + Provider: executor.Identifier(), + Status: coreauth.StatusActive, + Attributes: map[string]string{"websockets": "true"}, + } + if _, err := manager.Register(context.Background(), authWS); err != nil { + t.Fatalf("Register websocket auth: %v", err) + } + authHTTP := &coreauth.Auth{ID: "auth-http", Provider: executor.Identifier(), Status: coreauth.StatusActive} + if _, err := manager.Register(context.Background(), authHTTP); err != nil { + t.Fatalf("Register HTTP auth: %v", err) + } + + registry.GetGlobalRegistry().RegisterClient(authWS.ID, authWS.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + registry.GetGlobalRegistry().RegisterClient(authHTTP.ID, authHTTP.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(authWS.ID) + registry.GetGlobalRegistry().UnregisterClient(authHTTP.ID) + }) + + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + h := NewOpenAIResponsesAPIHandler(base) + router := gin.New() + router.GET("/v1/responses/ws", h.ResponsesWebsocket) + + server := httptest.NewServer(router) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws" + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("dial websocket: %v", err) + } + defer func() { + if errClose := conn.Close(); errClose != nil { + t.Fatalf("close websocket: %v", errClose) + } + }() + + request := `{"type":"response.create","model":"test-model","generate":false,"input":[{"type":"message","id":"msg-1"}]}` + if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(request)); errWrite != nil { + t.Fatalf("write websocket message: %v", errWrite) + } + _, payload, errReadMessage := conn.ReadMessage() + if errReadMessage != nil { + t.Fatalf("read websocket message: %v", errReadMessage) + } + if got := gjson.GetBytes(payload, "type").String(); got != wsEventTypeCompleted { + t.Fatalf("payload type = %s, want %s: %s", got, wsEventTypeCompleted, payload) + } + + if got := executor.AuthIDs(); len(got) != 2 || got[0] != "auth-ws" || got[1] != "auth-http" { + t.Fatalf("selected auth IDs = %v, want [auth-ws auth-http]", got) + } + + wsPayloads := executor.Payloads("auth-ws") + if len(wsPayloads) != 1 { + t.Fatalf("auth-ws payload count = %d, want 1", len(wsPayloads)) + } + if !gjson.GetBytes(wsPayloads[0], "generate").Exists() { + t.Fatalf("websocket attempt payload unexpectedly stripped generate: %s", wsPayloads[0]) + } + + httpPayloads := executor.Payloads("auth-http") + if len(httpPayloads) != 1 { + t.Fatalf("auth-http payload count = %d, want 1", len(httpPayloads)) + } + if gjson.GetBytes(httpPayloads[0], "generate").Exists() { + t.Fatalf("generate leaked after HTTP fallback: %s", httpPayloads[0]) + } +} + func TestWebsocketClientAddressUsesGinClientIP(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 5413dcf4b..33116fba8 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -25,6 +25,7 @@ import ( cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor" coreusage "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/usage" log "github.com/sirupsen/logrus" + "github.com/tidwall/sjson" ) // ProviderExecutor defines the contract required by Manager to execute provider calls. @@ -1581,7 +1582,8 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string lastErr = errPrepare continue } - streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel, models, pooled) + execReq := sanitizeDownstreamWebsocketFallbackRequest(execCtx, auth, req) + streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, execReq, opts, routeModel, models, pooled) if errStream != nil { if errCtx := execCtx.Err(); errCtx != nil { return nil, errCtx @@ -1599,6 +1601,18 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string } } +func sanitizeDownstreamWebsocketFallbackRequest(ctx context.Context, auth *Auth, req cliproxyexecutor.Request) cliproxyexecutor.Request { + if !cliproxyexecutor.DownstreamWebsocket(ctx) || authWebsocketsEnabled(auth) || len(req.Payload) == 0 { + return req + } + updated, errDelete := sjson.DeleteBytes(req.Payload, "generate") + if errDelete != nil { + return req + } + req.Payload = updated + return req +} + func ensureRequestedModelMetadata(opts cliproxyexecutor.Options, requestedModel string) cliproxyexecutor.Options { requestedModel = strings.TrimSpace(requestedModel) if requestedModel == "" {