feat(websockets): refine incremental repair logic for tool call responses

- Updated WebSocket response repair tests to validate incremental preservation of response calls and outputs.
- Added new test cases for custom tool responses ensuring accurate handling of output cache and call cache.
- Refactored `repairResponsesWebsocketToolCallsWithCaches` to handle orphan outputs more consistently.
- Adjusted input filtering logic for clearer incremental repair behavior.

Closes: #3569
This commit is contained in:
Luis Pater
2026-05-27 01:01:57 +08:00
parent e399edd3cc
commit de280d993d
2 changed files with 66 additions and 14 deletions

View File

@@ -691,7 +691,7 @@ func TestRepairResponsesWebsocketToolCallsInsertsCachedCallForOrphanOutput(t *te
}
}
func TestRepairResponsesWebsocketToolCallsInsertsCachedCallForPreviousResponseOutput(t *testing.T) {
func TestRepairResponsesWebsocketToolCallsKeepsPreviousResponseOutputIncremental(t *testing.T) {
outputCache := newWebsocketToolOutputCache(time.Minute, 10)
callCache := newWebsocketToolOutputCache(time.Minute, 10)
sessionKey := "session-1"
@@ -705,17 +705,39 @@ func TestRepairResponsesWebsocketToolCallsInsertsCachedCallForPreviousResponseOu
t.Fatalf("previous_response_id = %q, want resp-latest", got)
}
input := gjson.GetBytes(repaired, "input").Array()
if len(input) != 3 {
t.Fatalf("repaired input len = %d, want 3: %s", len(input), repaired)
if len(input) != 2 {
t.Fatalf("repaired input len = %d, want 2: %s", len(input), repaired)
}
if input[0].Get("type").String() != "function_call_output" || input[0].Get("call_id").String() != "call-1" {
t.Fatalf("unexpected output item: %s", input[0].Raw)
}
if input[1].Get("type").String() != "message" || input[1].Get("id").String() != "msg-1" {
t.Fatalf("unexpected trailing item: %s", input[1].Raw)
}
}
func TestRepairResponsesWebsocketToolCallsKeepsPreviousResponseCallIncremental(t *testing.T) {
outputCache := newWebsocketToolOutputCache(time.Minute, 10)
callCache := newWebsocketToolOutputCache(time.Minute, 10)
sessionKey := "session-1"
outputCache.record(sessionKey, "call-1", []byte(`{"type":"function_call_output","call_id":"call-1","id":"tool-out-1","output":"ok"}`))
raw := []byte(`{"previous_response_id":"resp-latest","input":[{"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-1"}]}`)
repaired := repairResponsesWebsocketToolCallsWithCaches(outputCache, callCache, sessionKey, raw)
if got := gjson.GetBytes(repaired, "previous_response_id").String(); got != "resp-latest" {
t.Fatalf("previous_response_id = %q, want resp-latest", got)
}
input := gjson.GetBytes(repaired, "input").Array()
if len(input) != 2 {
t.Fatalf("repaired input len = %d, want 2: %s", len(input), repaired)
}
if input[0].Get("type").String() != "function_call" || input[0].Get("call_id").String() != "call-1" {
t.Fatalf("missing inserted call: %s", input[0].Raw)
t.Fatalf("unexpected call item: %s", input[0].Raw)
}
if input[1].Get("type").String() != "function_call_output" || input[1].Get("call_id").String() != "call-1" {
t.Fatalf("unexpected output item: %s", input[1].Raw)
}
if input[2].Get("type").String() != "message" || input[2].Get("id").String() != "msg-1" {
t.Fatalf("unexpected trailing item: %s", input[2].Raw)
if input[1].Get("type").String() != "message" || input[1].Get("id").String() != "msg-1" {
t.Fatalf("unexpected trailing item: %s", input[1].Raw)
}
}
@@ -805,6 +827,31 @@ func TestRepairResponsesWebsocketToolCallsInsertsCachedCustomToolCallForOrphanOu
}
}
func TestRepairResponsesWebsocketToolCallsKeepsPreviousResponseCustomToolOutputIncremental(t *testing.T) {
outputCache := newWebsocketToolOutputCache(time.Minute, 10)
callCache := newWebsocketToolOutputCache(time.Minute, 10)
sessionKey := "session-1"
callCache.record(sessionKey, "call-1", []byte(`{"type":"custom_tool_call","call_id":"call-1","name":"apply_patch"}`))
raw := []byte(`{"previous_response_id":"resp-latest","input":[{"type":"custom_tool_call_output","call_id":"call-1","output":"ok"},{"type":"message","id":"msg-1"}]}`)
repaired := repairResponsesWebsocketToolCallsWithCaches(outputCache, callCache, sessionKey, raw)
if got := gjson.GetBytes(repaired, "previous_response_id").String(); got != "resp-latest" {
t.Fatalf("previous_response_id = %q, want resp-latest", got)
}
input := gjson.GetBytes(repaired, "input").Array()
if len(input) != 2 {
t.Fatalf("repaired input len = %d, want 2: %s", len(input), repaired)
}
if input[0].Get("type").String() != "custom_tool_call_output" || input[0].Get("call_id").String() != "call-1" {
t.Fatalf("unexpected output item: %s", input[0].Raw)
}
if input[1].Get("type").String() != "message" || input[1].Get("id").String() != "msg-1" {
t.Fatalf("unexpected trailing item: %s", input[1].Raw)
}
}
func TestRepairResponsesWebsocketToolCallsDropsOrphanCustomToolOutputWhenCallMissing(t *testing.T) {
outputCache := newWebsocketToolOutputCache(time.Minute, 10)
callCache := newWebsocketToolOutputCache(time.Minute, 10)

View File

@@ -305,6 +305,11 @@ func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCa
continue
}
if allowOrphanOutputs {
filtered = append(filtered, item)
continue
}
if callCache != nil {
if cached, ok := callCache.get(sessionKey, callID); ok {
if _, already := insertedCalls[callID]; !already {
@@ -317,11 +322,6 @@ func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCa
}
}
if allowOrphanOutputs {
filtered = append(filtered, item)
continue
}
// Drop orphaned function_call_output items; upstream rejects transcripts with missing calls.
continue
}
@@ -341,6 +341,11 @@ func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCa
continue
}
if allowOrphanOutputs {
filtered = append(filtered, item)
continue
}
if cached, ok := outputCache.get(sessionKey, callID); ok {
filtered = append(filtered, item)
filtered = append(filtered, cached)