Merge pull request #3620 from iBenzene/fix/responses-input-id-dedupe

fix(openai): dedupe response websocket input item IDs
This commit is contained in:
Luis Pater
2026-05-30 03:27:36 +08:00
committed by GitHub
2 changed files with 103 additions and 0 deletions

View File

@@ -381,6 +381,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
}
requestJSON = repairResponsesWebsocketToolCalls(downstreamSessionKey, requestJSON)
requestJSON = dedupeResponsesWebsocketInputItemsByID(requestJSON)
updatedLastRequest = bytes.Clone(requestJSON)
previousLastRequest := bytes.Clone(lastRequest)
previousLastResponseOutput := bytes.Clone(lastResponseOutput)
@@ -582,6 +583,10 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
if errDedupeFunctionCalls == nil {
mergedInput = dedupedInput
}
dedupedInput, errDedupeItemIDs := dedupeInputItemsByID(mergedInput)
if errDedupeItemIDs == nil {
mergedInput = dedupedInput
}
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
if errDelete != nil {
@@ -697,6 +702,64 @@ func dedupeFunctionCallsByCallID(rawArray string) (string, error) {
return string(out), nil
}
func dedupeResponsesWebsocketInputItemsByID(payload []byte) []byte {
input := gjson.GetBytes(payload, "input")
if !input.Exists() || !input.IsArray() {
return payload
}
dedupedInput, errDedupe := dedupeInputItemsByID(input.Raw)
if errDedupe != nil || dedupedInput == input.Raw {
return payload
}
updated, errSet := sjson.SetRawBytes(payload, "input", []byte(dedupedInput))
if errSet != nil {
return payload
}
return updated
}
func dedupeInputItemsByID(rawArray string) (string, error) {
rawArray = strings.TrimSpace(rawArray)
if rawArray == "" {
return "[]", nil
}
var items []json.RawMessage
if errUnmarshal := json.Unmarshal([]byte(rawArray), &items); errUnmarshal != nil {
return "", errUnmarshal
}
lastIndexByID := make(map[string]int, len(items))
for i, item := range items {
if len(item) == 0 {
continue
}
itemID := strings.TrimSpace(gjson.GetBytes(item, "id").String())
if itemID != "" {
lastIndexByID[itemID] = i
}
}
filtered := make([]json.RawMessage, 0, len(items))
for i, item := range items {
if len(item) == 0 {
continue
}
itemID := strings.TrimSpace(gjson.GetBytes(item, "id").String())
if itemID != "" {
if lastIndexByID[itemID] != i {
continue
}
}
filtered = append(filtered, item)
}
out, errMarshal := json.Marshal(filtered)
if errMarshal != nil {
return "", errMarshal
}
return string(out), nil
}
func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool {
if len(attributes) > 0 {
if raw := strings.TrimSpace(attributes["websockets"]); raw != "" {

View File

@@ -1603,6 +1603,30 @@ func TestNormalizeResponsesWebsocketRequestDropsDuplicateFunctionCallsByCallID(t
}
}
func TestNormalizeResponsesWebsocketRequestDropsDuplicateInputItemsByID(t *testing.T) {
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1","role":"user"}]}`)
lastResponseOutput := []byte(`[
{"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool"}
]`)
raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call","id":"fc-1","call_id":"call-2","name":"tool"},{"type":"function_call_output","id":"tool-out-1","call_id":"call-2"}]}`)
normalized, _, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, false, true)
if errMsg != nil {
t.Fatalf("unexpected error: %v", errMsg.Error)
}
items := gjson.GetBytes(normalized, "input").Array()
if len(items) != 3 {
t.Fatalf("merged input len = %d, want 3: %s", len(items), normalized)
}
if items[0].Get("id").String() != "msg-1" ||
items[1].Get("id").String() != "fc-1" ||
items[1].Get("call_id").String() != "call-2" ||
items[2].Get("id").String() != "tool-out-1" {
t.Fatalf("unexpected merged input order: %s", normalized)
}
}
func TestNormalizeResponsesWebsocketRequestTreatsCustomToolTranscriptReplacementAsReset(t *testing.T) {
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"},{"type":"custom_tool_call","id":"ctc-1","call_id":"call-1","name":"apply_patch"},{"type":"custom_tool_call_output","id":"tool-out-1","call_id":"call-1"},{"type":"message","id":"assistant-1","role":"assistant"}]}`)
lastResponseOutput := []byte(`[
@@ -1654,6 +1678,22 @@ func TestNormalizeResponsesWebsocketRequestDropsDuplicateCustomToolCallsByCallID
}
}
func TestDedupeResponsesWebsocketInputItemsByIDAfterRepair(t *testing.T) {
payload := []byte(`{"input":[{"type":"custom_tool_call","id":"ctc-1","call_id":"call-1","name":"tool"},{"type":"custom_tool_call","id":"ctc-1","call_id":"call-2","name":"tool"},{"type":"custom_tool_call_output","id":"tool-out-1","call_id":"call-2"}]}`)
deduped := dedupeResponsesWebsocketInputItemsByID(payload)
items := gjson.GetBytes(deduped, "input").Array()
if len(items) != 2 {
t.Fatalf("deduped input len = %d, want 2: %s", len(items), deduped)
}
if items[0].Get("id").String() != "ctc-1" ||
items[0].Get("call_id").String() != "call-2" ||
items[1].Get("id").String() != "tool-out-1" {
t.Fatalf("unexpected deduped input: %s", deduped)
}
}
func TestResponsesWebsocketCompactionResetsTurnStateOnCustomToolTranscriptReplacement(t *testing.T) {
gin.SetMode(gin.TestMode)