From ea90ab6f775f3ef834602e7aed5ed91bc3477b3b Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Mon, 15 Jun 2026 08:22:07 +0800 Subject: [PATCH] feat(websockets): implement XAIWebsocketsExecutor with enhanced execution and ID mapping - Developed `XAIWebsocketsExecutor` for handling xAI Responses via WebSocket transport. - Introduced session and state management with `codexWebsocketSessionStore` and `xaiWebsocketIDStateStore`. - Added robust ID mapping for upstream and downstream request/response sequences. - Enhanced error propagation and handling of WebSocket terminal events. - Included utility methods for WebSocket request preparation, connection management, and state tracking. - Added foundational support for compact and streamed responses via enhanced session tracking. --- .../executor/xai_websockets_executor.go | 1241 +++++++++++++++++ .../executor/xai_websockets_executor_test.go | 425 ++++++ .../openai/openai_responses_websocket.go | 56 +- .../openai/openai_responses_websocket_test.go | 158 ++- sdk/cliproxy/auth/scheduler.go | 11 +- sdk/cliproxy/auth/scheduler_test.go | 26 + sdk/cliproxy/service.go | 5 +- .../service_codex_executor_binding_test.go | 23 + 8 files changed, 1925 insertions(+), 20 deletions(-) create mode 100644 internal/runtime/executor/xai_websockets_executor.go create mode 100644 internal/runtime/executor/xai_websockets_executor_test.go diff --git a/internal/runtime/executor/xai_websockets_executor.go b/internal/runtime/executor/xai_websockets_executor.go new file mode 100644 index 000000000..4102ce08a --- /dev/null +++ b/internal/runtime/executor/xai_websockets_executor.go @@ -0,0 +1,1241 @@ +// Package executor provides runtime execution capabilities for various AI service providers. +// This file implements an xAI executor that uses the Responses API WebSocket transport. +package executor + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strconv" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" + xaiauth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/xai" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + "github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor/helps" + "github.com/router-for-me/CLIProxyAPI/v7/internal/util" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// XAIWebsocketsExecutor executes xAI Responses requests using a WebSocket transport. +type XAIWebsocketsExecutor struct { + *XAIExecutor + + store *codexWebsocketSessionStore + idStore *xaiWebsocketIDStateStore +} + +var globalXAIWebsocketSessionStore = &codexWebsocketSessionStore{ + sessions: make(map[string]*codexWebsocketSession), +} + +var globalXAIWebsocketIDStates = &xaiWebsocketIDStateStore{ + sessions: make(map[string]*xaiWebsocketIDState), +} + +type xaiWebsocketIDStateStore struct { + mu sync.Mutex + sessions map[string]*xaiWebsocketIDState +} + +type xaiWebsocketIDState struct { + mu sync.Mutex + downstreamToUpstream map[string]string + sequence int +} + +type xaiWebsocketRequestIDMapper struct { + state *xaiWebsocketIDState + downstreamPreviousID string + upstreamPreviousID string + upstreamResponseID string + downstreamResponseID string +} + +func NewXAIWebsocketsExecutor(cfg *config.Config) *XAIWebsocketsExecutor { + return &XAIWebsocketsExecutor{ + XAIExecutor: NewXAIExecutor(cfg), + store: globalXAIWebsocketSessionStore, + idStore: globalXAIWebsocketIDStates, + } +} + +func getXAIWebsocketIDState(store *xaiWebsocketIDStateStore, sessionID string) *xaiWebsocketIDState { + sessionID = strings.TrimSpace(sessionID) + if sessionID == "" || store == nil { + return nil + } + store.mu.Lock() + defer store.mu.Unlock() + if store.sessions == nil { + store.sessions = make(map[string]*xaiWebsocketIDState) + } + if state := store.sessions[sessionID]; state != nil { + return state + } + state := &xaiWebsocketIDState{ + downstreamToUpstream: make(map[string]string), + } + store.sessions[sessionID] = state + return state +} + +func deleteXAIWebsocketIDState(store *xaiWebsocketIDStateStore, sessionID string) { + sessionID = strings.TrimSpace(sessionID) + if sessionID == "" || store == nil { + return + } + store.mu.Lock() + delete(store.sessions, sessionID) + store.mu.Unlock() +} + +func newXAIWebsocketRequestIDMapper(store *xaiWebsocketIDStateStore, sessionID string, downstreamRequest []byte) *xaiWebsocketRequestIDMapper { + state := getXAIWebsocketIDState(store, sessionID) + if state == nil { + return nil + } + downstreamPreviousID := strings.TrimSpace(gjson.GetBytes(downstreamRequest, "previous_response_id").String()) + upstreamPreviousID := downstreamPreviousID + if downstreamPreviousID != "" { + upstreamPreviousID = state.upstreamIDForDownstream(downstreamPreviousID) + } + return &xaiWebsocketRequestIDMapper{ + state: state, + downstreamPreviousID: downstreamPreviousID, + upstreamPreviousID: upstreamPreviousID, + } +} + +func (s *xaiWebsocketIDState) upstreamIDForDownstream(downstreamID string) string { + downstreamID = strings.TrimSpace(downstreamID) + if s == nil || downstreamID == "" { + return downstreamID + } + s.mu.Lock() + defer s.mu.Unlock() + if upstreamID := strings.TrimSpace(s.downstreamToUpstream[downstreamID]); upstreamID != "" { + return upstreamID + } + return downstreamID +} + +func (m *xaiWebsocketRequestIDMapper) upstreamRequestPayload(payload []byte) []byte { + if m == nil || len(payload) == 0 || m.downstreamPreviousID == m.upstreamPreviousID { + return payload + } + if m.upstreamPreviousID == "" { + out, errDelete := sjson.DeleteBytes(payload, "previous_response_id") + if errDelete == nil { + return out + } + return payload + } + out, errSet := sjson.SetBytes(payload, "previous_response_id", m.upstreamPreviousID) + if errSet != nil { + return payload + } + return out +} + +func (m *xaiWebsocketRequestIDMapper) downstreamResponsePayload(payload []byte) []byte { + if m == nil || len(payload) == 0 { + return payload + } + upstreamResponseID := strings.TrimSpace(gjson.GetBytes(payload, "response.id").String()) + downstreamResponseID := m.downstreamIDForUpstreamResponse(upstreamResponseID) + if downstreamResponseID == "" { + return payload + } + return rewriteXAIWebsocketDownstreamIDs(payload, m.upstreamResponseID, downstreamResponseID, m.upstreamPreviousID, m.downstreamPreviousID) +} + +func (m *xaiWebsocketRequestIDMapper) downstreamIDForUpstreamResponse(upstreamResponseID string) string { + upstreamResponseID = strings.TrimSpace(upstreamResponseID) + if m == nil || m.state == nil { + return upstreamResponseID + } + if m.upstreamResponseID != "" { + return m.downstreamResponseID + } + if upstreamResponseID == "" { + return "" + } + + m.state.mu.Lock() + defer m.state.mu.Unlock() + m.upstreamResponseID = upstreamResponseID + m.downstreamResponseID = upstreamResponseID + if m.downstreamPreviousID != "" && m.upstreamPreviousID != "" && upstreamResponseID == m.upstreamPreviousID { + m.state.sequence++ + m.downstreamResponseID = fmt.Sprintf("%s-xai-%d", upstreamResponseID, m.state.sequence) + } + if m.state.downstreamToUpstream == nil { + m.state.downstreamToUpstream = make(map[string]string) + } + m.state.downstreamToUpstream[upstreamResponseID] = upstreamResponseID + m.state.downstreamToUpstream[m.downstreamResponseID] = upstreamResponseID + return m.downstreamResponseID +} + +func rewriteXAIWebsocketDownstreamIDs(payload []byte, upstreamResponseID string, downstreamResponseID string, upstreamPreviousID string, downstreamPreviousID string) []byte { + upstreamResponseID = strings.TrimSpace(upstreamResponseID) + downstreamResponseID = strings.TrimSpace(downstreamResponseID) + upstreamPreviousID = strings.TrimSpace(upstreamPreviousID) + downstreamPreviousID = strings.TrimSpace(downstreamPreviousID) + if len(payload) == 0 || (upstreamResponseID == downstreamResponseID && upstreamPreviousID == downstreamPreviousID) { + return payload + } + + var value any + decoder := json.NewDecoder(bytes.NewReader(payload)) + decoder.UseNumber() + if errDecode := decoder.Decode(&value); errDecode != nil { + return payload + } + if !rewriteXAIWebsocketDownstreamIDValue(value, upstreamResponseID, downstreamResponseID, upstreamPreviousID, downstreamPreviousID, "") { + return payload + } + out, errMarshal := json.Marshal(value) + if errMarshal != nil { + return payload + } + return out +} + +func rewriteXAIWebsocketDownstreamIDValue(value any, upstreamResponseID string, downstreamResponseID string, upstreamPreviousID string, downstreamPreviousID string, key string) bool { + switch typed := value.(type) { + case map[string]any: + changed := false + for childKey, childValue := range typed { + if childString, ok := childValue.(string); ok { + replaced := rewriteXAIWebsocketDownstreamIDString(childString, childKey, upstreamResponseID, downstreamResponseID, upstreamPreviousID, downstreamPreviousID) + if replaced != childString { + typed[childKey] = replaced + changed = true + } + continue + } + if rewriteXAIWebsocketDownstreamIDValue(childValue, upstreamResponseID, downstreamResponseID, upstreamPreviousID, downstreamPreviousID, childKey) { + changed = true + } + } + return changed + case []any: + changed := false + for i := range typed { + if rewriteXAIWebsocketDownstreamIDValue(typed[i], upstreamResponseID, downstreamResponseID, upstreamPreviousID, downstreamPreviousID, key) { + changed = true + } + } + return changed + default: + return false + } +} + +func rewriteXAIWebsocketDownstreamIDString(value string, key string, upstreamResponseID string, downstreamResponseID string, upstreamPreviousID string, downstreamPreviousID string) string { + switch key { + case "id", "item_id": + if upstreamResponseID != "" && downstreamResponseID != "" && downstreamResponseID != upstreamResponseID && strings.Contains(value, upstreamResponseID) { + return strings.ReplaceAll(value, upstreamResponseID, downstreamResponseID) + } + case "previous_response_id": + if upstreamPreviousID != "" && downstreamPreviousID != "" && value == upstreamPreviousID { + return downstreamPreviousID + } + } + return value +} + +func (e *XAIWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + if e == nil || e.XAIExecutor == nil { + return cliproxyexecutor.Response{}, fmt.Errorf("xai websockets executor: executor is nil") + } + return e.XAIExecutor.Execute(ctx, auth, req, opts) +} + +func (e *XAIWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { + if e == nil || e.XAIExecutor == nil { + return nil, fmt.Errorf("xai websockets executor: executor is nil") + } + if ctx == nil { + ctx = context.Background() + } + if opts.Alt == "responses/compact" { + return nil, statusErr{code: http.StatusBadRequest, msg: "streaming not supported for /responses/compact"} + } + if xaiInputHasItemType(req.Payload, "compaction_trigger") { + return e.XAIExecutor.ExecuteStream(ctx, auth, req, opts) + } + + executionSessionID := executionSessionIDFromOptions(opts) + idMapper := newXAIWebsocketRequestIDMapper(e.idStore, executionSessionID, req.Payload) + token, baseURL := xaiCreds(auth) + if baseURL == "" { + baseURL = xaiauth.DefaultAPIBaseURL + } + + prepared, err := e.prepareResponsesWebsocketRequest(ctx, req, opts) + if err != nil { + return nil, err + } + if idMapper != nil { + prepared.body = idMapper.upstreamRequestPayload(prepared.body) + } + + reporter := helps.NewExecutorUsageReporter(ctx, e, prepared.baseModel, auth) + defer reporter.TrackFailure(ctx, &err) + reporter.SetTranslatedReasoningEffort(prepared.body, e.Identifier()) + + httpURL := strings.TrimSuffix(baseURL, "/") + "/responses" + wsURL, err := buildXAIResponsesWebsocketURL(httpURL) + if err != nil { + return nil, err + } + wsHeaders := applyXAIWebsocketHeaders(http.Header{}, auth, token, prepared.sessionID) + wsReqBody := buildXAIWebsocketRequestBody(prepared.body) + warmupRequest := xaiWebsocketGenerateFalse(wsReqBody) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + + var sess *codexWebsocketSession + if executionSessionID != "" { + sess = e.getOrCreateSession(executionSessionID) + if sess != nil { + sess.reqMu.Lock() + } + } + + wsReqLog := helps.UpstreamRequestLog{ + URL: wsURL, + Method: "WEBSOCKET", + Headers: wsHeaders.Clone(), + Body: wsReqBody, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + } + helps.RecordAPIWebsocketRequest(ctx, e.cfg, wsReqLog) + logXAIWebsocketRequest(executionSessionID, authID, wsURL, wsReqBody) + + conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) + var upstreamHeaders http.Header + if respHS != nil { + upstreamHeaders = respHS.Header.Clone() + } + if errDial != nil { + bodyErr := websocketHandshakeBody(respHS) + if respHS != nil { + helps.RecordAPIWebsocketUpgradeRejection(ctx, e.cfg, websocketUpgradeRequestLog(wsReqLog), respHS.StatusCode, respHS.Header.Clone(), bodyErr) + } + if respHS != nil && respHS.StatusCode > 0 { + if sess != nil { + sess.reqMu.Unlock() + } + return nil, statusErr{code: respHS.StatusCode, msg: string(bodyErr)} + } + helps.RecordAPIWebsocketError(ctx, e.cfg, "dial", errDial) + if sess != nil { + sess.reqMu.Unlock() + } + return nil, errDial + } + recordAPIWebsocketHandshake(ctx, e.cfg, respHS) + reporter.StartResponseTTFT() + + if sess == nil { + logXAIWebsocketConnected(executionSessionID, authID, wsURL) + } + + var readCh chan codexWebsocketRead + if sess != nil { + readCh = make(chan codexWebsocketRead, 4096) + sess.setActive(readCh) + } + + if errSend := writeCodexWebsocketMessage(sess, conn, wsReqBody); errSend != nil { + helps.RecordAPIWebsocketError(ctx, e.cfg, "send", errSend) + if sess != nil { + e.invalidateUpstreamConn(sess, conn, "send_error", errSend) + connRetry, respHSRetry, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) + if errDialRetry != nil || connRetry == nil { + closeHTTPResponseBody(respHSRetry, "xai websockets executor: close handshake response body error") + helps.RecordAPIWebsocketError(ctx, e.cfg, "dial_retry", errDialRetry) + sess.clearActive(readCh) + sess.reqMu.Unlock() + return nil, errDialRetry + } + wsReqBodyRetry := buildXAIWebsocketRequestBody(prepared.body) + helps.RecordAPIWebsocketRequest(ctx, e.cfg, helps.UpstreamRequestLog{ + URL: wsURL, + Method: "WEBSOCKET", + Headers: wsHeaders.Clone(), + Body: wsReqBodyRetry, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + logXAIWebsocketRequest(executionSessionID, authID, wsURL, wsReqBodyRetry) + recordAPIWebsocketHandshake(ctx, e.cfg, respHSRetry) + reporter.StartResponseTTFT() + if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry != nil { + helps.RecordAPIWebsocketError(ctx, e.cfg, "send_retry", errSendRetry) + e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry) + sess.clearActive(readCh) + sess.reqMu.Unlock() + return nil, errSendRetry + } + conn = connRetry + wsReqBody = wsReqBodyRetry + } else { + logXAIWebsocketDisconnected(executionSessionID, authID, wsURL, "send_error", errSend) + if errClose := conn.Close(); errClose != nil { + log.Errorf("xai websockets executor: close websocket error: %v", errClose) + } + return nil, errSend + } + } + + out := make(chan cliproxyexecutor.StreamChunk) + go func() { + terminateReason := "completed" + var terminateErr error + + defer close(out) + defer func() { + if sess != nil { + sess.clearActive(readCh) + sess.reqMu.Unlock() + return + } + logXAIWebsocketDisconnected(executionSessionID, authID, wsURL, terminateReason, terminateErr) + if errClose := conn.Close(); errClose != nil { + log.Errorf("xai websockets executor: close websocket error: %v", errClose) + } + }() + + send := func(chunk cliproxyexecutor.StreamChunk) bool { + if ctx == nil { + out <- chunk + return true + } + select { + case out <- chunk: + return true + case <-ctx.Done(): + return false + } + } + + var param any + outputItemsByIndex := make(map[int64][]byte) + var outputItemsFallback [][]byte + for { + if ctx != nil && ctx.Err() != nil { + terminateReason = "context_done" + terminateErr = ctx.Err() + _ = send(cliproxyexecutor.StreamChunk{Err: ctx.Err()}) + return + } + msgType, payload, errRead := readXAIWebsocketMessage(ctx, sess, conn, readCh) + if errRead != nil { + if sess != nil && ctx != nil && ctx.Err() != nil { + terminateReason = "context_done" + terminateErr = ctx.Err() + _ = send(cliproxyexecutor.StreamChunk{Err: ctx.Err()}) + return + } + terminateReason = "read_error" + terminateErr = errRead + helps.RecordAPIWebsocketError(ctx, e.cfg, "read", errRead) + reporter.PublishFailure(ctx, errRead) + _ = send(cliproxyexecutor.StreamChunk{Err: errRead}) + return + } + if msgType != websocket.TextMessage { + if msgType == websocket.BinaryMessage { + errBinary := fmt.Errorf("xai websockets executor: unexpected binary message") + terminateReason = "unexpected_binary" + terminateErr = errBinary + helps.RecordAPIWebsocketError(ctx, e.cfg, "unexpected_binary", errBinary) + reporter.PublishFailure(ctx, errBinary) + if sess != nil { + e.invalidateUpstreamConn(sess, conn, "unexpected_binary", errBinary) + } + _ = send(cliproxyexecutor.StreamChunk{Err: errBinary}) + return + } + continue + } + + payload = bytes.TrimSpace(payload) + if len(payload) == 0 { + continue + } + reporter.MarkFirstResponseByte() + helps.AppendAPIWebsocketResponse(ctx, e.cfg, payload) + + if wsErr, ok := parseXAIWebsocketError(payload); ok { + terminateReason = "upstream_error" + terminateErr = wsErr + helps.RecordAPIWebsocketError(ctx, e.cfg, "upstream_error", wsErr) + reporter.PublishFailure(ctx, wsErr) + if sess != nil { + e.invalidateUpstreamConnWithoutDisconnectNotify(sess, conn, "upstream_error", wsErr) + } + _ = send(cliproxyexecutor.StreamChunk{Err: wsErr}) + return + } + + eventType := gjson.GetBytes(payload, "type").String() + isTerminalEvent := eventType == "response.completed" || eventType == "response.done" || eventType == "error" + warmupCompletedPayload := []byte(nil) + switch eventType { + case "response.created": + if warmupRequest { + warmupCompletedPayload = buildXAIWebsocketWarmupCompletedPayload(payload) + logXAIWebsocketWarmupCompleted(executionSessionID, authID, wsURL, payload) + } + case "response.output_item.done": + xaiCollectOutputItemDone(payload, outputItemsByIndex, &outputItemsFallback) + case "response.completed": + logXAIWebsocketTerminalResponse(executionSessionID, authID, wsURL, eventType, payload) + if detail, ok := helps.ParseCodexUsage(payload); ok { + reporter.Publish(ctx, detail) + } + payload = xaiPatchCompletedOutput(payload, outputItemsByIndex, outputItemsFallback) + case "response.done": + logXAIWebsocketTerminalResponse(executionSessionID, authID, wsURL, eventType, payload) + if detail, ok := helps.ParseCodexUsage(payload); ok { + reporter.Publish(ctx, detail) + } + } + + if cliproxyexecutor.DownstreamWebsocket(ctx) { + downstreamPayload := payload + downstreamWarmupCompletedPayload := warmupCompletedPayload + if idMapper != nil { + downstreamPayload = idMapper.downstreamResponsePayload(payload) + if len(warmupCompletedPayload) > 0 { + downstreamWarmupCompletedPayload = idMapper.downstreamResponsePayload(warmupCompletedPayload) + } + } + if !send(cliproxyexecutor.StreamChunk{Payload: downstreamPayload}) { + terminateReason = "context_done" + terminateErr = ctx.Err() + return + } + if len(downstreamWarmupCompletedPayload) > 0 { + if !send(cliproxyexecutor.StreamChunk{Payload: downstreamWarmupCompletedPayload}) { + terminateReason = "context_done" + terminateErr = ctx.Err() + return + } + return + } + if isTerminalEvent { + return + } + continue + } + + payload = normalizeCodexWebsocketCompletion(payload) + line := encodeCodexWebsocketAsSSE(payload) + chunks := sdktranslator.TranslateStream(ctx, prepared.to, prepared.responseFormat, req.Model, prepared.originalPayload, prepared.body, line, ¶m) + for i := range chunks { + if !send(cliproxyexecutor.StreamChunk{Payload: chunks[i]}) { + terminateReason = "context_done" + terminateErr = ctx.Err() + return + } + } + if len(warmupCompletedPayload) > 0 { + line = encodeCodexWebsocketAsSSE(warmupCompletedPayload) + chunks = sdktranslator.TranslateStream(ctx, prepared.to, prepared.responseFormat, req.Model, prepared.originalPayload, prepared.body, line, ¶m) + for i := range chunks { + if !send(cliproxyexecutor.StreamChunk{Payload: chunks[i]}) { + terminateReason = "context_done" + terminateErr = ctx.Err() + return + } + } + return + } + if eventType == "response.completed" || eventType == "response.done" { + return + } + } + }() + return &cliproxyexecutor.StreamResult{Headers: upstreamHeaders, Chunks: out}, nil +} + +func xaiWebsocketGenerateFalse(payload []byte) bool { + generate := gjson.GetBytes(payload, "generate") + return generate.Exists() && !generate.Bool() +} + +func buildXAIWebsocketWarmupCompletedPayload(createdPayload []byte) []byte { + completed := []byte(`{"type":"response.completed","response":{"output":[],"usage":{"input_tokens":0,"output_tokens":0,"total_tokens":0}}}`) + if sequence := gjson.GetBytes(createdPayload, "sequence_number"); sequence.Exists() { + completed, _ = sjson.SetBytes(completed, "sequence_number", sequence.Int()+1) + } + if response := gjson.GetBytes(createdPayload, "response"); response.Exists() && response.IsObject() { + responsePayload := []byte(response.Raw) + responsePayload, _ = sjson.SetBytes(responsePayload, "status", "completed") + if !gjson.GetBytes(responsePayload, "output").Exists() { + responsePayload, _ = sjson.SetRawBytes(responsePayload, "output", []byte("[]")) + } + if !gjson.GetBytes(responsePayload, "usage").Exists() { + responsePayload, _ = sjson.SetRawBytes(responsePayload, "usage", []byte(`{"input_tokens":0,"output_tokens":0,"total_tokens":0}`)) + } + completed, _ = sjson.SetRawBytes(completed, "response", responsePayload) + } + return completed +} + +func parseXAIWebsocketError(payload []byte) (error, bool) { + if wsErr, ok := parseCodexWebsocketError(payload); ok { + return wsErr, true + } + if len(payload) == 0 || !gjson.GetBytes(payload, "error").Exists() { + return nil, false + } + status := int(gjson.GetBytes(payload, "status").Int()) + if status <= 0 { + status = int(gjson.GetBytes(payload, "status_code").Int()) + } + if status <= 0 { + status = xaiBareWebsocketErrorStatus(payload) + } + out := []byte(`{}`) + out, _ = sjson.SetBytes(out, "type", "error") + out, _ = sjson.SetBytes(out, "status", status) + if errNode := gjson.GetBytes(payload, "error"); errNode.Exists() { + out, _ = sjson.SetRawBytes(out, "error", []byte(errNode.Raw)) + } + return statusErr{code: status, msg: string(out)}, true +} + +func xaiBareWebsocketErrorStatus(payload []byte) int { + for _, path := range []string{"error.code", "error.status", "code"} { + raw := strings.TrimSpace(gjson.GetBytes(payload, path).String()) + if raw == "" { + continue + } + status, errAtoi := strconv.Atoi(raw) + if errAtoi == nil && status > 0 { + return status + } + } + message := strings.TrimSpace(gjson.GetBytes(payload, "error.message").String()) + if strings.Contains(message, `"code":"400"`) || strings.Contains(message, "Request validation error") { + return http.StatusBadRequest + } + return http.StatusInternalServerError +} + +func (e *XAIWebsocketsExecutor) prepareResponsesWebsocketRequest(ctx context.Context, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*xaiPreparedRequest, error) { + prepared, err := e.prepareResponsesRequest(ctx, req, opts, true) + if err != nil { + return nil, err + } + if previousResponseID := strings.TrimSpace(gjson.GetBytes(req.Payload, "previous_response_id").String()); previousResponseID != "" { + prepared.body, _ = sjson.SetBytes(prepared.body, "previous_response_id", previousResponseID) + } + return prepared, nil +} + +func (e *XAIWebsocketsExecutor) dialXAIWebsocket(ctx context.Context, auth *cliproxyauth.Auth, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) { + dialer := newProxyAwareWebsocketDialer(e.cfg, auth) + dialer.HandshakeTimeout = codexResponsesWebsocketHandshakeTO + dialer.EnableCompression = true + if ctx == nil { + ctx = context.Background() + } + conn, resp, err := dialer.DialContext(ctx, wsURL, headers) + if conn != nil { + // Avoid gorilla/websocket flate tail validation issues on some upstreams/Go versions. + conn.EnableWriteCompression(false) + } + return conn, resp, err +} + +func (e *XAIWebsocketsExecutor) getOrCreateSession(sessionID string) *codexWebsocketSession { + sessionID = strings.TrimSpace(sessionID) + if sessionID == "" || e == nil { + return nil + } + store := e.store + if store == nil { + store = globalXAIWebsocketSessionStore + } + store.mu.Lock() + defer store.mu.Unlock() + if store.sessions == nil { + store.sessions = make(map[string]*codexWebsocketSession) + } + if sess, ok := store.sessions[sessionID]; ok && sess != nil { + return sess + } + sess := &codexWebsocketSession{ + sessionID: sessionID, + upstreamDisconnectCh: make(chan error, 1), + } + store.sessions[sessionID] = sess + return sess +} + +func (e *XAIWebsocketsExecutor) UpstreamDisconnectChan(sessionID string) <-chan error { + sess := e.getOrCreateSession(sessionID) + if sess == nil { + return nil + } + return sess.upstreamDisconnectCh +} + +func (e *XAIWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth *cliproxyauth.Auth, sess *codexWebsocketSession, authID string, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) { + if sess == nil { + return e.dialXAIWebsocket(ctx, auth, wsURL, headers) + } + + sess.connMu.Lock() + conn := sess.conn + readerConn := sess.readerConn + sess.connMu.Unlock() + if conn != nil { + if readerConn != conn { + sess.connMu.Lock() + sess.readerConn = conn + sess.connMu.Unlock() + configureXAIWebsocketConn(sess, conn) + go e.readUpstreamLoop(sess, conn) + } + return conn, nil, nil + } + + conn, resp, errDial := e.dialXAIWebsocket(ctx, auth, wsURL, headers) + if errDial != nil { + return nil, resp, errDial + } + + sess.connMu.Lock() + if sess.conn != nil { + previous := sess.conn + sess.connMu.Unlock() + if errClose := conn.Close(); errClose != nil { + log.Errorf("xai websockets executor: close websocket error: %v", errClose) + } + return previous, nil, nil + } + sess.conn = conn + sess.wsURL = wsURL + sess.authID = authID + sess.readerConn = conn + sess.connMu.Unlock() + + configureXAIWebsocketConn(sess, conn) + go e.readUpstreamLoop(sess, conn) + logXAIWebsocketConnected(sess.sessionID, authID, wsURL) + return conn, resp, nil +} + +func configureXAIWebsocketConn(sess *codexWebsocketSession, conn *websocket.Conn) { + if sess == nil || conn == nil { + return + } + conn.SetPingHandler(func(appData string) error { + sess.writeMu.Lock() + defer sess.writeMu.Unlock() + return conn.WriteControl(websocket.PongMessage, []byte(appData), time.Time{}) + }) +} + +func readXAIWebsocketMessage(ctx context.Context, sess *codexWebsocketSession, conn *websocket.Conn, readCh chan codexWebsocketRead) (int, []byte, error) { + if ctx == nil { + ctx = context.Background() + } + if sess == nil { + if conn == nil { + return 0, nil, fmt.Errorf("xai websockets executor: websocket conn is nil") + } + msgType, payload, errRead := conn.ReadMessage() + return msgType, payload, errRead + } + if conn == nil { + return 0, nil, fmt.Errorf("xai websockets executor: websocket conn is nil") + } + if readCh == nil { + return 0, nil, fmt.Errorf("xai websockets executor: session read channel is nil") + } + for { + select { + case <-ctx.Done(): + return 0, nil, ctx.Err() + case ev, ok := <-readCh: + if !ok { + return 0, nil, fmt.Errorf("xai websockets executor: session read channel closed") + } + if ev.conn != conn { + continue + } + if ev.err != nil { + return 0, nil, ev.err + } + return ev.msgType, ev.payload, nil + } + } +} + +func (e *XAIWebsocketsExecutor) readUpstreamLoop(sess *codexWebsocketSession, conn *websocket.Conn) { + if e == nil || sess == nil || conn == nil { + return + } + for { + msgType, payload, errRead := conn.ReadMessage() + if errRead != nil { + sess.activeMu.Lock() + ch := sess.activeCh + done := sess.activeDone + sess.activeMu.Unlock() + if ch != nil { + select { + case ch <- codexWebsocketRead{conn: conn, err: errRead}: + case <-done: + default: + } + sess.clearActive(ch) + close(ch) + } + e.invalidateUpstreamConn(sess, conn, "upstream_disconnected", errRead) + return + } + + if msgType != websocket.TextMessage { + if msgType == websocket.BinaryMessage { + errBinary := fmt.Errorf("xai websockets executor: unexpected binary message") + sess.activeMu.Lock() + ch := sess.activeCh + done := sess.activeDone + sess.activeMu.Unlock() + if ch != nil { + select { + case ch <- codexWebsocketRead{conn: conn, err: errBinary}: + case <-done: + default: + } + sess.clearActive(ch) + close(ch) + } + e.invalidateUpstreamConn(sess, conn, "unexpected_binary", errBinary) + return + } + continue + } + + sess.activeMu.Lock() + ch := sess.activeCh + done := sess.activeDone + sess.activeMu.Unlock() + if ch == nil { + continue + } + select { + case ch <- codexWebsocketRead{conn: conn, msgType: msgType, payload: payload}: + case <-done: + } + } +} + +func (e *XAIWebsocketsExecutor) invalidateUpstreamConn(sess *codexWebsocketSession, conn *websocket.Conn, reason string, err error) { + e.invalidateUpstreamConnWithNotify(sess, conn, reason, err, true) +} + +func (e *XAIWebsocketsExecutor) invalidateUpstreamConnWithoutDisconnectNotify(sess *codexWebsocketSession, conn *websocket.Conn, reason string, err error) { + e.invalidateUpstreamConnWithNotify(sess, conn, reason, err, false) +} + +func (e *XAIWebsocketsExecutor) invalidateUpstreamConnWithNotify(sess *codexWebsocketSession, conn *websocket.Conn, reason string, err error, notify bool) { + if sess == nil || conn == nil { + return + } + + sess.connMu.Lock() + current := sess.conn + authID := sess.authID + wsURL := sess.wsURL + sessionID := sess.sessionID + if current == nil || current != conn { + sess.connMu.Unlock() + return + } + sess.conn = nil + if sess.readerConn == conn { + sess.readerConn = nil + } + sess.connMu.Unlock() + + logXAIWebsocketDisconnected(sessionID, authID, wsURL, reason, err) + if notify { + sess.notifyUpstreamDisconnect(err) + } + if errClose := conn.Close(); errClose != nil { + log.Errorf("xai websockets executor: close websocket error: %v", errClose) + } +} + +func (e *XAIWebsocketsExecutor) CloseExecutionSession(sessionID string) { + sessionID = strings.TrimSpace(sessionID) + if e == nil || sessionID == "" { + return + } + if sessionID == cliproxyauth.CloseAllExecutionSessionsID { + return + } + + store := e.store + if store == nil { + store = globalXAIWebsocketSessionStore + } + store.mu.Lock() + sess := store.sessions[sessionID] + delete(store.sessions, sessionID) + store.mu.Unlock() + deleteXAIWebsocketIDState(e.idStore, sessionID) + + e.closeExecutionSession(sess, "session_closed") +} + +func (e *XAIWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSession, reason string) { + closeXAIWebsocketSession(sess, reason) +} + +func closeXAIWebsocketSession(sess *codexWebsocketSession, reason string) { + if sess == nil { + return + } + reason = strings.TrimSpace(reason) + if reason == "" { + reason = "session_closed" + } + + sess.connMu.Lock() + conn := sess.conn + authID := sess.authID + wsURL := sess.wsURL + sess.conn = nil + if sess.readerConn == conn { + sess.readerConn = nil + } + sessionID := sess.sessionID + sess.connMu.Unlock() + + if conn == nil { + return + } + logXAIWebsocketDisconnected(sessionID, authID, wsURL, reason, nil) + if errClose := conn.Close(); errClose != nil { + log.Errorf("xai websockets executor: close websocket error: %v", errClose) + } +} + +func buildXAIWebsocketRequestBody(body []byte) []byte { + if len(body) == 0 { + return nil + } + wsReqBody := bytes.Clone(body) + wsReqBody, _ = sjson.SetBytes(wsReqBody, "type", "response.create") + wsReqBody, _ = sjson.DeleteBytes(wsReqBody, "stream") + wsReqBody, _ = sjson.DeleteBytes(wsReqBody, "stream_options") + wsReqBody, _ = sjson.DeleteBytes(wsReqBody, "background") + wsReqBody, _ = sjson.SetBytes(wsReqBody, "store", true) + if strings.TrimSpace(gjson.GetBytes(wsReqBody, "previous_response_id").String()) != "" { + wsReqBody, _ = sjson.DeleteBytes(wsReqBody, "instructions") + } + return wsReqBody +} + +func buildXAIResponsesWebsocketURL(httpURL string) (string, error) { + parsed, err := url.Parse(strings.TrimSpace(httpURL)) + if err != nil { + return "", err + } + switch strings.ToLower(parsed.Scheme) { + case "http": + parsed.Scheme = "ws" + case "https": + parsed.Scheme = "wss" + case "ws", "wss": + default: + return "", fmt.Errorf("xai websockets executor: unsupported responses websocket URL scheme %q", parsed.Scheme) + } + if strings.TrimSpace(parsed.Host) == "" { + return "", fmt.Errorf("xai websockets executor: responses websocket URL host is empty") + } + return parsed.String(), nil +} + +func applyXAIWebsocketHeaders(headers http.Header, auth *cliproxyauth.Auth, token string, sessionID string) http.Header { + if headers == nil { + headers = http.Header{} + } + headers.Set("Content-Type", "application/json") + if strings.TrimSpace(token) != "" { + headers.Set("Authorization", "Bearer "+token) + } + if sessionID != "" { + headers.Set("x-grok-conv-id", sessionID) + } + var attrs map[string]string + if auth != nil { + attrs = auth.Attributes + } + util.ApplyCustomHeadersFromAttrs(&http.Request{Header: headers}, attrs) + return headers +} + +func logXAIWebsocketConnected(sessionID string, authID string, wsURL string) { + log.Infof("xai websockets: upstream connected session=%s auth=%s url=%s", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL)) +} + +func logXAIWebsocketRequest(sessionID string, authID string, wsURL string, payload []byte) { + if len(payload) == 0 { + log.Infof("xai websockets: upstream request sent session=%s auth=%s url=%s", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL)) + return + } + generateValue := "default" + if generate := gjson.GetBytes(payload, "generate"); generate.Exists() { + generateValue = strings.TrimSpace(generate.Raw) + } + log.Infof( + "xai websockets: upstream request sent session=%s auth=%s url=%s event=%s previous_response_id=%s generate=%s input_items=%d", + strings.TrimSpace(sessionID), + strings.TrimSpace(authID), + strings.TrimSpace(wsURL), + strings.TrimSpace(gjson.GetBytes(payload, "type").String()), + strings.TrimSpace(gjson.GetBytes(payload, "previous_response_id").String()), + generateValue, + len(gjson.GetBytes(payload, "input").Array()), + ) +} + +func logXAIWebsocketWarmupCompleted(sessionID string, authID string, wsURL string, payload []byte) { + log.Infof( + "xai websockets: upstream warmup completed session=%s auth=%s url=%s response_id=%s", + strings.TrimSpace(sessionID), + strings.TrimSpace(authID), + strings.TrimSpace(wsURL), + strings.TrimSpace(gjson.GetBytes(payload, "response.id").String()), + ) +} + +func logXAIWebsocketTerminalResponse(sessionID string, authID string, wsURL string, eventType string, payload []byte) { + log.Infof( + "xai websockets: upstream terminal response session=%s auth=%s url=%s event=%s response_id=%s previous_response_id=%s", + strings.TrimSpace(sessionID), + strings.TrimSpace(authID), + strings.TrimSpace(wsURL), + strings.TrimSpace(eventType), + strings.TrimSpace(gjson.GetBytes(payload, "response.id").String()), + strings.TrimSpace(gjson.GetBytes(payload, "response.previous_response_id").String()), + ) +} + +func logXAIWebsocketDisconnected(sessionID string, authID string, wsURL string, reason string, err error) { + if err != nil { + log.Infof("xai websockets: upstream disconnected session=%s auth=%s url=%s reason=%s err=%v", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL), strings.TrimSpace(reason), err) + return + } + log.Infof("xai websockets: upstream disconnected session=%s auth=%s url=%s reason=%s", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL), strings.TrimSpace(reason)) +} + +// CloseXAIWebsocketSessionsForAuthID closes all active xAI upstream websocket sessions +// associated with the supplied auth ID. +func CloseXAIWebsocketSessionsForAuthID(authID string, reason string) { + authID = strings.TrimSpace(authID) + if authID == "" { + return + } + reason = strings.TrimSpace(reason) + if reason == "" { + reason = "auth_removed" + } + + store := globalXAIWebsocketSessionStore + if store == nil { + return + } + + type sessionItem struct { + sessionID string + sess *codexWebsocketSession + } + + store.mu.Lock() + items := make([]sessionItem, 0, len(store.sessions)) + for sessionID, sess := range store.sessions { + items = append(items, sessionItem{sessionID: sessionID, sess: sess}) + } + store.mu.Unlock() + + matches := make([]sessionItem, 0) + for i := range items { + sess := items[i].sess + if sess == nil { + continue + } + sess.connMu.Lock() + sessAuthID := strings.TrimSpace(sess.authID) + sess.connMu.Unlock() + if sessAuthID == authID { + matches = append(matches, items[i]) + } + } + if len(matches) == 0 { + return + } + + toClose := make([]*codexWebsocketSession, 0, len(matches)) + store.mu.Lock() + for i := range matches { + current, ok := store.sessions[matches[i].sessionID] + if !ok || current == nil || current != matches[i].sess { + continue + } + delete(store.sessions, matches[i].sessionID) + deleteXAIWebsocketIDState(globalXAIWebsocketIDStates, matches[i].sessionID) + toClose = append(toClose, current) + } + store.mu.Unlock() + + for i := range toClose { + closeXAIWebsocketSession(toClose[i], reason) + } +} + +// XAIAutoExecutor routes xAI stream requests to the websocket transport only +// when the downstream transport is websocket and the selected auth enables +// websockets. Non-stream requests keep using the HTTP implementation. +type XAIAutoExecutor struct { + httpExec *XAIExecutor + wsExec *XAIWebsocketsExecutor +} + +func NewXAIAutoExecutor(cfg *config.Config) *XAIAutoExecutor { + return &XAIAutoExecutor{ + httpExec: NewXAIExecutor(cfg), + wsExec: NewXAIWebsocketsExecutor(cfg), + } +} + +func (e *XAIAutoExecutor) Identifier() string { return "xai" } + +func (e *XAIAutoExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error { + if e == nil || e.httpExec == nil { + return nil + } + return e.httpExec.PrepareRequest(req, auth) +} + +func (e *XAIAutoExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) { + if e == nil || e.httpExec == nil { + return nil, fmt.Errorf("xai auto executor: http executor is nil") + } + return e.httpExec.HttpRequest(ctx, auth, req) +} + +func (e *XAIAutoExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + if e == nil || e.httpExec == nil { + return cliproxyexecutor.Response{}, fmt.Errorf("xai auto executor: executor is nil") + } + return e.httpExec.Execute(ctx, auth, req, opts) +} + +func (e *XAIAutoExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) { + if e == nil || e.httpExec == nil || e.wsExec == nil { + return nil, fmt.Errorf("xai auto executor: executor is nil") + } + if cliproxyexecutor.DownstreamWebsocket(ctx) && xaiWebsocketsEnabled(auth) { + return e.wsExec.ExecuteStream(ctx, auth, req, opts) + } + return e.httpExec.ExecuteStream(ctx, auth, req, opts) +} + +func (e *XAIAutoExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + if e == nil || e.httpExec == nil { + return nil, fmt.Errorf("xai auto executor: http executor is nil") + } + return e.httpExec.Refresh(ctx, auth) +} + +func (e *XAIAutoExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + if e == nil || e.httpExec == nil { + return cliproxyexecutor.Response{}, fmt.Errorf("xai auto executor: http executor is nil") + } + return e.httpExec.CountTokens(ctx, auth, req, opts) +} + +func (e *XAIAutoExecutor) CloseExecutionSession(sessionID string) { + if e == nil || e.wsExec == nil { + return + } + e.wsExec.CloseExecutionSession(sessionID) +} + +func (e *XAIAutoExecutor) UpstreamDisconnectChan(sessionID string) <-chan error { + if e == nil || e.wsExec == nil { + return nil + } + return e.wsExec.UpstreamDisconnectChan(sessionID) +} + +func xaiWebsocketsEnabled(auth *cliproxyauth.Auth) bool { + if auth == nil { + return false + } + if len(auth.Attributes) > 0 { + if raw := strings.TrimSpace(auth.Attributes["websockets"]); raw != "" { + parsed, errParse := strconv.ParseBool(raw) + if errParse == nil { + return parsed + } + } + } + if len(auth.Metadata) == 0 { + return false + } + raw, ok := auth.Metadata["websockets"] + if !ok || raw == nil { + return false + } + switch v := raw.(type) { + case bool: + return v + case string: + parsed, errParse := strconv.ParseBool(strings.TrimSpace(v)) + if errParse == nil { + return parsed + } + default: + } + return false +} diff --git a/internal/runtime/executor/xai_websockets_executor_test.go b/internal/runtime/executor/xai_websockets_executor_test.go new file mode 100644 index 000000000..68ef26956 --- /dev/null +++ b/internal/runtime/executor/xai_websockets_executor_test.go @@ -0,0 +1,425 @@ +package executor + +import ( + "bytes" + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + _ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator" + "github.com/tidwall/gjson" +) + +func TestXAIWebsocketsExecuteStreamSendsResponseCreateWithPreviousResponseID(t *testing.T) { + upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + capturedPayload := make(chan []byte, 1) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/responses" { + t.Errorf("path = %q, want /responses", r.URL.Path) + } + if got := r.Header.Get("Authorization"); got != "Bearer xai-token" { + t.Errorf("Authorization = %q, want Bearer xai-token", got) + } + if got := r.Header.Get("x-grok-conv-id"); got != "execution-session-1" { + t.Errorf("x-grok-conv-id = %q, want execution-session-1", got) + } + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket: %v", err) + return + } + defer func() { _ = conn.Close() }() + + _, payload, errRead := conn.ReadMessage() + if errRead != nil { + t.Errorf("read upstream websocket message: %v", errRead) + return + } + capturedPayload <- bytes.Clone(payload) + completed := []byte(`{"type":"response.completed","response":{"id":"resp-xai-1","output":[],"usage":{"input_tokens":0,"output_tokens":0,"total_tokens":0}}}`) + if errWrite := conn.WriteMessage(websocket.TextMessage, completed); errWrite != nil { + t.Errorf("write completed websocket message: %v", errWrite) + } + })) + defer server.Close() + + exec := NewXAIWebsocketsExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{ + ID: "xai-auth", + Provider: "xai", + Attributes: map[string]string{ + "base_url": server.URL, + "websockets": "true", + }, + Metadata: map[string]any{"access_token": "xai-token"}, + } + req := cliproxyexecutor.Request{ + Model: "grok-4.3", + Payload: []byte(`{"model":"grok-4.3","stream":true,"previous_response_id":"resp-prev","instructions":"system prompt","input":[{"type":"message","role":"user","content":"hello"}]}`), + } + opts := cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FormatOpenAIResponse, + ResponseFormat: sdktranslator.FormatOpenAIResponse, + Metadata: map[string]any{ + cliproxyexecutor.ExecutionSessionMetadataKey: "execution-session-1", + }, + } + ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background()) + + result, err := exec.ExecuteStream(ctx, auth, req, opts) + if err != nil { + t.Fatalf("ExecuteStream() error = %v", err) + } + + select { + case payload := <-capturedPayload: + if got := gjson.GetBytes(payload, "type").String(); got != "response.create" { + t.Fatalf("type = %q, want response.create; payload=%s", got, payload) + } + if got := gjson.GetBytes(payload, "previous_response_id").String(); got != "resp-prev" { + t.Fatalf("previous_response_id = %q, want resp-prev; payload=%s", got, payload) + } + if gjson.GetBytes(payload, "stream").Exists() { + t.Fatalf("stream must be omitted for xAI websocket payload: %s", payload) + } + if gjson.GetBytes(payload, "instructions").Exists() { + t.Fatalf("instructions must be omitted when previous_response_id is set: %s", payload) + } + if got := gjson.GetBytes(payload, "prompt_cache_key").String(); got != "execution-session-1" { + t.Fatalf("prompt_cache_key = %q, want execution-session-1; payload=%s", got, payload) + } + if got := gjson.GetBytes(payload, "store").Bool(); !got { + t.Fatalf("store = false, want true; payload=%s", payload) + } + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for upstream websocket payload") + } + + select { + case chunk, ok := <-result.Chunks: + if !ok { + t.Fatal("stream closed before completed chunk") + } + if chunk.Err != nil { + t.Fatalf("chunk error = %v", chunk.Err) + } + if got := gjson.GetBytes(bytes.TrimSpace(chunk.Payload), "type").String(); got != "response.completed" { + t.Fatalf("chunk type = %q, want response.completed; payload=%s", got, chunk.Payload) + } + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for completed chunk") + } +} + +func TestXAIWebsocketsExecuteStreamRewritesRepeatedResponseIDForDownstream(t *testing.T) { + upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + capturedPreviousIDs := make(chan string, 3) + releaseServer := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket: %v", err) + return + } + defer func() { _ = conn.Close() }() + + for i := 0; i < 3; i++ { + _, payload, errRead := conn.ReadMessage() + if errRead != nil { + t.Errorf("read upstream websocket message: %v", errRead) + return + } + previousID := gjson.GetBytes(payload, "previous_response_id").String() + capturedPreviousIDs <- previousID + completed := []byte(fmt.Sprintf(`{"type":"response.completed","response":{"id":"resp-real","previous_response_id":%q,"output":[{"id":"rs_resp-real","type":"reasoning","status":"completed"}],"usage":{"input_tokens":0,"output_tokens":0,"total_tokens":0}}}`, previousID)) + if errWrite := conn.WriteMessage(websocket.TextMessage, completed); errWrite != nil { + t.Errorf("write completed websocket message: %v", errWrite) + return + } + } + <-releaseServer + })) + defer server.Close() + defer close(releaseServer) + + exec := NewXAIWebsocketsExecutor(&config.Config{}) + exec.store = &codexWebsocketSessionStore{sessions: make(map[string]*codexWebsocketSession)} + exec.idStore = &xaiWebsocketIDStateStore{sessions: make(map[string]*xaiWebsocketIDState)} + auth := &cliproxyauth.Auth{ + ID: "xai-auth-id-map", + Provider: "xai", + Attributes: map[string]string{ + "base_url": server.URL, + "websockets": "true", + }, + Metadata: map[string]any{"access_token": "xai-token"}, + } + opts := cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FormatOpenAIResponse, + ResponseFormat: sdktranslator.FormatOpenAIResponse, + Metadata: map[string]any{ + cliproxyexecutor.ExecutionSessionMetadataKey: "xai-id-map-session", + }, + } + ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background()) + + runRequest := func(previousID string) (string, string, string) { + body := []byte(`{"model":"grok-4.3","input":[{"type":"message","role":"user","content":"hello"}]}`) + if previousID != "" { + body = []byte(fmt.Sprintf(`{"model":"grok-4.3","previous_response_id":%q,"input":[{"type":"function_call_output","call_id":"call-1","output":"ok"}]}`, previousID)) + } + result, err := exec.ExecuteStream(ctx, auth, cliproxyexecutor.Request{Model: "grok-4.3", Payload: body}, opts) + if err != nil { + t.Fatalf("ExecuteStream() error = %v", err) + } + select { + case chunk, ok := <-result.Chunks: + if !ok { + t.Fatal("stream closed before completed chunk") + } + if chunk.Err != nil { + t.Fatalf("chunk error = %v", chunk.Err) + } + payload := bytes.TrimSpace(chunk.Payload) + return gjson.GetBytes(payload, "response.id").String(), + gjson.GetBytes(payload, "response.output.0.id").String(), + gjson.GetBytes(payload, "response.previous_response_id").String() + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for completed chunk") + } + return "", "", "" + } + + firstDownstreamID, firstOutputID, firstResponsePrevious := runRequest("") + if firstDownstreamID != "resp-real" { + t.Fatalf("first downstream id = %q, want resp-real", firstDownstreamID) + } + if firstOutputID != "rs_resp-real" { + t.Fatalf("first output item id = %q, want rs_resp-real", firstOutputID) + } + if firstResponsePrevious != "" { + t.Fatalf("first response previous_response_id = %q, want empty", firstResponsePrevious) + } + firstUpstreamPrevious := <-capturedPreviousIDs + if firstUpstreamPrevious != "" { + t.Fatalf("first upstream previous_response_id = %q, want empty", firstUpstreamPrevious) + } + + secondDownstreamID, secondOutputID, secondResponsePrevious := runRequest(firstDownstreamID) + if secondDownstreamID == "" || secondDownstreamID == "resp-real" { + t.Fatalf("second downstream id = %q, want synthetic id different from resp-real", secondDownstreamID) + } + if secondOutputID == "rs_resp-real" || !strings.Contains(secondOutputID, secondDownstreamID) { + t.Fatalf("second output item id = %q, want rewritten id containing %q", secondOutputID, secondDownstreamID) + } + if secondResponsePrevious != firstDownstreamID { + t.Fatalf("second response previous_response_id = %q, want %q", secondResponsePrevious, firstDownstreamID) + } + secondUpstreamPrevious := <-capturedPreviousIDs + if secondUpstreamPrevious != "resp-real" { + t.Fatalf("second upstream previous_response_id = %q, want resp-real", secondUpstreamPrevious) + } + + thirdDownstreamID, thirdOutputID, thirdResponsePrevious := runRequest(secondDownstreamID) + if thirdDownstreamID == "" || thirdDownstreamID == "resp-real" || thirdDownstreamID == secondDownstreamID { + t.Fatalf("third downstream id = %q, want a new synthetic id", thirdDownstreamID) + } + if thirdOutputID == "rs_resp-real" || !strings.Contains(thirdOutputID, thirdDownstreamID) { + t.Fatalf("third output item id = %q, want rewritten id containing %q", thirdOutputID, thirdDownstreamID) + } + if thirdResponsePrevious != secondDownstreamID { + t.Fatalf("third response previous_response_id = %q, want %q", thirdResponsePrevious, secondDownstreamID) + } + thirdUpstreamPrevious := <-capturedPreviousIDs + if thirdUpstreamPrevious != "resp-real" { + t.Fatalf("third upstream previous_response_id = %q, want resp-real", thirdUpstreamPrevious) + } +} + +func TestBuildXAIWebsocketRequestBodySetsStoreAndKeepsPromptCacheKey(t *testing.T) { + body := []byte(`{"model":"grok-4.3","stream":true,"stream_options":{"include_usage":true},"background":true,"prompt_cache_key":"cache-1","previous_response_id":"resp-prev","instructions":"system prompt","input":[{"type":"message","role":"user","content":"hello"}]}`) + + payload := buildXAIWebsocketRequestBody(body) + + if got := gjson.GetBytes(payload, "type").String(); got != "response.create" { + t.Fatalf("type = %q, want response.create; payload=%s", got, payload) + } + if gjson.GetBytes(payload, "stream").Exists() { + t.Fatalf("stream must be omitted for xAI websocket payload: %s", payload) + } + if gjson.GetBytes(payload, "stream_options").Exists() { + t.Fatalf("stream_options must be omitted for xAI websocket payload: %s", payload) + } + if gjson.GetBytes(payload, "background").Exists() { + t.Fatalf("background must be omitted for xAI websocket payload: %s", payload) + } + if got := gjson.GetBytes(payload, "prompt_cache_key").String(); got != "cache-1" { + t.Fatalf("prompt_cache_key = %q, want cache-1; payload=%s", got, payload) + } + if got := gjson.GetBytes(payload, "store").Bool(); !got { + t.Fatalf("store = false, want true; payload=%s", payload) + } + if gjson.GetBytes(payload, "instructions").Exists() { + t.Fatalf("instructions must be omitted when previous_response_id is set: %s", payload) + } +} + +func TestXAIWebsocketsExecuteStreamCompletesGenerateFalseWarmup(t *testing.T) { + upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + capturedPayload := make(chan []byte, 1) + releaseServer := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket: %v", err) + return + } + defer func() { _ = conn.Close() }() + + _, payload, errRead := conn.ReadMessage() + if errRead != nil { + t.Errorf("read upstream websocket message: %v", errRead) + return + } + capturedPayload <- bytes.Clone(payload) + created := []byte(`{"type":"response.created","response":{"id":"resp-warmup-1","object":"response","status":"in_progress","output":[]}}`) + if errWrite := conn.WriteMessage(websocket.TextMessage, created); errWrite != nil { + t.Errorf("write created websocket message: %v", errWrite) + return + } + <-releaseServer + })) + defer server.Close() + defer close(releaseServer) + + exec := NewXAIWebsocketsExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{ + ID: "xai-auth-warmup", + Provider: "xai", + Attributes: map[string]string{ + "base_url": server.URL, + "websockets": "true", + }, + Metadata: map[string]any{"access_token": "xai-token"}, + } + req := cliproxyexecutor.Request{ + Model: "grok-4.3", + Payload: []byte(`{"model":"grok-4.3","generate":false,"input":[{"type":"message","role":"user","content":"warm up"}]}`), + } + opts := cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FormatOpenAIResponse, + ResponseFormat: sdktranslator.FormatOpenAIResponse, + } + ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background()) + + result, err := exec.ExecuteStream(ctx, auth, req, opts) + if err != nil { + t.Fatalf("ExecuteStream() error = %v", err) + } + + select { + case payload := <-capturedPayload: + if got := gjson.GetBytes(payload, "generate").Bool(); got { + t.Fatalf("generate = true, want false; payload=%s", payload) + } + if got := gjson.GetBytes(payload, "type").String(); got != "response.create" { + t.Fatalf("type = %q, want response.create; payload=%s", got, payload) + } + if got := gjson.GetBytes(payload, "store").Bool(); !got { + t.Fatalf("store = false, want true; payload=%s", payload) + } + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for upstream websocket payload") + } + + var gotTypes []string + for { + select { + case chunk, ok := <-result.Chunks: + if !ok { + if len(gotTypes) != 2 { + t.Fatalf("event types = %v, want response.created and response.completed", gotTypes) + } + return + } + if chunk.Err != nil { + t.Fatalf("chunk error = %v", chunk.Err) + } + gotTypes = append(gotTypes, gjson.GetBytes(bytes.TrimSpace(chunk.Payload), "type").String()) + case <-time.After(5 * time.Second): + t.Fatalf("timed out waiting for warmup stream to close; event types so far: %v", gotTypes) + } + } +} + +func TestXAIWebsocketsExecuteStreamStopsOnBareErrorPayload(t *testing.T) { + upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + releaseServer := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket: %v", err) + return + } + defer func() { _ = conn.Close() }() + + if _, _, errRead := conn.ReadMessage(); errRead != nil { + t.Errorf("read upstream websocket message: %v", errRead) + return + } + payload := []byte(`{"error":{"message":"Request validation error: {\"code\":\"400\",\"error\":\"Argument not supported: instructions and previous_response_id together\"}","type":"api_error"}}`) + if errWrite := conn.WriteMessage(websocket.TextMessage, payload); errWrite != nil { + t.Errorf("write error websocket message: %v", errWrite) + return + } + <-releaseServer + })) + defer server.Close() + defer close(releaseServer) + + exec := NewXAIWebsocketsExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{ + ID: "xai-auth-error", + Provider: "xai", + Attributes: map[string]string{ + "base_url": server.URL, + "websockets": "true", + }, + Metadata: map[string]any{"access_token": "xai-token"}, + } + req := cliproxyexecutor.Request{ + Model: "grok-4.3", + Payload: []byte(`{"model":"grok-4.3","input":"hello"}`), + } + opts := cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FormatOpenAIResponse, + ResponseFormat: sdktranslator.FormatOpenAIResponse, + } + ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background()) + + result, err := exec.ExecuteStream(ctx, auth, req, opts) + if err != nil { + t.Fatalf("ExecuteStream() error = %v", err) + } + + select { + case chunk, ok := <-result.Chunks: + if !ok { + t.Fatal("stream closed before error chunk") + } + if chunk.Err == nil { + t.Fatalf("chunk error = nil, want upstream error; payload=%s", chunk.Payload) + } + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for bare upstream error") + } +} diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index 318d5dc14..0bf9eb5a9 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -228,9 +228,13 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { defer close(wsDone) if h != nil && h.AuthManager != nil { - if exec, ok := h.AuthManager.Executor("codex"); ok && exec != nil { - type upstreamDisconnectSubscriber interface { - UpstreamDisconnectChan(sessionID string) <-chan error + type upstreamDisconnectSubscriber interface { + UpstreamDisconnectChan(sessionID string) <-chan error + } + for _, provider := range []string{"codex", "xai"} { + exec, ok := h.AuthManager.Executor(provider) + if !ok || exec == nil { + continue } if subscriber, ok := exec.(upstreamDisconnectSubscriber); ok && subscriber != nil { disconnectCh := subscriber.UpstreamDisconnectChan(passthroughSessionID) @@ -315,13 +319,13 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { if requestModelName == "" { requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String()) } - useCodexWebsocketPassthrough := h.responsesWebsocketUsesCodexWebsocketPassthrough(requestModelName) + useUpstreamWebsocketPassthrough := h.responsesWebsocketUsesUpstreamWebsocketPassthrough(requestModelName) allowIncrementalInputWithPreviousResponseID := false allowCompactionReplayBypass := false - if !useCodexWebsocketPassthrough { + if !useUpstreamWebsocketPassthrough { if pinnedAuthID != "" { if pinnedAuth, ok := sessionAuthByID(pinnedAuthID); ok && pinnedAuth != nil { - allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata) + allowIncrementalInputWithPreviousResponseID = responsesWebsocketAuthSupportsIncrementalInput(pinnedAuth) allowCompactionReplayBypass = responsesWebsocketAuthSupportsCompactionReplay(pinnedAuth) } } else { @@ -336,7 +340,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { var requestJSON []byte var updatedLastRequest []byte var errMsg *interfaces.ErrorMessage - if useCodexWebsocketPassthrough { + if useUpstreamWebsocketPassthrough { requestJSON, errMsg = normalizeResponsesWebsocketPassthroughRequest(payload, requestModelName) } else { requestJSON, updatedLastRequest, errMsg = normalizeResponsesWebsocketRequestWithIncrementalState( @@ -371,7 +375,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } continue } - if !useCodexWebsocketPassthrough && shouldHandleResponsesWebsocketPrewarmLocally(payload, lastRequest, allowIncrementalInputWithPreviousResponseID) { + if !useUpstreamWebsocketPassthrough && shouldHandleResponsesWebsocketPrewarmLocally(payload, lastRequest, allowIncrementalInputWithPreviousResponseID) { if updated, errDelete := sjson.DeleteBytes(requestJSON, "generate"); errDelete == nil { requestJSON = updated } @@ -394,7 +398,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { previousLastResponseID := lastResponseID previousLastResponsePendingToolCallIDs := append([]string(nil), lastResponsePendingToolCallIDs...) forcedTranscriptReplay := forceTranscriptReplayNextRequest - if useCodexWebsocketPassthrough { + if useUpstreamWebsocketPassthrough { if modelName := strings.TrimSpace(gjson.GetBytes(requestJSON, "model").String()); modelName != "" { passthroughModelName = modelName } @@ -443,7 +447,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { if shouldReleaseResponsesWebsocketPinnedAuth(forwardErrMsg) { pinnedAuthID = "" forceTranscriptReplayNextRequest = true - if useCodexWebsocketPassthrough { + if useUpstreamWebsocketPassthrough { passthroughModelName = "" } else { lastRequest = previousLastRequest @@ -453,7 +457,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } continue } - if !useCodexWebsocketPassthrough { + if !useUpstreamWebsocketPassthrough { lastResponseOutput = completedOutput lastResponseID = strings.TrimSpace(completedResponseID) lastResponsePendingToolCallIDs = append([]string(nil), completedPendingToolCallIDs...) @@ -917,7 +921,7 @@ func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, met func (h *OpenAIResponsesAPIHandler) websocketUpstreamSupportsIncrementalInputForModel(modelName string) bool { auths, _ := h.responsesWebsocketAvailableAuthsForModel(modelName) for _, auth := range auths { - if websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) { + if responsesWebsocketAuthSupportsIncrementalInput(auth) { return true } } @@ -961,29 +965,47 @@ func (h *OpenAIResponsesAPIHandler) responsesWebsocketAvailableAuthsForModel(mod } func (h *OpenAIResponsesAPIHandler) responsesWebsocketUsesCodexWebsocketPassthrough(modelName string) bool { + return h.responsesWebsocketUsesUpstreamWebsocketPassthrough(modelName) +} + +func (h *OpenAIResponsesAPIHandler) responsesWebsocketUsesUpstreamWebsocketPassthrough(modelName string) bool { modelName = strings.TrimSpace(modelName) if h == nil || h.AuthManager == nil || modelName == "" { return false } - if _, ok := h.AuthManager.Executor("codex"); !ok { - return false - } auths, _ := h.responsesWebsocketAvailableAuthsForModel(modelName) if len(auths) == 0 { return false } + provider := "" for _, auth := range auths { if auth == nil { return false } - if !strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") { + authProvider := strings.ToLower(strings.TrimSpace(auth.Provider)) + if authProvider != "codex" && authProvider != "xai" { + return false + } + if provider == "" { + provider = authProvider + if _, ok := h.AuthManager.Executor(provider); !ok { + return false + } + } else if authProvider != provider { return false } if !websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) { return false } } - return true + return provider != "" +} + +func responsesWebsocketAuthSupportsIncrementalInput(auth *coreauth.Auth) bool { + if auth == nil { + return false + } + return websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) } func normalizeResponsesWebsocketPassthroughRequest(rawJSON []byte, modelName string) ([]byte, *interfaces.ErrorMessage) { diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go index 99f4e555f..ad66cf089 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_test.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go @@ -29,6 +29,11 @@ type websocketCaptureExecutor struct { payloads [][]byte } +type websocketProviderCaptureExecutor struct { + provider string + websocketCaptureExecutor +} + type websocketCompactionCaptureExecutor struct { mu sync.Mutex streamPayloads [][]byte @@ -85,6 +90,7 @@ type websocketBootstrapFallbackExecutor struct { type websocketDirectCaptureExecutor struct { mu sync.Mutex + provider string authIDs []string payloads [][]byte done chan struct{} @@ -164,7 +170,12 @@ func (e *websocketBootstrapFallbackExecutor) Payloads(authID string) [][]byte { return out } -func (e *websocketDirectCaptureExecutor) Identifier() string { return "codex" } +func (e *websocketDirectCaptureExecutor) Identifier() string { + if e != nil && strings.TrimSpace(e.provider) != "" { + return strings.TrimSpace(e.provider) + } + return "codex" +} func (e *websocketDirectCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { return coreexecutor.Response{}, errors.New("not implemented") @@ -403,6 +414,13 @@ func (e *websocketPinnedFailoverExecutor) Payloads(authID string) [][]byte { func (e *websocketCaptureExecutor) Identifier() string { return "test-provider" } +func (e *websocketProviderCaptureExecutor) Identifier() string { + if e != nil && strings.TrimSpace(e.provider) != "" { + return strings.TrimSpace(e.provider) + } + return "test-provider" +} + func (e *websocketCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { return coreexecutor.Response{}, errors.New("not implemented") } @@ -1641,6 +1659,94 @@ func TestResponsesWebsocketCodexWebsocketPassthroughPassesCompactedRequestWithou } } +func TestResponsesWebsocketXAIWebsocketPassthroughCarriesPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + modelName := "xai-websocket-passthrough-model" + executor := &websocketDirectCaptureExecutor{provider: "xai", done: make(chan struct{})} + manager := coreauth.NewManager(nil, nil, nil) + manager.RegisterExecutor(executor) + auth := &coreauth.Auth{ + ID: "auth-xai-ws", + Provider: "xai", + Status: coreauth.StatusActive, + Attributes: map[string]string{"websockets": "true"}, + } + if _, err := manager.Register(context.Background(), auth); err != nil { + t.Fatalf("Register auth: %v", err) + } + registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: modelName}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(auth.ID) + }) + + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + h := NewOpenAIResponsesAPIHandler(base) + router := gin.New() + router.GET("/v1/responses/ws", h.ResponsesWebsocket) + + server := httptest.NewServer(router) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws" + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("dial websocket: %v", err) + } + defer func() { _ = conn.Close() }() + + firstRequest := []byte(fmt.Sprintf(`{"type":"response.create","model":%q,"input":[{"type":"message","id":"msg-1","role":"user","content":"first"}]}`, modelName)) + if errWrite := conn.WriteMessage(websocket.TextMessage, firstRequest); errWrite != nil { + t.Fatalf("write first websocket message: %v", errWrite) + } + if _, _, errRead := conn.ReadMessage(); errRead != nil { + t.Fatalf("read first websocket response: %v", errRead) + } + + secondRequest := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"message","id":"msg-2","role":"user","content":"second"}]}`) + if errWrite := conn.WriteMessage(websocket.TextMessage, secondRequest); errWrite != nil { + t.Fatalf("write second websocket message: %v", errWrite) + } + if _, _, errRead := conn.ReadMessage(); errRead != nil { + t.Fatalf("read second websocket response: %v", errRead) + } + + select { + case <-executor.done: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for websocket passthrough") + } + + payloads := executor.Payloads() + if len(payloads) != 2 { + t.Fatalf("xai websocket payload count = %d, want 2", len(payloads)) + } + secondPayload := payloads[1] + if got := gjson.GetBytes(secondPayload, "type").String(); got != wsRequestTypeCreate { + t.Fatalf("second xai passthrough type = %s, want %s: %s", got, wsRequestTypeCreate, secondPayload) + } + if got := gjson.GetBytes(secondPayload, "model").String(); got != modelName { + t.Fatalf("second xai payload model = %s, want %s", got, modelName) + } + if got := gjson.GetBytes(secondPayload, "previous_response_id").String(); got != "resp-1" { + t.Fatalf("second xai previous_response_id = %s, want resp-1: %s", got, secondPayload) + } + input := gjson.GetBytes(secondPayload, "input").Array() + if len(input) != 1 { + t.Fatalf("second xai passthrough input len = %d, want 1: %s", len(input), secondPayload) + } + if input[0].Get("id").String() != "msg-2" { + t.Fatalf("second xai passthrough input must contain only the new turn: %s", secondPayload) + } + if bytes.Contains(secondPayload, []byte(`"id":"msg-1"`)) || bytes.Contains(secondPayload, []byte(`"id":"out-1"`)) { + t.Fatalf("second xai passthrough payload contains stale transcript state: %s", secondPayload) + } + authIDs := executor.AuthIDs() + if len(authIDs) != 2 || authIDs[0] != "auth-xai-ws" || authIDs[1] != "auth-xai-ws" { + t.Fatalf("xai websocket auth IDs = %v, want [auth-xai-ws auth-xai-ws]", authIDs) + } +} + func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) { manager := coreauth.NewManager(nil, nil, nil) auth := &coreauth.Auth{ @@ -1664,6 +1770,56 @@ func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) { } } +func TestWebsocketUpstreamSupportsIncrementalInputForXAI(t *testing.T) { + manager := coreauth.NewManager(nil, nil, nil) + auth := &coreauth.Auth{ + ID: "auth-xai-ws", + Provider: "xai", + Status: coreauth.StatusActive, + Attributes: map[string]string{"websockets": "true"}, + } + if _, err := manager.Register(context.Background(), auth); err != nil { + t.Fatalf("Register auth: %v", err) + } + registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "xai-test-model"}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(auth.ID) + }) + + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + h := NewOpenAIResponsesAPIHandler(base) + if !h.websocketUpstreamSupportsIncrementalInputForModel("xai-test-model") { + t.Fatalf("expected xai websocket upstream to support previous_response_id incremental input") + } +} + +func TestResponsesWebsocketUsesUpstreamWebsocketPassthroughForXAI(t *testing.T) { + manager := coreauth.NewManager(nil, nil, nil) + executor := &websocketProviderCaptureExecutor{provider: "xai"} + manager.RegisterExecutor(executor) + + modelName := "xai-passthrough-model" + auth := &coreauth.Auth{ + ID: "auth-xai-ws", + Provider: "xai", + Status: coreauth.StatusActive, + Attributes: map[string]string{"websockets": "true"}, + } + if _, err := manager.Register(context.Background(), auth); err != nil { + t.Fatalf("Register auth: %v", err) + } + registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: modelName}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(auth.ID) + }) + + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + h := NewOpenAIResponsesAPIHandler(base) + if !h.responsesWebsocketUsesUpstreamWebsocketPassthrough(modelName) { + t.Fatalf("expected xai websocket upstream passthrough for %s", modelName) + } +} + func TestWebsocketUpstreamSupportsCompactionReplayForModel(t *testing.T) { manager := coreauth.NewManager(nil, nil, nil) auth := &coreauth.Auth{ diff --git a/sdk/cliproxy/auth/scheduler.go b/sdk/cliproxy/auth/scheduler.go index 9f9718d49..b3b61534f 100644 --- a/sdk/cliproxy/auth/scheduler.go +++ b/sdk/cliproxy/auth/scheduler.go @@ -249,7 +249,7 @@ func (s *authScheduler) pickSingleWithStrategy(ctx context.Context, provider, mo providerKey := strings.ToLower(strings.TrimSpace(provider)) modelKey := canonicalModelKey(model) pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata) - preferWebsocket := cliproxyexecutor.DownstreamWebsocket(ctx) && providerKey == "codex" && pinnedAuthID == "" + preferWebsocket := cliproxyexecutor.DownstreamWebsocket(ctx) && providerPrefersWebsocketTransport(providerKey) && pinnedAuthID == "" s.mu.Lock() defer s.mu.Unlock() @@ -284,6 +284,15 @@ func (s *authScheduler) pickSingleWithStrategy(ctx context.Context, provider, mo return nil, shard.unavailableErrorLocked(provider, model, predicate) } +func providerPrefersWebsocketTransport(providerKey string) bool { + switch strings.ToLower(strings.TrimSpace(providerKey)) { + case "codex", "xai": + return true + default: + return false + } +} + // pickMixed returns the next auth and provider for a mixed-provider request. func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, string, error) { return s.pickMixedWithStrategy(ctx, providers, model, opts, tried, schedulerStrategyCurrent) diff --git a/sdk/cliproxy/auth/scheduler_test.go b/sdk/cliproxy/auth/scheduler_test.go index 39b6c6fb5..5843eaed3 100644 --- a/sdk/cliproxy/auth/scheduler_test.go +++ b/sdk/cliproxy/auth/scheduler_test.go @@ -237,6 +237,32 @@ func TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledSubset(t *testing.T) } } +func TestSchedulerPick_XAIWebsocketPrefersWebsocketEnabledSubset(t *testing.T) { + t.Parallel() + + scheduler := newSchedulerForTest( + &RoundRobinSelector{}, + &Auth{ID: "xai-http", Provider: "xai"}, + &Auth{ID: "xai-ws-a", Provider: "xai", Attributes: map[string]string{"websockets": "true"}}, + &Auth{ID: "xai-ws-b", Provider: "xai", Attributes: map[string]string{"websockets": "true"}}, + ) + + ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background()) + want := []string{"xai-ws-a", "xai-ws-b", "xai-ws-a"} + for index, wantID := range want { + got, errPick := scheduler.pickSingle(ctx, "xai", "", cliproxyexecutor.Options{}, nil) + if errPick != nil { + t.Fatalf("pickSingle() #%d error = %v", index, errPick) + } + if got == nil { + t.Fatalf("pickSingle() #%d auth = nil", index) + } + if got.ID != wantID { + t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantID) + } + } +} + func TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledAcrossPriorities(t *testing.T) { t.Parallel() diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index bedbffb80..f5abd389c 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -740,6 +740,9 @@ func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) { if strings.EqualFold(provider, "codex") { executor.CloseCodexWebsocketSessionsForAuthID(id, "auth_removed") } + if strings.EqualFold(provider, "xai") { + executor.CloseXAIWebsocketSessionsForAuthID(id, "auth_removed") + } s.syncPluginRuntime(ctx) } @@ -948,7 +951,7 @@ func (s *Service) registerExecutorForAuth(a *coreauth.Auth, forceReplace bool) { case "kimi": s.coreManager.RegisterExecutor(executor.NewKimiExecutor(s.cfg)) case "xai": - s.coreManager.RegisterExecutor(executor.NewXAIExecutor(s.cfg)) + s.coreManager.RegisterExecutor(executor.NewXAIAutoExecutor(s.cfg)) default: providerKey := strings.ToLower(strings.TrimSpace(a.Provider)) if providerKey == "" { diff --git a/sdk/cliproxy/service_codex_executor_binding_test.go b/sdk/cliproxy/service_codex_executor_binding_test.go index 20a9cd7c8..0cd399ef2 100644 --- a/sdk/cliproxy/service_codex_executor_binding_test.go +++ b/sdk/cliproxy/service_codex_executor_binding_test.go @@ -3,6 +3,7 @@ package cliproxy import ( "testing" + "github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor" coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" "github.com/router-for-me/CLIProxyAPI/v7/sdk/config" ) @@ -62,3 +63,25 @@ func TestEnsureExecutorsForAuthWithMode_CodexForceReplace(t *testing.T) { t.Fatal("expected codex executor replacement in force mode") } } + +func TestEnsureExecutorsForAuth_XAIBindsAutoExecutor(t *testing.T) { + service := &Service{ + cfg: &config.Config{}, + coreManager: coreauth.NewManager(nil, nil, nil), + } + auth := &coreauth.Auth{ + ID: "xai-auth-1", + Provider: "xai", + Status: coreauth.StatusActive, + } + + service.ensureExecutorsForAuth(auth) + + gotExecutor, ok := service.coreManager.Executor("xai") + if !ok || gotExecutor == nil { + t.Fatal("expected xai executor after bind") + } + if _, ok := gotExecutor.(*executor.XAIAutoExecutor); !ok { + t.Fatalf("xai executor type = %T, want *executor.XAIAutoExecutor", gotExecutor) + } +}