From 17a1f53c47c2b6d846cd4cc928428c2824cb0ce5 Mon Sep 17 00:00:00 2001 From: songyu Date: Wed, 6 May 2026 14:37:18 +0800 Subject: [PATCH] =?UTF-8?q?fix=EF=BC=9Aopenai=202=20kimi=20error=20=20=20C?= =?UTF-8?q?ontinuous=20function=5Fcall=20=E8=BF=9E=E7=BB=AD=E7=9A=84functi?= =?UTF-8?q?on=5Fcall=20=E8=BD=AC=E6=8D=A2=20tool=5Fcalls?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../openai_openai-responses_request.go | 68 +++++++++++++++++-- .../openai_openai-responses_request_test.go | 37 ++++++++++ 2 files changed, 100 insertions(+), 5 deletions(-) diff --git a/internal/translator/openai/openai/responses/openai_openai-responses_request.go b/internal/translator/openai/openai/responses/openai_openai-responses_request.go index 9164a4116..15acf7cdb 100644 --- a/internal/translator/openai/openai/responses/openai_openai-responses_request.go +++ b/internal/translator/openai/openai/responses/openai_openai-responses_request.go @@ -57,7 +57,24 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu // Convert input array to messages if input := root.Get("input"); input.Exists() && input.IsArray() { + inputItems := input.Array() + outputCallIDs := make(map[string]struct{}) + for _, item := range inputItems { + if item.Get("type").String() != "function_call_output" { + continue + } + callID := strings.TrimSpace(item.Get("call_id").String()) + if callID == "" { + continue + } + outputCallIDs[callID] = struct{}{} + } + pendingToolCalls := make([]interface{}, 0) + pendingToolCallIDs := make([]string, 0) + awaitingToolOutputs := make(map[string]struct{}) + deferredMessages := make([][]byte, 0) + flushPendingToolCalls := func() { if len(pendingToolCalls) == 0 { return @@ -65,10 +82,40 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu assistantMessage := []byte(`{"role":"assistant","tool_calls":[]}`) assistantMessage, _ = sjson.SetBytes(assistantMessage, "tool_calls", pendingToolCalls) out, _ = sjson.SetRawBytes(out, "messages.-1", assistantMessage) + for _, id := range pendingToolCallIDs { + if strings.TrimSpace(id) == "" { + continue + } + awaitingToolOutputs[id] = struct{}{} + } pendingToolCalls = pendingToolCalls[:0] + pendingToolCallIDs = pendingToolCallIDs[:0] + } + flushDeferredMessages := func() { + for _, message := range deferredMessages { + out, _ = sjson.SetRawBytes(out, "messages.-1", message) + } + deferredMessages = deferredMessages[:0] + } + hasAwaitingToolOutput := func() bool { + for id := range awaitingToolOutputs { + if _, ok := outputCallIDs[id]; ok { + return true + } + } + return false + } + appendRegularMessage := func(message []byte) { + // Keep tool-call adjacency strict for providers that require + // assistant(tool_calls) -> tool(tool_call_id) with no message in between. + if hasAwaitingToolOutput() { + deferredMessages = append(deferredMessages, message) + return + } + out, _ = sjson.SetRawBytes(out, "messages.-1", message) } - input.ForEach(func(_, item gjson.Result) bool { + for _, item := range inputItems { itemType := item.Get("type").String() if itemType == "" && item.Get("role").String() != "" { itemType = "message" @@ -123,7 +170,7 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu message, _ = sjson.SetBytes(message, "content", content.String()) } - out, _ = sjson.SetRawBytes(out, "messages.-1", message) + appendRegularMessage(message) case "function_call": // Buffer consecutive function calls and emit them as one assistant message. @@ -141,13 +188,18 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu toolCall, _ = sjson.SetBytes(toolCall, "function.arguments", arguments.String()) } pendingToolCalls = append(pendingToolCalls, gjson.ParseBytes(toolCall).Value()) + if callID := strings.TrimSpace(item.Get("call_id").String()); callID != "" { + pendingToolCallIDs = append(pendingToolCallIDs, callID) + } case "function_call_output": // Handle function call output conversion to tool message toolMessage := []byte(`{"role":"tool","tool_call_id":"","content":""}`) + callID := "" if callId := item.Get("call_id"); callId.Exists() { - toolMessage, _ = sjson.SetBytes(toolMessage, "tool_call_id", callId.String()) + callID = strings.TrimSpace(callId.String()) + toolMessage, _ = sjson.SetBytes(toolMessage, "tool_call_id", callID) } if output := item.Get("output"); output.Exists() { @@ -155,11 +207,17 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu } out, _ = sjson.SetRawBytes(out, "messages.-1", toolMessage) + if callID != "" { + delete(awaitingToolOutputs, callID) + } + if len(awaitingToolOutputs) == 0 && len(deferredMessages) > 0 { + flushDeferredMessages() + } } - return true - }) + } flushPendingToolCalls() + flushDeferredMessages() } else if input.Type == gjson.String { msg := []byte(`{}`) msg, _ = sjson.SetBytes(msg, "role", "user") diff --git a/internal/translator/openai/openai/responses/openai_openai-responses_request_test.go b/internal/translator/openai/openai/responses/openai_openai-responses_request_test.go index e9339753a..9dd0e288b 100644 --- a/internal/translator/openai/openai/responses/openai_openai-responses_request_test.go +++ b/internal/translator/openai/openai/responses/openai_openai-responses_request_test.go @@ -85,3 +85,40 @@ func TestConvertOpenAIResponsesRequestToOpenAIChatCompletions_SplitFunctionCalls t.Fatalf("messages.2.tool_calls.0.id = %q, want %q", got, "call_b") } } + +func TestConvertOpenAIResponsesRequestToOpenAIChatCompletions_DefersMessageUntilToolOutput(t *testing.T) { + raw := []byte(`{ + "input": [ + {"type":"function_call","call_id":"call_x","name":"exec_command","arguments":"{\"cmd\":\"echo hi\"}"}, + {"type":"message","role":"user","content":"Approved command prefix saved"}, + {"type":"function_call_output","call_id":"call_x","output":"ok"}, + {"type":"message","role":"user","content":"next"} + ] + }`) + t.Logf("input json:\n%s", prettyJSONForTest(raw)) + + out := ConvertOpenAIResponsesRequestToOpenAIChatCompletions("kimi-k2.6", raw, true) + t.Logf("output json:\n%s", prettyJSONForTest(out)) + + if got := len(gjson.GetBytes(out, "messages").Array()); got != 4 { + t.Fatalf("messages count = %d, want %d", got, 4) + } + if got := gjson.GetBytes(out, "messages.0.role").String(); got != "assistant" { + t.Fatalf("messages.0.role = %q, want %q", got, "assistant") + } + if got := gjson.GetBytes(out, "messages.1.role").String(); got != "tool" { + t.Fatalf("messages.1.role = %q, want %q", got, "tool") + } + if got := gjson.GetBytes(out, "messages.1.tool_call_id").String(); got != "call_x" { + t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "call_x") + } + if got := gjson.GetBytes(out, "messages.2.role").String(); got != "user" { + t.Fatalf("messages.2.role = %q, want %q", got, "user") + } + if got := gjson.GetBytes(out, "messages.2.content").String(); got != "Approved command prefix saved" { + t.Fatalf("messages.2.content = %q, want %q", got, "Approved command prefix saved") + } + if got := gjson.GetBytes(out, "messages.3.content").String(); got != "next" { + t.Fatalf("messages.3.content = %q, want %q", got, "next") + } +}