mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-06-22 15:03:35 +08:00
feat(websockets): add Codex websocket passthrough support with tests
- Implemented `websocketDirectCaptureExecutor` for Codex websocket passthrough functionality. - Added logic to bypass incremental state handling for passthrough models. - Updated normalization, compaction, and replay handling to support passthrough mode. - Introduced `responsesWebsocketUsesCodexWebsocketPassthrough` utility for model-specific passthrough determination. - Expanded test coverage for websocket passthrough scenarios, including compaction and response validation.
This commit is contained in:
@@ -272,6 +272,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
lastResponseID := ""
|
||||
var lastResponsePendingToolCallIDs []string
|
||||
pinnedAuthID := ""
|
||||
passthroughModelName := ""
|
||||
sessionAuthByID := func(authID string) (*coreauth.Auth, bool) {
|
||||
if h == nil || h.AuthManager == nil {
|
||||
return nil, false
|
||||
@@ -307,47 +308,47 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
wsTimelineLog.BeginRequest()
|
||||
wsTimelineLog.Append("request", payload, time.Now())
|
||||
|
||||
requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
||||
if requestModelName == "" {
|
||||
requestModelName = passthroughModelName
|
||||
}
|
||||
if requestModelName == "" {
|
||||
requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
|
||||
}
|
||||
useCodexWebsocketPassthrough := h.responsesWebsocketUsesCodexWebsocketPassthrough(requestModelName)
|
||||
allowIncrementalInputWithPreviousResponseID := false
|
||||
if pinnedAuthID != "" {
|
||||
if pinnedAuth, ok := sessionAuthByID(pinnedAuthID); ok && pinnedAuth != nil {
|
||||
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
|
||||
}
|
||||
} else {
|
||||
requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
||||
if requestModelName == "" {
|
||||
requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
|
||||
}
|
||||
allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName)
|
||||
}
|
||||
if forceTranscriptReplayNextRequest {
|
||||
allowIncrementalInputWithPreviousResponseID = false
|
||||
}
|
||||
|
||||
allowCompactionReplayBypass := false
|
||||
if pinnedAuthID != "" {
|
||||
if pinnedAuth, ok := sessionAuthByID(pinnedAuthID); ok && pinnedAuth != nil {
|
||||
allowCompactionReplayBypass = responsesWebsocketAuthSupportsCompactionReplay(pinnedAuth)
|
||||
if !useCodexWebsocketPassthrough {
|
||||
if pinnedAuthID != "" {
|
||||
if pinnedAuth, ok := sessionAuthByID(pinnedAuthID); ok && pinnedAuth != nil {
|
||||
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
|
||||
allowCompactionReplayBypass = responsesWebsocketAuthSupportsCompactionReplay(pinnedAuth)
|
||||
}
|
||||
} else {
|
||||
allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName)
|
||||
allowCompactionReplayBypass = h.websocketUpstreamSupportsCompactionReplayForModel(requestModelName)
|
||||
}
|
||||
} else {
|
||||
requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
||||
if requestModelName == "" {
|
||||
requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
|
||||
if forceTranscriptReplayNextRequest {
|
||||
allowIncrementalInputWithPreviousResponseID = false
|
||||
}
|
||||
allowCompactionReplayBypass = h.websocketUpstreamSupportsCompactionReplayForModel(requestModelName)
|
||||
}
|
||||
|
||||
var requestJSON []byte
|
||||
var updatedLastRequest []byte
|
||||
var errMsg *interfaces.ErrorMessage
|
||||
requestJSON, updatedLastRequest, errMsg = normalizeResponsesWebsocketRequestWithIncrementalState(
|
||||
payload,
|
||||
lastRequest,
|
||||
lastResponseOutput,
|
||||
lastResponseID,
|
||||
lastResponsePendingToolCallIDs,
|
||||
allowIncrementalInputWithPreviousResponseID,
|
||||
allowCompactionReplayBypass,
|
||||
)
|
||||
if useCodexWebsocketPassthrough {
|
||||
requestJSON, errMsg = normalizeResponsesWebsocketPassthroughRequest(payload, requestModelName)
|
||||
} else {
|
||||
requestJSON, updatedLastRequest, errMsg = normalizeResponsesWebsocketRequestWithIncrementalState(
|
||||
payload,
|
||||
lastRequest,
|
||||
lastResponseOutput,
|
||||
lastResponseID,
|
||||
lastResponsePendingToolCallIDs,
|
||||
allowIncrementalInputWithPreviousResponseID,
|
||||
allowCompactionReplayBypass,
|
||||
)
|
||||
}
|
||||
if errMsg != nil {
|
||||
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
|
||||
markAPIResponseTimestamp(c)
|
||||
@@ -370,7 +371,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
}
|
||||
continue
|
||||
}
|
||||
if shouldHandleResponsesWebsocketPrewarmLocally(payload, lastRequest, allowIncrementalInputWithPreviousResponseID) {
|
||||
if !useCodexWebsocketPassthrough && shouldHandleResponsesWebsocketPrewarmLocally(payload, lastRequest, allowIncrementalInputWithPreviousResponseID) {
|
||||
if updated, errDelete := sjson.DeleteBytes(requestJSON, "generate"); errDelete == nil {
|
||||
requestJSON = updated
|
||||
}
|
||||
@@ -388,17 +389,26 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
continue
|
||||
}
|
||||
|
||||
requestJSON = repairResponsesWebsocketToolCalls(downstreamSessionKey, requestJSON)
|
||||
requestJSON = dedupeResponsesWebsocketInputItemsByID(requestJSON)
|
||||
updatedLastRequest = bytes.Clone(requestJSON)
|
||||
previousLastRequest := bytes.Clone(lastRequest)
|
||||
previousLastResponseOutput := bytes.Clone(lastResponseOutput)
|
||||
previousLastResponseID := lastResponseID
|
||||
previousLastResponsePendingToolCallIDs := append([]string(nil), lastResponsePendingToolCallIDs...)
|
||||
forcedTranscriptReplay := forceTranscriptReplayNextRequest
|
||||
lastRequest = updatedLastRequest
|
||||
if forcedTranscriptReplay {
|
||||
forceTranscriptReplayNextRequest = false
|
||||
if useCodexWebsocketPassthrough {
|
||||
if modelName := strings.TrimSpace(gjson.GetBytes(requestJSON, "model").String()); modelName != "" {
|
||||
passthroughModelName = modelName
|
||||
}
|
||||
if forcedTranscriptReplay {
|
||||
forceTranscriptReplayNextRequest = false
|
||||
}
|
||||
} else {
|
||||
requestJSON = repairResponsesWebsocketToolCalls(downstreamSessionKey, requestJSON)
|
||||
requestJSON = dedupeResponsesWebsocketInputItemsByID(requestJSON)
|
||||
updatedLastRequest = bytes.Clone(requestJSON)
|
||||
lastRequest = updatedLastRequest
|
||||
if forcedTranscriptReplay {
|
||||
forceTranscriptReplayNextRequest = false
|
||||
}
|
||||
}
|
||||
|
||||
modelName := gjson.GetBytes(requestJSON, "model").String()
|
||||
@@ -433,15 +443,21 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
if shouldReleaseResponsesWebsocketPinnedAuth(forwardErrMsg) {
|
||||
pinnedAuthID = ""
|
||||
forceTranscriptReplayNextRequest = true
|
||||
lastRequest = previousLastRequest
|
||||
lastResponseOutput = previousLastResponseOutput
|
||||
lastResponseID = previousLastResponseID
|
||||
lastResponsePendingToolCallIDs = previousLastResponsePendingToolCallIDs
|
||||
if useCodexWebsocketPassthrough {
|
||||
passthroughModelName = ""
|
||||
} else {
|
||||
lastRequest = previousLastRequest
|
||||
lastResponseOutput = previousLastResponseOutput
|
||||
lastResponseID = previousLastResponseID
|
||||
lastResponsePendingToolCallIDs = previousLastResponsePendingToolCallIDs
|
||||
}
|
||||
continue
|
||||
}
|
||||
lastResponseOutput = completedOutput
|
||||
lastResponseID = strings.TrimSpace(completedResponseID)
|
||||
lastResponsePendingToolCallIDs = append([]string(nil), completedPendingToolCallIDs...)
|
||||
if !useCodexWebsocketPassthrough {
|
||||
lastResponseOutput = completedOutput
|
||||
lastResponseID = strings.TrimSpace(completedResponseID)
|
||||
lastResponsePendingToolCallIDs = append([]string(nil), completedPendingToolCallIDs...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -944,6 +960,65 @@ func (h *OpenAIResponsesAPIHandler) responsesWebsocketAvailableAuthsForModel(mod
|
||||
return available, modelKey
|
||||
}
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) responsesWebsocketUsesCodexWebsocketPassthrough(modelName string) bool {
|
||||
modelName = strings.TrimSpace(modelName)
|
||||
if h == nil || h.AuthManager == nil || modelName == "" {
|
||||
return false
|
||||
}
|
||||
if _, ok := h.AuthManager.Executor("codex"); !ok {
|
||||
return false
|
||||
}
|
||||
auths, _ := h.responsesWebsocketAvailableAuthsForModel(modelName)
|
||||
if len(auths) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, auth := range auths {
|
||||
if auth == nil {
|
||||
return false
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") {
|
||||
return false
|
||||
}
|
||||
if !websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func normalizeResponsesWebsocketPassthroughRequest(rawJSON []byte, modelName string) ([]byte, *interfaces.ErrorMessage) {
|
||||
if !json.Valid(rawJSON) {
|
||||
return nil, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Error: fmt.Errorf("invalid websocket request JSON"),
|
||||
}
|
||||
}
|
||||
|
||||
requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String())
|
||||
switch requestType {
|
||||
case wsRequestTypeCreate, wsRequestTypeAppend:
|
||||
default:
|
||||
return nil, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Error: fmt.Errorf("unsupported websocket request type: %s", requestType),
|
||||
}
|
||||
}
|
||||
|
||||
normalized := bytes.Clone(rawJSON)
|
||||
if strings.TrimSpace(gjson.GetBytes(normalized, "model").String()) == "" {
|
||||
modelName = strings.TrimSpace(modelName)
|
||||
if modelName == "" {
|
||||
return nil, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Error: fmt.Errorf("missing model in response.create request"),
|
||||
}
|
||||
}
|
||||
normalized, _ = sjson.SetBytes(normalized, "model", modelName)
|
||||
}
|
||||
normalized, _ = sjson.SetBytes(normalized, "stream", true)
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
func responsesWebsocketResolvedModelName(modelName string) string {
|
||||
initialSuffix := thinking.ParseSuffix(modelName)
|
||||
if initialSuffix.ModelName == "auto" {
|
||||
|
||||
@@ -83,6 +83,14 @@ type websocketBootstrapFallbackExecutor struct {
|
||||
payloads map[string][][]byte
|
||||
}
|
||||
|
||||
type websocketDirectCaptureExecutor struct {
|
||||
mu sync.Mutex
|
||||
authIDs []string
|
||||
payloads [][]byte
|
||||
done chan struct{}
|
||||
doneOnce sync.Once
|
||||
}
|
||||
|
||||
type websocketPinnedFailoverStatusError struct {
|
||||
status int
|
||||
msg string
|
||||
@@ -156,6 +164,63 @@ func (e *websocketBootstrapFallbackExecutor) Payloads(authID string) [][]byte {
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *websocketDirectCaptureExecutor) Identifier() string { return "codex" }
|
||||
|
||||
func (e *websocketDirectCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (e *websocketDirectCaptureExecutor) ExecuteStream(_ context.Context, auth *coreauth.Auth, req coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||
authID := ""
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
}
|
||||
e.mu.Lock()
|
||||
e.authIDs = append(e.authIDs, authID)
|
||||
e.payloads = append(e.payloads, bytes.Clone(req.Payload))
|
||||
count := len(e.payloads)
|
||||
e.mu.Unlock()
|
||||
|
||||
chunks := make(chan coreexecutor.StreamChunk, 1)
|
||||
responseID := fmt.Sprintf("resp-%d", count)
|
||||
chunks <- coreexecutor.StreamChunk{Payload: []byte(fmt.Sprintf(`{"type":"response.completed","response":{"id":%q,"output":[{"type":"message","id":"out-%d"}]}}`, responseID, count))}
|
||||
close(chunks)
|
||||
if count >= 2 && e.done != nil {
|
||||
e.doneOnce.Do(func() {
|
||||
close(e.done)
|
||||
})
|
||||
}
|
||||
return &coreexecutor.StreamResult{Chunks: chunks}, nil
|
||||
}
|
||||
|
||||
func (e *websocketDirectCaptureExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *websocketDirectCaptureExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (e *websocketDirectCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (e *websocketDirectCaptureExecutor) Payloads() [][]byte {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
out := make([][]byte, len(e.payloads))
|
||||
for i := range e.payloads {
|
||||
out[i] = bytes.Clone(e.payloads[i])
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *websocketDirectCaptureExecutor) AuthIDs() []string {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
return append([]string(nil), e.authIDs...)
|
||||
}
|
||||
|
||||
type websocketUpstreamDisconnectExecutor struct {
|
||||
mu sync.Mutex
|
||||
subscribed chan string
|
||||
@@ -1497,6 +1562,85 @@ func TestResponsesWebsocketClosesOnCodexUpstreamDisconnect(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsesWebsocketCodexWebsocketPassthroughPassesCompactedRequestWithoutTranscriptMerge(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
executor := &websocketDirectCaptureExecutor{done: make(chan struct{})}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
auth := &coreauth.Auth{
|
||||
ID: "auth-ws",
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: map[string]string{"websockets": "true"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth); err != nil {
|
||||
t.Fatalf("Register auth: %v", err)
|
||||
}
|
||||
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
|
||||
})
|
||||
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
firstRequest := []byte(`{"type":"response.create","model":"test-model","input":[{"type":"message","role":"user","content":"first"}]}`)
|
||||
router := gin.New()
|
||||
router.GET("/v1/responses/ws", h.ResponsesWebsocket)
|
||||
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("dial websocket: %v", err)
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
if errWrite := conn.WriteMessage(websocket.TextMessage, firstRequest); errWrite != nil {
|
||||
t.Fatalf("write first websocket message: %v", errWrite)
|
||||
}
|
||||
if _, _, errRead := conn.ReadMessage(); errRead != nil {
|
||||
t.Fatalf("read first websocket response: %v", errRead)
|
||||
}
|
||||
|
||||
compactedRequest := []byte(`{"type":"response.create","input":[{"type":"compaction_summary","summary":"compressed history"},{"type":"message","role":"user","content":"after compaction"}]}`)
|
||||
if errWrite := conn.WriteMessage(websocket.TextMessage, compactedRequest); errWrite != nil {
|
||||
t.Fatalf("write compacted websocket message: %v", errWrite)
|
||||
}
|
||||
if _, _, errRead := conn.ReadMessage(); errRead != nil {
|
||||
t.Fatalf("read compacted websocket response: %v", errRead)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-executor.done:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timed out waiting for websocket passthrough")
|
||||
}
|
||||
|
||||
payloads := executor.Payloads()
|
||||
if len(payloads) != 2 {
|
||||
t.Fatalf("passthrough payload count = %d, want 2", len(payloads))
|
||||
}
|
||||
if got := gjson.GetBytes(payloads[0], "input").Raw; got != gjson.GetBytes(firstRequest, "input").Raw {
|
||||
t.Fatalf("first passthrough input = %s, want %s", got, gjson.GetBytes(firstRequest, "input").Raw)
|
||||
}
|
||||
if got := gjson.GetBytes(payloads[1], "input").Raw; got != gjson.GetBytes(compactedRequest, "input").Raw {
|
||||
t.Fatalf("compacted passthrough input = %s, want %s", got, gjson.GetBytes(compactedRequest, "input").Raw)
|
||||
}
|
||||
if got := gjson.GetBytes(payloads[1], "model").String(); got != "test-model" {
|
||||
t.Fatalf("compacted passthrough model = %s, want test-model", got)
|
||||
}
|
||||
if bytes.Contains(payloads[1], []byte(`"content":"first"`)) || bytes.Contains(payloads[1], []byte(`"id":"out-1"`)) {
|
||||
t.Fatalf("compacted passthrough payload contains stale transcript state: %s", payloads[1])
|
||||
}
|
||||
authIDs := executor.AuthIDs()
|
||||
if len(authIDs) != 2 || authIDs[0] != "auth-ws" || authIDs[1] != "auth-ws" {
|
||||
t.Fatalf("passthrough auth IDs = %v, want [auth-ws auth-ws]", authIDs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) {
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
auth := &coreauth.Auth{
|
||||
|
||||
Reference in New Issue
Block a user