mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-05-31 20:02:36 +08:00
fix(auth): strip "generate" from payload during WebSocket HTTP fallback
- Added `sanitizeDownstreamWebsocketFallbackRequest` to clean `generate` from payload for HTTP fallback requests. - Implemented tests to validate payload handling logic in WebSocket-to-HTTP transitions. Closes: #3556
This commit is contained in:
@@ -77,6 +77,12 @@ type websocketPinnedFailoverExecutor struct {
|
||||
payloads map[string][][]byte
|
||||
}
|
||||
|
||||
type websocketBootstrapFallbackExecutor struct {
|
||||
mu sync.Mutex
|
||||
authIDs []string
|
||||
payloads map[string][][]byte
|
||||
}
|
||||
|
||||
type websocketPinnedFailoverStatusError struct {
|
||||
status int
|
||||
msg string
|
||||
@@ -86,6 +92,70 @@ func (e websocketPinnedFailoverStatusError) Error() string { return e.msg }
|
||||
|
||||
func (e websocketPinnedFailoverStatusError) StatusCode() int { return e.status }
|
||||
|
||||
func (e *websocketBootstrapFallbackExecutor) Identifier() string { return "test-provider" }
|
||||
|
||||
func (e *websocketBootstrapFallbackExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (e *websocketBootstrapFallbackExecutor) ExecuteStream(_ context.Context, auth *coreauth.Auth, req coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||
authID := ""
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
}
|
||||
|
||||
e.mu.Lock()
|
||||
if e.payloads == nil {
|
||||
e.payloads = make(map[string][][]byte)
|
||||
}
|
||||
e.authIDs = append(e.authIDs, authID)
|
||||
e.payloads[authID] = append(e.payloads[authID], bytes.Clone(req.Payload))
|
||||
e.mu.Unlock()
|
||||
|
||||
chunks := make(chan coreexecutor.StreamChunk, 1)
|
||||
if authID == "auth-ws" {
|
||||
chunks <- coreexecutor.StreamChunk{Err: websocketPinnedFailoverStatusError{
|
||||
status: http.StatusServiceUnavailable,
|
||||
msg: `{"error":{"message":"websocket bootstrap failed","type":"server_error","code":"ws_failed"}}`,
|
||||
}}
|
||||
close(chunks)
|
||||
return &coreexecutor.StreamResult{Chunks: chunks}, nil
|
||||
}
|
||||
|
||||
chunks <- coreexecutor.StreamChunk{Payload: []byte(`{"type":"response.completed","response":{"id":"resp-http","output":[{"type":"message","id":"out-http"}]}}`)}
|
||||
close(chunks)
|
||||
return &coreexecutor.StreamResult{Chunks: chunks}, nil
|
||||
}
|
||||
|
||||
func (e *websocketBootstrapFallbackExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *websocketBootstrapFallbackExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (e *websocketBootstrapFallbackExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (e *websocketBootstrapFallbackExecutor) AuthIDs() []string {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
return append([]string(nil), e.authIDs...)
|
||||
}
|
||||
|
||||
func (e *websocketBootstrapFallbackExecutor) Payloads(authID string) [][]byte {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
src := e.payloads[authID]
|
||||
out := make([][]byte, len(src))
|
||||
for i := range src {
|
||||
out[i] = bytes.Clone(src[i])
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
type websocketUpstreamDisconnectExecutor struct {
|
||||
mu sync.Mutex
|
||||
subscribed chan string
|
||||
@@ -1340,6 +1410,87 @@ func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsesWebsocketStripsGenerateWhenWebsocketAttemptFallsBackToHTTP(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
selector := &orderedWebsocketSelector{order: []string{"auth-ws", "auth-http"}}
|
||||
executor := &websocketBootstrapFallbackExecutor{}
|
||||
manager := coreauth.NewManager(nil, selector, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
|
||||
authWS := &coreauth.Auth{
|
||||
ID: "auth-ws",
|
||||
Provider: executor.Identifier(),
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: map[string]string{"websockets": "true"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), authWS); err != nil {
|
||||
t.Fatalf("Register websocket auth: %v", err)
|
||||
}
|
||||
authHTTP := &coreauth.Auth{ID: "auth-http", Provider: executor.Identifier(), Status: coreauth.StatusActive}
|
||||
if _, err := manager.Register(context.Background(), authHTTP); err != nil {
|
||||
t.Fatalf("Register HTTP auth: %v", err)
|
||||
}
|
||||
|
||||
registry.GetGlobalRegistry().RegisterClient(authWS.ID, authWS.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
registry.GetGlobalRegistry().RegisterClient(authHTTP.ID, authHTTP.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(authWS.ID)
|
||||
registry.GetGlobalRegistry().UnregisterClient(authHTTP.ID)
|
||||
})
|
||||
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
router := gin.New()
|
||||
router.GET("/v1/responses/ws", h.ResponsesWebsocket)
|
||||
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("dial websocket: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
t.Fatalf("close websocket: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
request := `{"type":"response.create","model":"test-model","generate":false,"input":[{"type":"message","id":"msg-1"}]}`
|
||||
if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(request)); errWrite != nil {
|
||||
t.Fatalf("write websocket message: %v", errWrite)
|
||||
}
|
||||
_, payload, errReadMessage := conn.ReadMessage()
|
||||
if errReadMessage != nil {
|
||||
t.Fatalf("read websocket message: %v", errReadMessage)
|
||||
}
|
||||
if got := gjson.GetBytes(payload, "type").String(); got != wsEventTypeCompleted {
|
||||
t.Fatalf("payload type = %s, want %s: %s", got, wsEventTypeCompleted, payload)
|
||||
}
|
||||
|
||||
if got := executor.AuthIDs(); len(got) != 2 || got[0] != "auth-ws" || got[1] != "auth-http" {
|
||||
t.Fatalf("selected auth IDs = %v, want [auth-ws auth-http]", got)
|
||||
}
|
||||
|
||||
wsPayloads := executor.Payloads("auth-ws")
|
||||
if len(wsPayloads) != 1 {
|
||||
t.Fatalf("auth-ws payload count = %d, want 1", len(wsPayloads))
|
||||
}
|
||||
if !gjson.GetBytes(wsPayloads[0], "generate").Exists() {
|
||||
t.Fatalf("websocket attempt payload unexpectedly stripped generate: %s", wsPayloads[0])
|
||||
}
|
||||
|
||||
httpPayloads := executor.Payloads("auth-http")
|
||||
if len(httpPayloads) != 1 {
|
||||
t.Fatalf("auth-http payload count = %d, want 1", len(httpPayloads))
|
||||
}
|
||||
if gjson.GetBytes(httpPayloads[0], "generate").Exists() {
|
||||
t.Fatalf("generate leaked after HTTP fallback: %s", httpPayloads[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebsocketClientAddressUsesGinClientIP(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
|
||||
coreusage "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/usage"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// ProviderExecutor defines the contract required by Manager to execute provider calls.
|
||||
@@ -1581,7 +1582,8 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
lastErr = errPrepare
|
||||
continue
|
||||
}
|
||||
streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel, models, pooled)
|
||||
execReq := sanitizeDownstreamWebsocketFallbackRequest(execCtx, auth, req)
|
||||
streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, execReq, opts, routeModel, models, pooled)
|
||||
if errStream != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return nil, errCtx
|
||||
@@ -1599,6 +1601,18 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeDownstreamWebsocketFallbackRequest(ctx context.Context, auth *Auth, req cliproxyexecutor.Request) cliproxyexecutor.Request {
|
||||
if !cliproxyexecutor.DownstreamWebsocket(ctx) || authWebsocketsEnabled(auth) || len(req.Payload) == 0 {
|
||||
return req
|
||||
}
|
||||
updated, errDelete := sjson.DeleteBytes(req.Payload, "generate")
|
||||
if errDelete != nil {
|
||||
return req
|
||||
}
|
||||
req.Payload = updated
|
||||
return req
|
||||
}
|
||||
|
||||
func ensureRequestedModelMetadata(opts cliproxyexecutor.Options, requestedModel string) cliproxyexecutor.Options {
|
||||
requestedModel = strings.TrimSpace(requestedModel)
|
||||
if requestedModel == "" {
|
||||
|
||||
Reference in New Issue
Block a user