mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-06-01 04:12:28 +08:00
Merge pull request #3595 from Progress-infinitely/fix/anthropic-tool-name-reverse-map
fix(amp): restore response tool casing from request
This commit is contained in:
@@ -252,7 +252,7 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
// Log: Model was mapped to another model
|
||||
log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel)
|
||||
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
|
||||
rewriter := NewResponseRewriter(c.Writer, modelName)
|
||||
rewriter := NewResponseRewriterForRequest(c.Writer, modelName, bodyBytes)
|
||||
rewriter.suppressThinking = true
|
||||
c.Writer = rewriter
|
||||
// Filter Anthropic-Beta header only for local handling paths
|
||||
@@ -267,7 +267,7 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
// Wrap with ResponseRewriter for local providers too, because upstream
|
||||
// proxies (e.g. NewAPI) may return a different model name and lack
|
||||
// Amp-required fields like thinking.signature.
|
||||
rewriter := NewResponseRewriter(c.Writer, modelName)
|
||||
rewriter := NewResponseRewriterForRequest(c.Writer, modelName, bodyBytes)
|
||||
rewriter.suppressThinking = providerName != "claude"
|
||||
c.Writer = rewriter
|
||||
// Filter Anthropic-Beta header only for local handling paths
|
||||
|
||||
@@ -13,6 +13,38 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
|
||||
)
|
||||
|
||||
func TestFallbackHandler_RequestToolCasing_RewritesStreamingResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client-amp-tool-casing", "codex", []*registry.ModelInfo{
|
||||
{ID: "test/gpt-tool-casing", OwnedBy: "openai", Type: "codex"},
|
||||
})
|
||||
defer reg.UnregisterClient("test-client-amp-tool-casing")
|
||||
|
||||
fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return nil }, nil, nil)
|
||||
handler := func(c *gin.Context) {
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = c.Writer.Write([]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"name\":\"glob\",\"id\":\"toolu_01\",\"input\":{}}}\n\n"))
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/messages", fallback.WrapHandler(handler))
|
||||
|
||||
reqBody := []byte(`{"model":"test/gpt-tool-casing","tools":[{"name":"Glob","input_schema":{"type":"object"}}]}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/messages", bytes.NewReader(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
if !bytes.Contains(w.Body.Bytes(), []byte(`"name":"Glob"`)) {
|
||||
t.Fatalf("expected streaming response to restore glob->Glob, got %s", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallbackHandler_ModelMapping_PreservesThinkingSuffixAndRewritesResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ type ResponseRewriter struct {
|
||||
originalModel string
|
||||
isStreaming bool
|
||||
suppressThinking bool
|
||||
requestToolNames map[string]string
|
||||
}
|
||||
|
||||
// NewResponseRewriter creates a new response rewriter for model name substitution.
|
||||
@@ -33,6 +34,12 @@ func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRe
|
||||
}
|
||||
}
|
||||
|
||||
func NewResponseRewriterForRequest(w gin.ResponseWriter, originalModel string, requestBody []byte) *ResponseRewriter {
|
||||
rw := NewResponseRewriter(w, originalModel)
|
||||
rw.requestToolNames = collectRequestToolNames(requestBody)
|
||||
return rw
|
||||
}
|
||||
|
||||
const maxBufferedResponseBytes = 2 * 1024 * 1024 // 2MB safety cap
|
||||
|
||||
func looksLikeSSEChunk(data []byte) bool {
|
||||
@@ -134,17 +141,70 @@ var ampCanonicalToolNames = map[string]string{
|
||||
"check": "Check",
|
||||
}
|
||||
|
||||
func collectRequestToolNames(data []byte) map[string]string {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
parsed := gjson.ParseBytes(data)
|
||||
names := map[string]string{}
|
||||
conflicts := map[string]bool{}
|
||||
record := func(name string) {
|
||||
if name == "" {
|
||||
return
|
||||
}
|
||||
key := strings.ToLower(name)
|
||||
if conflicts[key] {
|
||||
return
|
||||
}
|
||||
if existing, exists := names[key]; exists {
|
||||
if existing != name {
|
||||
names[key] = ""
|
||||
conflicts[key] = true
|
||||
}
|
||||
return
|
||||
}
|
||||
names[key] = name
|
||||
}
|
||||
|
||||
for _, tool := range parsed.Get("tools").Array() {
|
||||
record(tool.Get("name").String())
|
||||
}
|
||||
if parsed.Get("tool_choice.type").String() == "tool" {
|
||||
record(parsed.Get("tool_choice.name").String())
|
||||
}
|
||||
if len(names) == 0 {
|
||||
return nil
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
func canonicalAmpToolName(name string, requestToolNames map[string]string) (string, bool) {
|
||||
key := strings.ToLower(name)
|
||||
if canonical, ok := requestToolNames[key]; ok {
|
||||
if canonical == "" {
|
||||
return "", false
|
||||
}
|
||||
return canonical, true
|
||||
}
|
||||
canonical, ok := ampCanonicalToolNames[key]
|
||||
return canonical, ok
|
||||
}
|
||||
|
||||
// normalizeAmpToolNames fixes tool_use block names to match Amp's canonical casing.
|
||||
// Some upstream models return lowercase tool names (e.g. "bash" instead of "Bash")
|
||||
// which causes Amp's case-sensitive mode whitelist to reject them.
|
||||
func normalizeAmpToolNames(data []byte) []byte {
|
||||
return normalizeAmpToolNamesForRequest(data, nil)
|
||||
}
|
||||
|
||||
func normalizeAmpToolNamesForRequest(data []byte, requestToolNames map[string]string) []byte {
|
||||
// Non-streaming: content[].name in tool_use blocks
|
||||
for index, block := range gjson.GetBytes(data, "content").Array() {
|
||||
if block.Get("type").String() != "tool_use" {
|
||||
continue
|
||||
}
|
||||
name := block.Get("name").String()
|
||||
if canonical, ok := ampCanonicalToolNames[strings.ToLower(name)]; ok && name != canonical {
|
||||
if canonical, ok := canonicalAmpToolName(name, requestToolNames); ok && name != canonical {
|
||||
path := fmt.Sprintf("content.%d.name", index)
|
||||
var err error
|
||||
data, err = sjson.SetBytes(data, path, canonical)
|
||||
@@ -157,7 +217,7 @@ func normalizeAmpToolNames(data []byte) []byte {
|
||||
// Streaming: content_block.name in content_block_start events
|
||||
if gjson.GetBytes(data, "content_block.type").String() == "tool_use" {
|
||||
name := gjson.GetBytes(data, "content_block.name").String()
|
||||
if canonical, ok := ampCanonicalToolNames[strings.ToLower(name)]; ok && name != canonical {
|
||||
if canonical, ok := canonicalAmpToolName(name, requestToolNames); ok && name != canonical {
|
||||
var err error
|
||||
data, err = sjson.SetBytes(data, "content_block.name", canonical)
|
||||
if err != nil {
|
||||
@@ -169,6 +229,10 @@ func normalizeAmpToolNames(data []byte) []byte {
|
||||
return data
|
||||
}
|
||||
|
||||
func (rw *ResponseRewriter) normalizeToolNames(data []byte) []byte {
|
||||
return normalizeAmpToolNamesForRequest(data, rw.requestToolNames)
|
||||
}
|
||||
|
||||
// ensureAmpSignature injects empty signature fields into tool_use/thinking blocks
|
||||
// in API responses so that the Amp TUI does not crash on P.signature.length.
|
||||
func ensureAmpSignature(data []byte) []byte {
|
||||
@@ -225,7 +289,7 @@ func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte {
|
||||
|
||||
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
|
||||
data = ensureAmpSignature(data)
|
||||
data = normalizeAmpToolNames(data)
|
||||
data = rw.normalizeToolNames(data)
|
||||
data = rw.suppressAmpThinking(data)
|
||||
if len(data) == 0 {
|
||||
return data
|
||||
@@ -326,7 +390,7 @@ func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte {
|
||||
data = ensureAmpSignature(data)
|
||||
|
||||
// Normalize tool names to canonical casing
|
||||
data = normalizeAmpToolNames(data)
|
||||
data = rw.normalizeToolNames(data)
|
||||
|
||||
// Rewrite model name
|
||||
if rw.originalModel != "" {
|
||||
|
||||
@@ -217,6 +217,96 @@ func TestNormalizeAmpToolNames_GlobPreserved(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAmpToolNames_RequestToolCasing_NonStreaming(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`)
|
||||
result := normalizeAmpToolNamesForRequest(input, map[string]string{"glob": "Glob"})
|
||||
|
||||
if !contains(result, []byte(`"name":"Glob"`)) {
|
||||
t.Errorf("expected glob->Glob when request advertised Glob, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAmpToolNames_RequestToolCasing_Streaming(t *testing.T) {
|
||||
input := []byte(`{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","name":"glob","id":"toolu_01","input":{}}}`)
|
||||
result := normalizeAmpToolNamesForRequest(input, map[string]string{"glob": "Glob"})
|
||||
|
||||
if !contains(result, []byte(`"name":"Glob"`)) {
|
||||
t.Errorf("expected glob->Glob in streaming when request advertised Glob, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseRewriter_RequestToolCasingFromBody(t *testing.T) {
|
||||
requestBody := []byte(`{"tools":[{"name":"Glob","input_schema":{"type":"object"}}]}`)
|
||||
rw := &ResponseRewriter{requestToolNames: collectRequestToolNames(requestBody)}
|
||||
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`)
|
||||
|
||||
result := rw.rewriteModelInResponse(input)
|
||||
|
||||
if !contains(result, []byte(`"name":"Glob"`)) {
|
||||
t.Errorf("expected request body casing to restore glob->Glob, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseRewriter_LowercaseNativeRequestPreserved(t *testing.T) {
|
||||
requestBody := []byte(`{"tools":[{"name":"glob","input_schema":{"type":"object"}}]}`)
|
||||
rw := &ResponseRewriter{requestToolNames: collectRequestToolNames(requestBody)}
|
||||
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`)
|
||||
|
||||
result := rw.rewriteModelInResponse(input)
|
||||
|
||||
if string(result) == string(input) {
|
||||
return
|
||||
}
|
||||
if !contains(result, []byte(`"name":"glob"`)) {
|
||||
t.Errorf("expected lowercase-native request to preserve glob, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectRequestToolNames_CollisionIgnored(t *testing.T) {
|
||||
tests := []struct {
|
||||
requestBody []byte
|
||||
input []byte
|
||||
forbidden []byte
|
||||
}{
|
||||
{
|
||||
requestBody: []byte(`{"tools":[{"name":"Glob","input_schema":{"type":"object"}},{"name":"glob","input_schema":{"type":"object"}}]}`),
|
||||
input: []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`),
|
||||
forbidden: []byte(`"name":"Glob"`),
|
||||
},
|
||||
{
|
||||
requestBody: []byte(`{"tools":[{"name":"glob","input_schema":{"type":"object"}},{"name":"Glob","input_schema":{"type":"object"}}]}`),
|
||||
input: []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`),
|
||||
forbidden: []byte(`"name":"Glob"`),
|
||||
},
|
||||
{
|
||||
requestBody: []byte(`{"tools":[{"name":"Bash","input_schema":{"type":"object"}},{"name":"bash","input_schema":{"type":"object"}}]}`),
|
||||
input: []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"bash","input":{"cmd":"ls"}}]}`),
|
||||
forbidden: []byte(`"name":"Bash"`),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
rw := &ResponseRewriter{requestToolNames: collectRequestToolNames(tt.requestBody)}
|
||||
result := rw.rewriteModelInResponse(tt.input)
|
||||
|
||||
if contains(result, tt.forbidden) {
|
||||
t.Errorf("expected conflicting tool casing not to force %s, got %s", string(tt.forbidden), string(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseRewriter_RequestToolCasingFromBody_Streaming(t *testing.T) {
|
||||
requestBody := []byte(`{"tools":[{"name":"Glob","input_schema":{"type":"object"}}]}`)
|
||||
rw := &ResponseRewriter{requestToolNames: collectRequestToolNames(requestBody)}
|
||||
input := []byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"name\":\"glob\",\"id\":\"toolu_01\",\"input\":{}}}\n\n")
|
||||
|
||||
result := rw.rewriteStreamChunk(input)
|
||||
|
||||
if !contains(result, []byte(`"name":"Glob"`)) {
|
||||
t.Errorf("expected streaming response to restore glob->Glob from request body, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAmpToolNames_UnknownToolUntouched(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"edit_file","input":{"path":"/tmp/x"}}]}`)
|
||||
result := normalizeAmpToolNames(input)
|
||||
|
||||
Reference in New Issue
Block a user