feat(websockets): implement XAIWebsocketsExecutor with enhanced execution and ID mapping

- Developed `XAIWebsocketsExecutor` for handling xAI Responses via WebSocket transport.
- Introduced session and state management with `codexWebsocketSessionStore` and `xaiWebsocketIDStateStore`.
- Added robust ID mapping for upstream and downstream request/response sequences.
- Enhanced error propagation and handling of WebSocket terminal events.
- Included utility methods for WebSocket request preparation, connection management, and state tracking.
- Added foundational support for compact and streamed responses via enhanced session tracking.
This commit is contained in:
Luis Pater
2026-06-15 08:22:07 +08:00
parent 56988aea0f
commit ea90ab6f77
8 changed files with 1925 additions and 20 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,425 @@
package executor
import (
"bytes"
"context"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
_ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator"
"github.com/tidwall/gjson"
)
// TestXAIWebsocketsExecuteStreamSendsResponseCreateWithPreviousResponseID runs
// ExecuteStream against a fake upstream WebSocket server and checks:
//   - the handshake request (path, Authorization and x-grok-conv-id headers);
//   - the first upstream frame: a response.create message that keeps
//     previous_response_id, omits stream and instructions, uses the execution
//     session as prompt_cache_key and sets store;
//   - that the upstream response.completed event is delivered as a stream chunk.
func TestXAIWebsocketsExecuteStreamSendsResponseCreateWithPreviousResponseID(t *testing.T) {
	const waitLimit = 5 * time.Second

	wsUpgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
	upstreamFrames := make(chan []byte, 1)

	// Fake upstream: validate the handshake, capture one frame, reply once.
	handler := func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/responses" {
			t.Errorf("path = %q, want /responses", r.URL.Path)
		}
		if authHeader := r.Header.Get("Authorization"); authHeader != "Bearer xai-token" {
			t.Errorf("Authorization = %q, want Bearer xai-token", authHeader)
		}
		if convID := r.Header.Get("x-grok-conv-id"); convID != "execution-session-1" {
			t.Errorf("x-grok-conv-id = %q, want execution-session-1", convID)
		}

		wsConn, errUpgrade := wsUpgrader.Upgrade(w, r, nil)
		if errUpgrade != nil {
			t.Errorf("upgrade websocket: %v", errUpgrade)
			return
		}
		defer func() { _ = wsConn.Close() }()

		_, frame, errRead := wsConn.ReadMessage()
		if errRead != nil {
			t.Errorf("read upstream websocket message: %v", errRead)
			return
		}
		upstreamFrames <- bytes.Clone(frame)

		const completedEvent = `{"type":"response.completed","response":{"id":"resp-xai-1","output":[],"usage":{"input_tokens":0,"output_tokens":0,"total_tokens":0}}}`
		if errWrite := wsConn.WriteMessage(websocket.TextMessage, []byte(completedEvent)); errWrite != nil {
			t.Errorf("write completed websocket message: %v", errWrite)
		}
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	defer server.Close()

	xaiExec := NewXAIWebsocketsExecutor(&config.Config{})
	auth := &cliproxyauth.Auth{
		ID:       "xai-auth",
		Provider: "xai",
		Attributes: map[string]string{
			"base_url":   server.URL,
			"websockets": "true",
		},
		Metadata: map[string]any{"access_token": "xai-token"},
	}
	const requestBody = `{"model":"grok-4.3","stream":true,"previous_response_id":"resp-prev","instructions":"system prompt","input":[{"type":"message","role":"user","content":"hello"}]}`
	req := cliproxyexecutor.Request{
		Model:   "grok-4.3",
		Payload: []byte(requestBody),
	}
	opts := cliproxyexecutor.Options{
		SourceFormat:   sdktranslator.FormatOpenAIResponse,
		ResponseFormat: sdktranslator.FormatOpenAIResponse,
		Metadata: map[string]any{
			cliproxyexecutor.ExecutionSessionMetadataKey: "execution-session-1",
		},
	}

	ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())
	result, err := xaiExec.ExecuteStream(ctx, auth, req, opts)
	if err != nil {
		t.Fatalf("ExecuteStream() error = %v", err)
	}

	// Inspect the frame the executor actually put on the upstream socket.
	select {
	case frame := <-upstreamFrames:
		field := func(path string) gjson.Result { return gjson.GetBytes(frame, path) }

		if msgType := field("type").String(); msgType != "response.create" {
			t.Fatalf("type = %q, want response.create; payload=%s", msgType, frame)
		}
		if prevID := field("previous_response_id").String(); prevID != "resp-prev" {
			t.Fatalf("previous_response_id = %q, want resp-prev; payload=%s", prevID, frame)
		}
		if field("stream").Exists() {
			t.Fatalf("stream must be omitted for xAI websocket payload: %s", frame)
		}
		if field("instructions").Exists() {
			t.Fatalf("instructions must be omitted when previous_response_id is set: %s", frame)
		}
		if cacheKey := field("prompt_cache_key").String(); cacheKey != "execution-session-1" {
			t.Fatalf("prompt_cache_key = %q, want execution-session-1; payload=%s", cacheKey, frame)
		}
		if stored := field("store").Bool(); !stored {
			t.Fatalf("store = false, want true; payload=%s", frame)
		}
	case <-time.After(waitLimit):
		t.Fatal("timed out waiting for upstream websocket payload")
	}

	// The completed event must come back out of the stream without an error.
	select {
	case chunk, open := <-result.Chunks:
		switch {
		case !open:
			t.Fatal("stream closed before completed chunk")
		case chunk.Err != nil:
			t.Fatalf("chunk error = %v", chunk.Err)
		}
		eventType := gjson.GetBytes(bytes.TrimSpace(chunk.Payload), "type").String()
		if eventType != "response.completed" {
			t.Fatalf("chunk type = %q, want response.completed; payload=%s", eventType, chunk.Payload)
		}
	case <-time.After(waitLimit):
		t.Fatal("timed out waiting for completed chunk")
	}
}
// TestXAIWebsocketsExecuteStreamRewritesRepeatedResponseIDForDownstream checks the
// response-ID mapping when the upstream reuses the same response id on every turn.
//
// The fake upstream answers three requests on one connection, always with
// response id "resp-real" and output item id "rs_resp-real", echoing back the
// previous_response_id it received. The test expects:
//   - turn 1: downstream sees the real ids unchanged;
//   - turns 2 and 3: downstream sees a fresh synthetic response id, an output item
//     id that contains that synthetic id, and a previous_response_id equal to the
//     id the client sent; upstream still receives "resp-real" as
//     previous_response_id.
func TestXAIWebsocketsExecuteStreamRewritesRepeatedResponseIDForDownstream(t *testing.T) {
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
// Buffered for all three turns so the handler never blocks on the test goroutine.
capturedPreviousIDs := make(chan string, 3)
// Closed at test end; until then the handler keeps the upstream connection open.
releaseServer := make(chan struct{})
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Errorf("upgrade websocket: %v", err)
return
}
defer func() { _ = conn.Close() }()
// Serve exactly three request/response turns on this single connection.
for i := 0; i < 3; i++ {
_, payload, errRead := conn.ReadMessage()
if errRead != nil {
t.Errorf("read upstream websocket message: %v", errRead)
return
}
previousID := gjson.GetBytes(payload, "previous_response_id").String()
// Record what the executor sent before replying, so the value is already
// available once the test has received the matching chunk.
capturedPreviousIDs <- previousID
// Always reply with the same response id and echo the received previous id.
completed := []byte(fmt.Sprintf(`{"type":"response.completed","response":{"id":"resp-real","previous_response_id":%q,"output":[{"id":"rs_resp-real","type":"reasoning","status":"completed"}],"usage":{"input_tokens":0,"output_tokens":0,"total_tokens":0}}}`, previousID))
if errWrite := conn.WriteMessage(websocket.TextMessage, completed); errWrite != nil {
t.Errorf("write completed websocket message: %v", errWrite)
return
}
}
<-releaseServer
}))
// Deferred calls run last-in-first-out: releaseServer is closed first so the
// handler can return, then server.Close runs.
defer server.Close()
defer close(releaseServer)
exec := NewXAIWebsocketsExecutor(&config.Config{})
// Give this executor its own empty session and ID-state stores
// (presumably to isolate it from state shared with other tests — confirm).
exec.store = &codexWebsocketSessionStore{sessions: make(map[string]*codexWebsocketSession)}
exec.idStore = &xaiWebsocketIDStateStore{sessions: make(map[string]*xaiWebsocketIDState)}
auth := &cliproxyauth.Auth{
ID: "xai-auth-id-map",
Provider: "xai",
Attributes: map[string]string{
"base_url": server.URL,
"websockets": "true",
},
Metadata: map[string]any{"access_token": "xai-token"},
}
// Every request uses the same execution session key.
opts := cliproxyexecutor.Options{
SourceFormat: sdktranslator.FormatOpenAIResponse,
ResponseFormat: sdktranslator.FormatOpenAIResponse,
Metadata: map[string]any{
cliproxyexecutor.ExecutionSessionMetadataKey: "xai-id-map-session",
},
}
ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())
// runRequest sends one request (a follow-up when previousID is non-empty) and
// returns, from the first chunk received: response.id, response.output.0.id and
// response.previous_response_id.
runRequest := func(previousID string) (string, string, string) {
body := []byte(`{"model":"grok-4.3","input":[{"type":"message","role":"user","content":"hello"}]}`)
if previousID != "" {
body = []byte(fmt.Sprintf(`{"model":"grok-4.3","previous_response_id":%q,"input":[{"type":"function_call_output","call_id":"call-1","output":"ok"}]}`, previousID))
}
result, err := exec.ExecuteStream(ctx, auth, cliproxyexecutor.Request{Model: "grok-4.3", Payload: body}, opts)
if err != nil {
t.Fatalf("ExecuteStream() error = %v", err)
}
select {
case chunk, ok := <-result.Chunks:
if !ok {
t.Fatal("stream closed before completed chunk")
}
if chunk.Err != nil {
t.Fatalf("chunk error = %v", chunk.Err)
}
payload := bytes.TrimSpace(chunk.Payload)
return gjson.GetBytes(payload, "response.id").String(),
gjson.GetBytes(payload, "response.output.0.id").String(),
gjson.GetBytes(payload, "response.previous_response_id").String()
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for completed chunk")
}
// Not reached at runtime (t.Fatal stops the test); required by the compiler.
return "", "", ""
}
// Turn 1: no previous id, so the real upstream ids pass through untouched.
firstDownstreamID, firstOutputID, firstResponsePrevious := runRequest("")
if firstDownstreamID != "resp-real" {
t.Fatalf("first downstream id = %q, want resp-real", firstDownstreamID)
}
if firstOutputID != "rs_resp-real" {
t.Fatalf("first output item id = %q, want rs_resp-real", firstOutputID)
}
if firstResponsePrevious != "" {
t.Fatalf("first response previous_response_id = %q, want empty", firstResponsePrevious)
}
firstUpstreamPrevious := <-capturedPreviousIDs
if firstUpstreamPrevious != "" {
t.Fatalf("first upstream previous_response_id = %q, want empty", firstUpstreamPrevious)
}
// Turn 2: upstream repeats "resp-real", so downstream must get a synthetic id.
secondDownstreamID, secondOutputID, secondResponsePrevious := runRequest(firstDownstreamID)
if secondDownstreamID == "" || secondDownstreamID == "resp-real" {
t.Fatalf("second downstream id = %q, want synthetic id different from resp-real", secondDownstreamID)
}
if secondOutputID == "rs_resp-real" || !strings.Contains(secondOutputID, secondDownstreamID) {
t.Fatalf("second output item id = %q, want rewritten id containing %q", secondOutputID, secondDownstreamID)
}
if secondResponsePrevious != firstDownstreamID {
t.Fatalf("second response previous_response_id = %q, want %q", secondResponsePrevious, firstDownstreamID)
}
secondUpstreamPrevious := <-capturedPreviousIDs
if secondUpstreamPrevious != "resp-real" {
t.Fatalf("second upstream previous_response_id = %q, want resp-real", secondUpstreamPrevious)
}
// Turn 3: the client sends the synthetic id from turn 2; upstream must still
// receive the real id, and downstream must get yet another synthetic id.
thirdDownstreamID, thirdOutputID, thirdResponsePrevious := runRequest(secondDownstreamID)
if thirdDownstreamID == "" || thirdDownstreamID == "resp-real" || thirdDownstreamID == secondDownstreamID {
t.Fatalf("third downstream id = %q, want a new synthetic id", thirdDownstreamID)
}
if thirdOutputID == "rs_resp-real" || !strings.Contains(thirdOutputID, thirdDownstreamID) {
t.Fatalf("third output item id = %q, want rewritten id containing %q", thirdOutputID, thirdDownstreamID)
}
if thirdResponsePrevious != secondDownstreamID {
t.Fatalf("third response previous_response_id = %q, want %q", thirdResponsePrevious, secondDownstreamID)
}
thirdUpstreamPrevious := <-capturedPreviousIDs
if thirdUpstreamPrevious != "resp-real" {
t.Fatalf("third upstream previous_response_id = %q, want resp-real", thirdUpstreamPrevious)
}
}
// TestBuildXAIWebsocketRequestBodySetsStoreAndKeepsPromptCacheKey feeds
// buildXAIWebsocketRequestBody a request carrying stream, stream_options,
// background, prompt_cache_key, previous_response_id and instructions, and
// checks the result: type is response.create, the three transport-only fields
// and instructions are gone, prompt_cache_key is unchanged and store is true.
func TestBuildXAIWebsocketRequestBodySetsStoreAndKeepsPromptCacheKey(t *testing.T) {
	const rawRequest = `{"model":"grok-4.3","stream":true,"stream_options":{"include_usage":true},"background":true,"prompt_cache_key":"cache-1","previous_response_id":"resp-prev","instructions":"system prompt","input":[{"type":"message","role":"user","content":"hello"}]}`

	built := buildXAIWebsocketRequestBody([]byte(rawRequest))
	lookup := func(path string) gjson.Result { return gjson.GetBytes(built, path) }

	if msgType := lookup("type").String(); msgType != "response.create" {
		t.Fatalf("type = %q, want response.create; payload=%s", msgType, built)
	}
	if lookup("stream").Exists() {
		t.Fatalf("stream must be omitted for xAI websocket payload: %s", built)
	}
	if lookup("stream_options").Exists() {
		t.Fatalf("stream_options must be omitted for xAI websocket payload: %s", built)
	}
	if lookup("background").Exists() {
		t.Fatalf("background must be omitted for xAI websocket payload: %s", built)
	}
	if cacheKey := lookup("prompt_cache_key").String(); cacheKey != "cache-1" {
		t.Fatalf("prompt_cache_key = %q, want cache-1; payload=%s", cacheKey, built)
	}
	if stored := lookup("store").Bool(); !stored {
		t.Fatalf("store = false, want true; payload=%s", built)
	}
	if lookup("instructions").Exists() {
		t.Fatalf("instructions must be omitted when previous_response_id is set: %s", built)
	}
}
// TestXAIWebsocketsExecuteStreamCompletesGenerateFalseWarmup covers a
// generate=false warm-up request. The fake upstream answers only with
// response.created and then keeps the connection open, so the test expects the
// stream to deliver exactly response.created followed by response.completed and
// then close without any further upstream traffic.
func TestXAIWebsocketsExecuteStreamCompletesGenerateFalseWarmup(t *testing.T) {
	upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
	capturedPayload := make(chan []byte, 1)
	// Closed at test end; until then the handler holds the connection open so
	// stream closure cannot be caused by the upstream hanging up.
	releaseServer := make(chan struct{})
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			t.Errorf("upgrade websocket: %v", err)
			return
		}
		defer func() { _ = conn.Close() }()
		_, payload, errRead := conn.ReadMessage()
		if errRead != nil {
			t.Errorf("read upstream websocket message: %v", errRead)
			return
		}
		capturedPayload <- bytes.Clone(payload)
		created := []byte(`{"type":"response.created","response":{"id":"resp-warmup-1","object":"response","status":"in_progress","output":[]}}`)
		if errWrite := conn.WriteMessage(websocket.TextMessage, created); errWrite != nil {
			t.Errorf("write created websocket message: %v", errWrite)
			return
		}
		<-releaseServer
	}))
	// Deferred calls run last-in-first-out: the handler is released before
	// server.Close waits for it.
	defer server.Close()
	defer close(releaseServer)

	exec := NewXAIWebsocketsExecutor(&config.Config{})
	auth := &cliproxyauth.Auth{
		ID:       "xai-auth-warmup",
		Provider: "xai",
		Attributes: map[string]string{
			"base_url":   server.URL,
			"websockets": "true",
		},
		Metadata: map[string]any{"access_token": "xai-token"},
	}
	req := cliproxyexecutor.Request{
		Model:   "grok-4.3",
		Payload: []byte(`{"model":"grok-4.3","generate":false,"input":[{"type":"message","role":"user","content":"warm up"}]}`),
	}
	opts := cliproxyexecutor.Options{
		SourceFormat:   sdktranslator.FormatOpenAIResponse,
		ResponseFormat: sdktranslator.FormatOpenAIResponse,
	}
	ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())
	result, err := exec.ExecuteStream(ctx, auth, req, opts)
	if err != nil {
		t.Fatalf("ExecuteStream() error = %v", err)
	}

	select {
	case payload := <-capturedPayload:
		if got := gjson.GetBytes(payload, "generate").Bool(); got {
			t.Fatalf("generate = true, want false; payload=%s", payload)
		}
		if got := gjson.GetBytes(payload, "type").String(); got != "response.create" {
			t.Fatalf("type = %q, want response.create; payload=%s", got, payload)
		}
		if got := gjson.GetBytes(payload, "store").Bool(); !got {
			t.Fatalf("store = false, want true; payload=%s", payload)
		}
	case <-time.After(5 * time.Second):
		t.Fatal("timed out waiting for upstream websocket payload")
	}

	// One deadline for the whole drain: re-arming a timer per iteration would let
	// a stream that keeps emitting chunks run until the global test timeout.
	deadline := time.After(5 * time.Second)
	var gotTypes []string
	for {
		select {
		case chunk, ok := <-result.Chunks:
			if !ok {
				// Check the actual event types and their order, not just the count.
				if len(gotTypes) != 2 || gotTypes[0] != "response.created" || gotTypes[1] != "response.completed" {
					t.Fatalf("event types = %v, want response.created and response.completed", gotTypes)
				}
				return
			}
			if chunk.Err != nil {
				t.Fatalf("chunk error = %v", chunk.Err)
			}
			gotTypes = append(gotTypes, gjson.GetBytes(bytes.TrimSpace(chunk.Payload), "type").String())
		case <-deadline:
			t.Fatalf("timed out waiting for warmup stream to close; event types so far: %v", gotTypes)
		}
	}
}
// TestXAIWebsocketsExecuteStreamStopsOnBareErrorPayload has the fake upstream
// reply with a JSON object that contains only an "error" member (no "type"
// field) and then keep the connection open. The first stream chunk must carry
// a non-nil Err.
func TestXAIWebsocketsExecuteStreamStopsOnBareErrorPayload(t *testing.T) {
	wsUpgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
	// Closed at test end; the handler blocks on it so the connection stays open.
	releaseServer := make(chan struct{})

	handler := func(w http.ResponseWriter, r *http.Request) {
		wsConn, errUpgrade := wsUpgrader.Upgrade(w, r, nil)
		if errUpgrade != nil {
			t.Errorf("upgrade websocket: %v", errUpgrade)
			return
		}
		defer func() { _ = wsConn.Close() }()

		// The request content is irrelevant here; it only has to arrive.
		_, _, errRead := wsConn.ReadMessage()
		if errRead != nil {
			t.Errorf("read upstream websocket message: %v", errRead)
			return
		}

		const bareError = `{"error":{"message":"Request validation error: {\"code\":\"400\",\"error\":\"Argument not supported: instructions and previous_response_id together\"}","type":"api_error"}}`
		if errWrite := wsConn.WriteMessage(websocket.TextMessage, []byte(bareError)); errWrite != nil {
			t.Errorf("write error websocket message: %v", errWrite)
			return
		}
		<-releaseServer
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	defer server.Close()
	defer close(releaseServer)

	xaiExec := NewXAIWebsocketsExecutor(&config.Config{})
	auth := &cliproxyauth.Auth{
		ID:       "xai-auth-error",
		Provider: "xai",
		Attributes: map[string]string{
			"base_url":   server.URL,
			"websockets": "true",
		},
		Metadata: map[string]any{"access_token": "xai-token"},
	}
	req := cliproxyexecutor.Request{
		Model:   "grok-4.3",
		Payload: []byte(`{"model":"grok-4.3","input":"hello"}`),
	}
	opts := cliproxyexecutor.Options{
		SourceFormat:   sdktranslator.FormatOpenAIResponse,
		ResponseFormat: sdktranslator.FormatOpenAIResponse,
	}

	ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())
	result, err := xaiExec.ExecuteStream(ctx, auth, req, opts)
	if err != nil {
		t.Fatalf("ExecuteStream() error = %v", err)
	}

	select {
	case chunk, open := <-result.Chunks:
		switch {
		case !open:
			t.Fatal("stream closed before error chunk")
		case chunk.Err == nil:
			t.Fatalf("chunk error = nil, want upstream error; payload=%s", chunk.Payload)
		}
	case <-time.After(5 * time.Second):
		t.Fatal("timed out waiting for bare upstream error")
	}
}

View File

@@ -228,9 +228,13 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
defer close(wsDone)
if h != nil && h.AuthManager != nil {
if exec, ok := h.AuthManager.Executor("codex"); ok && exec != nil {
type upstreamDisconnectSubscriber interface {
UpstreamDisconnectChan(sessionID string) <-chan error
type upstreamDisconnectSubscriber interface {
UpstreamDisconnectChan(sessionID string) <-chan error
}
for _, provider := range []string{"codex", "xai"} {
exec, ok := h.AuthManager.Executor(provider)
if !ok || exec == nil {
continue
}
if subscriber, ok := exec.(upstreamDisconnectSubscriber); ok && subscriber != nil {
disconnectCh := subscriber.UpstreamDisconnectChan(passthroughSessionID)
@@ -315,13 +319,13 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
if requestModelName == "" {
requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
}
useCodexWebsocketPassthrough := h.responsesWebsocketUsesCodexWebsocketPassthrough(requestModelName)
useUpstreamWebsocketPassthrough := h.responsesWebsocketUsesUpstreamWebsocketPassthrough(requestModelName)
allowIncrementalInputWithPreviousResponseID := false
allowCompactionReplayBypass := false
if !useCodexWebsocketPassthrough {
if !useUpstreamWebsocketPassthrough {
if pinnedAuthID != "" {
if pinnedAuth, ok := sessionAuthByID(pinnedAuthID); ok && pinnedAuth != nil {
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
allowIncrementalInputWithPreviousResponseID = responsesWebsocketAuthSupportsIncrementalInput(pinnedAuth)
allowCompactionReplayBypass = responsesWebsocketAuthSupportsCompactionReplay(pinnedAuth)
}
} else {
@@ -336,7 +340,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
var requestJSON []byte
var updatedLastRequest []byte
var errMsg *interfaces.ErrorMessage
if useCodexWebsocketPassthrough {
if useUpstreamWebsocketPassthrough {
requestJSON, errMsg = normalizeResponsesWebsocketPassthroughRequest(payload, requestModelName)
} else {
requestJSON, updatedLastRequest, errMsg = normalizeResponsesWebsocketRequestWithIncrementalState(
@@ -371,7 +375,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
}
continue
}
if !useCodexWebsocketPassthrough && shouldHandleResponsesWebsocketPrewarmLocally(payload, lastRequest, allowIncrementalInputWithPreviousResponseID) {
if !useUpstreamWebsocketPassthrough && shouldHandleResponsesWebsocketPrewarmLocally(payload, lastRequest, allowIncrementalInputWithPreviousResponseID) {
if updated, errDelete := sjson.DeleteBytes(requestJSON, "generate"); errDelete == nil {
requestJSON = updated
}
@@ -394,7 +398,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
previousLastResponseID := lastResponseID
previousLastResponsePendingToolCallIDs := append([]string(nil), lastResponsePendingToolCallIDs...)
forcedTranscriptReplay := forceTranscriptReplayNextRequest
if useCodexWebsocketPassthrough {
if useUpstreamWebsocketPassthrough {
if modelName := strings.TrimSpace(gjson.GetBytes(requestJSON, "model").String()); modelName != "" {
passthroughModelName = modelName
}
@@ -443,7 +447,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
if shouldReleaseResponsesWebsocketPinnedAuth(forwardErrMsg) {
pinnedAuthID = ""
forceTranscriptReplayNextRequest = true
if useCodexWebsocketPassthrough {
if useUpstreamWebsocketPassthrough {
passthroughModelName = ""
} else {
lastRequest = previousLastRequest
@@ -453,7 +457,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
}
continue
}
if !useCodexWebsocketPassthrough {
if !useUpstreamWebsocketPassthrough {
lastResponseOutput = completedOutput
lastResponseID = strings.TrimSpace(completedResponseID)
lastResponsePendingToolCallIDs = append([]string(nil), completedPendingToolCallIDs...)
@@ -917,7 +921,7 @@ func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, met
func (h *OpenAIResponsesAPIHandler) websocketUpstreamSupportsIncrementalInputForModel(modelName string) bool {
auths, _ := h.responsesWebsocketAvailableAuthsForModel(modelName)
for _, auth := range auths {
if websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) {
if responsesWebsocketAuthSupportsIncrementalInput(auth) {
return true
}
}
@@ -961,29 +965,47 @@ func (h *OpenAIResponsesAPIHandler) responsesWebsocketAvailableAuthsForModel(mod
}
// responsesWebsocketUsesCodexWebsocketPassthrough returns exactly what
// responsesWebsocketUsesUpstreamWebsocketPassthrough returns for modelName; it
// only forwards the call under the Codex-specific name (presumably kept for
// existing callers — confirm before removing).
func (h *OpenAIResponsesAPIHandler) responsesWebsocketUsesCodexWebsocketPassthrough(modelName string) bool {
return h.responsesWebsocketUsesUpstreamWebsocketPassthrough(modelName)
}
func (h *OpenAIResponsesAPIHandler) responsesWebsocketUsesUpstreamWebsocketPassthrough(modelName string) bool {
modelName = strings.TrimSpace(modelName)
if h == nil || h.AuthManager == nil || modelName == "" {
return false
}
if _, ok := h.AuthManager.Executor("codex"); !ok {
return false
}
auths, _ := h.responsesWebsocketAvailableAuthsForModel(modelName)
if len(auths) == 0 {
return false
}
provider := ""
for _, auth := range auths {
if auth == nil {
return false
}
if !strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") {
authProvider := strings.ToLower(strings.TrimSpace(auth.Provider))
if authProvider != "codex" && authProvider != "xai" {
return false
}
if provider == "" {
provider = authProvider
if _, ok := h.AuthManager.Executor(provider); !ok {
return false
}
} else if authProvider != provider {
return false
}
if !websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) {
return false
}
}
return true
return provider != ""
}
// responsesWebsocketAuthSupportsIncrementalInput is the nil-safe, auth-level
// form of websocketUpstreamSupportsIncrementalInput: it returns false for a nil
// auth and otherwise the result of that check on the auth's attributes and
// metadata.
func responsesWebsocketAuthSupportsIncrementalInput(auth *coreauth.Auth) bool {
	return auth != nil && websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata)
}
func normalizeResponsesWebsocketPassthroughRequest(rawJSON []byte, modelName string) ([]byte, *interfaces.ErrorMessage) {

View File

@@ -29,6 +29,11 @@ type websocketCaptureExecutor struct {
payloads [][]byte
}
// websocketProviderCaptureExecutor is a websocketCaptureExecutor whose reported
// provider identifier can be chosen per test.
type websocketProviderCaptureExecutor struct {
// provider is returned by Identifier when it is non-blank after trimming.
provider string
// Embedded so the remaining executor behaviour comes from websocketCaptureExecutor.
websocketCaptureExecutor
}
type websocketCompactionCaptureExecutor struct {
mu sync.Mutex
streamPayloads [][]byte
@@ -85,6 +90,7 @@ type websocketBootstrapFallbackExecutor struct {
type websocketDirectCaptureExecutor struct {
mu sync.Mutex
provider string
authIDs []string
payloads [][]byte
done chan struct{}
@@ -164,7 +170,12 @@ func (e *websocketBootstrapFallbackExecutor) Payloads(authID string) [][]byte {
return out
}
func (e *websocketDirectCaptureExecutor) Identifier() string { return "codex" }
// Identifier returns the configured provider with surrounding whitespace
// removed. A nil receiver or a blank provider yields "codex".
func (e *websocketDirectCaptureExecutor) Identifier() string {
	const fallback = "codex"
	if e == nil {
		return fallback
	}
	if name := strings.TrimSpace(e.provider); name != "" {
		return name
	}
	return fallback
}
func (e *websocketDirectCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, errors.New("not implemented")
@@ -403,6 +414,13 @@ func (e *websocketPinnedFailoverExecutor) Payloads(authID string) [][]byte {
func (e *websocketCaptureExecutor) Identifier() string { return "test-provider" }
// Identifier returns the configured provider with surrounding whitespace
// removed. A nil receiver or a blank provider yields "test-provider".
func (e *websocketProviderCaptureExecutor) Identifier() string {
	const fallback = "test-provider"
	if e == nil {
		return fallback
	}
	if name := strings.TrimSpace(e.provider); name != "" {
		return name
	}
	return fallback
}
func (e *websocketCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, errors.New("not implemented")
}
@@ -1641,6 +1659,94 @@ func TestResponsesWebsocketCodexWebsocketPassthroughPassesCompactedRequestWithou
}
}
// TestResponsesWebsocketXAIWebsocketPassthroughCarriesPreviousResponseID sends
// two response.create messages through the downstream WebSocket endpoint with a
// websocket-enabled xai auth registered. It checks that the second payload
// handed to the executor is passed through incrementally: type and model are
// set, previous_response_id is kept, input holds only the new turn, no earlier
// transcript ids appear, and both calls used the same auth.
func TestResponsesWebsocketXAIWebsocketPassthroughCarriesPreviousResponseID(t *testing.T) {
	gin.SetMode(gin.TestMode)

	const (
		modelName = "xai-websocket-passthrough-model"
		authID    = "auth-xai-ws"
	)

	capture := &websocketDirectCaptureExecutor{provider: "xai", done: make(chan struct{})}
	manager := coreauth.NewManager(nil, nil, nil)
	manager.RegisterExecutor(capture)

	auth := &coreauth.Auth{
		ID:         authID,
		Provider:   "xai",
		Status:     coreauth.StatusActive,
		Attributes: map[string]string{"websockets": "true"},
	}
	if _, errRegister := manager.Register(context.Background(), auth); errRegister != nil {
		t.Fatalf("Register auth: %v", errRegister)
	}
	registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: modelName}})
	t.Cleanup(func() {
		registry.GetGlobalRegistry().UnregisterClient(auth.ID)
	})

	handler := NewOpenAIResponsesAPIHandler(handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager))
	router := gin.New()
	router.GET("/v1/responses/ws", handler.ResponsesWebsocket)
	server := httptest.NewServer(router)
	defer server.Close()

	// httptest serves http://…; swap the scheme prefix to dial it as ws://….
	wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
	client, _, errDial := websocket.DefaultDialer.Dial(wsURL, nil)
	if errDial != nil {
		t.Fatalf("dial websocket: %v", errDial)
	}
	defer func() { _ = client.Close() }()

	// Turn 1 names the model; turn 2 omits it and refers to a previous response.
	firstRequest := []byte(fmt.Sprintf(`{"type":"response.create","model":%q,"input":[{"type":"message","id":"msg-1","role":"user","content":"first"}]}`, modelName))
	if errWrite := client.WriteMessage(websocket.TextMessage, firstRequest); errWrite != nil {
		t.Fatalf("write first websocket message: %v", errWrite)
	}
	if _, _, errRead := client.ReadMessage(); errRead != nil {
		t.Fatalf("read first websocket response: %v", errRead)
	}

	secondRequest := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"message","id":"msg-2","role":"user","content":"second"}]}`)
	if errWrite := client.WriteMessage(websocket.TextMessage, secondRequest); errWrite != nil {
		t.Fatalf("write second websocket message: %v", errWrite)
	}
	if _, _, errRead := client.ReadMessage(); errRead != nil {
		t.Fatalf("read second websocket response: %v", errRead)
	}

	select {
	case <-capture.done:
	case <-time.After(5 * time.Second):
		t.Fatal("timed out waiting for websocket passthrough")
	}

	recorded := capture.Payloads()
	if len(recorded) != 2 {
		t.Fatalf("xai websocket payload count = %d, want 2", len(recorded))
	}
	secondPayload := recorded[1]

	if msgType := gjson.GetBytes(secondPayload, "type").String(); msgType != wsRequestTypeCreate {
		t.Fatalf("second xai passthrough type = %s, want %s: %s", msgType, wsRequestTypeCreate, secondPayload)
	}
	if model := gjson.GetBytes(secondPayload, "model").String(); model != modelName {
		t.Fatalf("second xai payload model = %s, want %s", model, modelName)
	}
	if prevID := gjson.GetBytes(secondPayload, "previous_response_id").String(); prevID != "resp-1" {
		t.Fatalf("second xai previous_response_id = %s, want resp-1: %s", prevID, secondPayload)
	}

	items := gjson.GetBytes(secondPayload, "input").Array()
	if len(items) != 1 {
		t.Fatalf("second xai passthrough input len = %d, want 1: %s", len(items), secondPayload)
	}
	if items[0].Get("id").String() != "msg-2" {
		t.Fatalf("second xai passthrough input must contain only the new turn: %s", secondPayload)
	}
	// Neither the first turn's input id nor an earlier output id may be replayed.
	for _, stale := range []string{`"id":"msg-1"`, `"id":"out-1"`} {
		if bytes.Contains(secondPayload, []byte(stale)) {
			t.Fatalf("second xai passthrough payload contains stale transcript state: %s", secondPayload)
		}
	}

	usedAuthIDs := capture.AuthIDs()
	sameAuthTwice := len(usedAuthIDs) == 2 && usedAuthIDs[0] == authID && usedAuthIDs[1] == authID
	if !sameAuthTwice {
		t.Fatalf("xai websocket auth IDs = %v, want [auth-xai-ws auth-xai-ws]", usedAuthIDs)
	}
}
func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) {
manager := coreauth.NewManager(nil, nil, nil)
auth := &coreauth.Auth{
@@ -1664,6 +1770,56 @@ func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) {
}
}
// TestWebsocketUpstreamSupportsIncrementalInputForXAI registers an active xai
// auth with websockets enabled for a test model and expects
// websocketUpstreamSupportsIncrementalInputForModel to report true for it.
func TestWebsocketUpstreamSupportsIncrementalInputForXAI(t *testing.T) {
	const modelID = "xai-test-model"

	xaiAuth := &coreauth.Auth{
		ID:         "auth-xai-ws",
		Provider:   "xai",
		Status:     coreauth.StatusActive,
		Attributes: map[string]string{"websockets": "true"},
	}
	manager := coreauth.NewManager(nil, nil, nil)
	if _, errRegister := manager.Register(context.Background(), xaiAuth); errRegister != nil {
		t.Fatalf("Register auth: %v", errRegister)
	}

	// Map the model to this auth in the global registry; undo it after the test.
	registry.GetGlobalRegistry().RegisterClient(xaiAuth.ID, xaiAuth.Provider, []*registry.ModelInfo{{ID: modelID}})
	t.Cleanup(func() {
		registry.GetGlobalRegistry().UnregisterClient(xaiAuth.ID)
	})

	handler := NewOpenAIResponsesAPIHandler(handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager))
	if supported := handler.websocketUpstreamSupportsIncrementalInputForModel(modelID); !supported {
		t.Fatalf("expected xai websocket upstream to support previous_response_id incremental input")
	}
}
// TestResponsesWebsocketUsesUpstreamWebsocketPassthroughForXAI registers an
// executor that identifies as "xai" plus an active websocket-enabled xai auth
// for a test model, and expects
// responsesWebsocketUsesUpstreamWebsocketPassthrough to report true.
func TestResponsesWebsocketUsesUpstreamWebsocketPassthroughForXAI(t *testing.T) {
	const modelName = "xai-passthrough-model"

	manager := coreauth.NewManager(nil, nil, nil)
	manager.RegisterExecutor(&websocketProviderCaptureExecutor{provider: "xai"})

	xaiAuth := &coreauth.Auth{
		ID:         "auth-xai-ws",
		Provider:   "xai",
		Status:     coreauth.StatusActive,
		Attributes: map[string]string{"websockets": "true"},
	}
	if _, errRegister := manager.Register(context.Background(), xaiAuth); errRegister != nil {
		t.Fatalf("Register auth: %v", errRegister)
	}

	// Map the model to this auth in the global registry; undo it after the test.
	registry.GetGlobalRegistry().RegisterClient(xaiAuth.ID, xaiAuth.Provider, []*registry.ModelInfo{{ID: modelName}})
	t.Cleanup(func() {
		registry.GetGlobalRegistry().UnregisterClient(xaiAuth.ID)
	})

	handler := NewOpenAIResponsesAPIHandler(handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager))
	if usesPassthrough := handler.responsesWebsocketUsesUpstreamWebsocketPassthrough(modelName); !usesPassthrough {
		t.Fatalf("expected xai websocket upstream passthrough for %s", modelName)
	}
}
func TestWebsocketUpstreamSupportsCompactionReplayForModel(t *testing.T) {
manager := coreauth.NewManager(nil, nil, nil)
auth := &coreauth.Auth{

View File

@@ -249,7 +249,7 @@ func (s *authScheduler) pickSingleWithStrategy(ctx context.Context, provider, mo
providerKey := strings.ToLower(strings.TrimSpace(provider))
modelKey := canonicalModelKey(model)
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
preferWebsocket := cliproxyexecutor.DownstreamWebsocket(ctx) && providerKey == "codex" && pinnedAuthID == ""
preferWebsocket := cliproxyexecutor.DownstreamWebsocket(ctx) && providerPrefersWebsocketTransport(providerKey) && pinnedAuthID == ""
s.mu.Lock()
defer s.mu.Unlock()
@@ -284,6 +284,15 @@ func (s *authScheduler) pickSingleWithStrategy(ctx context.Context, provider, mo
return nil, shard.unavailableErrorLocked(provider, model, predicate)
}
// providerPrefersWebsocketTransport reports whether providerKey, compared after
// trimming whitespace and lower-casing, is one of the providers for which
// websocket transport is preferred: "codex" or "xai".
func providerPrefersWebsocketTransport(providerKey string) bool {
	normalized := strings.ToLower(strings.TrimSpace(providerKey))
	return normalized == "codex" || normalized == "xai"
}
// pickMixed returns the next auth and provider for a mixed-provider request.
func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, string, error) {
return s.pickMixedWithStrategy(ctx, providers, model, opts, tried, schedulerStrategyCurrent)

View File

@@ -237,6 +237,32 @@ func TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledSubset(t *testing.T)
}
}
// TestSchedulerPick_XAIWebsocketPrefersWebsocketEnabledSubset builds a
// round-robin scheduler over three xai auths, two of them websocket-enabled.
// With a downstream-websocket context, three consecutive picks must rotate
// through the two websocket-enabled auths only and never return the plain one.
func TestSchedulerPick_XAIWebsocketPrefersWebsocketEnabledSubset(t *testing.T) {
	t.Parallel()

	scheduler := newSchedulerForTest(
		&RoundRobinSelector{},
		&Auth{ID: "xai-http", Provider: "xai"},
		&Auth{ID: "xai-ws-a", Provider: "xai", Attributes: map[string]string{"websockets": "true"}},
		&Auth{ID: "xai-ws-b", Provider: "xai", Attributes: map[string]string{"websockets": "true"}},
	)
	ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())

	expectedIDs := []string{"xai-ws-a", "xai-ws-b", "xai-ws-a"}
	for attempt := 0; attempt < len(expectedIDs); attempt++ {
		picked, errPick := scheduler.pickSingle(ctx, "xai", "", cliproxyexecutor.Options{}, nil)
		switch {
		case errPick != nil:
			t.Fatalf("pickSingle() #%d error = %v", attempt, errPick)
		case picked == nil:
			t.Fatalf("pickSingle() #%d auth = nil", attempt)
		case picked.ID != expectedIDs[attempt]:
			t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", attempt, picked.ID, expectedIDs[attempt])
		}
	}
}
func TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledAcrossPriorities(t *testing.T) {
t.Parallel()

View File

@@ -740,6 +740,9 @@ func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
if strings.EqualFold(provider, "codex") {
executor.CloseCodexWebsocketSessionsForAuthID(id, "auth_removed")
}
if strings.EqualFold(provider, "xai") {
executor.CloseXAIWebsocketSessionsForAuthID(id, "auth_removed")
}
s.syncPluginRuntime(ctx)
}
@@ -948,7 +951,7 @@ func (s *Service) registerExecutorForAuth(a *coreauth.Auth, forceReplace bool) {
case "kimi":
s.coreManager.RegisterExecutor(executor.NewKimiExecutor(s.cfg))
case "xai":
s.coreManager.RegisterExecutor(executor.NewXAIExecutor(s.cfg))
s.coreManager.RegisterExecutor(executor.NewXAIAutoExecutor(s.cfg))
default:
providerKey := strings.ToLower(strings.TrimSpace(a.Provider))
if providerKey == "" {

View File

@@ -3,6 +3,7 @@ package cliproxy
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor"
coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v7/sdk/config"
)
@@ -62,3 +63,25 @@ func TestEnsureExecutorsForAuthWithMode_CodexForceReplace(t *testing.T) {
t.Fatal("expected codex executor replacement in force mode")
}
}
// TestEnsureExecutorsForAuth_XAIBindsAutoExecutor calls ensureExecutorsForAuth
// with an active xai auth and expects the executor registered under "xai" to be
// an *executor.XAIAutoExecutor.
func TestEnsureExecutorsForAuth_XAIBindsAutoExecutor(t *testing.T) {
	service := &Service{
		cfg:         &config.Config{},
		coreManager: coreauth.NewManager(nil, nil, nil),
	}

	service.ensureExecutorsForAuth(&coreauth.Auth{
		ID:       "xai-auth-1",
		Provider: "xai",
		Status:   coreauth.StatusActive,
	})

	bound, found := service.coreManager.Executor("xai")
	if !found || bound == nil {
		t.Fatal("expected xai executor after bind")
	}
	if _, isAuto := bound.(*executor.XAIAutoExecutor); !isAuto {
		t.Fatalf("xai executor type = %T, want *executor.XAIAutoExecutor", bound)
	}
}