Merge pull request #2896 from edlsh/fix/oauth-tool-rename-per-request-map

fix(amp): smart-mode tool name fixes + deep-mode response repair
This commit is contained in:
Luis Pater
2026-05-05 00:58:39 +08:00
committed by GitHub
6 changed files with 558 additions and 85 deletions

View File

@@ -123,6 +123,52 @@ func (rw *ResponseRewriter) Flush() {
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}
// ampCanonicalToolNames maps tool names to the exact casing expected by the
// Amp mode tool whitelist (case-sensitive match).
var ampCanonicalToolNames = map[string]string{
"bash": "Bash",
"read": "Read",
"grep": "Grep",
"glob": "glob",
"task": "Task",
"check": "Check",
}
// normalizeAmpToolNames fixes tool_use block names to match Amp's canonical casing.
// Some upstream models return lowercase tool names (e.g. "bash" instead of "Bash")
// which causes Amp's case-sensitive mode whitelist to reject them.
func normalizeAmpToolNames(data []byte) []byte {
// Non-streaming: content[].name in tool_use blocks
for index, block := range gjson.GetBytes(data, "content").Array() {
if block.Get("type").String() != "tool_use" {
continue
}
name := block.Get("name").String()
if canonical, ok := ampCanonicalToolNames[strings.ToLower(name)]; ok && name != canonical {
path := fmt.Sprintf("content.%d.name", index)
var err error
data, err = sjson.SetBytes(data, path, canonical)
if err != nil {
log.Warnf("Amp ResponseRewriter: failed to normalize tool name %q to %q: %v", name, canonical, err)
}
}
}
// Streaming: content_block.name in content_block_start events
if gjson.GetBytes(data, "content_block.type").String() == "tool_use" {
name := gjson.GetBytes(data, "content_block.name").String()
if canonical, ok := ampCanonicalToolNames[strings.ToLower(name)]; ok && name != canonical {
var err error
data, err = sjson.SetBytes(data, "content_block.name", canonical)
if err != nil {
log.Warnf("Amp ResponseRewriter: failed to normalize streaming tool name %q to %q: %v", name, canonical, err)
}
}
}
return data
}
// ensureAmpSignature injects empty signature fields into tool_use/thinking blocks
// in API responses so that the Amp TUI does not crash on P.signature.length.
func ensureAmpSignature(data []byte) []byte {
@@ -179,6 +225,7 @@ func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte {
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
data = ensureAmpSignature(data)
data = normalizeAmpToolNames(data)
data = rw.suppressAmpThinking(data)
if len(data) == 0 {
return data
@@ -278,6 +325,9 @@ func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte {
// Inject empty signature where needed
data = ensureAmpSignature(data)
// Normalize tool names to canonical casing
data = normalizeAmpToolNames(data)
// Rewrite model name
if rw.originalModel != "" {
for _, path := range modelFieldPaths {

View File

@@ -175,6 +175,57 @@ func TestSanitizeAmpRequestBody_MixedInvalidThinkingAndToolUseSignature(t *testi
}
}
func TestNormalizeAmpToolNames_NonStreaming(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"bash","input":{"cmd":"ls"}},{"type":"tool_use","id":"toolu_02","name":"read","input":{"path":"/tmp"}},{"type":"text","text":"hello"}]}`)
result := normalizeAmpToolNames(input)
if !contains(result, []byte(`"name":"Bash"`)) {
t.Errorf("expected bash->Bash, got %s", string(result))
}
if !contains(result, []byte(`"name":"Read"`)) {
t.Errorf("expected read->Read, got %s", string(result))
}
if contains(result, []byte(`"name":"bash"`)) {
t.Errorf("expected lowercase bash to be replaced, got %s", string(result))
}
}
func TestNormalizeAmpToolNames_Streaming(t *testing.T) {
input := []byte(`{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","name":"grep","id":"toolu_01","input":{}}}`)
result := normalizeAmpToolNames(input)
if !contains(result, []byte(`"name":"Grep"`)) {
t.Errorf("expected grep->Grep in streaming, got %s", string(result))
}
}
func TestNormalizeAmpToolNames_AlreadyCorrect(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
result := normalizeAmpToolNames(input)
if string(result) != string(input) {
t.Errorf("expected no modification for correctly-cased tool, got %s", string(result))
}
}
func TestNormalizeAmpToolNames_GlobPreserved(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`)
result := normalizeAmpToolNames(input)
if string(result) != string(input) {
t.Errorf("expected glob to remain lowercase, got %s", string(result))
}
}
func TestNormalizeAmpToolNames_UnknownToolUntouched(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"edit_file","input":{"path":"/tmp/x"}}]}`)
result := normalizeAmpToolNames(input)
if string(result) != string(input) {
t.Errorf("expected no modification for unknown tool, got %s", string(result))
}
}
func contains(data, substr []byte) bool {
for i := 0; i <= len(data)-len(substr); i++ {
if string(data[i:i+len(substr)]) == string(substr) {

View File

@@ -65,14 +65,13 @@ var oauthToolRenameMap = map[string]string{
"notebookedit": "NotebookEdit",
}
// oauthToolRenameReverseMap is the inverse of oauthToolRenameMap for response decoding.
var oauthToolRenameReverseMap = func() map[string]string {
m := make(map[string]string, len(oauthToolRenameMap))
for k, v := range oauthToolRenameMap {
m[v] = k
}
return m
}()
// The reverse map is now computed per-request in remapOAuthToolNames so that
// only names the client actually caused us to rewrite are restored on the
// response. A global reverse map — as used previously — corrupted responses
// for clients that sent mixed casing (e.g. Amp CLI sends `Bash` TitleCase
// alongside `glob` lowercase; the request flagged renames via `glob→Glob`,
// then the global reverse map incorrectly rewrote every `Bash` in the
// response to `bash`, causing Amp to reject the tool_use as unknown).
// oauthToolsToRemove lists tool names that must be stripped from OAuth requests
// even after remapping. Currently empty — all tools are mapped instead of removed.
@@ -192,15 +191,9 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
bodyForTranslation := body
bodyForUpstream := body
oauthToken := isClaudeOAuthToken(apiKey)
oauthToolNamesRemapped := false
if oauthToken && !auth.ToolPrefixDisabled() {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
}
// Remap third-party tool names to Claude Code equivalents and remove
// tools without official counterparts. This prevents Anthropic from
// fingerprinting the request as third-party via tool naming patterns.
var oauthToolNamesReverseMap map[string]string
if oauthToken {
bodyForUpstream, oauthToolNamesRemapped = remapOAuthToolNames(bodyForUpstream)
bodyForUpstream, oauthToolNamesReverseMap = prepareClaudeOAuthToolNamesForUpstream(bodyForUpstream, claudeToolPrefix, auth.ToolPrefixDisabled())
}
// Enable cch signing by default for OAuth tokens (not just experimental flag).
// Claude Code always computes cch; missing or invalid cch is a detectable fingerprint.
@@ -298,13 +291,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
} else {
reporter.Publish(ctx, helps.ParseClaudeUsage(data))
}
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix)
}
// Reverse the OAuth tool name remap so the downstream client sees original names.
if isClaudeOAuthToken(apiKey) && oauthToolNamesRemapped {
data = reverseRemapOAuthToolNames(data)
}
data = restoreClaudeOAuthToolNamesFromResponse(data, claudeToolPrefix, auth.ToolPrefixDisabled(), oauthToolNamesReverseMap)
var param any
out := sdktranslator.TranslateNonStream(
ctx,
@@ -379,15 +366,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
bodyForTranslation := body
bodyForUpstream := body
oauthToken := isClaudeOAuthToken(apiKey)
oauthToolNamesRemapped := false
if oauthToken && !auth.ToolPrefixDisabled() {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
}
// Remap third-party tool names to Claude Code equivalents and remove
// tools without official counterparts. This prevents Anthropic from
// fingerprinting the request as third-party via tool naming patterns.
var oauthToolNamesReverseMap map[string]string
if oauthToken {
bodyForUpstream, oauthToolNamesRemapped = remapOAuthToolNames(bodyForUpstream)
bodyForUpstream, oauthToolNamesReverseMap = prepareClaudeOAuthToolNamesForUpstream(bodyForUpstream, claudeToolPrefix, auth.ToolPrefixDisabled())
}
// Enable cch signing by default for OAuth tokens (not just experimental flag).
if oauthToken || experimentalCCHSigningEnabled(e.cfg, auth) {
@@ -478,12 +459,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if detail, ok := helps.ParseClaudeStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
}
if isClaudeOAuthToken(apiKey) && oauthToolNamesRemapped {
line = reverseRemapOAuthToolNamesFromStreamLine(line)
}
line = restoreClaudeOAuthToolNamesFromStreamLine(line, claudeToolPrefix, auth.ToolPrefixDisabled(), oauthToolNamesReverseMap)
// Forward the line as-is to preserve SSE format
cloned := make([]byte, len(line)+1)
copy(cloned, line)
@@ -515,12 +491,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if detail, ok := helps.ParseClaudeStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
}
if isClaudeOAuthToken(apiKey) && oauthToolNamesRemapped {
line = reverseRemapOAuthToolNamesFromStreamLine(line)
}
line = restoreClaudeOAuthToolNamesFromStreamLine(line, claudeToolPrefix, auth.ToolPrefixDisabled(), oauthToolNamesReverseMap)
chunks := sdktranslator.TranslateStream(
ctx,
to,
@@ -635,12 +606,8 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
// Extract betas from body and convert to header (for count_tokens too)
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
body = applyClaudeToolPrefix(body, claudeToolPrefix)
}
// Remap tool names for OAuth token requests to avoid third-party fingerprinting.
if isClaudeOAuthToken(apiKey) {
body, _ = remapOAuthToolNames(body)
body, _ = prepareClaudeOAuthToolNamesForUpstream(body, claudeToolPrefix, auth.ToolPrefixDisabled())
}
url := fmt.Sprintf("%s/v1/messages/count_tokens?beta=true", baseURL)
@@ -1080,6 +1047,36 @@ func isClaudeOAuthToken(apiKey string) bool {
return strings.Contains(apiKey, "sk-ant-oat")
}
// prepareClaudeOAuthToolNamesForUpstream applies the Claude OAuth tool-name
// transforms in the same order across request paths. Remap runs before prefixing
// so any future non-empty prefix still composes correctly with the per-request
// reverse map.
func prepareClaudeOAuthToolNamesForUpstream(body []byte, prefix string, prefixDisabled bool) ([]byte, map[string]string) {
body, reverseMap := remapOAuthToolNames(body)
if !prefixDisabled {
body = applyClaudeToolPrefix(body, prefix)
}
return body, reverseMap
}
// restoreClaudeOAuthToolNamesFromResponse undoes the Claude OAuth tool-name
// transforms for non-stream responses in reverse order.
func restoreClaudeOAuthToolNamesFromResponse(body []byte, prefix string, prefixDisabled bool, reverseMap map[string]string) []byte {
if !prefixDisabled {
body = stripClaudeToolPrefixFromResponse(body, prefix)
}
return reverseRemapOAuthToolNames(body, reverseMap)
}
// restoreClaudeOAuthToolNamesFromStreamLine undoes the Claude OAuth tool-name
// transforms for SSE lines in reverse order.
func restoreClaudeOAuthToolNamesFromStreamLine(line []byte, prefix string, prefixDisabled bool, reverseMap map[string]string) []byte {
if !prefixDisabled {
line = stripClaudeToolPrefixFromStreamLine(line, prefix)
}
return reverseRemapOAuthToolNamesFromStreamLine(line, reverseMap)
}
// remapOAuthToolNames renames third-party tool names to Claude Code equivalents
// and removes tools without an official counterpart. This prevents Anthropic from
// fingerprinting the request as a third-party client via tool naming patterns.
@@ -1087,8 +1084,25 @@ func isClaudeOAuthToken(apiKey string) bool {
// It operates on: tools[].name, tool_choice.name, and all tool_use/tool_reference
// references in messages. Removed tools' corresponding tool_result blocks are preserved
// (they just become orphaned, which is safe for Claude).
func remapOAuthToolNames(body []byte) ([]byte, bool) {
renamed := false
//
// The returned map is keyed on the upstream (TitleCase) name and maps to the
// client-supplied original name. Callers MUST pass this map to the reverse
// functions so only names the client actually caused us to rewrite are restored
// on the response. A global reverse map (the previous implementation) incorrectly
// rewrote names the client originally sent in TitleCase (e.g. Amp CLI's `Bash`)
// when any OTHER tool in the same request triggered a forward rename (e.g.
// Amp's `glob`→`Glob`), because the global reverse map contained `Bash`→`bash`
// regardless of what the client originally sent.
func remapOAuthToolNames(body []byte) ([]byte, map[string]string) {
reverseMap := make(map[string]string, len(oauthToolRenameMap))
recordRename := func(original, renamed string) {
// Preserve the first-seen original name if the same upstream name is
// produced from multiple call sites; they all map back identically.
if _, exists := reverseMap[renamed]; !exists {
reverseMap[renamed] = original
}
}
// 1. Rewrite tools array in a single pass (if present).
// IMPORTANT: do not mutate names first and then rebuild from an older gjson
// snapshot. gjson results are snapshots of the original bytes; rebuilding from a
@@ -1121,7 +1135,7 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
updatedTool, err := sjson.Set(toolJSON, "name", newName)
if err == nil {
toolJSON = updatedTool
renamed = true
recordRename(name, newName)
}
}
@@ -1146,7 +1160,7 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
body, _ = sjson.DeleteBytes(body, "tool_choice")
} else if newName, ok := oauthToolRenameMap[tcName]; ok && newName != tcName {
body, _ = sjson.SetBytes(body, "tool_choice.name", newName)
renamed = true
recordRename(tcName, newName)
}
}
@@ -1166,14 +1180,14 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
if newName, ok := oauthToolRenameMap[name]; ok && newName != name {
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, newName)
renamed = true
recordRename(name, newName)
}
case "tool_reference":
toolName := part.Get("tool_name").String()
if newName, ok := oauthToolRenameMap[toolName]; ok && newName != toolName {
path := fmt.Sprintf("messages.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, newName)
renamed = true
recordRename(toolName, newName)
}
case "tool_result":
// Handle nested tool_reference blocks inside tool_result.content[]
@@ -1187,7 +1201,7 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
if newName, ok := oauthToolRenameMap[nestedToolName]; ok && newName != nestedToolName {
nestedPath := fmt.Sprintf("messages.%d.content.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int(), nestedIndex.Int())
body, _ = sjson.SetBytes(body, nestedPath, newName)
renamed = true
recordRename(nestedToolName, newName)
}
}
return true
@@ -1200,13 +1214,16 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
})
}
return body, renamed
return body, reverseMap
}
// reverseRemapOAuthToolNames reverses the tool name mapping for non-stream responses.
// It maps Claude Code TitleCase names back to the original lowercase names so the
// downstream client receives tool names it recognizes.
func reverseRemapOAuthToolNames(body []byte) []byte {
// reverseRemapOAuthToolNames reverses the tool name mapping for non-stream responses
// using the per-request map produced by remapOAuthToolNames. Names the client sent
// that were NOT forward-renamed are passed through unchanged.
func reverseRemapOAuthToolNames(body []byte, reverseMap map[string]string) []byte {
if len(reverseMap) == 0 {
return body
}
content := gjson.GetBytes(body, "content")
if !content.Exists() || !content.IsArray() {
return body
@@ -1216,13 +1233,13 @@ func reverseRemapOAuthToolNames(body []byte) []byte {
switch partType {
case "tool_use":
name := part.Get("name").String()
if origName, ok := oauthToolRenameReverseMap[name]; ok {
if origName, ok := reverseMap[name]; ok {
path := fmt.Sprintf("content.%d.name", index.Int())
body, _ = sjson.SetBytes(body, path, origName)
}
case "tool_reference":
toolName := part.Get("tool_name").String()
if origName, ok := oauthToolRenameReverseMap[toolName]; ok {
if origName, ok := reverseMap[toolName]; ok {
path := fmt.Sprintf("content.%d.tool_name", index.Int())
body, _ = sjson.SetBytes(body, path, origName)
}
@@ -1232,8 +1249,12 @@ func reverseRemapOAuthToolNames(body []byte) []byte {
return body
}
// reverseRemapOAuthToolNamesFromStreamLine reverses the tool name mapping for SSE stream lines.
func reverseRemapOAuthToolNamesFromStreamLine(line []byte) []byte {
// reverseRemapOAuthToolNamesFromStreamLine reverses the tool name mapping for SSE
// stream lines, using the per-request reverseMap produced by remapOAuthToolNames.
func reverseRemapOAuthToolNamesFromStreamLine(line []byte, reverseMap map[string]string) []byte {
if len(reverseMap) == 0 {
return line
}
payload := helps.JSONPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return line
@@ -1251,7 +1272,7 @@ func reverseRemapOAuthToolNamesFromStreamLine(line []byte) []byte {
switch blockType {
case "tool_use":
name := contentBlock.Get("name").String()
if origName, ok := oauthToolRenameReverseMap[name]; ok {
if origName, ok := reverseMap[name]; ok {
updated, err = sjson.SetBytes(payload, "content_block.name", origName)
if err != nil {
return line
@@ -1261,7 +1282,7 @@ func reverseRemapOAuthToolNamesFromStreamLine(line []byte) []byte {
}
case "tool_reference":
toolName := contentBlock.Get("tool_name").String()
if origName, ok := oauthToolRenameReverseMap[toolName]; ok {
if origName, ok := reverseMap[toolName]; ok {
updated, err = sjson.SetBytes(payload, "content_block.tool_name", origName)
if err != nil {
return line

View File

@@ -2096,19 +2096,16 @@ func TestNormalizeClaudeTemperatureForThinking_AfterForcedToolChoiceKeepsOrigina
func TestRemapOAuthToolNames_TitleCase_NoReverseNeeded(t *testing.T) {
body := []byte(`{"tools":[{"name":"Bash","description":"Run shell commands","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
out, renamed := remapOAuthToolNames(body)
if renamed {
t.Fatalf("renamed = true, want false")
out, reverseMap := remapOAuthToolNames(body)
if len(reverseMap) != 0 {
t.Fatalf("reverseMap = %v, want empty", reverseMap)
}
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "Bash" {
t.Fatalf("tools.0.name = %q, want %q", got, "Bash")
}
resp := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
reversed := resp
if renamed {
reversed = reverseRemapOAuthToolNames(resp)
}
reversed := reverseRemapOAuthToolNames(resp, reverseMap)
if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "Bash" {
t.Fatalf("content.0.name = %q, want %q", got, "Bash")
}
@@ -2117,20 +2114,150 @@ func TestRemapOAuthToolNames_TitleCase_NoReverseNeeded(t *testing.T) {
func TestRemapOAuthToolNames_Lowercase_ReverseApplied(t *testing.T) {
body := []byte(`{"tools":[{"name":"bash","description":"Run shell commands","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
out, renamed := remapOAuthToolNames(body)
if !renamed {
t.Fatalf("renamed = false, want true")
out, reverseMap := remapOAuthToolNames(body)
if reverseMap["Bash"] != "bash" {
t.Fatalf("reverseMap = %v, want entry Bash->bash", reverseMap)
}
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "Bash" {
t.Fatalf("tools.0.name = %q, want %q", got, "Bash")
}
resp := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
reversed := resp
if renamed {
reversed = reverseRemapOAuthToolNames(resp)
}
reversed := reverseRemapOAuthToolNames(resp, reverseMap)
if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "bash" {
t.Fatalf("content.0.name = %q, want %q", got, "bash")
}
}
// TestRemapOAuthToolNames_MixedCase_OnlyRenamedToolsReversed is the regression
// test for a case where a single request contains both a TitleCase tool (which
// must pass through unchanged) and a lowercase tool that we forward-rename.
// Before the fix, triggering ANY forward rename caused the reverse pass to
// lowercase every TitleCase tool in the response using a global reverse map,
// corrupting tool names the client originally sent in TitleCase (notably Amp
// CLI's `Bash`, which its registry lookup cannot find as `bash`).
func TestRemapOAuthToolNames_MixedCase_OnlyRenamedToolsReversed(t *testing.T) {
body := []byte(`{"tools":[` +
`{"name":"Bash","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}},` +
`{"name":"glob","input_schema":{"type":"object","properties":{"filePattern":{"type":"string"}}}}` +
`]}`)
out, reverseMap := remapOAuthToolNames(body)
// Forward: TitleCase `Bash` is not a forward-map key, must pass through.
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "Bash" {
t.Fatalf("tools.0.name = %q, want %q (TitleCase tool must not be renamed)", got, "Bash")
}
// Forward: `glob` is a forward-map key, upstream sees `Glob`.
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "Glob" {
t.Fatalf("tools.1.name = %q, want %q", got, "Glob")
}
// Reverse map records ONLY the rename that happened.
if len(reverseMap) != 1 || reverseMap["Glob"] != "glob" {
t.Fatalf("reverseMap = %v, want {Glob:glob}", reverseMap)
}
// Upstream responds with a `Bash` tool_use. Since we never renamed `Bash`,
// reverseRemap MUST leave it alone.
bashResp := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
reversed := reverseRemapOAuthToolNames(bashResp, reverseMap)
if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "Bash" {
t.Fatalf("content.0.name = %q, want %q (Bash must be preserved; was never forward-renamed)", got, "Bash")
}
// Upstream responds with a `Glob` tool_use. Since we renamed `glob`→`Glob`,
// reverseRemap MUST restore the original `glob`.
globResp := []byte(`{"content":[{"type":"tool_use","id":"toolu_02","name":"Glob","input":{"filePattern":"**/*.go"}}]}`)
reversed = reverseRemapOAuthToolNames(globResp, reverseMap)
if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "glob" {
t.Fatalf("content.0.name = %q, want %q (Glob must be restored to client's original `glob`)", got, "glob")
}
}
// TestReverseRemapOAuthToolNamesFromStreamLine_HonorsPerRequestMap guards the
// SSE streaming code path against the same mixed-case bug.
func TestReverseRemapOAuthToolNamesFromStreamLine_HonorsPerRequestMap(t *testing.T) {
reverseMap := map[string]string{"Glob": "glob"}
// Bash block was never renamed, must pass through as-is.
bashLine := []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_01","name":"Bash","input":{}}}`)
out := reverseRemapOAuthToolNamesFromStreamLine(bashLine, reverseMap)
if !bytes.Contains(out, []byte(`"name":"Bash"`)) {
t.Fatalf("Bash should be preserved, got: %s", string(out))
}
if bytes.Contains(out, []byte(`"name":"bash"`)) {
t.Fatalf("Bash must not be lowercased, got: %s", string(out))
}
// Glob block IS in the reverseMap, must be restored to `glob`.
globLine := []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_02","name":"Glob","input":{}}}`)
out = reverseRemapOAuthToolNamesFromStreamLine(globLine, reverseMap)
if !bytes.Contains(out, []byte(`"name":"glob"`)) {
t.Fatalf("Glob should be restored to glob, got: %s", string(out))
}
}
func TestPrepareClaudeOAuthToolNamesForUpstream_MixedCaseWithPrefix(t *testing.T) {
body := []byte(`{"tools":[` +
`{"name":"Bash","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}},` +
`{"name":"glob","input_schema":{"type":"object","properties":{"filePattern":{"type":"string"}}}}` +
`],"messages":[{"role":"assistant","content":[` +
`{"type":"tool_use","id":"toolu_01","name":"Bash","input":{}},` +
`{"type":"tool_use","id":"toolu_02","name":"glob","input":{}}` +
`]}]}`)
out, reverseMap := prepareClaudeOAuthToolNamesForUpstream(body, "proxy_", false)
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Bash" {
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Bash")
}
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_Glob" {
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_Glob")
}
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "proxy_Bash" {
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "proxy_Bash")
}
if got := gjson.GetBytes(out, "messages.0.content.1.name").String(); got != "proxy_Glob" {
t.Fatalf("messages.0.content.1.name = %q, want %q", got, "proxy_Glob")
}
if len(reverseMap) != 1 || reverseMap["Glob"] != "glob" {
t.Fatalf("reverseMap = %v, want {Glob:glob}", reverseMap)
}
}
func TestRestoreClaudeOAuthToolNamesFromResponse_MixedCaseWithPrefix(t *testing.T) {
reverseMap := map[string]string{"Glob": "glob"}
resp := []byte(`{"content":[` +
`{"type":"tool_use","id":"toolu_01","name":"proxy_Bash","input":{}},` +
`{"type":"tool_use","id":"toolu_02","name":"proxy_Glob","input":{}}` +
`]}`)
out := restoreClaudeOAuthToolNamesFromResponse(resp, "proxy_", false, reverseMap)
if got := gjson.GetBytes(out, "content.0.name").String(); got != "Bash" {
t.Fatalf("content.0.name = %q, want %q", got, "Bash")
}
if got := gjson.GetBytes(out, "content.1.name").String(); got != "glob" {
t.Fatalf("content.1.name = %q, want %q", got, "glob")
}
}
func TestRestoreClaudeOAuthToolNamesFromStreamLine_MixedCaseWithPrefix(t *testing.T) {
reverseMap := map[string]string{"Glob": "glob"}
bashLine := []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_01","name":"proxy_Bash","input":{}}}`)
out := restoreClaudeOAuthToolNamesFromStreamLine(bashLine, "proxy_", false, reverseMap)
if !bytes.Contains(out, []byte(`"name":"Bash"`)) {
t.Fatalf("Bash should be preserved, got: %s", string(out))
}
if bytes.Contains(out, []byte(`"name":"bash"`)) {
t.Fatalf("Bash must not be lowercased, got: %s", string(out))
}
globLine := []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_02","name":"proxy_Glob","input":{}}}`)
out = restoreClaudeOAuthToolNamesFromStreamLine(globLine, "proxy_", false, reverseMap)
if !bytes.Contains(out, []byte(`"name":"glob"`)) {
t.Fatalf("Glob should be restored to glob, got: %s", string(out))
}
}

View File

@@ -13,6 +13,7 @@ import (
"fmt"
"io"
"net/http"
"sort"
"github.com/gin-gonic/gin"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
@@ -45,7 +46,10 @@ func writeResponsesSSEChunk(w io.Writer, chunk []byte) {
}
type responsesSSEFramer struct {
pending []byte
pending []byte
outputItems map[int][]byte
outputOrder []int
unindexedOutputItems [][]byte
}
func (f *responsesSSEFramer) WriteChunk(w io.Writer, chunk []byte) {
@@ -61,7 +65,7 @@ func (f *responsesSSEFramer) WriteChunk(w io.Writer, chunk []byte) {
if frameLen == 0 {
break
}
writeResponsesSSEChunk(w, f.pending[:frameLen])
f.writeFrame(w, f.pending[:frameLen])
copy(f.pending, f.pending[frameLen:])
f.pending = f.pending[:len(f.pending)-frameLen]
}
@@ -72,7 +76,7 @@ func (f *responsesSSEFramer) WriteChunk(w io.Writer, chunk []byte) {
if len(f.pending) == 0 || !responsesSSECanEmitWithoutDelimiter(f.pending) {
return
}
writeResponsesSSEChunk(w, f.pending)
f.writeFrame(w, f.pending)
f.pending = f.pending[:0]
}
@@ -88,10 +92,133 @@ func (f *responsesSSEFramer) Flush(w io.Writer) {
f.pending = f.pending[:0]
return
}
writeResponsesSSEChunk(w, f.pending)
f.writeFrame(w, f.pending)
f.pending = f.pending[:0]
}
func (f *responsesSSEFramer) writeFrame(w io.Writer, frame []byte) {
writeResponsesSSEChunk(w, f.repairFrame(frame))
}
func (f *responsesSSEFramer) repairFrame(frame []byte) []byte {
payload, ok := responsesSSEDataPayload(frame)
if !ok || len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) || !json.Valid(payload) {
return frame
}
switch gjson.GetBytes(payload, "type").String() {
case "response.output_item.done":
f.recordOutputItem(payload)
case "response.completed":
repaired := f.repairCompletedPayload(payload)
if !bytes.Equal(repaired, payload) {
return responsesSSEFrameWithData(frame, repaired)
}
}
return frame
}
func responsesSSEDataPayload(frame []byte) ([]byte, bool) {
var payload []byte
found := false
for _, line := range bytes.Split(frame, []byte("\n")) {
line = bytes.TrimRight(line, "\r")
trimmed := bytes.TrimSpace(line)
if !bytes.HasPrefix(trimmed, []byte("data:")) {
continue
}
data := bytes.TrimSpace(trimmed[len("data:"):])
if found {
payload = append(payload, '\n')
}
payload = append(payload, data...)
found = true
}
return payload, found
}
func responsesSSEFrameWithData(frame, payload []byte) []byte {
var out bytes.Buffer
for _, line := range bytes.Split(frame, []byte("\n")) {
line = bytes.TrimRight(line, "\r")
trimmed := bytes.TrimSpace(line)
if len(trimmed) == 0 || bytes.HasPrefix(trimmed, []byte("data:")) {
continue
}
out.Write(line)
out.WriteByte('\n')
}
for _, line := range bytes.Split(payload, []byte("\n")) {
out.WriteString("data: ")
out.Write(line)
out.WriteByte('\n')
}
out.WriteByte('\n')
return out.Bytes()
}
func (f *responsesSSEFramer) recordOutputItem(payload []byte) {
item := gjson.GetBytes(payload, "item")
if !item.Exists() || !item.IsObject() || item.Get("type").String() == "" {
return
}
if outputIndex := gjson.GetBytes(payload, "output_index"); outputIndex.Exists() {
index := int(outputIndex.Int())
if f.outputItems == nil {
f.outputItems = make(map[int][]byte)
}
if _, exists := f.outputItems[index]; !exists {
f.outputOrder = append(f.outputOrder, index)
}
f.outputItems[index] = append([]byte(nil), item.Raw...)
return
}
f.unindexedOutputItems = append(f.unindexedOutputItems, append([]byte(nil), item.Raw...))
}
func (f *responsesSSEFramer) repairCompletedPayload(payload []byte) []byte {
if len(f.outputOrder) == 0 && len(f.unindexedOutputItems) == 0 {
return payload
}
output := gjson.GetBytes(payload, "response.output")
if output.Exists() && (!output.IsArray() || len(output.Array()) > 0) {
return payload
}
var outputJSON bytes.Buffer
outputJSON.WriteByte('[')
indexes := append([]int(nil), f.outputOrder...)
sort.Ints(indexes)
written := 0
for _, index := range indexes {
item, ok := f.outputItems[index]
if !ok {
continue
}
if written > 0 {
outputJSON.WriteByte(',')
}
outputJSON.Write(item)
written++
}
for _, item := range f.unindexedOutputItems {
if written > 0 {
outputJSON.WriteByte(',')
}
outputJSON.Write(item)
written++
}
outputJSON.WriteByte(']')
repaired, err := sjson.SetRawBytes(payload, "response.output", outputJSON.Bytes())
if err != nil {
return payload
}
return repaired
}
func responsesSSEFrameLen(chunk []byte) int {
if len(chunk) == 0 {
return 0

View File

@@ -10,6 +10,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
"github.com/tidwall/gjson"
)
func newResponsesStreamTestHandler(t *testing.T) (*OpenAIResponsesAPIHandler, *httptest.ResponseRecorder, *gin.Context, http.Flusher) {
@@ -53,12 +54,108 @@ func TestForwardResponsesStreamSeparatesDataOnlySSEChunks(t *testing.T) {
t.Errorf("unexpected first event.\nGot: %q\nWant: %q", parts[0], expectedPart1)
}
expectedPart2 := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[]}}"
expectedPart2 := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[{\"type\":\"function_call\",\"arguments\":\"{}\"}]}}"
if parts[1] != expectedPart2 {
t.Errorf("unexpected second event.\nGot: %q\nWant: %q", parts[1], expectedPart2)
}
}
func TestForwardResponsesStreamRepairsEmptyCompletedOutputFromDoneItems(t *testing.T) {
h, recorder, c, flusher := newResponsesStreamTestHandler(t)
data := make(chan []byte, 3)
errs := make(chan *interfaces.ErrorMessage)
data <- []byte(`data: {"type":"response.output_item.done","output_index":0,"item":{"type":"reasoning","id":"rs-1","summary":[]}}`)
data <- []byte(`data: {"type":"response.output_item.done","output_index":1,"item":{"type":"function_call","id":"fc-1","call_id":"call-1","name":"shell","arguments":"{\"cmd\":\"pwd\"}","status":"completed"}}`)
data <- []byte(`data: {"type":"response.completed","response":{"id":"resp-1","output":[]}}`)
close(data)
close(errs)
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
parts := strings.Split(strings.TrimSpace(recorder.Body.String()), "\n\n")
if len(parts) != 3 {
t.Fatalf("expected 3 SSE events, got %d. Body: %q", len(parts), recorder.Body.String())
}
payload := strings.TrimPrefix(parts[2], "data: ")
output := gjson.Get(payload, "response.output")
if !output.IsArray() || len(output.Array()) != 2 {
t.Fatalf("expected repaired completed output with 2 items, got %s", output.Raw)
}
if got := gjson.Get(payload, "response.output.1.name").String(); got != "shell" {
t.Fatalf("expected function_call name to be preserved, got %q in %s", got, payload)
}
if got := gjson.Get(payload, "response.output.1.arguments").String(); got != `{"cmd":"pwd"}` {
t.Fatalf("expected function_call arguments to be preserved, got %q in %s", got, payload)
}
}
func TestForwardResponsesStreamRepairsMixedIndexedAndUnindexedDoneItems(t *testing.T) {
h, recorder, c, flusher := newResponsesStreamTestHandler(t)
data := make(chan []byte, 3)
errs := make(chan *interfaces.ErrorMessage)
data <- []byte(`data: {"type":"response.output_item.done","output_index":1,"item":{"type":"function_call","id":"fc-1","call_id":"call-1","name":"shell","arguments":"{}","status":"completed"}}`)
data <- []byte(`data: {"type":"response.output_item.done","item":{"type":"message","id":"msg-1","role":"assistant","content":[{"type":"output_text","text":"done"}]}}`)
data <- []byte(`data: {"type":"response.completed","response":{"id":"resp-1","output":[]}}`)
close(data)
close(errs)
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
parts := strings.Split(strings.TrimSpace(recorder.Body.String()), "\n\n")
if len(parts) != 3 {
t.Fatalf("expected 3 SSE events, got %d. Body: %q", len(parts), recorder.Body.String())
}
payload := strings.TrimPrefix(parts[2], "data: ")
output := gjson.Get(payload, "response.output")
if !output.IsArray() || len(output.Array()) != 2 {
t.Fatalf("expected repaired completed output with 2 items, got %s", output.Raw)
}
if got := gjson.Get(payload, "response.output.0.name").String(); got != "shell" {
t.Fatalf("expected indexed function_call to be preserved first, got %q in %s", got, payload)
}
if got := gjson.Get(payload, "response.output.1.id").String(); got != "msg-1" {
t.Fatalf("expected unindexed message to be appended, got %q in %s", got, payload)
}
}
func TestForwardResponsesStreamRepairsMultilineCompletedOutputAsSSEDataLines(t *testing.T) {
h, recorder, c, flusher := newResponsesStreamTestHandler(t)
data := make(chan []byte, 2)
errs := make(chan *interfaces.ErrorMessage)
data <- []byte(`data: {"type":"response.output_item.done","item":{"type":"function_call","arguments":"{}"}}`)
data <- []byte("data: {\"type\":\"response.completed\",\ndata: \"response\":{\"id\":\"resp-1\",\"output\":[]}}\n\n")
close(data)
close(errs)
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
parts := strings.Split(strings.TrimSpace(recorder.Body.String()), "\n\n")
if len(parts) != 2 {
t.Fatalf("expected 2 SSE events, got %d. Body: %q", len(parts), recorder.Body.String())
}
completedFrame := []byte(parts[1])
for _, line := range strings.Split(parts[1], "\n") {
if line != "" && !strings.HasPrefix(line, "data: ") {
t.Fatalf("expected every completed payload line to be an SSE data line, got %q in %q", line, parts[1])
}
}
payload, ok := responsesSSEDataPayload(completedFrame)
if !ok {
t.Fatalf("expected completed frame to contain data payload: %q", parts[1])
}
output := gjson.GetBytes(payload, "response.output")
if !output.IsArray() || len(output.Array()) != 1 {
t.Fatalf("expected repaired completed output with 1 item, got %s from %q", output.Raw, payload)
}
}
func TestForwardResponsesStreamReassemblesSplitSSEEventChunks(t *testing.T) {
h, recorder, c, flusher := newResponsesStreamTestHandler(t)