Merge pull request #3595 from Progress-infinitely/fix/anthropic-tool-name-reverse-map

fix(amp): restore response tool casing from request
This commit is contained in:
Luis Pater
2026-05-29 02:17:27 +08:00
committed by GitHub
4 changed files with 192 additions and 6 deletions

View File

@@ -252,7 +252,7 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
// Log: Model was mapped to another model
log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel)
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
rewriter := NewResponseRewriter(c.Writer, modelName)
rewriter := NewResponseRewriterForRequest(c.Writer, modelName, bodyBytes)
rewriter.suppressThinking = true
c.Writer = rewriter
// Filter Anthropic-Beta header only for local handling paths
@@ -267,7 +267,7 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
// Wrap with ResponseRewriter for local providers too, because upstream
// proxies (e.g. NewAPI) may return a different model name and lack
// Amp-required fields like thinking.signature.
rewriter := NewResponseRewriter(c.Writer, modelName)
rewriter := NewResponseRewriterForRequest(c.Writer, modelName, bodyBytes)
rewriter.suppressThinking = providerName != "claude"
c.Writer = rewriter
// Filter Anthropic-Beta header only for local handling paths

View File

@@ -13,6 +13,38 @@ import (
"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
)
func TestFallbackHandler_RequestToolCasing_RewritesStreamingResponse(t *testing.T) {
gin.SetMode(gin.TestMode)
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client-amp-tool-casing", "codex", []*registry.ModelInfo{
{ID: "test/gpt-tool-casing", OwnedBy: "openai", Type: "codex"},
})
defer reg.UnregisterClient("test-client-amp-tool-casing")
fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return nil }, nil, nil)
handler := func(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
_, _ = c.Writer.Write([]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"name\":\"glob\",\"id\":\"toolu_01\",\"input\":{}}}\n\n"))
}
r := gin.New()
r.POST("/messages", fallback.WrapHandler(handler))
reqBody := []byte(`{"model":"test/gpt-tool-casing","tools":[{"name":"Glob","input_schema":{"type":"object"}}]}`)
req := httptest.NewRequest(http.MethodPost, "/messages", bytes.NewReader(reqBody))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("Expected status 200, got %d", w.Code)
}
if !bytes.Contains(w.Body.Bytes(), []byte(`"name":"Glob"`)) {
t.Fatalf("expected streaming response to restore glob->Glob, got %s", w.Body.String())
}
}
func TestFallbackHandler_ModelMapping_PreservesThinkingSuffixAndRewritesResponse(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@@ -22,6 +22,7 @@ type ResponseRewriter struct {
originalModel string
isStreaming bool
suppressThinking bool
requestToolNames map[string]string
}
// NewResponseRewriter creates a new response rewriter for model name substitution.
@@ -33,6 +34,12 @@ func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRe
}
}
func NewResponseRewriterForRequest(w gin.ResponseWriter, originalModel string, requestBody []byte) *ResponseRewriter {
rw := NewResponseRewriter(w, originalModel)
rw.requestToolNames = collectRequestToolNames(requestBody)
return rw
}
const maxBufferedResponseBytes = 2 * 1024 * 1024 // 2MB safety cap
func looksLikeSSEChunk(data []byte) bool {
@@ -134,17 +141,70 @@ var ampCanonicalToolNames = map[string]string{
"check": "Check",
}
func collectRequestToolNames(data []byte) map[string]string {
if len(data) == 0 {
return nil
}
parsed := gjson.ParseBytes(data)
names := map[string]string{}
conflicts := map[string]bool{}
record := func(name string) {
if name == "" {
return
}
key := strings.ToLower(name)
if conflicts[key] {
return
}
if existing, exists := names[key]; exists {
if existing != name {
names[key] = ""
conflicts[key] = true
}
return
}
names[key] = name
}
for _, tool := range parsed.Get("tools").Array() {
record(tool.Get("name").String())
}
if parsed.Get("tool_choice.type").String() == "tool" {
record(parsed.Get("tool_choice.name").String())
}
if len(names) == 0 {
return nil
}
return names
}
func canonicalAmpToolName(name string, requestToolNames map[string]string) (string, bool) {
key := strings.ToLower(name)
if canonical, ok := requestToolNames[key]; ok {
if canonical == "" {
return "", false
}
return canonical, true
}
canonical, ok := ampCanonicalToolNames[key]
return canonical, ok
}
// normalizeAmpToolNames fixes tool_use block names to match Amp's canonical casing.
// Some upstream models return lowercase tool names (e.g. "bash" instead of "Bash")
// which causes Amp's case-sensitive mode whitelist to reject them.
func normalizeAmpToolNames(data []byte) []byte {
return normalizeAmpToolNamesForRequest(data, nil)
}
func normalizeAmpToolNamesForRequest(data []byte, requestToolNames map[string]string) []byte {
// Non-streaming: content[].name in tool_use blocks
for index, block := range gjson.GetBytes(data, "content").Array() {
if block.Get("type").String() != "tool_use" {
continue
}
name := block.Get("name").String()
if canonical, ok := ampCanonicalToolNames[strings.ToLower(name)]; ok && name != canonical {
if canonical, ok := canonicalAmpToolName(name, requestToolNames); ok && name != canonical {
path := fmt.Sprintf("content.%d.name", index)
var err error
data, err = sjson.SetBytes(data, path, canonical)
@@ -157,7 +217,7 @@ func normalizeAmpToolNames(data []byte) []byte {
// Streaming: content_block.name in content_block_start events
if gjson.GetBytes(data, "content_block.type").String() == "tool_use" {
name := gjson.GetBytes(data, "content_block.name").String()
if canonical, ok := ampCanonicalToolNames[strings.ToLower(name)]; ok && name != canonical {
if canonical, ok := canonicalAmpToolName(name, requestToolNames); ok && name != canonical {
var err error
data, err = sjson.SetBytes(data, "content_block.name", canonical)
if err != nil {
@@ -169,6 +229,10 @@ func normalizeAmpToolNames(data []byte) []byte {
return data
}
func (rw *ResponseRewriter) normalizeToolNames(data []byte) []byte {
return normalizeAmpToolNamesForRequest(data, rw.requestToolNames)
}
// ensureAmpSignature injects empty signature fields into tool_use/thinking blocks
// in API responses so that the Amp TUI does not crash on P.signature.length.
func ensureAmpSignature(data []byte) []byte {
@@ -225,7 +289,7 @@ func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte {
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
data = ensureAmpSignature(data)
data = normalizeAmpToolNames(data)
data = rw.normalizeToolNames(data)
data = rw.suppressAmpThinking(data)
if len(data) == 0 {
return data
@@ -326,7 +390,7 @@ func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte {
data = ensureAmpSignature(data)
// Normalize tool names to canonical casing
data = normalizeAmpToolNames(data)
data = rw.normalizeToolNames(data)
// Rewrite model name
if rw.originalModel != "" {

View File

@@ -217,6 +217,96 @@ func TestNormalizeAmpToolNames_GlobPreserved(t *testing.T) {
}
}
func TestNormalizeAmpToolNames_RequestToolCasing_NonStreaming(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`)
result := normalizeAmpToolNamesForRequest(input, map[string]string{"glob": "Glob"})
if !contains(result, []byte(`"name":"Glob"`)) {
t.Errorf("expected glob->Glob when request advertised Glob, got %s", string(result))
}
}
func TestNormalizeAmpToolNames_RequestToolCasing_Streaming(t *testing.T) {
input := []byte(`{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","name":"glob","id":"toolu_01","input":{}}}`)
result := normalizeAmpToolNamesForRequest(input, map[string]string{"glob": "Glob"})
if !contains(result, []byte(`"name":"Glob"`)) {
t.Errorf("expected glob->Glob in streaming when request advertised Glob, got %s", string(result))
}
}
func TestResponseRewriter_RequestToolCasingFromBody(t *testing.T) {
requestBody := []byte(`{"tools":[{"name":"Glob","input_schema":{"type":"object"}}]}`)
rw := &ResponseRewriter{requestToolNames: collectRequestToolNames(requestBody)}
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`)
result := rw.rewriteModelInResponse(input)
if !contains(result, []byte(`"name":"Glob"`)) {
t.Errorf("expected request body casing to restore glob->Glob, got %s", string(result))
}
}
func TestResponseRewriter_LowercaseNativeRequestPreserved(t *testing.T) {
requestBody := []byte(`{"tools":[{"name":"glob","input_schema":{"type":"object"}}]}`)
rw := &ResponseRewriter{requestToolNames: collectRequestToolNames(requestBody)}
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`)
result := rw.rewriteModelInResponse(input)
if string(result) == string(input) {
return
}
if !contains(result, []byte(`"name":"glob"`)) {
t.Errorf("expected lowercase-native request to preserve glob, got %s", string(result))
}
}
func TestCollectRequestToolNames_CollisionIgnored(t *testing.T) {
tests := []struct {
requestBody []byte
input []byte
forbidden []byte
}{
{
requestBody: []byte(`{"tools":[{"name":"Glob","input_schema":{"type":"object"}},{"name":"glob","input_schema":{"type":"object"}}]}`),
input: []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`),
forbidden: []byte(`"name":"Glob"`),
},
{
requestBody: []byte(`{"tools":[{"name":"glob","input_schema":{"type":"object"}},{"name":"Glob","input_schema":{"type":"object"}}]}`),
input: []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`),
forbidden: []byte(`"name":"Glob"`),
},
{
requestBody: []byte(`{"tools":[{"name":"Bash","input_schema":{"type":"object"}},{"name":"bash","input_schema":{"type":"object"}}]}`),
input: []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"bash","input":{"cmd":"ls"}}]}`),
forbidden: []byte(`"name":"Bash"`),
},
}
for _, tt := range tests {
rw := &ResponseRewriter{requestToolNames: collectRequestToolNames(tt.requestBody)}
result := rw.rewriteModelInResponse(tt.input)
if contains(result, tt.forbidden) {
t.Errorf("expected conflicting tool casing not to force %s, got %s", string(tt.forbidden), string(result))
}
}
}
func TestResponseRewriter_RequestToolCasingFromBody_Streaming(t *testing.T) {
requestBody := []byte(`{"tools":[{"name":"Glob","input_schema":{"type":"object"}}]}`)
rw := &ResponseRewriter{requestToolNames: collectRequestToolNames(requestBody)}
input := []byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"name\":\"glob\",\"id\":\"toolu_01\",\"input\":{}}}\n\n")
result := rw.rewriteStreamChunk(input)
if !contains(result, []byte(`"name":"Glob"`)) {
t.Errorf("expected streaming response to restore glob->Glob from request body, got %s", string(result))
}
}
func TestNormalizeAmpToolNames_UnknownToolUntouched(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"edit_file","input":{"path":"/tmp/x"}}]}`)
result := normalizeAmpToolNames(input)