mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-06-20 03:17:20 +08:00
feat(translator): ensure correct finish_reason handling for all response chunks
- Added tests (`TestCliFinishReasonOnlyOnFinalChunk`, `TestGeminiFinishReasonOnlyOnFinalChunk`) to validate correct `finish_reason` and `native_finish_reason` assignment. - Refactored Gemini and CLI translators to track `SawToolCall` and `UpstreamFinishReason` for accurate final-chunk determination. - Improved response parsing logic to align with upstream metadata and provide consistent reasoning on chunk outputs.
This commit is contained in:
@@ -22,9 +22,11 @@ import (
|
||||
|
||||
// convertCliResponseToOpenAIChatParams holds parameters for response conversion.
|
||||
type convertCliResponseToOpenAIChatParams struct {
|
||||
UnixTimestamp int64
|
||||
FunctionIndex int
|
||||
SanitizedNameMap map[string]string
|
||||
UnixTimestamp int64
|
||||
FunctionIndex int
|
||||
SawToolCall bool
|
||||
UpstreamFinishReason string
|
||||
SanitizedNameMap map[string]string
|
||||
}
|
||||
|
||||
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
|
||||
@@ -84,16 +86,12 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
template, _ = sjson.SetBytes(template, "id", responseIDResult.String())
|
||||
}
|
||||
|
||||
finishReason := ""
|
||||
if stopReasonResult := gjson.GetBytes(rawJSON, "response.stop_reason"); stopReasonResult.Exists() {
|
||||
finishReason = stopReasonResult.String()
|
||||
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||
(*param).(*convertCliResponseToOpenAIChatParams).UpstreamFinishReason = strings.ToUpper(finishReasonResult.String())
|
||||
}
|
||||
if finishReason == "" {
|
||||
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||
finishReason = finishReasonResult.String()
|
||||
}
|
||||
if stopReasonResult := gjson.GetBytes(rawJSON, "response.stop_reason"); stopReasonResult.Exists() && stopReasonResult.String() != "" {
|
||||
(*param).(*convertCliResponseToOpenAIChatParams).UpstreamFinishReason = strings.ToUpper(stopReasonResult.String())
|
||||
}
|
||||
finishReason = strings.ToLower(finishReason)
|
||||
|
||||
// Extract and set usage metadata (token counts).
|
||||
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
|
||||
@@ -122,7 +120,6 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
|
||||
// Process the main content part of the response.
|
||||
partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts")
|
||||
hasFunctionCall := false
|
||||
if partsResult.IsArray() {
|
||||
partResults := partsResult.Array()
|
||||
for i := 0; i < len(partResults); i++ {
|
||||
@@ -158,7 +155,7 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
|
||||
} else if functionCallResult.Exists() {
|
||||
// Handle function call content.
|
||||
hasFunctionCall = true
|
||||
(*param).(*convertCliResponseToOpenAIChatParams).SawToolCall = true
|
||||
toolCallsResult := gjson.GetBytes(template, "choices.0.delta.tool_calls")
|
||||
functionCallIndex := (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex
|
||||
(*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex++
|
||||
@@ -205,15 +202,23 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
}
|
||||
}
|
||||
|
||||
if hasFunctionCall {
|
||||
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", "tool_calls")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", "tool_calls")
|
||||
} else if finishReason != "" && (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex == 0 {
|
||||
// Only pass through specific finish reasons
|
||||
if finishReason == "max_tokens" || finishReason == "stop" {
|
||||
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", finishReason)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", finishReason)
|
||||
params := (*param).(*convertCliResponseToOpenAIChatParams)
|
||||
upstreamFinishReason := params.UpstreamFinishReason
|
||||
sawToolCall := params.SawToolCall
|
||||
usageExists := gjson.GetBytes(rawJSON, "response.usageMetadata").Exists()
|
||||
isFinalChunk := upstreamFinishReason != "" && usageExists
|
||||
|
||||
if isFinalChunk {
|
||||
var finishReason string
|
||||
if sawToolCall {
|
||||
finishReason = "tool_calls"
|
||||
} else if upstreamFinishReason == "MAX_TOKENS" {
|
||||
finishReason = "max_tokens"
|
||||
} else {
|
||||
finishReason = "stop"
|
||||
}
|
||||
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", finishReason)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", strings.ToLower(upstreamFinishReason))
|
||||
}
|
||||
|
||||
return [][]byte{template}
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestCliFinishReasonOnlyOnFinalChunk(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
var param any
|
||||
|
||||
chunk1 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"functionCall":{"name":"list_dir","args":{"path":"C:/"}}}]}}],"usageMetadata":{"trafficType":"ON_DEMAND"}}}`)
|
||||
result1 := ConvertCliResponseToOpenAI(ctx, "model", nil, nil, chunk1, ¶m)
|
||||
if len(result1) != 1 {
|
||||
t.Fatalf("expected 1 result from chunk1, got %d", len(result1))
|
||||
}
|
||||
fr1 := gjson.GetBytes(result1[0], "choices.0.finish_reason")
|
||||
if fr1.Exists() && fr1.String() != "" && fr1.Type.String() != "Null" {
|
||||
t.Fatalf("expected null finish_reason on tool chunk, got %v", fr1.String())
|
||||
}
|
||||
|
||||
chunk2 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"functionCall":{"name":"list_dir","args":{"path":"D:/"}}}]}}],"usageMetadata":{"trafficType":"ON_DEMAND"}}}`)
|
||||
ConvertCliResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||
|
||||
chunk3 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":""}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15}}}`)
|
||||
result3 := ConvertCliResponseToOpenAI(ctx, "model", nil, nil, chunk3, ¶m)
|
||||
if len(result3) != 1 {
|
||||
t.Fatalf("expected 1 result from chunk3, got %d", len(result3))
|
||||
}
|
||||
fr3 := gjson.GetBytes(result3[0], "choices.0.finish_reason").String()
|
||||
if fr3 != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason tool_calls, got %s", fr3)
|
||||
}
|
||||
nfr3 := gjson.GetBytes(result3[0], "choices.0.native_finish_reason").String()
|
||||
if nfr3 != "stop" {
|
||||
t.Fatalf("expected native_finish_reason stop, got %s", nfr3)
|
||||
}
|
||||
}
|
||||
@@ -23,8 +23,10 @@ import (
|
||||
type convertGeminiResponseToOpenAIChatParams struct {
|
||||
UnixTimestamp int64
|
||||
// FunctionIndex tracks tool call indices per candidate index to support multiple candidates.
|
||||
FunctionIndex map[int]int
|
||||
SanitizedNameMap map[string]string
|
||||
FunctionIndex map[int]int
|
||||
SawToolCall map[int]bool
|
||||
UpstreamFinishReason map[int]string
|
||||
SanitizedNameMap map[string]string
|
||||
}
|
||||
|
||||
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
|
||||
@@ -48,9 +50,11 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
||||
// Initialize parameters if nil.
|
||||
if *param == nil {
|
||||
*param = &convertGeminiResponseToOpenAIChatParams{
|
||||
UnixTimestamp: 0,
|
||||
FunctionIndex: make(map[int]int),
|
||||
SanitizedNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
|
||||
UnixTimestamp: 0,
|
||||
FunctionIndex: make(map[int]int),
|
||||
SawToolCall: make(map[int]bool),
|
||||
UpstreamFinishReason: make(map[int]string),
|
||||
SanitizedNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,6 +63,12 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
||||
if p.FunctionIndex == nil {
|
||||
p.FunctionIndex = make(map[int]int)
|
||||
}
|
||||
if p.SawToolCall == nil {
|
||||
p.SawToolCall = make(map[int]bool)
|
||||
}
|
||||
if p.UpstreamFinishReason == nil {
|
||||
p.UpstreamFinishReason = make(map[int]string)
|
||||
}
|
||||
if p.SanitizedNameMap == nil {
|
||||
p.SanitizedNameMap = util.SanitizedToolNameMap(originalRequestRawJSON)
|
||||
}
|
||||
@@ -135,19 +145,11 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
||||
candidateIndex := int(candidate.Get("index").Int())
|
||||
template, _ = sjson.SetBytes(template, "choices.0.index", candidateIndex)
|
||||
|
||||
finishReason := ""
|
||||
if stopReasonResult := gjson.GetBytes(rawJSON, "stop_reason"); stopReasonResult.Exists() {
|
||||
finishReason = stopReasonResult.String()
|
||||
if finishReasonResult := candidate.Get("finishReason"); finishReasonResult.Exists() {
|
||||
p.UpstreamFinishReason[candidateIndex] = strings.ToUpper(finishReasonResult.String())
|
||||
}
|
||||
if finishReason == "" {
|
||||
if finishReasonResult := gjson.GetBytes(rawJSON, "candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||
finishReason = finishReasonResult.String()
|
||||
}
|
||||
}
|
||||
finishReason = strings.ToLower(finishReason)
|
||||
|
||||
partsResult := candidate.Get("content.parts")
|
||||
hasFunctionCall := false
|
||||
|
||||
if partsResult.IsArray() {
|
||||
partResults := partsResult.Array()
|
||||
@@ -183,7 +185,7 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
|
||||
} else if functionCallResult.Exists() {
|
||||
// Handle function call content.
|
||||
hasFunctionCall = true
|
||||
p.SawToolCall[candidateIndex] = true
|
||||
toolCallsResult := gjson.GetBytes(template, "choices.0.delta.tool_calls")
|
||||
|
||||
// Retrieve the function index for this specific candidate.
|
||||
@@ -233,15 +235,22 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
||||
}
|
||||
}
|
||||
|
||||
if hasFunctionCall {
|
||||
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", "tool_calls")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", "tool_calls")
|
||||
} else if finishReason != "" {
|
||||
// Only pass through specific finish reasons
|
||||
if finishReason == "max_tokens" || finishReason == "stop" {
|
||||
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", finishReason)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", finishReason)
|
||||
upstreamFinishReason := p.UpstreamFinishReason[candidateIndex]
|
||||
sawToolCall := p.SawToolCall[candidateIndex]
|
||||
usageExists := gjson.GetBytes(rawJSON, "usageMetadata").Exists()
|
||||
isFinalChunk := upstreamFinishReason != "" && usageExists
|
||||
|
||||
if isFinalChunk {
|
||||
var finishReason string
|
||||
if sawToolCall {
|
||||
finishReason = "tool_calls"
|
||||
} else if upstreamFinishReason == "MAX_TOKENS" {
|
||||
finishReason = "max_tokens"
|
||||
} else {
|
||||
finishReason = "stop"
|
||||
}
|
||||
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", finishReason)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", strings.ToLower(upstreamFinishReason))
|
||||
}
|
||||
|
||||
responseStrings = append(responseStrings, template)
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestGeminiFinishReasonOnlyOnFinalChunk(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
var param any
|
||||
|
||||
chunk1 := []byte(`{"candidates":[{"content":{"parts":[{"functionCall":{"name":"list_dir","args":{"path":"C:/"}}}]}}],"usageMetadata":{"trafficType":"ON_DEMAND"}}`)
|
||||
result1 := ConvertGeminiResponseToOpenAI(ctx, "model", nil, nil, chunk1, ¶m)
|
||||
if len(result1) != 1 {
|
||||
t.Fatalf("expected 1 result from chunk1, got %d", len(result1))
|
||||
}
|
||||
fr1 := gjson.GetBytes(result1[0], "choices.0.finish_reason")
|
||||
if fr1.Exists() && fr1.String() != "" && fr1.Type.String() != "Null" {
|
||||
t.Fatalf("expected null finish_reason on tool chunk, got %v", fr1.String())
|
||||
}
|
||||
|
||||
chunk2 := []byte(`{"candidates":[{"content":{"parts":[{"functionCall":{"name":"list_dir","args":{"path":"D:/"}}}]}}],"usageMetadata":{"trafficType":"ON_DEMAND"}}`)
|
||||
ConvertGeminiResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||
|
||||
chunk3 := []byte(`{"candidates":[{"content":{"parts":[{"text":""}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15}}`)
|
||||
result3 := ConvertGeminiResponseToOpenAI(ctx, "model", nil, nil, chunk3, ¶m)
|
||||
if len(result3) != 1 {
|
||||
t.Fatalf("expected 1 result from chunk3, got %d", len(result3))
|
||||
}
|
||||
fr3 := gjson.GetBytes(result3[0], "choices.0.finish_reason").String()
|
||||
if fr3 != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason tool_calls, got %s", fr3)
|
||||
}
|
||||
nfr3 := gjson.GetBytes(result3[0], "choices.0.native_finish_reason").String()
|
||||
if nfr3 != "stop" {
|
||||
t.Fatalf("expected native_finish_reason stop, got %s", nfr3)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user