feat(websockets): add Codex websocket passthrough support with tests

- Implemented `websocketDirectCaptureExecutor` for Codex websocket passthrough functionality.
- Added logic to bypass incremental state handling for passthrough models.
- Updated normalization, compaction, and replay handling to support passthrough mode.
- Introduced `responsesWebsocketUsesCodexWebsocketPassthrough` utility for model-specific passthrough determination.
- Expanded test coverage for websocket passthrough scenarios, including compaction and response validation.
This commit is contained in:
Luis Pater
2026-06-15 02:31:05 +08:00
parent 7de9757c82
commit 56988aea0f
2 changed files with 265 additions and 46 deletions

View File

@@ -272,6 +272,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
lastResponseID := ""
var lastResponsePendingToolCallIDs []string
pinnedAuthID := ""
passthroughModelName := ""
sessionAuthByID := func(authID string) (*coreauth.Auth, bool) {
if h == nil || h.AuthManager == nil {
return nil, false
@@ -307,47 +308,47 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
wsTimelineLog.BeginRequest()
wsTimelineLog.Append("request", payload, time.Now())
requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
if requestModelName == "" {
requestModelName = passthroughModelName
}
if requestModelName == "" {
requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
}
useCodexWebsocketPassthrough := h.responsesWebsocketUsesCodexWebsocketPassthrough(requestModelName)
allowIncrementalInputWithPreviousResponseID := false
if pinnedAuthID != "" {
if pinnedAuth, ok := sessionAuthByID(pinnedAuthID); ok && pinnedAuth != nil {
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
}
} else {
requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
if requestModelName == "" {
requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
}
allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName)
}
if forceTranscriptReplayNextRequest {
allowIncrementalInputWithPreviousResponseID = false
}
allowCompactionReplayBypass := false
if pinnedAuthID != "" {
if pinnedAuth, ok := sessionAuthByID(pinnedAuthID); ok && pinnedAuth != nil {
allowCompactionReplayBypass = responsesWebsocketAuthSupportsCompactionReplay(pinnedAuth)
if !useCodexWebsocketPassthrough {
if pinnedAuthID != "" {
if pinnedAuth, ok := sessionAuthByID(pinnedAuthID); ok && pinnedAuth != nil {
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
allowCompactionReplayBypass = responsesWebsocketAuthSupportsCompactionReplay(pinnedAuth)
}
} else {
allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName)
allowCompactionReplayBypass = h.websocketUpstreamSupportsCompactionReplayForModel(requestModelName)
}
} else {
requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
if requestModelName == "" {
requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
if forceTranscriptReplayNextRequest {
allowIncrementalInputWithPreviousResponseID = false
}
allowCompactionReplayBypass = h.websocketUpstreamSupportsCompactionReplayForModel(requestModelName)
}
var requestJSON []byte
var updatedLastRequest []byte
var errMsg *interfaces.ErrorMessage
requestJSON, updatedLastRequest, errMsg = normalizeResponsesWebsocketRequestWithIncrementalState(
payload,
lastRequest,
lastResponseOutput,
lastResponseID,
lastResponsePendingToolCallIDs,
allowIncrementalInputWithPreviousResponseID,
allowCompactionReplayBypass,
)
if useCodexWebsocketPassthrough {
requestJSON, errMsg = normalizeResponsesWebsocketPassthroughRequest(payload, requestModelName)
} else {
requestJSON, updatedLastRequest, errMsg = normalizeResponsesWebsocketRequestWithIncrementalState(
payload,
lastRequest,
lastResponseOutput,
lastResponseID,
lastResponsePendingToolCallIDs,
allowIncrementalInputWithPreviousResponseID,
allowCompactionReplayBypass,
)
}
if errMsg != nil {
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
markAPIResponseTimestamp(c)
@@ -370,7 +371,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
}
continue
}
if shouldHandleResponsesWebsocketPrewarmLocally(payload, lastRequest, allowIncrementalInputWithPreviousResponseID) {
if !useCodexWebsocketPassthrough && shouldHandleResponsesWebsocketPrewarmLocally(payload, lastRequest, allowIncrementalInputWithPreviousResponseID) {
if updated, errDelete := sjson.DeleteBytes(requestJSON, "generate"); errDelete == nil {
requestJSON = updated
}
@@ -388,17 +389,26 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
continue
}
requestJSON = repairResponsesWebsocketToolCalls(downstreamSessionKey, requestJSON)
requestJSON = dedupeResponsesWebsocketInputItemsByID(requestJSON)
updatedLastRequest = bytes.Clone(requestJSON)
previousLastRequest := bytes.Clone(lastRequest)
previousLastResponseOutput := bytes.Clone(lastResponseOutput)
previousLastResponseID := lastResponseID
previousLastResponsePendingToolCallIDs := append([]string(nil), lastResponsePendingToolCallIDs...)
forcedTranscriptReplay := forceTranscriptReplayNextRequest
lastRequest = updatedLastRequest
if forcedTranscriptReplay {
forceTranscriptReplayNextRequest = false
if useCodexWebsocketPassthrough {
if modelName := strings.TrimSpace(gjson.GetBytes(requestJSON, "model").String()); modelName != "" {
passthroughModelName = modelName
}
if forcedTranscriptReplay {
forceTranscriptReplayNextRequest = false
}
} else {
requestJSON = repairResponsesWebsocketToolCalls(downstreamSessionKey, requestJSON)
requestJSON = dedupeResponsesWebsocketInputItemsByID(requestJSON)
updatedLastRequest = bytes.Clone(requestJSON)
lastRequest = updatedLastRequest
if forcedTranscriptReplay {
forceTranscriptReplayNextRequest = false
}
}
modelName := gjson.GetBytes(requestJSON, "model").String()
@@ -433,15 +443,21 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
if shouldReleaseResponsesWebsocketPinnedAuth(forwardErrMsg) {
pinnedAuthID = ""
forceTranscriptReplayNextRequest = true
lastRequest = previousLastRequest
lastResponseOutput = previousLastResponseOutput
lastResponseID = previousLastResponseID
lastResponsePendingToolCallIDs = previousLastResponsePendingToolCallIDs
if useCodexWebsocketPassthrough {
passthroughModelName = ""
} else {
lastRequest = previousLastRequest
lastResponseOutput = previousLastResponseOutput
lastResponseID = previousLastResponseID
lastResponsePendingToolCallIDs = previousLastResponsePendingToolCallIDs
}
continue
}
lastResponseOutput = completedOutput
lastResponseID = strings.TrimSpace(completedResponseID)
lastResponsePendingToolCallIDs = append([]string(nil), completedPendingToolCallIDs...)
if !useCodexWebsocketPassthrough {
lastResponseOutput = completedOutput
lastResponseID = strings.TrimSpace(completedResponseID)
lastResponsePendingToolCallIDs = append([]string(nil), completedPendingToolCallIDs...)
}
}
}
@@ -944,6 +960,65 @@ func (h *OpenAIResponsesAPIHandler) responsesWebsocketAvailableAuthsForModel(mod
return available, modelKey
}
func (h *OpenAIResponsesAPIHandler) responsesWebsocketUsesCodexWebsocketPassthrough(modelName string) bool {
modelName = strings.TrimSpace(modelName)
if h == nil || h.AuthManager == nil || modelName == "" {
return false
}
if _, ok := h.AuthManager.Executor("codex"); !ok {
return false
}
auths, _ := h.responsesWebsocketAvailableAuthsForModel(modelName)
if len(auths) == 0 {
return false
}
for _, auth := range auths {
if auth == nil {
return false
}
if !strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") {
return false
}
if !websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) {
return false
}
}
return true
}
func normalizeResponsesWebsocketPassthroughRequest(rawJSON []byte, modelName string) ([]byte, *interfaces.ErrorMessage) {
if !json.Valid(rawJSON) {
return nil, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
Error: fmt.Errorf("invalid websocket request JSON"),
}
}
requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String())
switch requestType {
case wsRequestTypeCreate, wsRequestTypeAppend:
default:
return nil, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
Error: fmt.Errorf("unsupported websocket request type: %s", requestType),
}
}
normalized := bytes.Clone(rawJSON)
if strings.TrimSpace(gjson.GetBytes(normalized, "model").String()) == "" {
modelName = strings.TrimSpace(modelName)
if modelName == "" {
return nil, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
Error: fmt.Errorf("missing model in response.create request"),
}
}
normalized, _ = sjson.SetBytes(normalized, "model", modelName)
}
normalized, _ = sjson.SetBytes(normalized, "stream", true)
return normalized, nil
}
func responsesWebsocketResolvedModelName(modelName string) string {
initialSuffix := thinking.ParseSuffix(modelName)
if initialSuffix.ModelName == "auto" {

View File

@@ -83,6 +83,14 @@ type websocketBootstrapFallbackExecutor struct {
payloads map[string][][]byte
}
type websocketDirectCaptureExecutor struct {
mu sync.Mutex
authIDs []string
payloads [][]byte
done chan struct{}
doneOnce sync.Once
}
type websocketPinnedFailoverStatusError struct {
status int
msg string
@@ -156,6 +164,63 @@ func (e *websocketBootstrapFallbackExecutor) Payloads(authID string) [][]byte {
return out
}
func (e *websocketDirectCaptureExecutor) Identifier() string { return "codex" }
func (e *websocketDirectCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, errors.New("not implemented")
}
func (e *websocketDirectCaptureExecutor) ExecuteStream(_ context.Context, auth *coreauth.Auth, req coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) {
authID := ""
if auth != nil {
authID = auth.ID
}
e.mu.Lock()
e.authIDs = append(e.authIDs, authID)
e.payloads = append(e.payloads, bytes.Clone(req.Payload))
count := len(e.payloads)
e.mu.Unlock()
chunks := make(chan coreexecutor.StreamChunk, 1)
responseID := fmt.Sprintf("resp-%d", count)
chunks <- coreexecutor.StreamChunk{Payload: []byte(fmt.Sprintf(`{"type":"response.completed","response":{"id":%q,"output":[{"type":"message","id":"out-%d"}]}}`, responseID, count))}
close(chunks)
if count >= 2 && e.done != nil {
e.doneOnce.Do(func() {
close(e.done)
})
}
return &coreexecutor.StreamResult{Chunks: chunks}, nil
}
func (e *websocketDirectCaptureExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
return auth, nil
}
func (e *websocketDirectCaptureExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, errors.New("not implemented")
}
func (e *websocketDirectCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
return nil, errors.New("not implemented")
}
func (e *websocketDirectCaptureExecutor) Payloads() [][]byte {
e.mu.Lock()
defer e.mu.Unlock()
out := make([][]byte, len(e.payloads))
for i := range e.payloads {
out[i] = bytes.Clone(e.payloads[i])
}
return out
}
func (e *websocketDirectCaptureExecutor) AuthIDs() []string {
e.mu.Lock()
defer e.mu.Unlock()
return append([]string(nil), e.authIDs...)
}
type websocketUpstreamDisconnectExecutor struct {
mu sync.Mutex
subscribed chan string
@@ -1497,6 +1562,85 @@ func TestResponsesWebsocketClosesOnCodexUpstreamDisconnect(t *testing.T) {
}
}
func TestResponsesWebsocketCodexWebsocketPassthroughPassesCompactedRequestWithoutTranscriptMerge(t *testing.T) {
gin.SetMode(gin.TestMode)
executor := &websocketDirectCaptureExecutor{done: make(chan struct{})}
manager := coreauth.NewManager(nil, nil, nil)
manager.RegisterExecutor(executor)
auth := &coreauth.Auth{
ID: "auth-ws",
Provider: "codex",
Status: coreauth.StatusActive,
Attributes: map[string]string{"websockets": "true"},
}
if _, err := manager.Register(context.Background(), auth); err != nil {
t.Fatalf("Register auth: %v", err)
}
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
})
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
h := NewOpenAIResponsesAPIHandler(base)
firstRequest := []byte(`{"type":"response.create","model":"test-model","input":[{"type":"message","role":"user","content":"first"}]}`)
router := gin.New()
router.GET("/v1/responses/ws", h.ResponsesWebsocket)
server := httptest.NewServer(router)
defer server.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("dial websocket: %v", err)
}
defer func() { _ = conn.Close() }()
if errWrite := conn.WriteMessage(websocket.TextMessage, firstRequest); errWrite != nil {
t.Fatalf("write first websocket message: %v", errWrite)
}
if _, _, errRead := conn.ReadMessage(); errRead != nil {
t.Fatalf("read first websocket response: %v", errRead)
}
compactedRequest := []byte(`{"type":"response.create","input":[{"type":"compaction_summary","summary":"compressed history"},{"type":"message","role":"user","content":"after compaction"}]}`)
if errWrite := conn.WriteMessage(websocket.TextMessage, compactedRequest); errWrite != nil {
t.Fatalf("write compacted websocket message: %v", errWrite)
}
if _, _, errRead := conn.ReadMessage(); errRead != nil {
t.Fatalf("read compacted websocket response: %v", errRead)
}
select {
case <-executor.done:
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for websocket passthrough")
}
payloads := executor.Payloads()
if len(payloads) != 2 {
t.Fatalf("passthrough payload count = %d, want 2", len(payloads))
}
if got := gjson.GetBytes(payloads[0], "input").Raw; got != gjson.GetBytes(firstRequest, "input").Raw {
t.Fatalf("first passthrough input = %s, want %s", got, gjson.GetBytes(firstRequest, "input").Raw)
}
if got := gjson.GetBytes(payloads[1], "input").Raw; got != gjson.GetBytes(compactedRequest, "input").Raw {
t.Fatalf("compacted passthrough input = %s, want %s", got, gjson.GetBytes(compactedRequest, "input").Raw)
}
if got := gjson.GetBytes(payloads[1], "model").String(); got != "test-model" {
t.Fatalf("compacted passthrough model = %s, want test-model", got)
}
if bytes.Contains(payloads[1], []byte(`"content":"first"`)) || bytes.Contains(payloads[1], []byte(`"id":"out-1"`)) {
t.Fatalf("compacted passthrough payload contains stale transcript state: %s", payloads[1])
}
authIDs := executor.AuthIDs()
if len(authIDs) != 2 || authIDs[0] != "auth-ws" || authIDs[1] != "auth-ws" {
t.Fatalf("passthrough auth IDs = %v, want [auth-ws auth-ws]", authIDs)
}
}
func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) {
manager := coreauth.NewManager(nil, nil, nil)
auth := &coreauth.Auth{