diff --git a/internal/api/modules/amp/response_rewriter.go b/internal/api/modules/amp/response_rewriter.go index 707fe576b..895c494e7 100644 --- a/internal/api/modules/amp/response_rewriter.go +++ b/internal/api/modules/amp/response_rewriter.go @@ -123,6 +123,52 @@ func (rw *ResponseRewriter) Flush() { var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"} +// ampCanonicalToolNames maps tool names to the exact casing expected by the +// Amp mode tool whitelist (case-sensitive match). +var ampCanonicalToolNames = map[string]string{ + "bash": "Bash", + "read": "Read", + "grep": "Grep", + "glob": "glob", + "task": "Task", + "check": "Check", +} + +// normalizeAmpToolNames fixes tool_use block names to match Amp's canonical casing. +// Some upstream models return lowercase tool names (e.g. "bash" instead of "Bash") +// which causes Amp's case-sensitive mode whitelist to reject them. +func normalizeAmpToolNames(data []byte) []byte { + // Non-streaming: content[].name in tool_use blocks + for index, block := range gjson.GetBytes(data, "content").Array() { + if block.Get("type").String() != "tool_use" { + continue + } + name := block.Get("name").String() + if canonical, ok := ampCanonicalToolNames[strings.ToLower(name)]; ok && name != canonical { + path := fmt.Sprintf("content.%d.name", index) + var err error + data, err = sjson.SetBytes(data, path, canonical) + if err != nil { + log.Warnf("Amp ResponseRewriter: failed to normalize tool name %q to %q: %v", name, canonical, err) + } + } + } + + // Streaming: content_block.name in content_block_start events + if gjson.GetBytes(data, "content_block.type").String() == "tool_use" { + name := gjson.GetBytes(data, "content_block.name").String() + if canonical, ok := ampCanonicalToolNames[strings.ToLower(name)]; ok && name != canonical { + var err error + data, err = sjson.SetBytes(data, "content_block.name", canonical) + if err != nil { + log.Warnf("Amp ResponseRewriter: failed to normalize streaming tool name %q to %q: %v", name, canonical, err) + } + } + } + + return data +} + // ensureAmpSignature injects empty signature fields into tool_use/thinking blocks // in API responses so that the Amp TUI does not crash on P.signature.length. func ensureAmpSignature(data []byte) []byte { @@ -179,6 +225,7 @@ func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte { func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte { data = ensureAmpSignature(data) + data = normalizeAmpToolNames(data) data = rw.suppressAmpThinking(data) if len(data) == 0 { return data @@ -278,6 +325,9 @@ func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte { // Inject empty signature where needed data = ensureAmpSignature(data) + // Normalize tool names to canonical casing + data = normalizeAmpToolNames(data) + // Rewrite model name if rw.originalModel != "" { for _, path := range modelFieldPaths { diff --git a/internal/api/modules/amp/response_rewriter_test.go b/internal/api/modules/amp/response_rewriter_test.go index ac95dfc64..a3a350cb2 100644 --- a/internal/api/modules/amp/response_rewriter_test.go +++ b/internal/api/modules/amp/response_rewriter_test.go @@ -175,6 +175,57 @@ func TestSanitizeAmpRequestBody_MixedInvalidThinkingAndToolUseSignature(t *testi } } +func TestNormalizeAmpToolNames_NonStreaming(t *testing.T) { + input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"bash","input":{"cmd":"ls"}},{"type":"tool_use","id":"toolu_02","name":"read","input":{"path":"/tmp"}},{"type":"text","text":"hello"}]}`) + result := normalizeAmpToolNames(input) + + if !contains(result, []byte(`"name":"Bash"`)) { + t.Errorf("expected bash->Bash, got %s", string(result)) + } + if !contains(result, []byte(`"name":"Read"`)) { + t.Errorf("expected read->Read, got %s", string(result)) + } + if contains(result, []byte(`"name":"bash"`)) { + t.Errorf("expected lowercase bash to be replaced, got %s", string(result)) + } +} + +func TestNormalizeAmpToolNames_Streaming(t *testing.T) { + input := []byte(`{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","name":"grep","id":"toolu_01","input":{}}}`) + result := normalizeAmpToolNames(input) + + if !contains(result, []byte(`"name":"Grep"`)) { + t.Errorf("expected grep->Grep in streaming, got %s", string(result)) + } +} + +func TestNormalizeAmpToolNames_AlreadyCorrect(t *testing.T) { + input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`) + result := normalizeAmpToolNames(input) + + if string(result) != string(input) { + t.Errorf("expected no modification for correctly-cased tool, got %s", string(result)) + } +} + +func TestNormalizeAmpToolNames_GlobPreserved(t *testing.T) { + input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`) + result := normalizeAmpToolNames(input) + + if string(result) != string(input) { + t.Errorf("expected glob to remain lowercase, got %s", string(result)) + } +} + +func TestNormalizeAmpToolNames_UnknownToolUntouched(t *testing.T) { + input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"edit_file","input":{"path":"/tmp/x"}}]}`) + result := normalizeAmpToolNames(input) + + if string(result) != string(input) { + t.Errorf("expected no modification for unknown tool, got %s", string(result)) + } +} + func contains(data, substr []byte) bool { for i := 0; i <= len(data)-len(substr); i++ { if string(data[i:i+len(substr)]) == string(substr) { diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index e3293d5bc..b22f4e448 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -65,14 +65,13 @@ var oauthToolRenameMap = map[string]string{ "notebookedit": "NotebookEdit", } -// oauthToolRenameReverseMap is the inverse of oauthToolRenameMap for response decoding. -var oauthToolRenameReverseMap = func() map[string]string { - m := make(map[string]string, len(oauthToolRenameMap)) - for k, v := range oauthToolRenameMap { - m[v] = k - } - return m -}() +// The reverse map is now computed per-request in remapOAuthToolNames so that +// only names the client actually caused us to rewrite are restored on the +// response. A global reverse map — as used previously — corrupted responses +// for clients that sent mixed casing (e.g. Amp CLI sends `Bash` TitleCase +// alongside `glob` lowercase; the request flagged renames via `glob→Glob`, +// then the global reverse map incorrectly rewrote every `Bash` in the +// response to `bash`, causing Amp to reject the tool_use as unknown). // oauthToolsToRemove lists tool names that must be stripped from OAuth requests // even after remapping. Currently empty — all tools are mapped instead of removed. @@ -192,15 +191,9 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r bodyForTranslation := body bodyForUpstream := body oauthToken := isClaudeOAuthToken(apiKey) - oauthToolNamesRemapped := false - if oauthToken && !auth.ToolPrefixDisabled() { - bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix) - } - // Remap third-party tool names to Claude Code equivalents and remove - // tools without official counterparts. This prevents Anthropic from - // fingerprinting the request as third-party via tool naming patterns. + var oauthToolNamesReverseMap map[string]string if oauthToken { - bodyForUpstream, oauthToolNamesRemapped = remapOAuthToolNames(bodyForUpstream) + bodyForUpstream, oauthToolNamesReverseMap = prepareClaudeOAuthToolNamesForUpstream(bodyForUpstream, claudeToolPrefix, auth.ToolPrefixDisabled()) } // Enable cch signing by default for OAuth tokens (not just experimental flag). // Claude Code always computes cch; missing or invalid cch is a detectable fingerprint. @@ -298,13 +291,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r } else { reporter.Publish(ctx, helps.ParseClaudeUsage(data)) } - if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() { - data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix) - } - // Reverse the OAuth tool name remap so the downstream client sees original names. - if isClaudeOAuthToken(apiKey) && oauthToolNamesRemapped { - data = reverseRemapOAuthToolNames(data) - } + data = restoreClaudeOAuthToolNamesFromResponse(data, claudeToolPrefix, auth.ToolPrefixDisabled(), oauthToolNamesReverseMap) var param any out := sdktranslator.TranslateNonStream( ctx, @@ -379,15 +366,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A bodyForTranslation := body bodyForUpstream := body oauthToken := isClaudeOAuthToken(apiKey) - oauthToolNamesRemapped := false - if oauthToken && !auth.ToolPrefixDisabled() { - bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix) - } - // Remap third-party tool names to Claude Code equivalents and remove - // tools without official counterparts. This prevents Anthropic from - // fingerprinting the request as third-party via tool naming patterns. + var oauthToolNamesReverseMap map[string]string if oauthToken { - bodyForUpstream, oauthToolNamesRemapped = remapOAuthToolNames(bodyForUpstream) + bodyForUpstream, oauthToolNamesReverseMap = prepareClaudeOAuthToolNamesForUpstream(bodyForUpstream, claudeToolPrefix, auth.ToolPrefixDisabled()) } // Enable cch signing by default for OAuth tokens (not just experimental flag). if oauthToken || experimentalCCHSigningEnabled(e.cfg, auth) { @@ -478,12 +459,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A if detail, ok := helps.ParseClaudeStreamUsage(line); ok { reporter.Publish(ctx, detail) } - if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() { - line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix) - } - if isClaudeOAuthToken(apiKey) && oauthToolNamesRemapped { - line = reverseRemapOAuthToolNamesFromStreamLine(line) - } + line = restoreClaudeOAuthToolNamesFromStreamLine(line, claudeToolPrefix, auth.ToolPrefixDisabled(), oauthToolNamesReverseMap) // Forward the line as-is to preserve SSE format cloned := make([]byte, len(line)+1) copy(cloned, line) @@ -515,12 +491,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A if detail, ok := helps.ParseClaudeStreamUsage(line); ok { reporter.Publish(ctx, detail) } - if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() { - line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix) - } - if isClaudeOAuthToken(apiKey) && oauthToolNamesRemapped { - line = reverseRemapOAuthToolNamesFromStreamLine(line) - } + line = restoreClaudeOAuthToolNamesFromStreamLine(line, claudeToolPrefix, auth.ToolPrefixDisabled(), oauthToolNamesReverseMap) chunks := sdktranslator.TranslateStream( ctx, to, @@ -635,12 +606,8 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut // Extract betas from body and convert to header (for count_tokens too) var extraBetas []string extraBetas, body = extractAndRemoveBetas(body) - if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() { - body = applyClaudeToolPrefix(body, claudeToolPrefix) - } - // Remap tool names for OAuth token requests to avoid third-party fingerprinting. if isClaudeOAuthToken(apiKey) { - body, _ = remapOAuthToolNames(body) + body, _ = prepareClaudeOAuthToolNamesForUpstream(body, claudeToolPrefix, auth.ToolPrefixDisabled()) } url := fmt.Sprintf("%s/v1/messages/count_tokens?beta=true", baseURL) @@ -1080,6 +1047,36 @@ func isClaudeOAuthToken(apiKey string) bool { return strings.Contains(apiKey, "sk-ant-oat") } +// prepareClaudeOAuthToolNamesForUpstream applies the Claude OAuth tool-name +// transforms in the same order across request paths. Remap runs before prefixing +// so any future non-empty prefix still composes correctly with the per-request +// reverse map. +func prepareClaudeOAuthToolNamesForUpstream(body []byte, prefix string, prefixDisabled bool) ([]byte, map[string]string) { + body, reverseMap := remapOAuthToolNames(body) + if !prefixDisabled { + body = applyClaudeToolPrefix(body, prefix) + } + return body, reverseMap +} + +// restoreClaudeOAuthToolNamesFromResponse undoes the Claude OAuth tool-name +// transforms for non-stream responses in reverse order. +func restoreClaudeOAuthToolNamesFromResponse(body []byte, prefix string, prefixDisabled bool, reverseMap map[string]string) []byte { + if !prefixDisabled { + body = stripClaudeToolPrefixFromResponse(body, prefix) + } + return reverseRemapOAuthToolNames(body, reverseMap) +} + +// restoreClaudeOAuthToolNamesFromStreamLine undoes the Claude OAuth tool-name +// transforms for SSE lines in reverse order. +func restoreClaudeOAuthToolNamesFromStreamLine(line []byte, prefix string, prefixDisabled bool, reverseMap map[string]string) []byte { + if !prefixDisabled { + line = stripClaudeToolPrefixFromStreamLine(line, prefix) + } + return reverseRemapOAuthToolNamesFromStreamLine(line, reverseMap) +} + // remapOAuthToolNames renames third-party tool names to Claude Code equivalents // and removes tools without an official counterpart. This prevents Anthropic from // fingerprinting the request as a third-party client via tool naming patterns. @@ -1087,8 +1084,25 @@ func isClaudeOAuthToken(apiKey string) bool { // It operates on: tools[].name, tool_choice.name, and all tool_use/tool_reference // references in messages. Removed tools' corresponding tool_result blocks are preserved // (they just become orphaned, which is safe for Claude). -func remapOAuthToolNames(body []byte) ([]byte, bool) { - renamed := false +// +// The returned map is keyed on the upstream (TitleCase) name and maps to the +// client-supplied original name. Callers MUST pass this map to the reverse +// functions so only names the client actually caused us to rewrite are restored +// on the response. A global reverse map (the previous implementation) incorrectly +// rewrote names the client originally sent in TitleCase (e.g. Amp CLI's `Bash`) +// when any OTHER tool in the same request triggered a forward rename (e.g. +// Amp's `glob`→`Glob`), because the global reverse map contained `Bash`→`bash` +// regardless of what the client originally sent. +func remapOAuthToolNames(body []byte) ([]byte, map[string]string) { + reverseMap := make(map[string]string, len(oauthToolRenameMap)) + recordRename := func(original, renamed string) { + // Preserve the first-seen original name if the same upstream name is + // produced from multiple call sites; they all map back identically. + if _, exists := reverseMap[renamed]; !exists { + reverseMap[renamed] = original + } + } + // 1. Rewrite tools array in a single pass (if present). // IMPORTANT: do not mutate names first and then rebuild from an older gjson // snapshot. gjson results are snapshots of the original bytes; rebuilding from a @@ -1121,7 +1135,7 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) { updatedTool, err := sjson.Set(toolJSON, "name", newName) if err == nil { toolJSON = updatedTool - renamed = true + recordRename(name, newName) } } @@ -1146,7 +1160,7 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) { body, _ = sjson.DeleteBytes(body, "tool_choice") } else if newName, ok := oauthToolRenameMap[tcName]; ok && newName != tcName { body, _ = sjson.SetBytes(body, "tool_choice.name", newName) - renamed = true + recordRename(tcName, newName) } } @@ -1166,14 +1180,14 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) { if newName, ok := oauthToolRenameMap[name]; ok && newName != name { path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int()) body, _ = sjson.SetBytes(body, path, newName) - renamed = true + recordRename(name, newName) } case "tool_reference": toolName := part.Get("tool_name").String() if newName, ok := oauthToolRenameMap[toolName]; ok && newName != toolName { path := fmt.Sprintf("messages.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int()) body, _ = sjson.SetBytes(body, path, newName) - renamed = true + recordRename(toolName, newName) } case "tool_result": // Handle nested tool_reference blocks inside tool_result.content[] @@ -1187,7 +1201,7 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) { if newName, ok := oauthToolRenameMap[nestedToolName]; ok && newName != nestedToolName { nestedPath := fmt.Sprintf("messages.%d.content.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int(), nestedIndex.Int()) body, _ = sjson.SetBytes(body, nestedPath, newName) - renamed = true + recordRename(nestedToolName, newName) } } return true @@ -1200,13 +1214,16 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) { }) } - return body, renamed + return body, reverseMap } -// reverseRemapOAuthToolNames reverses the tool name mapping for non-stream responses. -// It maps Claude Code TitleCase names back to the original lowercase names so the -// downstream client receives tool names it recognizes. -func reverseRemapOAuthToolNames(body []byte) []byte { +// reverseRemapOAuthToolNames reverses the tool name mapping for non-stream responses +// using the per-request map produced by remapOAuthToolNames. Names the client sent +// that were NOT forward-renamed are passed through unchanged. +func reverseRemapOAuthToolNames(body []byte, reverseMap map[string]string) []byte { + if len(reverseMap) == 0 { + return body + } content := gjson.GetBytes(body, "content") if !content.Exists() || !content.IsArray() { return body @@ -1216,13 +1233,13 @@ func reverseRemapOAuthToolNames(body []byte) []byte { switch partType { case "tool_use": name := part.Get("name").String() - if origName, ok := oauthToolRenameReverseMap[name]; ok { + if origName, ok := reverseMap[name]; ok { path := fmt.Sprintf("content.%d.name", index.Int()) body, _ = sjson.SetBytes(body, path, origName) } case "tool_reference": toolName := part.Get("tool_name").String() - if origName, ok := oauthToolRenameReverseMap[toolName]; ok { + if origName, ok := reverseMap[toolName]; ok { path := fmt.Sprintf("content.%d.tool_name", index.Int()) body, _ = sjson.SetBytes(body, path, origName) } @@ -1232,8 +1249,12 @@ func reverseRemapOAuthToolNames(body []byte) []byte { return body } -// reverseRemapOAuthToolNamesFromStreamLine reverses the tool name mapping for SSE stream lines. -func reverseRemapOAuthToolNamesFromStreamLine(line []byte) []byte { +// reverseRemapOAuthToolNamesFromStreamLine reverses the tool name mapping for SSE +// stream lines, using the per-request reverseMap produced by remapOAuthToolNames. +func reverseRemapOAuthToolNamesFromStreamLine(line []byte, reverseMap map[string]string) []byte { + if len(reverseMap) == 0 { + return line + } payload := helps.JSONPayload(line) if len(payload) == 0 || !gjson.ValidBytes(payload) { return line @@ -1251,7 +1272,7 @@ func reverseRemapOAuthToolNamesFromStreamLine(line []byte) []byte { switch blockType { case "tool_use": name := contentBlock.Get("name").String() - if origName, ok := oauthToolRenameReverseMap[name]; ok { + if origName, ok := reverseMap[name]; ok { updated, err = sjson.SetBytes(payload, "content_block.name", origName) if err != nil { return line @@ -1261,7 +1282,7 @@ func reverseRemapOAuthToolNamesFromStreamLine(line []byte) []byte { } case "tool_reference": toolName := contentBlock.Get("tool_name").String() - if origName, ok := oauthToolRenameReverseMap[toolName]; ok { + if origName, ok := reverseMap[toolName]; ok { updated, err = sjson.SetBytes(payload, "content_block.tool_name", origName) if err != nil { return line diff --git a/internal/runtime/executor/claude_executor_test.go b/internal/runtime/executor/claude_executor_test.go index 6793adda4..2e9140440 100644 --- a/internal/runtime/executor/claude_executor_test.go +++ b/internal/runtime/executor/claude_executor_test.go @@ -2096,19 +2096,16 @@ func TestNormalizeClaudeTemperatureForThinking_AfterForcedToolChoiceKeepsOrigina func TestRemapOAuthToolNames_TitleCase_NoReverseNeeded(t *testing.T) { body := []byte(`{"tools":[{"name":"Bash","description":"Run shell commands","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`) - out, renamed := remapOAuthToolNames(body) - if renamed { - t.Fatalf("renamed = true, want false") + out, reverseMap := remapOAuthToolNames(body) + if len(reverseMap) != 0 { + t.Fatalf("reverseMap = %v, want empty", reverseMap) } if got := gjson.GetBytes(out, "tools.0.name").String(); got != "Bash" { t.Fatalf("tools.0.name = %q, want %q", got, "Bash") } resp := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`) - reversed := resp - if renamed { - reversed = reverseRemapOAuthToolNames(resp) - } + reversed := reverseRemapOAuthToolNames(resp, reverseMap) if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "Bash" { t.Fatalf("content.0.name = %q, want %q", got, "Bash") } @@ -2117,20 +2114,150 @@ func TestRemapOAuthToolNames_TitleCase_NoReverseNeeded(t *testing.T) { func TestRemapOAuthToolNames_Lowercase_ReverseApplied(t *testing.T) { body := []byte(`{"tools":[{"name":"bash","description":"Run shell commands","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`) - out, renamed := remapOAuthToolNames(body) - if !renamed { - t.Fatalf("renamed = false, want true") + out, reverseMap := remapOAuthToolNames(body) + if reverseMap["Bash"] != "bash" { + t.Fatalf("reverseMap = %v, want entry Bash->bash", reverseMap) } if got := gjson.GetBytes(out, "tools.0.name").String(); got != "Bash" { t.Fatalf("tools.0.name = %q, want %q", got, "Bash") } resp := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`) - reversed := resp - if renamed { - reversed = reverseRemapOAuthToolNames(resp) - } + reversed := reverseRemapOAuthToolNames(resp, reverseMap) if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "bash" { t.Fatalf("content.0.name = %q, want %q", got, "bash") } } + +// TestRemapOAuthToolNames_MixedCase_OnlyRenamedToolsReversed is the regression +// test for a case where a single request contains both a TitleCase tool (which +// must pass through unchanged) and a lowercase tool that we forward-rename. +// Before the fix, triggering ANY forward rename caused the reverse pass to +// lowercase every TitleCase tool in the response using a global reverse map, +// corrupting tool names the client originally sent in TitleCase (notably Amp +// CLI's `Bash`, which its registry lookup cannot find as `bash`). +func TestRemapOAuthToolNames_MixedCase_OnlyRenamedToolsReversed(t *testing.T) { + body := []byte(`{"tools":[` + + `{"name":"Bash","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}},` + + `{"name":"glob","input_schema":{"type":"object","properties":{"filePattern":{"type":"string"}}}}` + + `]}`) + + out, reverseMap := remapOAuthToolNames(body) + + // Forward: TitleCase `Bash` is not a forward-map key, must pass through. + if got := gjson.GetBytes(out, "tools.0.name").String(); got != "Bash" { + t.Fatalf("tools.0.name = %q, want %q (TitleCase tool must not be renamed)", got, "Bash") + } + // Forward: `glob` is a forward-map key, upstream sees `Glob`. + if got := gjson.GetBytes(out, "tools.1.name").String(); got != "Glob" { + t.Fatalf("tools.1.name = %q, want %q", got, "Glob") + } + + // Reverse map records ONLY the rename that happened. + if len(reverseMap) != 1 || reverseMap["Glob"] != "glob" { + t.Fatalf("reverseMap = %v, want {Glob:glob}", reverseMap) + } + + // Upstream responds with a `Bash` tool_use. Since we never renamed `Bash`, + // reverseRemap MUST leave it alone. + bashResp := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`) + reversed := reverseRemapOAuthToolNames(bashResp, reverseMap) + if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "Bash" { + t.Fatalf("content.0.name = %q, want %q (Bash must be preserved; was never forward-renamed)", got, "Bash") + } + + // Upstream responds with a `Glob` tool_use. Since we renamed `glob`→`Glob`, + // reverseRemap MUST restore the original `glob`. + globResp := []byte(`{"content":[{"type":"tool_use","id":"toolu_02","name":"Glob","input":{"filePattern":"**/*.go"}}]}`) + reversed = reverseRemapOAuthToolNames(globResp, reverseMap) + if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "glob" { + t.Fatalf("content.0.name = %q, want %q (Glob must be restored to client's original `glob`)", got, "glob") + } +} + +// TestReverseRemapOAuthToolNamesFromStreamLine_HonorsPerRequestMap guards the +// SSE streaming code path against the same mixed-case bug. +func TestReverseRemapOAuthToolNamesFromStreamLine_HonorsPerRequestMap(t *testing.T) { + reverseMap := map[string]string{"Glob": "glob"} + + // Bash block was never renamed, must pass through as-is. + bashLine := []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_01","name":"Bash","input":{}}}`) + out := reverseRemapOAuthToolNamesFromStreamLine(bashLine, reverseMap) + if !bytes.Contains(out, []byte(`"name":"Bash"`)) { + t.Fatalf("Bash should be preserved, got: %s", string(out)) + } + if bytes.Contains(out, []byte(`"name":"bash"`)) { + t.Fatalf("Bash must not be lowercased, got: %s", string(out)) + } + + // Glob block IS in the reverseMap, must be restored to `glob`. + globLine := []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_02","name":"Glob","input":{}}}`) + out = reverseRemapOAuthToolNamesFromStreamLine(globLine, reverseMap) + if !bytes.Contains(out, []byte(`"name":"glob"`)) { + t.Fatalf("Glob should be restored to glob, got: %s", string(out)) + } +} + +func TestPrepareClaudeOAuthToolNamesForUpstream_MixedCaseWithPrefix(t *testing.T) { + body := []byte(`{"tools":[` + + `{"name":"Bash","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}},` + + `{"name":"glob","input_schema":{"type":"object","properties":{"filePattern":{"type":"string"}}}}` + + `],"messages":[{"role":"assistant","content":[` + + `{"type":"tool_use","id":"toolu_01","name":"Bash","input":{}},` + + `{"type":"tool_use","id":"toolu_02","name":"glob","input":{}}` + + `]}]}`) + + out, reverseMap := prepareClaudeOAuthToolNamesForUpstream(body, "proxy_", false) + + if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Bash" { + t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Bash") + } + if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_Glob" { + t.Fatalf("tools.1.name = %q, want %q", got, "proxy_Glob") + } + if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "proxy_Bash" { + t.Fatalf("messages.0.content.0.name = %q, want %q", got, "proxy_Bash") + } + if got := gjson.GetBytes(out, "messages.0.content.1.name").String(); got != "proxy_Glob" { + t.Fatalf("messages.0.content.1.name = %q, want %q", got, "proxy_Glob") + } + if len(reverseMap) != 1 || reverseMap["Glob"] != "glob" { + t.Fatalf("reverseMap = %v, want {Glob:glob}", reverseMap) + } +} + +func TestRestoreClaudeOAuthToolNamesFromResponse_MixedCaseWithPrefix(t *testing.T) { + reverseMap := map[string]string{"Glob": "glob"} + resp := []byte(`{"content":[` + + `{"type":"tool_use","id":"toolu_01","name":"proxy_Bash","input":{}},` + + `{"type":"tool_use","id":"toolu_02","name":"proxy_Glob","input":{}}` + + `]}`) + + out := restoreClaudeOAuthToolNamesFromResponse(resp, "proxy_", false, reverseMap) + + if got := gjson.GetBytes(out, "content.0.name").String(); got != "Bash" { + t.Fatalf("content.0.name = %q, want %q", got, "Bash") + } + if got := gjson.GetBytes(out, "content.1.name").String(); got != "glob" { + t.Fatalf("content.1.name = %q, want %q", got, "glob") + } +} + +func TestRestoreClaudeOAuthToolNamesFromStreamLine_MixedCaseWithPrefix(t *testing.T) { + reverseMap := map[string]string{"Glob": "glob"} + + bashLine := []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_01","name":"proxy_Bash","input":{}}}`) + out := restoreClaudeOAuthToolNamesFromStreamLine(bashLine, "proxy_", false, reverseMap) + if !bytes.Contains(out, []byte(`"name":"Bash"`)) { + t.Fatalf("Bash should be preserved, got: %s", string(out)) + } + if bytes.Contains(out, []byte(`"name":"bash"`)) { + t.Fatalf("Bash must not be lowercased, got: %s", string(out)) + } + + globLine := []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_02","name":"proxy_Glob","input":{}}}`) + out = restoreClaudeOAuthToolNamesFromStreamLine(globLine, "proxy_", false, reverseMap) + if !bytes.Contains(out, []byte(`"name":"glob"`)) { + t.Fatalf("Glob should be restored to glob, got: %s", string(out)) + } +} diff --git a/sdk/api/handlers/openai/openai_responses_handlers.go b/sdk/api/handlers/openai/openai_responses_handlers.go index 8969ce2f6..8dd1a0a7b 100644 --- a/sdk/api/handlers/openai/openai_responses_handlers.go +++ b/sdk/api/handlers/openai/openai_responses_handlers.go @@ -13,6 +13,7 @@ import ( "fmt" "io" "net/http" + "sort" "github.com/gin-gonic/gin" . "github.com/router-for-me/CLIProxyAPI/v6/internal/constant" @@ -45,7 +46,10 @@ func writeResponsesSSEChunk(w io.Writer, chunk []byte) { } type responsesSSEFramer struct { - pending []byte + pending []byte + outputItems map[int][]byte + outputOrder []int + unindexedOutputItems [][]byte } func (f *responsesSSEFramer) WriteChunk(w io.Writer, chunk []byte) { @@ -61,7 +65,7 @@ func (f *responsesSSEFramer) WriteChunk(w io.Writer, chunk []byte) { if frameLen == 0 { break } - writeResponsesSSEChunk(w, f.pending[:frameLen]) + f.writeFrame(w, f.pending[:frameLen]) copy(f.pending, f.pending[frameLen:]) f.pending = f.pending[:len(f.pending)-frameLen] } @@ -72,7 +76,7 @@ func (f *responsesSSEFramer) WriteChunk(w io.Writer, chunk []byte) { if len(f.pending) == 0 || !responsesSSECanEmitWithoutDelimiter(f.pending) { return } - writeResponsesSSEChunk(w, f.pending) + f.writeFrame(w, f.pending) f.pending = f.pending[:0] } @@ -88,10 +92,133 @@ func (f *responsesSSEFramer) Flush(w io.Writer) { f.pending = f.pending[:0] return } - writeResponsesSSEChunk(w, f.pending) + f.writeFrame(w, f.pending) f.pending = f.pending[:0] } +func (f *responsesSSEFramer) writeFrame(w io.Writer, frame []byte) { + writeResponsesSSEChunk(w, f.repairFrame(frame)) +} + +func (f *responsesSSEFramer) repairFrame(frame []byte) []byte { + payload, ok := responsesSSEDataPayload(frame) + if !ok || len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) || !json.Valid(payload) { + return frame + } + + switch gjson.GetBytes(payload, "type").String() { + case "response.output_item.done": + f.recordOutputItem(payload) + case "response.completed": + repaired := f.repairCompletedPayload(payload) + if !bytes.Equal(repaired, payload) { + return responsesSSEFrameWithData(frame, repaired) + } + } + return frame +} + +func responsesSSEDataPayload(frame []byte) ([]byte, bool) { + var payload []byte + found := false + for _, line := range bytes.Split(frame, []byte("\n")) { + line = bytes.TrimRight(line, "\r") + trimmed := bytes.TrimSpace(line) + if !bytes.HasPrefix(trimmed, []byte("data:")) { + continue + } + data := bytes.TrimSpace(trimmed[len("data:"):]) + if found { + payload = append(payload, '\n') + } + payload = append(payload, data...) + found = true + } + return payload, found +} + +func responsesSSEFrameWithData(frame, payload []byte) []byte { + var out bytes.Buffer + for _, line := range bytes.Split(frame, []byte("\n")) { + line = bytes.TrimRight(line, "\r") + trimmed := bytes.TrimSpace(line) + if len(trimmed) == 0 || bytes.HasPrefix(trimmed, []byte("data:")) { + continue + } + out.Write(line) + out.WriteByte('\n') + } + for _, line := range bytes.Split(payload, []byte("\n")) { + out.WriteString("data: ") + out.Write(line) + out.WriteByte('\n') + } + out.WriteByte('\n') + return out.Bytes() +} + +func (f *responsesSSEFramer) recordOutputItem(payload []byte) { + item := gjson.GetBytes(payload, "item") + if !item.Exists() || !item.IsObject() || item.Get("type").String() == "" { + return + } + + if outputIndex := gjson.GetBytes(payload, "output_index"); outputIndex.Exists() { + index := int(outputIndex.Int()) + if f.outputItems == nil { + f.outputItems = make(map[int][]byte) + } + if _, exists := f.outputItems[index]; !exists { + f.outputOrder = append(f.outputOrder, index) + } + f.outputItems[index] = append([]byte(nil), item.Raw...) + return + } + + f.unindexedOutputItems = append(f.unindexedOutputItems, append([]byte(nil), item.Raw...)) +} + +func (f *responsesSSEFramer) repairCompletedPayload(payload []byte) []byte { + if len(f.outputOrder) == 0 && len(f.unindexedOutputItems) == 0 { + return payload + } + output := gjson.GetBytes(payload, "response.output") + if output.Exists() && (!output.IsArray() || len(output.Array()) > 0) { + return payload + } + + var outputJSON bytes.Buffer + outputJSON.WriteByte('[') + indexes := append([]int(nil), f.outputOrder...) + sort.Ints(indexes) + written := 0 + for _, index := range indexes { + item, ok := f.outputItems[index] + if !ok { + continue + } + if written > 0 { + outputJSON.WriteByte(',') + } + outputJSON.Write(item) + written++ + } + for _, item := range f.unindexedOutputItems { + if written > 0 { + outputJSON.WriteByte(',') + } + outputJSON.Write(item) + written++ + } + outputJSON.WriteByte(']') + + repaired, err := sjson.SetRawBytes(payload, "response.output", outputJSON.Bytes()) + if err != nil { + return payload + } + return repaired +} + func responsesSSEFrameLen(chunk []byte) int { if len(chunk) == 0 { return 0 diff --git a/sdk/api/handlers/openai/openai_responses_handlers_stream_test.go b/sdk/api/handlers/openai/openai_responses_handlers_stream_test.go index ef16fe80a..151da9a79 100644 --- a/sdk/api/handlers/openai/openai_responses_handlers_stream_test.go +++ b/sdk/api/handlers/openai/openai_responses_handlers_stream_test.go @@ -10,6 +10,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers" sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" + "github.com/tidwall/gjson" ) func newResponsesStreamTestHandler(t *testing.T) (*OpenAIResponsesAPIHandler, *httptest.ResponseRecorder, *gin.Context, http.Flusher) { @@ -53,12 +54,108 @@ func TestForwardResponsesStreamSeparatesDataOnlySSEChunks(t *testing.T) { t.Errorf("unexpected first event.\nGot: %q\nWant: %q", parts[0], expectedPart1) } - expectedPart2 := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[]}}" + expectedPart2 := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[{\"type\":\"function_call\",\"arguments\":\"{}\"}]}}" if parts[1] != expectedPart2 { t.Errorf("unexpected second event.\nGot: %q\nWant: %q", parts[1], expectedPart2) } } +func TestForwardResponsesStreamRepairsEmptyCompletedOutputFromDoneItems(t *testing.T) { + h, recorder, c, flusher := newResponsesStreamTestHandler(t) + + data := make(chan []byte, 3) + errs := make(chan *interfaces.ErrorMessage) + data <- []byte(`data: {"type":"response.output_item.done","output_index":0,"item":{"type":"reasoning","id":"rs-1","summary":[]}}`) + data <- []byte(`data: {"type":"response.output_item.done","output_index":1,"item":{"type":"function_call","id":"fc-1","call_id":"call-1","name":"shell","arguments":"{\"cmd\":\"pwd\"}","status":"completed"}}`) + data <- []byte(`data: {"type":"response.completed","response":{"id":"resp-1","output":[]}}`) + close(data) + close(errs) + + h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil) + + parts := strings.Split(strings.TrimSpace(recorder.Body.String()), "\n\n") + if len(parts) != 3 { + t.Fatalf("expected 3 SSE events, got %d. Body: %q", len(parts), recorder.Body.String()) + } + + payload := strings.TrimPrefix(parts[2], "data: ") + output := gjson.Get(payload, "response.output") + if !output.IsArray() || len(output.Array()) != 2 { + t.Fatalf("expected repaired completed output with 2 items, got %s", output.Raw) + } + if got := gjson.Get(payload, "response.output.1.name").String(); got != "shell" { + t.Fatalf("expected function_call name to be preserved, got %q in %s", got, payload) + } + if got := gjson.Get(payload, "response.output.1.arguments").String(); got != `{"cmd":"pwd"}` { + t.Fatalf("expected function_call arguments to be preserved, got %q in %s", got, payload) + } +} + +func TestForwardResponsesStreamRepairsMixedIndexedAndUnindexedDoneItems(t *testing.T) { + h, recorder, c, flusher := newResponsesStreamTestHandler(t) + + data := make(chan []byte, 3) + errs := make(chan *interfaces.ErrorMessage) + data <- []byte(`data: {"type":"response.output_item.done","output_index":1,"item":{"type":"function_call","id":"fc-1","call_id":"call-1","name":"shell","arguments":"{}","status":"completed"}}`) + data <- []byte(`data: {"type":"response.output_item.done","item":{"type":"message","id":"msg-1","role":"assistant","content":[{"type":"output_text","text":"done"}]}}`) + data <- []byte(`data: {"type":"response.completed","response":{"id":"resp-1","output":[]}}`) + close(data) + close(errs) + + h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil) + + parts := strings.Split(strings.TrimSpace(recorder.Body.String()), "\n\n") + if len(parts) != 3 { + t.Fatalf("expected 3 SSE events, got %d. Body: %q", len(parts), recorder.Body.String()) + } + + payload := strings.TrimPrefix(parts[2], "data: ") + output := gjson.Get(payload, "response.output") + if !output.IsArray() || len(output.Array()) != 2 { + t.Fatalf("expected repaired completed output with 2 items, got %s", output.Raw) + } + if got := gjson.Get(payload, "response.output.0.name").String(); got != "shell" { + t.Fatalf("expected indexed function_call to be preserved first, got %q in %s", got, payload) + } + if got := gjson.Get(payload, "response.output.1.id").String(); got != "msg-1" { + t.Fatalf("expected unindexed message to be appended, got %q in %s", got, payload) + } +} + +func TestForwardResponsesStreamRepairsMultilineCompletedOutputAsSSEDataLines(t *testing.T) { + h, recorder, c, flusher := newResponsesStreamTestHandler(t) + + data := make(chan []byte, 2) + errs := make(chan *interfaces.ErrorMessage) + data <- []byte(`data: {"type":"response.output_item.done","item":{"type":"function_call","arguments":"{}"}}`) + data <- []byte("data: {\"type\":\"response.completed\",\ndata: \"response\":{\"id\":\"resp-1\",\"output\":[]}}\n\n") + close(data) + close(errs) + + h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil) + + parts := strings.Split(strings.TrimSpace(recorder.Body.String()), "\n\n") + if len(parts) != 2 { + t.Fatalf("expected 2 SSE events, got %d. Body: %q", len(parts), recorder.Body.String()) + } + + completedFrame := []byte(parts[1]) + for _, line := range strings.Split(parts[1], "\n") { + if line != "" && !strings.HasPrefix(line, "data: ") { + t.Fatalf("expected every completed payload line to be an SSE data line, got %q in %q", line, parts[1]) + } + } + + payload, ok := responsesSSEDataPayload(completedFrame) + if !ok { + t.Fatalf("expected completed frame to contain data payload: %q", parts[1]) + } + output := gjson.GetBytes(payload, "response.output") + if !output.IsArray() || len(output.Array()) != 1 { + t.Fatalf("expected repaired completed output with 1 item, got %s from %q", output.Raw, payload) + } +} + func TestForwardResponsesStreamReassemblesSplitSSEEventChunks(t *testing.T) { h, recorder, c, flusher := newResponsesStreamTestHandler(t)