feat(translator): ensure correct finish_reason handling for all response chunks

- Added tests (`TestCliFinishReasonOnlyOnFinalChunk`, `TestGeminiFinishReasonOnlyOnFinalChunk`) to validate correct `finish_reason` and `native_finish_reason` assignment.
- Refactored Gemini and CLI translators to track `SawToolCall` and `UpstreamFinishReason` for accurate final-chunk determination.
- Improved response parsing logic to align with upstream metadata so that `finish_reason` is reported consistently, and only on the final chunk.
This commit is contained in:
Luis Pater
2026-06-11 00:17:45 +08:00
parent 1ca048abdc
commit 58bf645e66
4 changed files with 139 additions and 45 deletions

View File

@@ -22,9 +22,11 @@ import (
// convertCliResponseToOpenAIChatParams holds parameters for response conversion.
type convertCliResponseToOpenAIChatParams struct {
UnixTimestamp int64
FunctionIndex int
SanitizedNameMap map[string]string
UnixTimestamp int64
FunctionIndex int
SawToolCall bool
UpstreamFinishReason string
SanitizedNameMap map[string]string
}
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
@@ -84,16 +86,12 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
template, _ = sjson.SetBytes(template, "id", responseIDResult.String())
}
finishReason := ""
if stopReasonResult := gjson.GetBytes(rawJSON, "response.stop_reason"); stopReasonResult.Exists() {
finishReason = stopReasonResult.String()
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
(*param).(*convertCliResponseToOpenAIChatParams).UpstreamFinishReason = strings.ToUpper(finishReasonResult.String())
}
if finishReason == "" {
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
finishReason = finishReasonResult.String()
}
if stopReasonResult := gjson.GetBytes(rawJSON, "response.stop_reason"); stopReasonResult.Exists() && stopReasonResult.String() != "" {
(*param).(*convertCliResponseToOpenAIChatParams).UpstreamFinishReason = strings.ToUpper(stopReasonResult.String())
}
finishReason = strings.ToLower(finishReason)
// Extract and set usage metadata (token counts).
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
@@ -122,7 +120,6 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
// Process the main content part of the response.
partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts")
hasFunctionCall := false
if partsResult.IsArray() {
partResults := partsResult.Array()
for i := 0; i < len(partResults); i++ {
@@ -158,7 +155,7 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
} else if functionCallResult.Exists() {
// Handle function call content.
hasFunctionCall = true
(*param).(*convertCliResponseToOpenAIChatParams).SawToolCall = true
toolCallsResult := gjson.GetBytes(template, "choices.0.delta.tool_calls")
functionCallIndex := (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex
(*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex++
@@ -205,15 +202,23 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
}
}
if hasFunctionCall {
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", "tool_calls")
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", "tool_calls")
} else if finishReason != "" && (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex == 0 {
// Only pass through specific finish reasons
if finishReason == "max_tokens" || finishReason == "stop" {
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", finishReason)
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", finishReason)
params := (*param).(*convertCliResponseToOpenAIChatParams)
upstreamFinishReason := params.UpstreamFinishReason
sawToolCall := params.SawToolCall
usageExists := gjson.GetBytes(rawJSON, "response.usageMetadata").Exists()
isFinalChunk := upstreamFinishReason != "" && usageExists
if isFinalChunk {
var finishReason string
if sawToolCall {
finishReason = "tool_calls"
} else if upstreamFinishReason == "MAX_TOKENS" {
finishReason = "max_tokens"
} else {
finishReason = "stop"
}
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", finishReason)
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", strings.ToLower(upstreamFinishReason))
}
return [][]byte{template}

View File

@@ -0,0 +1,40 @@
package chat_completions
import (
"context"
"testing"
"github.com/tidwall/gjson"
)
// TestCliFinishReasonOnlyOnFinalChunk streams two functionCall chunks followed
// by a chunk carrying finishReason "STOP" through ConvertCliResponseToOpenAI,
// reusing the same param value across calls. Neither tool-call chunk may report
// a finish_reason; the last chunk must report finish_reason "tool_calls" and
// native_finish_reason "stop".
func TestCliFinishReasonOnlyOnFinalChunk(t *testing.T) {
	ctx := context.Background()
	var param any

	// convert feeds one chunk through the translator and requires exactly one output.
	convert := func(label string, chunk []byte) []byte {
		t.Helper()
		results := ConvertCliResponseToOpenAI(ctx, "model", nil, nil, chunk, &param)
		if len(results) != 1 {
			t.Fatalf("expected 1 result from %s, got %d", label, len(results))
		}
		return results[0]
	}

	// requireNoFinishReason accepts a missing, null, or empty finish_reason and
	// fails on anything else.
	requireNoFinishReason := func(label string, out []byte) {
		t.Helper()
		fr := gjson.GetBytes(out, "choices.0.finish_reason")
		if fr.Exists() && fr.Type != gjson.Null && fr.String() != "" {
			t.Fatalf("expected null finish_reason on %s, got %q", label, fr.String())
		}
	}

	// Both tool-call chunks include usageMetadata but no finishReason, so
	// neither of them is the final chunk.
	chunk1 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"functionCall":{"name":"list_dir","args":{"path":"C:/"}}}]}}],"usageMetadata":{"trafficType":"ON_DEMAND"}}}`)
	requireNoFinishReason("chunk1", convert("chunk1", chunk1))

	chunk2 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"functionCall":{"name":"list_dir","args":{"path":"D:/"}}}]}}],"usageMetadata":{"trafficType":"ON_DEMAND"}}}`)
	requireNoFinishReason("chunk2", convert("chunk2", chunk2))

	// The final chunk has no functionCall part itself; the tool-call outcome
	// has to come from state remembered in param.
	chunk3 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":""}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15}}}`)
	final := convert("chunk3", chunk3)

	if got := gjson.GetBytes(final, "choices.0.finish_reason").String(); got != "tool_calls" {
		t.Fatalf("expected finish_reason tool_calls, got %q", got)
	}
	if got := gjson.GetBytes(final, "choices.0.native_finish_reason").String(); got != "stop" {
		t.Fatalf("expected native_finish_reason stop, got %q", got)
	}
}

View File

@@ -23,8 +23,10 @@ import (
type convertGeminiResponseToOpenAIChatParams struct {
UnixTimestamp int64
// FunctionIndex tracks tool call indices per candidate index to support multiple candidates.
FunctionIndex map[int]int
SanitizedNameMap map[string]string
FunctionIndex map[int]int
SawToolCall map[int]bool
UpstreamFinishReason map[int]string
SanitizedNameMap map[string]string
}
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
@@ -48,9 +50,11 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
// Initialize parameters if nil.
if *param == nil {
*param = &convertGeminiResponseToOpenAIChatParams{
UnixTimestamp: 0,
FunctionIndex: make(map[int]int),
SanitizedNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
UnixTimestamp: 0,
FunctionIndex: make(map[int]int),
SawToolCall: make(map[int]bool),
UpstreamFinishReason: make(map[int]string),
SanitizedNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
}
}
@@ -59,6 +63,12 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
if p.FunctionIndex == nil {
p.FunctionIndex = make(map[int]int)
}
if p.SawToolCall == nil {
p.SawToolCall = make(map[int]bool)
}
if p.UpstreamFinishReason == nil {
p.UpstreamFinishReason = make(map[int]string)
}
if p.SanitizedNameMap == nil {
p.SanitizedNameMap = util.SanitizedToolNameMap(originalRequestRawJSON)
}
@@ -135,19 +145,11 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
candidateIndex := int(candidate.Get("index").Int())
template, _ = sjson.SetBytes(template, "choices.0.index", candidateIndex)
finishReason := ""
if stopReasonResult := gjson.GetBytes(rawJSON, "stop_reason"); stopReasonResult.Exists() {
finishReason = stopReasonResult.String()
if finishReasonResult := candidate.Get("finishReason"); finishReasonResult.Exists() {
p.UpstreamFinishReason[candidateIndex] = strings.ToUpper(finishReasonResult.String())
}
if finishReason == "" {
if finishReasonResult := gjson.GetBytes(rawJSON, "candidates.0.finishReason"); finishReasonResult.Exists() {
finishReason = finishReasonResult.String()
}
}
finishReason = strings.ToLower(finishReason)
partsResult := candidate.Get("content.parts")
hasFunctionCall := false
if partsResult.IsArray() {
partResults := partsResult.Array()
@@ -183,7 +185,7 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
} else if functionCallResult.Exists() {
// Handle function call content.
hasFunctionCall = true
p.SawToolCall[candidateIndex] = true
toolCallsResult := gjson.GetBytes(template, "choices.0.delta.tool_calls")
// Retrieve the function index for this specific candidate.
@@ -233,15 +235,22 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
}
}
if hasFunctionCall {
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", "tool_calls")
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", "tool_calls")
} else if finishReason != "" {
// Only pass through specific finish reasons
if finishReason == "max_tokens" || finishReason == "stop" {
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", finishReason)
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", finishReason)
upstreamFinishReason := p.UpstreamFinishReason[candidateIndex]
sawToolCall := p.SawToolCall[candidateIndex]
usageExists := gjson.GetBytes(rawJSON, "usageMetadata").Exists()
isFinalChunk := upstreamFinishReason != "" && usageExists
if isFinalChunk {
var finishReason string
if sawToolCall {
finishReason = "tool_calls"
} else if upstreamFinishReason == "MAX_TOKENS" {
finishReason = "max_tokens"
} else {
finishReason = "stop"
}
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", finishReason)
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", strings.ToLower(upstreamFinishReason))
}
responseStrings = append(responseStrings, template)

View File

@@ -0,0 +1,40 @@
package chat_completions
import (
"context"
"testing"
"github.com/tidwall/gjson"
)
// TestGeminiFinishReasonOnlyOnFinalChunk streams two functionCall chunks
// followed by a chunk carrying finishReason "STOP" through
// ConvertGeminiResponseToOpenAI, reusing the same param value across calls.
// Neither tool-call chunk may report a finish_reason; the last chunk must
// report finish_reason "tool_calls" and native_finish_reason "stop".
func TestGeminiFinishReasonOnlyOnFinalChunk(t *testing.T) {
	ctx := context.Background()
	var param any

	// convert feeds one chunk through the translator and requires exactly one output.
	convert := func(label string, chunk []byte) []byte {
		t.Helper()
		results := ConvertGeminiResponseToOpenAI(ctx, "model", nil, nil, chunk, &param)
		if len(results) != 1 {
			t.Fatalf("expected 1 result from %s, got %d", label, len(results))
		}
		return results[0]
	}

	// requireNoFinishReason accepts a missing, null, or empty finish_reason and
	// fails on anything else.
	requireNoFinishReason := func(label string, out []byte) {
		t.Helper()
		fr := gjson.GetBytes(out, "choices.0.finish_reason")
		if fr.Exists() && fr.Type != gjson.Null && fr.String() != "" {
			t.Fatalf("expected null finish_reason on %s, got %q", label, fr.String())
		}
	}

	// Both tool-call chunks include usageMetadata but no finishReason, so
	// neither of them is the final chunk.
	chunk1 := []byte(`{"candidates":[{"content":{"parts":[{"functionCall":{"name":"list_dir","args":{"path":"C:/"}}}]}}],"usageMetadata":{"trafficType":"ON_DEMAND"}}`)
	requireNoFinishReason("chunk1", convert("chunk1", chunk1))

	chunk2 := []byte(`{"candidates":[{"content":{"parts":[{"functionCall":{"name":"list_dir","args":{"path":"D:/"}}}]}}],"usageMetadata":{"trafficType":"ON_DEMAND"}}`)
	requireNoFinishReason("chunk2", convert("chunk2", chunk2))

	// The final chunk has no functionCall part itself; the tool-call outcome
	// has to come from state remembered in param.
	chunk3 := []byte(`{"candidates":[{"content":{"parts":[{"text":""}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15}}`)
	final := convert("chunk3", chunk3)

	if got := gjson.GetBytes(final, "choices.0.finish_reason").String(); got != "tool_calls" {
		t.Fatalf("expected finish_reason tool_calls, got %q", got)
	}
	if got := gjson.GetBytes(final, "choices.0.native_finish_reason").String(); got != "stop" {
		t.Fatalf("expected native_finish_reason stop, got %q", got)
	}
}