feat(translator): add token usage tracking and improve usage handling

- Introduced `claudeUsageTokens` struct for detailed token usage tracking.
- Replaced `calculateClaudeUsageTokens` with `Merge` and `OpenAIUsage` methods for better modularity.
- Enhanced integration of usage tokens into response processing, enabling more accurate reporting of token details.

Fixed: #2419
This commit is contained in:
Luis Pater
2026-05-04 16:57:50 +08:00
parent 89d80bfff4
commit 85c0150653
2 changed files with 103 additions and 13 deletions

View File

@@ -25,10 +25,19 @@ type ConvertAnthropicResponseToOpenAIParams struct {
CreatedAt int64
ResponseID string
FinishReason string
Usage claudeUsageTokens
// Tool calls accumulator for streaming
ToolCallsAccumulator map[int]*ToolCallAccumulator
}
type claudeUsageTokens struct {
InputTokens int64
OutputTokens int64
CacheCreationInputTokens int64
CacheReadInputTokens int64
HasUsage bool
}
// ToolCallAccumulator holds the state for accumulating tool call data
type ToolCallAccumulator struct {
ID string
@@ -36,15 +45,30 @@ type ToolCallAccumulator struct {
Arguments strings.Builder
}
func calculateClaudeUsageTokens(usage gjson.Result) (promptTokens, completionTokens, totalTokens, cachedTokens int64) {
inputTokens := usage.Get("input_tokens").Int()
completionTokens = usage.Get("output_tokens").Int()
cachedTokens = usage.Get("cache_read_input_tokens").Int()
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
func (u *claudeUsageTokens) Merge(usage gjson.Result) {
if !usage.Exists() {
return
}
u.HasUsage = true
if inputTokens := usage.Get("input_tokens"); inputTokens.Exists() {
u.InputTokens = inputTokens.Int()
}
if outputTokens := usage.Get("output_tokens"); outputTokens.Exists() {
u.OutputTokens = outputTokens.Int()
}
if cacheCreationInputTokens := usage.Get("cache_creation_input_tokens"); cacheCreationInputTokens.Exists() {
u.CacheCreationInputTokens = cacheCreationInputTokens.Int()
}
if cacheReadInputTokens := usage.Get("cache_read_input_tokens"); cacheReadInputTokens.Exists() {
u.CacheReadInputTokens = cacheReadInputTokens.Int()
}
}
promptTokens = inputTokens + cacheCreationInputTokens + cachedTokens
func (u claudeUsageTokens) OpenAIUsage() (promptTokens, completionTokens, totalTokens, cachedTokens int64) {
cachedTokens = u.CacheReadInputTokens
promptTokens = u.InputTokens + u.CacheCreationInputTokens + cachedTokens
completionTokens = u.OutputTokens
totalTokens = promptTokens + completionTokens
return promptTokens, completionTokens, totalTokens, cachedTokens
}
@@ -112,6 +136,7 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
if (*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator == nil {
(*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator = make(map[int]*ToolCallAccumulator)
}
(*param).(*ConvertAnthropicResponseToOpenAIParams).Usage.Merge(message.Get("usage"))
}
return [][]byte{template}
@@ -215,7 +240,8 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
// Handle usage information for token counts
if usage := root.Get("usage"); usage.Exists() {
promptTokens, completionTokens, totalTokens, cachedTokens := calculateClaudeUsageTokens(usage)
(*param).(*ConvertAnthropicResponseToOpenAIParams).Usage.Merge(usage)
promptTokens, completionTokens, totalTokens, cachedTokens := (*param).(*ConvertAnthropicResponseToOpenAIParams).Usage.OpenAIUsage()
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", promptTokens)
template, _ = sjson.SetBytes(template, "usage.completion_tokens", completionTokens)
template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokens)
@@ -296,6 +322,7 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
var stopReason string
var contentParts []string
var reasoningParts []string
usageTokens := claudeUsageTokens{}
toolCallsAccumulator := make(map[int]*ToolCallAccumulator)
for _, chunk := range chunks {
@@ -309,6 +336,7 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
messageID = message.Get("id").String()
model = message.Get("model").String()
createdAt = time.Now().Unix()
usageTokens.Merge(message.Get("usage"))
}
case "content_block_start":
@@ -371,15 +399,19 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
}
}
if usage := root.Get("usage"); usage.Exists() {
promptTokens, completionTokens, totalTokens, cachedTokens := calculateClaudeUsageTokens(usage)
out, _ = sjson.SetBytes(out, "usage.prompt_tokens", promptTokens)
out, _ = sjson.SetBytes(out, "usage.completion_tokens", completionTokens)
out, _ = sjson.SetBytes(out, "usage.total_tokens", totalTokens)
out, _ = sjson.SetBytes(out, "usage.prompt_tokens_details.cached_tokens", cachedTokens)
usageTokens.Merge(usage)
}
}
}
if usageTokens.HasUsage {
promptTokens, completionTokens, totalTokens, cachedTokens := usageTokens.OpenAIUsage()
out, _ = sjson.SetBytes(out, "usage.prompt_tokens", promptTokens)
out, _ = sjson.SetBytes(out, "usage.completion_tokens", completionTokens)
out, _ = sjson.SetBytes(out, "usage.total_tokens", totalTokens)
out, _ = sjson.SetBytes(out, "usage.prompt_tokens_details.cached_tokens", cachedTokens)
}
// Set basic response fields including message ID, creation time, and model
out, _ = sjson.SetBytes(out, "id", messageID)
out, _ = sjson.SetBytes(out, "created", createdAt)

View File

@@ -37,6 +37,44 @@ func TestConvertClaudeResponseToOpenAI_StreamUsageIncludesCachedTokens(t *testin
}
}
func TestConvertClaudeResponseToOpenAI_StreamUsageMergesMessageStartUsage(t *testing.T) {
ctx := context.Background()
var param any
ConvertClaudeResponseToOpenAI(
ctx,
"claude-opus-4-6",
nil,
nil,
[]byte(`data: {"type":"message_start","message":{"id":"msg_123","model":"claude-opus-4-6","usage":{"input_tokens":13,"output_tokens":1,"cache_read_input_tokens":22000,"cache_creation_input_tokens":31}}}`),
&param,
)
out := ConvertClaudeResponseToOpenAI(
ctx,
"claude-opus-4-6",
nil,
nil,
[]byte(`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":4}}`),
&param,
)
if len(out) != 1 {
t.Fatalf("expected 1 chunk, got %d", len(out))
}
if gotPromptTokens := gjson.GetBytes(out[0], "usage.prompt_tokens").Int(); gotPromptTokens != 22044 {
t.Fatalf("expected prompt_tokens %d, got %d", 22044, gotPromptTokens)
}
if gotCompletionTokens := gjson.GetBytes(out[0], "usage.completion_tokens").Int(); gotCompletionTokens != 4 {
t.Fatalf("expected completion_tokens %d, got %d", 4, gotCompletionTokens)
}
if gotTotalTokens := gjson.GetBytes(out[0], "usage.total_tokens").Int(); gotTotalTokens != 22048 {
t.Fatalf("expected total_tokens %d, got %d", 22048, gotTotalTokens)
}
if gotCachedTokens := gjson.GetBytes(out[0], "usage.prompt_tokens_details.cached_tokens").Int(); gotCachedTokens != 22000 {
t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
}
}
func TestConvertClaudeResponseToOpenAINonStream_UsageIncludesCachedTokens(t *testing.T) {
rawJSON := []byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"model\":\"claude-opus-4-6\"}}\n" +
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"input_tokens\":13,\"output_tokens\":4,\"cache_read_input_tokens\":22000,\"cache_creation_input_tokens\":31}}\n")
@@ -56,3 +94,23 @@ func TestConvertClaudeResponseToOpenAINonStream_UsageIncludesCachedTokens(t *tes
t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
}
}
func TestConvertClaudeResponseToOpenAINonStream_UsageMergesMessageStartUsage(t *testing.T) {
rawJSON := []byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"model\":\"claude-opus-4-6\",\"usage\":{\"input_tokens\":13,\"output_tokens\":1,\"cache_read_input_tokens\":22000,\"cache_creation_input_tokens\":31}}}\n" +
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":4}}\n")
out := ConvertClaudeResponseToOpenAINonStream(context.Background(), "", nil, nil, rawJSON, nil)
if gotPromptTokens := gjson.GetBytes(out, "usage.prompt_tokens").Int(); gotPromptTokens != 22044 {
t.Fatalf("expected prompt_tokens %d, got %d", 22044, gotPromptTokens)
}
if gotCompletionTokens := gjson.GetBytes(out, "usage.completion_tokens").Int(); gotCompletionTokens != 4 {
t.Fatalf("expected completion_tokens %d, got %d", 4, gotCompletionTokens)
}
if gotTotalTokens := gjson.GetBytes(out, "usage.total_tokens").Int(); gotTotalTokens != 22048 {
t.Fatalf("expected total_tokens %d, got %d", 22048, gotTotalTokens)
}
if gotCachedTokens := gjson.GetBytes(out, "usage.prompt_tokens_details.cached_tokens").Int(); gotCachedTokens != 22000 {
t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
}
}