mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-06-24 16:27:43 +08:00
feat(websockets): implement XAIWebsocketsExecutor with enhanced execution and ID mapping
- Developed `XAIWebsocketsExecutor` for handling xAI Responses via WebSocket transport. - Introduced session and state management with `codexWebsocketSessionStore` and `xaiWebsocketIDStateStore`. - Added robust ID mapping for upstream and downstream request/response sequences. - Enhanced error propagation and handling of WebSocket terminal events. - Included utility methods for WebSocket request preparation, connection management, and state tracking. - Added foundational support for compact and streamed responses via enhanced session tracking.
This commit is contained in:
1241
internal/runtime/executor/xai_websockets_executor.go
Normal file
1241
internal/runtime/executor/xai_websockets_executor.go
Normal file
File diff suppressed because it is too large
Load Diff
425
internal/runtime/executor/xai_websockets_executor_test.go
Normal file
425
internal/runtime/executor/xai_websockets_executor_test.go
Normal file
@@ -0,0 +1,425 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestXAIWebsocketsExecuteStreamSendsResponseCreateWithPreviousResponseID(t *testing.T) {
|
||||
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
|
||||
capturedPayload := make(chan []byte, 1)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/responses" {
|
||||
t.Errorf("path = %q, want /responses", r.URL.Path)
|
||||
}
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer xai-token" {
|
||||
t.Errorf("Authorization = %q, want Bearer xai-token", got)
|
||||
}
|
||||
if got := r.Header.Get("x-grok-conv-id"); got != "execution-session-1" {
|
||||
t.Errorf("x-grok-conv-id = %q, want execution-session-1", got)
|
||||
}
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Errorf("upgrade websocket: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
_, payload, errRead := conn.ReadMessage()
|
||||
if errRead != nil {
|
||||
t.Errorf("read upstream websocket message: %v", errRead)
|
||||
return
|
||||
}
|
||||
capturedPayload <- bytes.Clone(payload)
|
||||
completed := []byte(`{"type":"response.completed","response":{"id":"resp-xai-1","output":[],"usage":{"input_tokens":0,"output_tokens":0,"total_tokens":0}}}`)
|
||||
if errWrite := conn.WriteMessage(websocket.TextMessage, completed); errWrite != nil {
|
||||
t.Errorf("write completed websocket message: %v", errWrite)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
exec := NewXAIWebsocketsExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "xai-auth",
|
||||
Provider: "xai",
|
||||
Attributes: map[string]string{
|
||||
"base_url": server.URL,
|
||||
"websockets": "true",
|
||||
},
|
||||
Metadata: map[string]any{"access_token": "xai-token"},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "grok-4.3",
|
||||
Payload: []byte(`{"model":"grok-4.3","stream":true,"previous_response_id":"resp-prev","instructions":"system prompt","input":[{"type":"message","role":"user","content":"hello"}]}`),
|
||||
}
|
||||
opts := cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FormatOpenAIResponse,
|
||||
ResponseFormat: sdktranslator.FormatOpenAIResponse,
|
||||
Metadata: map[string]any{
|
||||
cliproxyexecutor.ExecutionSessionMetadataKey: "execution-session-1",
|
||||
},
|
||||
}
|
||||
ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())
|
||||
|
||||
result, err := exec.ExecuteStream(ctx, auth, req, opts)
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream() error = %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case payload := <-capturedPayload:
|
||||
if got := gjson.GetBytes(payload, "type").String(); got != "response.create" {
|
||||
t.Fatalf("type = %q, want response.create; payload=%s", got, payload)
|
||||
}
|
||||
if got := gjson.GetBytes(payload, "previous_response_id").String(); got != "resp-prev" {
|
||||
t.Fatalf("previous_response_id = %q, want resp-prev; payload=%s", got, payload)
|
||||
}
|
||||
if gjson.GetBytes(payload, "stream").Exists() {
|
||||
t.Fatalf("stream must be omitted for xAI websocket payload: %s", payload)
|
||||
}
|
||||
if gjson.GetBytes(payload, "instructions").Exists() {
|
||||
t.Fatalf("instructions must be omitted when previous_response_id is set: %s", payload)
|
||||
}
|
||||
if got := gjson.GetBytes(payload, "prompt_cache_key").String(); got != "execution-session-1" {
|
||||
t.Fatalf("prompt_cache_key = %q, want execution-session-1; payload=%s", got, payload)
|
||||
}
|
||||
if got := gjson.GetBytes(payload, "store").Bool(); !got {
|
||||
t.Fatalf("store = false, want true; payload=%s", payload)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timed out waiting for upstream websocket payload")
|
||||
}
|
||||
|
||||
select {
|
||||
case chunk, ok := <-result.Chunks:
|
||||
if !ok {
|
||||
t.Fatal("stream closed before completed chunk")
|
||||
}
|
||||
if chunk.Err != nil {
|
||||
t.Fatalf("chunk error = %v", chunk.Err)
|
||||
}
|
||||
if got := gjson.GetBytes(bytes.TrimSpace(chunk.Payload), "type").String(); got != "response.completed" {
|
||||
t.Fatalf("chunk type = %q, want response.completed; payload=%s", got, chunk.Payload)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timed out waiting for completed chunk")
|
||||
}
|
||||
}
|
||||
|
||||
func TestXAIWebsocketsExecuteStreamRewritesRepeatedResponseIDForDownstream(t *testing.T) {
|
||||
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
|
||||
capturedPreviousIDs := make(chan string, 3)
|
||||
releaseServer := make(chan struct{})
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Errorf("upgrade websocket: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
_, payload, errRead := conn.ReadMessage()
|
||||
if errRead != nil {
|
||||
t.Errorf("read upstream websocket message: %v", errRead)
|
||||
return
|
||||
}
|
||||
previousID := gjson.GetBytes(payload, "previous_response_id").String()
|
||||
capturedPreviousIDs <- previousID
|
||||
completed := []byte(fmt.Sprintf(`{"type":"response.completed","response":{"id":"resp-real","previous_response_id":%q,"output":[{"id":"rs_resp-real","type":"reasoning","status":"completed"}],"usage":{"input_tokens":0,"output_tokens":0,"total_tokens":0}}}`, previousID))
|
||||
if errWrite := conn.WriteMessage(websocket.TextMessage, completed); errWrite != nil {
|
||||
t.Errorf("write completed websocket message: %v", errWrite)
|
||||
return
|
||||
}
|
||||
}
|
||||
<-releaseServer
|
||||
}))
|
||||
defer server.Close()
|
||||
defer close(releaseServer)
|
||||
|
||||
exec := NewXAIWebsocketsExecutor(&config.Config{})
|
||||
exec.store = &codexWebsocketSessionStore{sessions: make(map[string]*codexWebsocketSession)}
|
||||
exec.idStore = &xaiWebsocketIDStateStore{sessions: make(map[string]*xaiWebsocketIDState)}
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "xai-auth-id-map",
|
||||
Provider: "xai",
|
||||
Attributes: map[string]string{
|
||||
"base_url": server.URL,
|
||||
"websockets": "true",
|
||||
},
|
||||
Metadata: map[string]any{"access_token": "xai-token"},
|
||||
}
|
||||
opts := cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FormatOpenAIResponse,
|
||||
ResponseFormat: sdktranslator.FormatOpenAIResponse,
|
||||
Metadata: map[string]any{
|
||||
cliproxyexecutor.ExecutionSessionMetadataKey: "xai-id-map-session",
|
||||
},
|
||||
}
|
||||
ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())
|
||||
|
||||
runRequest := func(previousID string) (string, string, string) {
|
||||
body := []byte(`{"model":"grok-4.3","input":[{"type":"message","role":"user","content":"hello"}]}`)
|
||||
if previousID != "" {
|
||||
body = []byte(fmt.Sprintf(`{"model":"grok-4.3","previous_response_id":%q,"input":[{"type":"function_call_output","call_id":"call-1","output":"ok"}]}`, previousID))
|
||||
}
|
||||
result, err := exec.ExecuteStream(ctx, auth, cliproxyexecutor.Request{Model: "grok-4.3", Payload: body}, opts)
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream() error = %v", err)
|
||||
}
|
||||
select {
|
||||
case chunk, ok := <-result.Chunks:
|
||||
if !ok {
|
||||
t.Fatal("stream closed before completed chunk")
|
||||
}
|
||||
if chunk.Err != nil {
|
||||
t.Fatalf("chunk error = %v", chunk.Err)
|
||||
}
|
||||
payload := bytes.TrimSpace(chunk.Payload)
|
||||
return gjson.GetBytes(payload, "response.id").String(),
|
||||
gjson.GetBytes(payload, "response.output.0.id").String(),
|
||||
gjson.GetBytes(payload, "response.previous_response_id").String()
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timed out waiting for completed chunk")
|
||||
}
|
||||
return "", "", ""
|
||||
}
|
||||
|
||||
firstDownstreamID, firstOutputID, firstResponsePrevious := runRequest("")
|
||||
if firstDownstreamID != "resp-real" {
|
||||
t.Fatalf("first downstream id = %q, want resp-real", firstDownstreamID)
|
||||
}
|
||||
if firstOutputID != "rs_resp-real" {
|
||||
t.Fatalf("first output item id = %q, want rs_resp-real", firstOutputID)
|
||||
}
|
||||
if firstResponsePrevious != "" {
|
||||
t.Fatalf("first response previous_response_id = %q, want empty", firstResponsePrevious)
|
||||
}
|
||||
firstUpstreamPrevious := <-capturedPreviousIDs
|
||||
if firstUpstreamPrevious != "" {
|
||||
t.Fatalf("first upstream previous_response_id = %q, want empty", firstUpstreamPrevious)
|
||||
}
|
||||
|
||||
secondDownstreamID, secondOutputID, secondResponsePrevious := runRequest(firstDownstreamID)
|
||||
if secondDownstreamID == "" || secondDownstreamID == "resp-real" {
|
||||
t.Fatalf("second downstream id = %q, want synthetic id different from resp-real", secondDownstreamID)
|
||||
}
|
||||
if secondOutputID == "rs_resp-real" || !strings.Contains(secondOutputID, secondDownstreamID) {
|
||||
t.Fatalf("second output item id = %q, want rewritten id containing %q", secondOutputID, secondDownstreamID)
|
||||
}
|
||||
if secondResponsePrevious != firstDownstreamID {
|
||||
t.Fatalf("second response previous_response_id = %q, want %q", secondResponsePrevious, firstDownstreamID)
|
||||
}
|
||||
secondUpstreamPrevious := <-capturedPreviousIDs
|
||||
if secondUpstreamPrevious != "resp-real" {
|
||||
t.Fatalf("second upstream previous_response_id = %q, want resp-real", secondUpstreamPrevious)
|
||||
}
|
||||
|
||||
thirdDownstreamID, thirdOutputID, thirdResponsePrevious := runRequest(secondDownstreamID)
|
||||
if thirdDownstreamID == "" || thirdDownstreamID == "resp-real" || thirdDownstreamID == secondDownstreamID {
|
||||
t.Fatalf("third downstream id = %q, want a new synthetic id", thirdDownstreamID)
|
||||
}
|
||||
if thirdOutputID == "rs_resp-real" || !strings.Contains(thirdOutputID, thirdDownstreamID) {
|
||||
t.Fatalf("third output item id = %q, want rewritten id containing %q", thirdOutputID, thirdDownstreamID)
|
||||
}
|
||||
if thirdResponsePrevious != secondDownstreamID {
|
||||
t.Fatalf("third response previous_response_id = %q, want %q", thirdResponsePrevious, secondDownstreamID)
|
||||
}
|
||||
thirdUpstreamPrevious := <-capturedPreviousIDs
|
||||
if thirdUpstreamPrevious != "resp-real" {
|
||||
t.Fatalf("third upstream previous_response_id = %q, want resp-real", thirdUpstreamPrevious)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildXAIWebsocketRequestBodySetsStoreAndKeepsPromptCacheKey(t *testing.T) {
|
||||
body := []byte(`{"model":"grok-4.3","stream":true,"stream_options":{"include_usage":true},"background":true,"prompt_cache_key":"cache-1","previous_response_id":"resp-prev","instructions":"system prompt","input":[{"type":"message","role":"user","content":"hello"}]}`)
|
||||
|
||||
payload := buildXAIWebsocketRequestBody(body)
|
||||
|
||||
if got := gjson.GetBytes(payload, "type").String(); got != "response.create" {
|
||||
t.Fatalf("type = %q, want response.create; payload=%s", got, payload)
|
||||
}
|
||||
if gjson.GetBytes(payload, "stream").Exists() {
|
||||
t.Fatalf("stream must be omitted for xAI websocket payload: %s", payload)
|
||||
}
|
||||
if gjson.GetBytes(payload, "stream_options").Exists() {
|
||||
t.Fatalf("stream_options must be omitted for xAI websocket payload: %s", payload)
|
||||
}
|
||||
if gjson.GetBytes(payload, "background").Exists() {
|
||||
t.Fatalf("background must be omitted for xAI websocket payload: %s", payload)
|
||||
}
|
||||
if got := gjson.GetBytes(payload, "prompt_cache_key").String(); got != "cache-1" {
|
||||
t.Fatalf("prompt_cache_key = %q, want cache-1; payload=%s", got, payload)
|
||||
}
|
||||
if got := gjson.GetBytes(payload, "store").Bool(); !got {
|
||||
t.Fatalf("store = false, want true; payload=%s", payload)
|
||||
}
|
||||
if gjson.GetBytes(payload, "instructions").Exists() {
|
||||
t.Fatalf("instructions must be omitted when previous_response_id is set: %s", payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestXAIWebsocketsExecuteStreamCompletesGenerateFalseWarmup(t *testing.T) {
|
||||
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
|
||||
capturedPayload := make(chan []byte, 1)
|
||||
releaseServer := make(chan struct{})
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Errorf("upgrade websocket: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
_, payload, errRead := conn.ReadMessage()
|
||||
if errRead != nil {
|
||||
t.Errorf("read upstream websocket message: %v", errRead)
|
||||
return
|
||||
}
|
||||
capturedPayload <- bytes.Clone(payload)
|
||||
created := []byte(`{"type":"response.created","response":{"id":"resp-warmup-1","object":"response","status":"in_progress","output":[]}}`)
|
||||
if errWrite := conn.WriteMessage(websocket.TextMessage, created); errWrite != nil {
|
||||
t.Errorf("write created websocket message: %v", errWrite)
|
||||
return
|
||||
}
|
||||
<-releaseServer
|
||||
}))
|
||||
defer server.Close()
|
||||
defer close(releaseServer)
|
||||
|
||||
exec := NewXAIWebsocketsExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "xai-auth-warmup",
|
||||
Provider: "xai",
|
||||
Attributes: map[string]string{
|
||||
"base_url": server.URL,
|
||||
"websockets": "true",
|
||||
},
|
||||
Metadata: map[string]any{"access_token": "xai-token"},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "grok-4.3",
|
||||
Payload: []byte(`{"model":"grok-4.3","generate":false,"input":[{"type":"message","role":"user","content":"warm up"}]}`),
|
||||
}
|
||||
opts := cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FormatOpenAIResponse,
|
||||
ResponseFormat: sdktranslator.FormatOpenAIResponse,
|
||||
}
|
||||
ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())
|
||||
|
||||
result, err := exec.ExecuteStream(ctx, auth, req, opts)
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream() error = %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case payload := <-capturedPayload:
|
||||
if got := gjson.GetBytes(payload, "generate").Bool(); got {
|
||||
t.Fatalf("generate = true, want false; payload=%s", payload)
|
||||
}
|
||||
if got := gjson.GetBytes(payload, "type").String(); got != "response.create" {
|
||||
t.Fatalf("type = %q, want response.create; payload=%s", got, payload)
|
||||
}
|
||||
if got := gjson.GetBytes(payload, "store").Bool(); !got {
|
||||
t.Fatalf("store = false, want true; payload=%s", payload)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timed out waiting for upstream websocket payload")
|
||||
}
|
||||
|
||||
var gotTypes []string
|
||||
for {
|
||||
select {
|
||||
case chunk, ok := <-result.Chunks:
|
||||
if !ok {
|
||||
if len(gotTypes) != 2 {
|
||||
t.Fatalf("event types = %v, want response.created and response.completed", gotTypes)
|
||||
}
|
||||
return
|
||||
}
|
||||
if chunk.Err != nil {
|
||||
t.Fatalf("chunk error = %v", chunk.Err)
|
||||
}
|
||||
gotTypes = append(gotTypes, gjson.GetBytes(bytes.TrimSpace(chunk.Payload), "type").String())
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatalf("timed out waiting for warmup stream to close; event types so far: %v", gotTypes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestXAIWebsocketsExecuteStreamStopsOnBareErrorPayload(t *testing.T) {
|
||||
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
|
||||
releaseServer := make(chan struct{})
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Errorf("upgrade websocket: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
if _, _, errRead := conn.ReadMessage(); errRead != nil {
|
||||
t.Errorf("read upstream websocket message: %v", errRead)
|
||||
return
|
||||
}
|
||||
payload := []byte(`{"error":{"message":"Request validation error: {\"code\":\"400\",\"error\":\"Argument not supported: instructions and previous_response_id together\"}","type":"api_error"}}`)
|
||||
if errWrite := conn.WriteMessage(websocket.TextMessage, payload); errWrite != nil {
|
||||
t.Errorf("write error websocket message: %v", errWrite)
|
||||
return
|
||||
}
|
||||
<-releaseServer
|
||||
}))
|
||||
defer server.Close()
|
||||
defer close(releaseServer)
|
||||
|
||||
exec := NewXAIWebsocketsExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "xai-auth-error",
|
||||
Provider: "xai",
|
||||
Attributes: map[string]string{
|
||||
"base_url": server.URL,
|
||||
"websockets": "true",
|
||||
},
|
||||
Metadata: map[string]any{"access_token": "xai-token"},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "grok-4.3",
|
||||
Payload: []byte(`{"model":"grok-4.3","input":"hello"}`),
|
||||
}
|
||||
opts := cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FormatOpenAIResponse,
|
||||
ResponseFormat: sdktranslator.FormatOpenAIResponse,
|
||||
}
|
||||
ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())
|
||||
|
||||
result, err := exec.ExecuteStream(ctx, auth, req, opts)
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream() error = %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case chunk, ok := <-result.Chunks:
|
||||
if !ok {
|
||||
t.Fatal("stream closed before error chunk")
|
||||
}
|
||||
if chunk.Err == nil {
|
||||
t.Fatalf("chunk error = nil, want upstream error; payload=%s", chunk.Payload)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timed out waiting for bare upstream error")
|
||||
}
|
||||
}
|
||||
@@ -228,9 +228,13 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
defer close(wsDone)
|
||||
|
||||
if h != nil && h.AuthManager != nil {
|
||||
if exec, ok := h.AuthManager.Executor("codex"); ok && exec != nil {
|
||||
type upstreamDisconnectSubscriber interface {
|
||||
UpstreamDisconnectChan(sessionID string) <-chan error
|
||||
type upstreamDisconnectSubscriber interface {
|
||||
UpstreamDisconnectChan(sessionID string) <-chan error
|
||||
}
|
||||
for _, provider := range []string{"codex", "xai"} {
|
||||
exec, ok := h.AuthManager.Executor(provider)
|
||||
if !ok || exec == nil {
|
||||
continue
|
||||
}
|
||||
if subscriber, ok := exec.(upstreamDisconnectSubscriber); ok && subscriber != nil {
|
||||
disconnectCh := subscriber.UpstreamDisconnectChan(passthroughSessionID)
|
||||
@@ -315,13 +319,13 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
if requestModelName == "" {
|
||||
requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
|
||||
}
|
||||
useCodexWebsocketPassthrough := h.responsesWebsocketUsesCodexWebsocketPassthrough(requestModelName)
|
||||
useUpstreamWebsocketPassthrough := h.responsesWebsocketUsesUpstreamWebsocketPassthrough(requestModelName)
|
||||
allowIncrementalInputWithPreviousResponseID := false
|
||||
allowCompactionReplayBypass := false
|
||||
if !useCodexWebsocketPassthrough {
|
||||
if !useUpstreamWebsocketPassthrough {
|
||||
if pinnedAuthID != "" {
|
||||
if pinnedAuth, ok := sessionAuthByID(pinnedAuthID); ok && pinnedAuth != nil {
|
||||
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
|
||||
allowIncrementalInputWithPreviousResponseID = responsesWebsocketAuthSupportsIncrementalInput(pinnedAuth)
|
||||
allowCompactionReplayBypass = responsesWebsocketAuthSupportsCompactionReplay(pinnedAuth)
|
||||
}
|
||||
} else {
|
||||
@@ -336,7 +340,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
var requestJSON []byte
|
||||
var updatedLastRequest []byte
|
||||
var errMsg *interfaces.ErrorMessage
|
||||
if useCodexWebsocketPassthrough {
|
||||
if useUpstreamWebsocketPassthrough {
|
||||
requestJSON, errMsg = normalizeResponsesWebsocketPassthroughRequest(payload, requestModelName)
|
||||
} else {
|
||||
requestJSON, updatedLastRequest, errMsg = normalizeResponsesWebsocketRequestWithIncrementalState(
|
||||
@@ -371,7 +375,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
}
|
||||
continue
|
||||
}
|
||||
if !useCodexWebsocketPassthrough && shouldHandleResponsesWebsocketPrewarmLocally(payload, lastRequest, allowIncrementalInputWithPreviousResponseID) {
|
||||
if !useUpstreamWebsocketPassthrough && shouldHandleResponsesWebsocketPrewarmLocally(payload, lastRequest, allowIncrementalInputWithPreviousResponseID) {
|
||||
if updated, errDelete := sjson.DeleteBytes(requestJSON, "generate"); errDelete == nil {
|
||||
requestJSON = updated
|
||||
}
|
||||
@@ -394,7 +398,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
previousLastResponseID := lastResponseID
|
||||
previousLastResponsePendingToolCallIDs := append([]string(nil), lastResponsePendingToolCallIDs...)
|
||||
forcedTranscriptReplay := forceTranscriptReplayNextRequest
|
||||
if useCodexWebsocketPassthrough {
|
||||
if useUpstreamWebsocketPassthrough {
|
||||
if modelName := strings.TrimSpace(gjson.GetBytes(requestJSON, "model").String()); modelName != "" {
|
||||
passthroughModelName = modelName
|
||||
}
|
||||
@@ -443,7 +447,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
if shouldReleaseResponsesWebsocketPinnedAuth(forwardErrMsg) {
|
||||
pinnedAuthID = ""
|
||||
forceTranscriptReplayNextRequest = true
|
||||
if useCodexWebsocketPassthrough {
|
||||
if useUpstreamWebsocketPassthrough {
|
||||
passthroughModelName = ""
|
||||
} else {
|
||||
lastRequest = previousLastRequest
|
||||
@@ -453,7 +457,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
}
|
||||
continue
|
||||
}
|
||||
if !useCodexWebsocketPassthrough {
|
||||
if !useUpstreamWebsocketPassthrough {
|
||||
lastResponseOutput = completedOutput
|
||||
lastResponseID = strings.TrimSpace(completedResponseID)
|
||||
lastResponsePendingToolCallIDs = append([]string(nil), completedPendingToolCallIDs...)
|
||||
@@ -917,7 +921,7 @@ func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, met
|
||||
func (h *OpenAIResponsesAPIHandler) websocketUpstreamSupportsIncrementalInputForModel(modelName string) bool {
|
||||
auths, _ := h.responsesWebsocketAvailableAuthsForModel(modelName)
|
||||
for _, auth := range auths {
|
||||
if websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) {
|
||||
if responsesWebsocketAuthSupportsIncrementalInput(auth) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -961,29 +965,47 @@ func (h *OpenAIResponsesAPIHandler) responsesWebsocketAvailableAuthsForModel(mod
|
||||
}
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) responsesWebsocketUsesCodexWebsocketPassthrough(modelName string) bool {
|
||||
return h.responsesWebsocketUsesUpstreamWebsocketPassthrough(modelName)
|
||||
}
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) responsesWebsocketUsesUpstreamWebsocketPassthrough(modelName string) bool {
|
||||
modelName = strings.TrimSpace(modelName)
|
||||
if h == nil || h.AuthManager == nil || modelName == "" {
|
||||
return false
|
||||
}
|
||||
if _, ok := h.AuthManager.Executor("codex"); !ok {
|
||||
return false
|
||||
}
|
||||
auths, _ := h.responsesWebsocketAvailableAuthsForModel(modelName)
|
||||
if len(auths) == 0 {
|
||||
return false
|
||||
}
|
||||
provider := ""
|
||||
for _, auth := range auths {
|
||||
if auth == nil {
|
||||
return false
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") {
|
||||
authProvider := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
if authProvider != "codex" && authProvider != "xai" {
|
||||
return false
|
||||
}
|
||||
if provider == "" {
|
||||
provider = authProvider
|
||||
if _, ok := h.AuthManager.Executor(provider); !ok {
|
||||
return false
|
||||
}
|
||||
} else if authProvider != provider {
|
||||
return false
|
||||
}
|
||||
if !websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
return provider != ""
|
||||
}
|
||||
|
||||
func responsesWebsocketAuthSupportsIncrementalInput(auth *coreauth.Auth) bool {
|
||||
if auth == nil {
|
||||
return false
|
||||
}
|
||||
return websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata)
|
||||
}
|
||||
|
||||
func normalizeResponsesWebsocketPassthroughRequest(rawJSON []byte, modelName string) ([]byte, *interfaces.ErrorMessage) {
|
||||
|
||||
@@ -29,6 +29,11 @@ type websocketCaptureExecutor struct {
|
||||
payloads [][]byte
|
||||
}
|
||||
|
||||
type websocketProviderCaptureExecutor struct {
|
||||
provider string
|
||||
websocketCaptureExecutor
|
||||
}
|
||||
|
||||
type websocketCompactionCaptureExecutor struct {
|
||||
mu sync.Mutex
|
||||
streamPayloads [][]byte
|
||||
@@ -85,6 +90,7 @@ type websocketBootstrapFallbackExecutor struct {
|
||||
|
||||
type websocketDirectCaptureExecutor struct {
|
||||
mu sync.Mutex
|
||||
provider string
|
||||
authIDs []string
|
||||
payloads [][]byte
|
||||
done chan struct{}
|
||||
@@ -164,7 +170,12 @@ func (e *websocketBootstrapFallbackExecutor) Payloads(authID string) [][]byte {
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *websocketDirectCaptureExecutor) Identifier() string { return "codex" }
|
||||
func (e *websocketDirectCaptureExecutor) Identifier() string {
|
||||
if e != nil && strings.TrimSpace(e.provider) != "" {
|
||||
return strings.TrimSpace(e.provider)
|
||||
}
|
||||
return "codex"
|
||||
}
|
||||
|
||||
func (e *websocketDirectCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, errors.New("not implemented")
|
||||
@@ -403,6 +414,13 @@ func (e *websocketPinnedFailoverExecutor) Payloads(authID string) [][]byte {
|
||||
|
||||
func (e *websocketCaptureExecutor) Identifier() string { return "test-provider" }
|
||||
|
||||
func (e *websocketProviderCaptureExecutor) Identifier() string {
|
||||
if e != nil && strings.TrimSpace(e.provider) != "" {
|
||||
return strings.TrimSpace(e.provider)
|
||||
}
|
||||
return "test-provider"
|
||||
}
|
||||
|
||||
func (e *websocketCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, errors.New("not implemented")
|
||||
}
|
||||
@@ -1641,6 +1659,94 @@ func TestResponsesWebsocketCodexWebsocketPassthroughPassesCompactedRequestWithou
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsesWebsocketXAIWebsocketPassthroughCarriesPreviousResponseID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
modelName := "xai-websocket-passthrough-model"
|
||||
executor := &websocketDirectCaptureExecutor{provider: "xai", done: make(chan struct{})}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
auth := &coreauth.Auth{
|
||||
ID: "auth-xai-ws",
|
||||
Provider: "xai",
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: map[string]string{"websockets": "true"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth); err != nil {
|
||||
t.Fatalf("Register auth: %v", err)
|
||||
}
|
||||
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: modelName}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
|
||||
})
|
||||
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
router := gin.New()
|
||||
router.GET("/v1/responses/ws", h.ResponsesWebsocket)
|
||||
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("dial websocket: %v", err)
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
firstRequest := []byte(fmt.Sprintf(`{"type":"response.create","model":%q,"input":[{"type":"message","id":"msg-1","role":"user","content":"first"}]}`, modelName))
|
||||
if errWrite := conn.WriteMessage(websocket.TextMessage, firstRequest); errWrite != nil {
|
||||
t.Fatalf("write first websocket message: %v", errWrite)
|
||||
}
|
||||
if _, _, errRead := conn.ReadMessage(); errRead != nil {
|
||||
t.Fatalf("read first websocket response: %v", errRead)
|
||||
}
|
||||
|
||||
secondRequest := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"message","id":"msg-2","role":"user","content":"second"}]}`)
|
||||
if errWrite := conn.WriteMessage(websocket.TextMessage, secondRequest); errWrite != nil {
|
||||
t.Fatalf("write second websocket message: %v", errWrite)
|
||||
}
|
||||
if _, _, errRead := conn.ReadMessage(); errRead != nil {
|
||||
t.Fatalf("read second websocket response: %v", errRead)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-executor.done:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timed out waiting for websocket passthrough")
|
||||
}
|
||||
|
||||
payloads := executor.Payloads()
|
||||
if len(payloads) != 2 {
|
||||
t.Fatalf("xai websocket payload count = %d, want 2", len(payloads))
|
||||
}
|
||||
secondPayload := payloads[1]
|
||||
if got := gjson.GetBytes(secondPayload, "type").String(); got != wsRequestTypeCreate {
|
||||
t.Fatalf("second xai passthrough type = %s, want %s: %s", got, wsRequestTypeCreate, secondPayload)
|
||||
}
|
||||
if got := gjson.GetBytes(secondPayload, "model").String(); got != modelName {
|
||||
t.Fatalf("second xai payload model = %s, want %s", got, modelName)
|
||||
}
|
||||
if got := gjson.GetBytes(secondPayload, "previous_response_id").String(); got != "resp-1" {
|
||||
t.Fatalf("second xai previous_response_id = %s, want resp-1: %s", got, secondPayload)
|
||||
}
|
||||
input := gjson.GetBytes(secondPayload, "input").Array()
|
||||
if len(input) != 1 {
|
||||
t.Fatalf("second xai passthrough input len = %d, want 1: %s", len(input), secondPayload)
|
||||
}
|
||||
if input[0].Get("id").String() != "msg-2" {
|
||||
t.Fatalf("second xai passthrough input must contain only the new turn: %s", secondPayload)
|
||||
}
|
||||
if bytes.Contains(secondPayload, []byte(`"id":"msg-1"`)) || bytes.Contains(secondPayload, []byte(`"id":"out-1"`)) {
|
||||
t.Fatalf("second xai passthrough payload contains stale transcript state: %s", secondPayload)
|
||||
}
|
||||
authIDs := executor.AuthIDs()
|
||||
if len(authIDs) != 2 || authIDs[0] != "auth-xai-ws" || authIDs[1] != "auth-xai-ws" {
|
||||
t.Fatalf("xai websocket auth IDs = %v, want [auth-xai-ws auth-xai-ws]", authIDs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) {
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
auth := &coreauth.Auth{
|
||||
@@ -1664,6 +1770,56 @@ func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebsocketUpstreamSupportsIncrementalInputForXAI(t *testing.T) {
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
auth := &coreauth.Auth{
|
||||
ID: "auth-xai-ws",
|
||||
Provider: "xai",
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: map[string]string{"websockets": "true"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth); err != nil {
|
||||
t.Fatalf("Register auth: %v", err)
|
||||
}
|
||||
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "xai-test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
|
||||
})
|
||||
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
if !h.websocketUpstreamSupportsIncrementalInputForModel("xai-test-model") {
|
||||
t.Fatalf("expected xai websocket upstream to support previous_response_id incremental input")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsesWebsocketUsesUpstreamWebsocketPassthroughForXAI(t *testing.T) {
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
executor := &websocketProviderCaptureExecutor{provider: "xai"}
|
||||
manager.RegisterExecutor(executor)
|
||||
|
||||
modelName := "xai-passthrough-model"
|
||||
auth := &coreauth.Auth{
|
||||
ID: "auth-xai-ws",
|
||||
Provider: "xai",
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: map[string]string{"websockets": "true"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth); err != nil {
|
||||
t.Fatalf("Register auth: %v", err)
|
||||
}
|
||||
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: modelName}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
|
||||
})
|
||||
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
if !h.responsesWebsocketUsesUpstreamWebsocketPassthrough(modelName) {
|
||||
t.Fatalf("expected xai websocket upstream passthrough for %s", modelName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebsocketUpstreamSupportsCompactionReplayForModel(t *testing.T) {
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
auth := &coreauth.Auth{
|
||||
|
||||
@@ -249,7 +249,7 @@ func (s *authScheduler) pickSingleWithStrategy(ctx context.Context, provider, mo
|
||||
providerKey := strings.ToLower(strings.TrimSpace(provider))
|
||||
modelKey := canonicalModelKey(model)
|
||||
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
|
||||
preferWebsocket := cliproxyexecutor.DownstreamWebsocket(ctx) && providerKey == "codex" && pinnedAuthID == ""
|
||||
preferWebsocket := cliproxyexecutor.DownstreamWebsocket(ctx) && providerPrefersWebsocketTransport(providerKey) && pinnedAuthID == ""
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -284,6 +284,15 @@ func (s *authScheduler) pickSingleWithStrategy(ctx context.Context, provider, mo
|
||||
return nil, shard.unavailableErrorLocked(provider, model, predicate)
|
||||
}
|
||||
|
||||
func providerPrefersWebsocketTransport(providerKey string) bool {
|
||||
switch strings.ToLower(strings.TrimSpace(providerKey)) {
|
||||
case "codex", "xai":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// pickMixed returns the next auth and provider for a mixed-provider request.
|
||||
func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, string, error) {
|
||||
return s.pickMixedWithStrategy(ctx, providers, model, opts, tried, schedulerStrategyCurrent)
|
||||
|
||||
@@ -237,6 +237,32 @@ func TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledSubset(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSchedulerPick_XAIWebsocketPrefersWebsocketEnabledSubset(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scheduler := newSchedulerForTest(
|
||||
&RoundRobinSelector{},
|
||||
&Auth{ID: "xai-http", Provider: "xai"},
|
||||
&Auth{ID: "xai-ws-a", Provider: "xai", Attributes: map[string]string{"websockets": "true"}},
|
||||
&Auth{ID: "xai-ws-b", Provider: "xai", Attributes: map[string]string{"websockets": "true"}},
|
||||
)
|
||||
|
||||
ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())
|
||||
want := []string{"xai-ws-a", "xai-ws-b", "xai-ws-a"}
|
||||
for index, wantID := range want {
|
||||
got, errPick := scheduler.pickSingle(ctx, "xai", "", cliproxyexecutor.Options{}, nil)
|
||||
if errPick != nil {
|
||||
t.Fatalf("pickSingle() #%d error = %v", index, errPick)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("pickSingle() #%d auth = nil", index)
|
||||
}
|
||||
if got.ID != wantID {
|
||||
t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledAcrossPriorities(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -740,6 +740,9 @@ func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
|
||||
if strings.EqualFold(provider, "codex") {
|
||||
executor.CloseCodexWebsocketSessionsForAuthID(id, "auth_removed")
|
||||
}
|
||||
if strings.EqualFold(provider, "xai") {
|
||||
executor.CloseXAIWebsocketSessionsForAuthID(id, "auth_removed")
|
||||
}
|
||||
s.syncPluginRuntime(ctx)
|
||||
}
|
||||
|
||||
@@ -948,7 +951,7 @@ func (s *Service) registerExecutorForAuth(a *coreauth.Auth, forceReplace bool) {
|
||||
case "kimi":
|
||||
s.coreManager.RegisterExecutor(executor.NewKimiExecutor(s.cfg))
|
||||
case "xai":
|
||||
s.coreManager.RegisterExecutor(executor.NewXAIExecutor(s.cfg))
|
||||
s.coreManager.RegisterExecutor(executor.NewXAIAutoExecutor(s.cfg))
|
||||
default:
|
||||
providerKey := strings.ToLower(strings.TrimSpace(a.Provider))
|
||||
if providerKey == "" {
|
||||
|
||||
@@ -3,6 +3,7 @@ package cliproxy
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/sdk/config"
|
||||
)
|
||||
@@ -62,3 +63,25 @@ func TestEnsureExecutorsForAuthWithMode_CodexForceReplace(t *testing.T) {
|
||||
t.Fatal("expected codex executor replacement in force mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureExecutorsForAuth_XAIBindsAutoExecutor(t *testing.T) {
|
||||
service := &Service{
|
||||
cfg: &config.Config{},
|
||||
coreManager: coreauth.NewManager(nil, nil, nil),
|
||||
}
|
||||
auth := &coreauth.Auth{
|
||||
ID: "xai-auth-1",
|
||||
Provider: "xai",
|
||||
Status: coreauth.StatusActive,
|
||||
}
|
||||
|
||||
service.ensureExecutorsForAuth(auth)
|
||||
|
||||
gotExecutor, ok := service.coreManager.Executor("xai")
|
||||
if !ok || gotExecutor == nil {
|
||||
t.Fatal("expected xai executor after bind")
|
||||
}
|
||||
if _, ok := gotExecutor.(*executor.XAIAutoExecutor); !ok {
|
||||
t.Fatalf("xai executor type = %T, want *executor.XAIAutoExecutor", gotExecutor)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user