From 01a7cc4a45880c9f49152131ebd529a099f3a294 Mon Sep 17 00:00:00 2001 From: Progress-infinitely <102594894+Progress-infinitely@users.noreply.github.com> Date: Thu, 28 May 2026 17:34:06 +0800 Subject: [PATCH] fix(amp): restore response tool casing from request --- internal/api/modules/amp/fallback_handlers.go | 4 +- .../api/modules/amp/fallback_handlers_test.go | 32 +++++++ internal/api/modules/amp/response_rewriter.go | 72 ++++++++++++++- .../api/modules/amp/response_rewriter_test.go | 90 +++++++++++++++++++ 4 files changed, 192 insertions(+), 6 deletions(-) diff --git a/internal/api/modules/amp/fallback_handlers.go b/internal/api/modules/amp/fallback_handlers.go index 06e0a035d..4949ef7a4 100644 --- a/internal/api/modules/amp/fallback_handlers.go +++ b/internal/api/modules/amp/fallback_handlers.go @@ -252,7 +252,7 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc // Log: Model was mapped to another model log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel) logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath) - rewriter := NewResponseRewriter(c.Writer, modelName) + rewriter := NewResponseRewriterForRequest(c.Writer, modelName, bodyBytes) rewriter.suppressThinking = true c.Writer = rewriter // Filter Anthropic-Beta header only for local handling paths @@ -267,7 +267,7 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc // Wrap with ResponseRewriter for local providers too, because upstream // proxies (e.g. NewAPI) may return a different model name and lack // Amp-required fields like thinking.signature. - rewriter := NewResponseRewriter(c.Writer, modelName) + rewriter := NewResponseRewriterForRequest(c.Writer, modelName, bodyBytes) rewriter.suppressThinking = providerName != "claude" c.Writer = rewriter // Filter Anthropic-Beta header only for local handling paths diff --git a/internal/api/modules/amp/fallback_handlers_test.go b/internal/api/modules/amp/fallback_handlers_test.go index 1aacaae21..7e6f10a2f 100644 --- a/internal/api/modules/amp/fallback_handlers_test.go +++ b/internal/api/modules/amp/fallback_handlers_test.go @@ -13,6 +13,38 @@ import ( "github.com/router-for-me/CLIProxyAPI/v7/internal/registry" ) +func TestFallbackHandler_RequestToolCasing_RewritesStreamingResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + + reg := registry.GetGlobalRegistry() + reg.RegisterClient("test-client-amp-tool-casing", "codex", []*registry.ModelInfo{ + {ID: "test/gpt-tool-casing", OwnedBy: "openai", Type: "codex"}, + }) + defer reg.UnregisterClient("test-client-amp-tool-casing") + + fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return nil }, nil, nil) + handler := func(c *gin.Context) { + c.Writer.Header().Set("Content-Type", "text/event-stream") + _, _ = c.Writer.Write([]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"name\":\"glob\",\"id\":\"toolu_01\",\"input\":{}}}\n\n")) + } + + r := gin.New() + r.POST("/messages", fallback.WrapHandler(handler)) + + reqBody := []byte(`{"model":"test/gpt-tool-casing","tools":[{"name":"Glob","input_schema":{"type":"object"}}]}`) + req := httptest.NewRequest(http.MethodPost, "/messages", bytes.NewReader(reqBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected status 200, got %d", w.Code) + } + if !bytes.Contains(w.Body.Bytes(), []byte(`"name":"Glob"`)) { + t.Fatalf("expected streaming response to restore glob->Glob, got %s", w.Body.String()) + } +} + func TestFallbackHandler_ModelMapping_PreservesThinkingSuffixAndRewritesResponse(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/internal/api/modules/amp/response_rewriter.go b/internal/api/modules/amp/response_rewriter.go index 895c494e7..86318119e 100644 --- a/internal/api/modules/amp/response_rewriter.go +++ b/internal/api/modules/amp/response_rewriter.go @@ -22,6 +22,7 @@ type ResponseRewriter struct { originalModel string isStreaming bool suppressThinking bool + requestToolNames map[string]string } // NewResponseRewriter creates a new response rewriter for model name substitution. @@ -33,6 +34,12 @@ func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRe } } +func NewResponseRewriterForRequest(w gin.ResponseWriter, originalModel string, requestBody []byte) *ResponseRewriter { + rw := NewResponseRewriter(w, originalModel) + rw.requestToolNames = collectRequestToolNames(requestBody) + return rw +} + const maxBufferedResponseBytes = 2 * 1024 * 1024 // 2MB safety cap func looksLikeSSEChunk(data []byte) bool { @@ -134,17 +141,70 @@ var ampCanonicalToolNames = map[string]string{ "check": "Check", } +func collectRequestToolNames(data []byte) map[string]string { + if len(data) == 0 { + return nil + } + parsed := gjson.ParseBytes(data) + names := map[string]string{} + conflicts := map[string]bool{} + record := func(name string) { + if name == "" { + return + } + key := strings.ToLower(name) + if conflicts[key] { + return + } + if existing, exists := names[key]; exists { + if existing != name { + names[key] = "" + conflicts[key] = true + } + return + } + names[key] = name + } + + for _, tool := range parsed.Get("tools").Array() { + record(tool.Get("name").String()) + } + if parsed.Get("tool_choice.type").String() == "tool" { + record(parsed.Get("tool_choice.name").String()) + } + if len(names) == 0 { + return nil + } + return names +} + +func canonicalAmpToolName(name string, requestToolNames map[string]string) (string, bool) { + key := strings.ToLower(name) + if canonical, ok := requestToolNames[key]; ok { + if canonical == "" { + return "", false + } + return canonical, true + } + canonical, ok := ampCanonicalToolNames[key] + return canonical, ok +} + // normalizeAmpToolNames fixes tool_use block names to match Amp's canonical casing. // Some upstream models return lowercase tool names (e.g. "bash" instead of "Bash") // which causes Amp's case-sensitive mode whitelist to reject them. func normalizeAmpToolNames(data []byte) []byte { + return normalizeAmpToolNamesForRequest(data, nil) +} + +func normalizeAmpToolNamesForRequest(data []byte, requestToolNames map[string]string) []byte { // Non-streaming: content[].name in tool_use blocks for index, block := range gjson.GetBytes(data, "content").Array() { if block.Get("type").String() != "tool_use" { continue } name := block.Get("name").String() - if canonical, ok := ampCanonicalToolNames[strings.ToLower(name)]; ok && name != canonical { + if canonical, ok := canonicalAmpToolName(name, requestToolNames); ok && name != canonical { path := fmt.Sprintf("content.%d.name", index) var err error data, err = sjson.SetBytes(data, path, canonical) @@ -157,7 +217,7 @@ func normalizeAmpToolNames(data []byte) []byte { // Streaming: content_block.name in content_block_start events if gjson.GetBytes(data, "content_block.type").String() == "tool_use" { name := gjson.GetBytes(data, "content_block.name").String() - if canonical, ok := ampCanonicalToolNames[strings.ToLower(name)]; ok && name != canonical { + if canonical, ok := canonicalAmpToolName(name, requestToolNames); ok && name != canonical { var err error data, err = sjson.SetBytes(data, "content_block.name", canonical) if err != nil { @@ -169,6 +229,10 @@ func normalizeAmpToolNames(data []byte) []byte { return data } +func (rw *ResponseRewriter) normalizeToolNames(data []byte) []byte { + return normalizeAmpToolNamesForRequest(data, rw.requestToolNames) +} + // ensureAmpSignature injects empty signature fields into tool_use/thinking blocks // in API responses so that the Amp TUI does not crash on P.signature.length. func ensureAmpSignature(data []byte) []byte { @@ -225,7 +289,7 @@ func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte { func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte { data = ensureAmpSignature(data) - data = normalizeAmpToolNames(data) + data = rw.normalizeToolNames(data) data = rw.suppressAmpThinking(data) if len(data) == 0 { return data @@ -326,7 +390,7 @@ func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte { data = ensureAmpSignature(data) // Normalize tool names to canonical casing - data = normalizeAmpToolNames(data) + data = rw.normalizeToolNames(data) // Rewrite model name if rw.originalModel != "" { diff --git a/internal/api/modules/amp/response_rewriter_test.go b/internal/api/modules/amp/response_rewriter_test.go index a3a350cb2..609942edd 100644 --- a/internal/api/modules/amp/response_rewriter_test.go +++ b/internal/api/modules/amp/response_rewriter_test.go @@ -217,6 +217,96 @@ func TestNormalizeAmpToolNames_GlobPreserved(t *testing.T) { } } +func TestNormalizeAmpToolNames_RequestToolCasing_NonStreaming(t *testing.T) { + input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`) + result := normalizeAmpToolNamesForRequest(input, map[string]string{"glob": "Glob"}) + + if !contains(result, []byte(`"name":"Glob"`)) { + t.Errorf("expected glob->Glob when request advertised Glob, got %s", string(result)) + } +} + +func TestNormalizeAmpToolNames_RequestToolCasing_Streaming(t *testing.T) { + input := []byte(`{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","name":"glob","id":"toolu_01","input":{}}}`) + result := normalizeAmpToolNamesForRequest(input, map[string]string{"glob": "Glob"}) + + if !contains(result, []byte(`"name":"Glob"`)) { + t.Errorf("expected glob->Glob in streaming when request advertised Glob, got %s", string(result)) + } +} + +func TestResponseRewriter_RequestToolCasingFromBody(t *testing.T) { + requestBody := []byte(`{"tools":[{"name":"Glob","input_schema":{"type":"object"}}]}`) + rw := &ResponseRewriter{requestToolNames: collectRequestToolNames(requestBody)} + input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`) + + result := rw.rewriteModelInResponse(input) + + if !contains(result, []byte(`"name":"Glob"`)) { + t.Errorf("expected request body casing to restore glob->Glob, got %s", string(result)) + } +} + +func TestResponseRewriter_LowercaseNativeRequestPreserved(t *testing.T) { + requestBody := []byte(`{"tools":[{"name":"glob","input_schema":{"type":"object"}}]}`) + rw := &ResponseRewriter{requestToolNames: collectRequestToolNames(requestBody)} + input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`) + + result := rw.rewriteModelInResponse(input) + + if string(result) == string(input) { + return + } + if !contains(result, []byte(`"name":"glob"`)) { + t.Errorf("expected lowercase-native request to preserve glob, got %s", string(result)) + } +} + +func TestCollectRequestToolNames_CollisionIgnored(t *testing.T) { + tests := []struct { + requestBody []byte + input []byte + forbidden []byte + }{ + { + requestBody: []byte(`{"tools":[{"name":"Glob","input_schema":{"type":"object"}},{"name":"glob","input_schema":{"type":"object"}}]}`), + input: []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`), + forbidden: []byte(`"name":"Glob"`), + }, + { + requestBody: []byte(`{"tools":[{"name":"glob","input_schema":{"type":"object"}},{"name":"Glob","input_schema":{"type":"object"}}]}`), + input: []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`), + forbidden: []byte(`"name":"Glob"`), + }, + { + requestBody: []byte(`{"tools":[{"name":"Bash","input_schema":{"type":"object"}},{"name":"bash","input_schema":{"type":"object"}}]}`), + input: []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"bash","input":{"cmd":"ls"}}]}`), + forbidden: []byte(`"name":"Bash"`), + }, + } + + for _, tt := range tests { + rw := &ResponseRewriter{requestToolNames: collectRequestToolNames(tt.requestBody)} + result := rw.rewriteModelInResponse(tt.input) + + if contains(result, tt.forbidden) { + t.Errorf("expected conflicting tool casing not to force %s, got %s", string(tt.forbidden), string(result)) + } + } +} + +func TestResponseRewriter_RequestToolCasingFromBody_Streaming(t *testing.T) { + requestBody := []byte(`{"tools":[{"name":"Glob","input_schema":{"type":"object"}}]}`) + rw := &ResponseRewriter{requestToolNames: collectRequestToolNames(requestBody)} + input := []byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"name\":\"glob\",\"id\":\"toolu_01\",\"input\":{}}}\n\n") + + result := rw.rewriteStreamChunk(input) + + if !contains(result, []byte(`"name":"Glob"`)) { + t.Errorf("expected streaming response to restore glob->Glob from request body, got %s", string(result)) + } +} + func TestNormalizeAmpToolNames_UnknownToolUntouched(t *testing.T) { input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"edit_file","input":{"path":"/tmp/x"}}]}`) result := normalizeAmpToolNames(input)