diff --git a/internal/translator/claude/openai/chat-completions/claude_openai_response.go b/internal/translator/claude/openai/chat-completions/claude_openai_response.go index 1fd3f2ae1..99c752387 100644 --- a/internal/translator/claude/openai/chat-completions/claude_openai_response.go +++ b/internal/translator/claude/openai/chat-completions/claude_openai_response.go @@ -25,10 +25,19 @@ type ConvertAnthropicResponseToOpenAIParams struct { CreatedAt int64 ResponseID string FinishReason string + Usage claudeUsageTokens // Tool calls accumulator for streaming ToolCallsAccumulator map[int]*ToolCallAccumulator } +type claudeUsageTokens struct { + InputTokens int64 + OutputTokens int64 + CacheCreationInputTokens int64 + CacheReadInputTokens int64 + HasUsage bool +} + // ToolCallAccumulator holds the state for accumulating tool call data type ToolCallAccumulator struct { ID string @@ -36,15 +45,30 @@ type ToolCallAccumulator struct { Arguments strings.Builder } -func calculateClaudeUsageTokens(usage gjson.Result) (promptTokens, completionTokens, totalTokens, cachedTokens int64) { - inputTokens := usage.Get("input_tokens").Int() - completionTokens = usage.Get("output_tokens").Int() - cachedTokens = usage.Get("cache_read_input_tokens").Int() - cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int() +func (u *claudeUsageTokens) Merge(usage gjson.Result) { + if !usage.Exists() { + return + } + u.HasUsage = true + if inputTokens := usage.Get("input_tokens"); inputTokens.Exists() { + u.InputTokens = inputTokens.Int() + } + if outputTokens := usage.Get("output_tokens"); outputTokens.Exists() { + u.OutputTokens = outputTokens.Int() + } + if cacheCreationInputTokens := usage.Get("cache_creation_input_tokens"); cacheCreationInputTokens.Exists() { + u.CacheCreationInputTokens = cacheCreationInputTokens.Int() + } + if cacheReadInputTokens := usage.Get("cache_read_input_tokens"); cacheReadInputTokens.Exists() { + u.CacheReadInputTokens = cacheReadInputTokens.Int() + } +} - promptTokens = inputTokens + cacheCreationInputTokens + cachedTokens +func (u claudeUsageTokens) OpenAIUsage() (promptTokens, completionTokens, totalTokens, cachedTokens int64) { + cachedTokens = u.CacheReadInputTokens + promptTokens = u.InputTokens + u.CacheCreationInputTokens + cachedTokens + completionTokens = u.OutputTokens totalTokens = promptTokens + completionTokens - return promptTokens, completionTokens, totalTokens, cachedTokens } @@ -112,6 +136,7 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original if (*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator == nil { (*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator = make(map[int]*ToolCallAccumulator) } + (*param).(*ConvertAnthropicResponseToOpenAIParams).Usage.Merge(message.Get("usage")) } return [][]byte{template} @@ -215,7 +240,8 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original // Handle usage information for token counts if usage := root.Get("usage"); usage.Exists() { - promptTokens, completionTokens, totalTokens, cachedTokens := calculateClaudeUsageTokens(usage) + (*param).(*ConvertAnthropicResponseToOpenAIParams).Usage.Merge(usage) + promptTokens, completionTokens, totalTokens, cachedTokens := (*param).(*ConvertAnthropicResponseToOpenAIParams).Usage.OpenAIUsage() template, _ = sjson.SetBytes(template, "usage.prompt_tokens", promptTokens) template, _ = sjson.SetBytes(template, "usage.completion_tokens", completionTokens) template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokens) @@ -296,6 +322,7 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina var stopReason string var contentParts []string var reasoningParts []string + usageTokens := claudeUsageTokens{} toolCallsAccumulator := make(map[int]*ToolCallAccumulator) for _, chunk := range chunks { @@ -309,6 +336,7 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina messageID = message.Get("id").String() model = message.Get("model").String() createdAt = time.Now().Unix() + usageTokens.Merge(message.Get("usage")) } case "content_block_start": @@ -371,15 +399,19 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina } } if usage := root.Get("usage"); usage.Exists() { - promptTokens, completionTokens, totalTokens, cachedTokens := calculateClaudeUsageTokens(usage) - out, _ = sjson.SetBytes(out, "usage.prompt_tokens", promptTokens) - out, _ = sjson.SetBytes(out, "usage.completion_tokens", completionTokens) - out, _ = sjson.SetBytes(out, "usage.total_tokens", totalTokens) - out, _ = sjson.SetBytes(out, "usage.prompt_tokens_details.cached_tokens", cachedTokens) + usageTokens.Merge(usage) } } } + if usageTokens.HasUsage { + promptTokens, completionTokens, totalTokens, cachedTokens := usageTokens.OpenAIUsage() + out, _ = sjson.SetBytes(out, "usage.prompt_tokens", promptTokens) + out, _ = sjson.SetBytes(out, "usage.completion_tokens", completionTokens) + out, _ = sjson.SetBytes(out, "usage.total_tokens", totalTokens) + out, _ = sjson.SetBytes(out, "usage.prompt_tokens_details.cached_tokens", cachedTokens) + } + // Set basic response fields including message ID, creation time, and model out, _ = sjson.SetBytes(out, "id", messageID) out, _ = sjson.SetBytes(out, "created", createdAt) diff --git a/internal/translator/claude/openai/chat-completions/claude_openai_response_test.go b/internal/translator/claude/openai/chat-completions/claude_openai_response_test.go index 7bd6eb1f1..5a9a6d3ad 100644 --- a/internal/translator/claude/openai/chat-completions/claude_openai_response_test.go +++ b/internal/translator/claude/openai/chat-completions/claude_openai_response_test.go @@ -37,6 +37,44 @@ func TestConvertClaudeResponseToOpenAI_StreamUsageIncludesCachedTokens(t *testin } } +func TestConvertClaudeResponseToOpenAI_StreamUsageMergesMessageStartUsage(t *testing.T) { + ctx := context.Background() + var param any + + ConvertClaudeResponseToOpenAI( + ctx, + "claude-opus-4-6", + nil, + nil, + []byte(`data: {"type":"message_start","message":{"id":"msg_123","model":"claude-opus-4-6","usage":{"input_tokens":13,"output_tokens":1,"cache_read_input_tokens":22000,"cache_creation_input_tokens":31}}}`), + ¶m, + ) + out := ConvertClaudeResponseToOpenAI( + ctx, + "claude-opus-4-6", + nil, + nil, + []byte(`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":4}}`), + ¶m, + ) + if len(out) != 1 { + t.Fatalf("expected 1 chunk, got %d", len(out)) + } + + if gotPromptTokens := gjson.GetBytes(out[0], "usage.prompt_tokens").Int(); gotPromptTokens != 22044 { + t.Fatalf("expected prompt_tokens %d, got %d", 22044, gotPromptTokens) + } + if gotCompletionTokens := gjson.GetBytes(out[0], "usage.completion_tokens").Int(); gotCompletionTokens != 4 { + t.Fatalf("expected completion_tokens %d, got %d", 4, gotCompletionTokens) + } + if gotTotalTokens := gjson.GetBytes(out[0], "usage.total_tokens").Int(); gotTotalTokens != 22048 { + t.Fatalf("expected total_tokens %d, got %d", 22048, gotTotalTokens) + } + if gotCachedTokens := gjson.GetBytes(out[0], "usage.prompt_tokens_details.cached_tokens").Int(); gotCachedTokens != 22000 { + t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens) + } +} + func TestConvertClaudeResponseToOpenAINonStream_UsageIncludesCachedTokens(t *testing.T) { rawJSON := []byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"model\":\"claude-opus-4-6\"}}\n" + "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"input_tokens\":13,\"output_tokens\":4,\"cache_read_input_tokens\":22000,\"cache_creation_input_tokens\":31}}\n") @@ -56,3 +94,23 @@ func TestConvertClaudeResponseToOpenAINonStream_UsageIncludesCachedTokens(t *tes t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens) } } + +func TestConvertClaudeResponseToOpenAINonStream_UsageMergesMessageStartUsage(t *testing.T) { + rawJSON := []byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"model\":\"claude-opus-4-6\",\"usage\":{\"input_tokens\":13,\"output_tokens\":1,\"cache_read_input_tokens\":22000,\"cache_creation_input_tokens\":31}}}\n" + + "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":4}}\n") + + out := ConvertClaudeResponseToOpenAINonStream(context.Background(), "", nil, nil, rawJSON, nil) + + if gotPromptTokens := gjson.GetBytes(out, "usage.prompt_tokens").Int(); gotPromptTokens != 22044 { + t.Fatalf("expected prompt_tokens %d, got %d", 22044, gotPromptTokens) + } + if gotCompletionTokens := gjson.GetBytes(out, "usage.completion_tokens").Int(); gotCompletionTokens != 4 { + t.Fatalf("expected completion_tokens %d, got %d", 4, gotCompletionTokens) + } + if gotTotalTokens := gjson.GetBytes(out, "usage.total_tokens").Int(); gotTotalTokens != 22048 { + t.Fatalf("expected total_tokens %d, got %d", 22048, gotTotalTokens) + } + if gotCachedTokens := gjson.GetBytes(out, "usage.prompt_tokens_details.cached_tokens").Int(); gotCachedTokens != 22000 { + t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens) + } +}