From de280d993d08a1612ee96b969988cd741ca3b71f Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Wed, 27 May 2026 01:01:57 +0800 Subject: [PATCH] feat(websockets): refine incremental repair logic for tool call responses - Updated WebSocket response repair tests to validate incremental preservation of response calls and outputs. - Added new test cases for custom tool responses ensuring accurate handling of output cache and call cache. - Refactored `repairResponsesWebsocketToolCallsWithCaches` to handle orphan outputs more consistently. - Adjusted input filtering logic for clearer incremental repair behavior. Closes: #3569 --- .../openai/openai_responses_websocket_test.go | 65 ++++++++++++++++--- ...nai_responses_websocket_toolcall_repair.go | 15 +++-- 2 files changed, 66 insertions(+), 14 deletions(-) diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go index 8b945b50c..d37c783db 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_test.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go @@ -691,7 +691,7 @@ func TestRepairResponsesWebsocketToolCallsInsertsCachedCallForOrphanOutput(t *te } } -func TestRepairResponsesWebsocketToolCallsInsertsCachedCallForPreviousResponseOutput(t *testing.T) { +func TestRepairResponsesWebsocketToolCallsKeepsPreviousResponseOutputIncremental(t *testing.T) { outputCache := newWebsocketToolOutputCache(time.Minute, 10) callCache := newWebsocketToolOutputCache(time.Minute, 10) sessionKey := "session-1" @@ -705,17 +705,39 @@ func TestRepairResponsesWebsocketToolCallsInsertsCachedCallForPreviousResponseOu t.Fatalf("previous_response_id = %q, want resp-latest", got) } input := gjson.GetBytes(repaired, "input").Array() - if len(input) != 3 { - t.Fatalf("repaired input len = %d, want 3: %s", len(input), repaired) + if len(input) != 2 { + t.Fatalf("repaired input len = %d, want 2: %s", len(input), repaired) + } + if input[0].Get("type").String() != "function_call_output" || input[0].Get("call_id").String() != "call-1" { + t.Fatalf("unexpected output item: %s", input[0].Raw) + } + if input[1].Get("type").String() != "message" || input[1].Get("id").String() != "msg-1" { + t.Fatalf("unexpected trailing item: %s", input[1].Raw) + } +} + +func TestRepairResponsesWebsocketToolCallsKeepsPreviousResponseCallIncremental(t *testing.T) { + outputCache := newWebsocketToolOutputCache(time.Minute, 10) + callCache := newWebsocketToolOutputCache(time.Minute, 10) + sessionKey := "session-1" + + outputCache.record(sessionKey, "call-1", []byte(`{"type":"function_call_output","call_id":"call-1","id":"tool-out-1","output":"ok"}`)) + + raw := []byte(`{"previous_response_id":"resp-latest","input":[{"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-1"}]}`) + repaired := repairResponsesWebsocketToolCallsWithCaches(outputCache, callCache, sessionKey, raw) + + if got := gjson.GetBytes(repaired, "previous_response_id").String(); got != "resp-latest" { + t.Fatalf("previous_response_id = %q, want resp-latest", got) + } + input := gjson.GetBytes(repaired, "input").Array() + if len(input) != 2 { + t.Fatalf("repaired input len = %d, want 2: %s", len(input), repaired) } if input[0].Get("type").String() != "function_call" || input[0].Get("call_id").String() != "call-1" { - t.Fatalf("missing inserted call: %s", input[0].Raw) + t.Fatalf("unexpected call item: %s", input[0].Raw) } - if input[1].Get("type").String() != "function_call_output" || input[1].Get("call_id").String() != "call-1" { - t.Fatalf("unexpected output item: %s", input[1].Raw) - } - if input[2].Get("type").String() != "message" || input[2].Get("id").String() != "msg-1" { - t.Fatalf("unexpected trailing item: %s", input[2].Raw) + if input[1].Get("type").String() != "message" || input[1].Get("id").String() != "msg-1" { + t.Fatalf("unexpected trailing item: %s", input[1].Raw) } } @@ -805,6 +827,31 @@ func TestRepairResponsesWebsocketToolCallsInsertsCachedCustomToolCallForOrphanOu } } +func TestRepairResponsesWebsocketToolCallsKeepsPreviousResponseCustomToolOutputIncremental(t *testing.T) { + outputCache := newWebsocketToolOutputCache(time.Minute, 10) + callCache := newWebsocketToolOutputCache(time.Minute, 10) + sessionKey := "session-1" + + callCache.record(sessionKey, "call-1", []byte(`{"type":"custom_tool_call","call_id":"call-1","name":"apply_patch"}`)) + + raw := []byte(`{"previous_response_id":"resp-latest","input":[{"type":"custom_tool_call_output","call_id":"call-1","output":"ok"},{"type":"message","id":"msg-1"}]}`) + repaired := repairResponsesWebsocketToolCallsWithCaches(outputCache, callCache, sessionKey, raw) + + if got := gjson.GetBytes(repaired, "previous_response_id").String(); got != "resp-latest" { + t.Fatalf("previous_response_id = %q, want resp-latest", got) + } + input := gjson.GetBytes(repaired, "input").Array() + if len(input) != 2 { + t.Fatalf("repaired input len = %d, want 2: %s", len(input), repaired) + } + if input[0].Get("type").String() != "custom_tool_call_output" || input[0].Get("call_id").String() != "call-1" { + t.Fatalf("unexpected output item: %s", input[0].Raw) + } + if input[1].Get("type").String() != "message" || input[1].Get("id").String() != "msg-1" { + t.Fatalf("unexpected trailing item: %s", input[1].Raw) + } +} + func TestRepairResponsesWebsocketToolCallsDropsOrphanCustomToolOutputWhenCallMissing(t *testing.T) { outputCache := newWebsocketToolOutputCache(time.Minute, 10) callCache := newWebsocketToolOutputCache(time.Minute, 10) diff --git a/sdk/api/handlers/openai/openai_responses_websocket_toolcall_repair.go b/sdk/api/handlers/openai/openai_responses_websocket_toolcall_repair.go index c521bec04..22219a8ab 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_toolcall_repair.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_toolcall_repair.go @@ -305,6 +305,11 @@ func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCa continue } + if allowOrphanOutputs { + filtered = append(filtered, item) + continue + } + if callCache != nil { if cached, ok := callCache.get(sessionKey, callID); ok { if _, already := insertedCalls[callID]; !already { @@ -317,11 +322,6 @@ func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCa } } - if allowOrphanOutputs { - filtered = append(filtered, item) - continue - } - // Drop orphaned function_call_output items; upstream rejects transcripts with missing calls. continue } @@ -341,6 +341,11 @@ func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCa continue } + if allowOrphanOutputs { + filtered = append(filtered, item) + continue + } + if cached, ok := outputCache.get(sessionKey, callID); ok { filtered = append(filtered, item) filtered = append(filtered, cached)