diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index eae042b9e..142719aa2 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -381,6 +381,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } requestJSON = repairResponsesWebsocketToolCalls(downstreamSessionKey, requestJSON) + requestJSON = dedupeResponsesWebsocketInputItemsByID(requestJSON) updatedLastRequest = bytes.Clone(requestJSON) previousLastRequest := bytes.Clone(lastRequest) previousLastResponseOutput := bytes.Clone(lastResponseOutput) @@ -582,6 +583,10 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last if errDedupeFunctionCalls == nil { mergedInput = dedupedInput } + dedupedInput, errDedupeItemIDs := dedupeInputItemsByID(mergedInput) + if errDedupeItemIDs == nil { + mergedInput = dedupedInput + } normalized, errDelete := sjson.DeleteBytes(rawJSON, "type") if errDelete != nil { @@ -697,6 +702,64 @@ func dedupeFunctionCallsByCallID(rawArray string) (string, error) { return string(out), nil } +func dedupeResponsesWebsocketInputItemsByID(payload []byte) []byte { + input := gjson.GetBytes(payload, "input") + if !input.Exists() || !input.IsArray() { + return payload + } + dedupedInput, errDedupe := dedupeInputItemsByID(input.Raw) + if errDedupe != nil || dedupedInput == input.Raw { + return payload + } + updated, errSet := sjson.SetRawBytes(payload, "input", []byte(dedupedInput)) + if errSet != nil { + return payload + } + return updated +} + +func dedupeInputItemsByID(rawArray string) (string, error) { + rawArray = strings.TrimSpace(rawArray) + if rawArray == "" { + return "[]", nil + } + var items []json.RawMessage + if errUnmarshal := json.Unmarshal([]byte(rawArray), &items); errUnmarshal != nil { + return "", errUnmarshal + } + + lastIndexByID := make(map[string]int, len(items)) + for i, item := range items { + if len(item) == 0 { + continue + } + itemID := strings.TrimSpace(gjson.GetBytes(item, "id").String()) + if itemID != "" { + lastIndexByID[itemID] = i + } + } + + filtered := make([]json.RawMessage, 0, len(items)) + for i, item := range items { + if len(item) == 0 { + continue + } + itemID := strings.TrimSpace(gjson.GetBytes(item, "id").String()) + if itemID != "" { + if lastIndexByID[itemID] != i { + continue + } + } + filtered = append(filtered, item) + } + + out, errMarshal := json.Marshal(filtered) + if errMarshal != nil { + return "", errMarshal + } + return string(out), nil +} + func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool { if len(attributes) > 0 { if raw := strings.TrimSpace(attributes["websockets"]); raw != "" { diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go index d37c783db..9f23af82d 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_test.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go @@ -1603,6 +1603,30 @@ func TestNormalizeResponsesWebsocketRequestDropsDuplicateFunctionCallsByCallID(t } } +func TestNormalizeResponsesWebsocketRequestDropsDuplicateInputItemsByID(t *testing.T) { + lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1","role":"user"}]}`) + lastResponseOutput := []byte(`[ + {"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool"} + ]`) + raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call","id":"fc-1","call_id":"call-2","name":"tool"},{"type":"function_call_output","id":"tool-out-1","call_id":"call-2"}]}`) + + normalized, _, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, false, true) + if errMsg != nil { + t.Fatalf("unexpected error: %v", errMsg.Error) + } + + items := gjson.GetBytes(normalized, "input").Array() + if len(items) != 3 { + t.Fatalf("merged input len = %d, want 3: %s", len(items), normalized) + } + if items[0].Get("id").String() != "msg-1" || + items[1].Get("id").String() != "fc-1" || + items[1].Get("call_id").String() != "call-2" || + items[2].Get("id").String() != "tool-out-1" { + t.Fatalf("unexpected merged input order: %s", normalized) + } +} + func TestNormalizeResponsesWebsocketRequestTreatsCustomToolTranscriptReplacementAsReset(t *testing.T) { lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"},{"type":"custom_tool_call","id":"ctc-1","call_id":"call-1","name":"apply_patch"},{"type":"custom_tool_call_output","id":"tool-out-1","call_id":"call-1"},{"type":"message","id":"assistant-1","role":"assistant"}]}`) lastResponseOutput := []byte(`[ @@ -1654,6 +1678,22 @@ func TestNormalizeResponsesWebsocketRequestDropsDuplicateCustomToolCallsByCallID } } +func TestDedupeResponsesWebsocketInputItemsByIDAfterRepair(t *testing.T) { + payload := []byte(`{"input":[{"type":"custom_tool_call","id":"ctc-1","call_id":"call-1","name":"tool"},{"type":"custom_tool_call","id":"ctc-1","call_id":"call-2","name":"tool"},{"type":"custom_tool_call_output","id":"tool-out-1","call_id":"call-2"}]}`) + + deduped := dedupeResponsesWebsocketInputItemsByID(payload) + + items := gjson.GetBytes(deduped, "input").Array() + if len(items) != 2 { + t.Fatalf("deduped input len = %d, want 2: %s", len(items), deduped) + } + if items[0].Get("id").String() != "ctc-1" || + items[0].Get("call_id").String() != "call-2" || + items[1].Get("id").String() != "tool-out-1" { + t.Fatalf("unexpected deduped input: %s", deduped) + } +} + func TestResponsesWebsocketCompactionResetsTurnStateOnCustomToolTranscriptReplacement(t *testing.T) { gin.SetMode(gin.TestMode)