mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-05-20 09:17:43 +08:00
Merge pull request #2896 from edlsh/fix/oauth-tool-rename-per-request-map
fix(amp): smart-mode tool name fixes + deep-mode response repair
This commit is contained in:
@@ -123,6 +123,52 @@ func (rw *ResponseRewriter) Flush() {
|
||||
|
||||
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}
|
||||
|
||||
// ampCanonicalToolNames maps tool names to the exact casing expected by the
|
||||
// Amp mode tool whitelist (case-sensitive match).
|
||||
var ampCanonicalToolNames = map[string]string{
|
||||
"bash": "Bash",
|
||||
"read": "Read",
|
||||
"grep": "Grep",
|
||||
"glob": "glob",
|
||||
"task": "Task",
|
||||
"check": "Check",
|
||||
}
|
||||
|
||||
// normalizeAmpToolNames fixes tool_use block names to match Amp's canonical casing.
|
||||
// Some upstream models return lowercase tool names (e.g. "bash" instead of "Bash")
|
||||
// which causes Amp's case-sensitive mode whitelist to reject them.
|
||||
func normalizeAmpToolNames(data []byte) []byte {
|
||||
// Non-streaming: content[].name in tool_use blocks
|
||||
for index, block := range gjson.GetBytes(data, "content").Array() {
|
||||
if block.Get("type").String() != "tool_use" {
|
||||
continue
|
||||
}
|
||||
name := block.Get("name").String()
|
||||
if canonical, ok := ampCanonicalToolNames[strings.ToLower(name)]; ok && name != canonical {
|
||||
path := fmt.Sprintf("content.%d.name", index)
|
||||
var err error
|
||||
data, err = sjson.SetBytes(data, path, canonical)
|
||||
if err != nil {
|
||||
log.Warnf("Amp ResponseRewriter: failed to normalize tool name %q to %q: %v", name, canonical, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Streaming: content_block.name in content_block_start events
|
||||
if gjson.GetBytes(data, "content_block.type").String() == "tool_use" {
|
||||
name := gjson.GetBytes(data, "content_block.name").String()
|
||||
if canonical, ok := ampCanonicalToolNames[strings.ToLower(name)]; ok && name != canonical {
|
||||
var err error
|
||||
data, err = sjson.SetBytes(data, "content_block.name", canonical)
|
||||
if err != nil {
|
||||
log.Warnf("Amp ResponseRewriter: failed to normalize streaming tool name %q to %q: %v", name, canonical, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// ensureAmpSignature injects empty signature fields into tool_use/thinking blocks
|
||||
// in API responses so that the Amp TUI does not crash on P.signature.length.
|
||||
func ensureAmpSignature(data []byte) []byte {
|
||||
@@ -179,6 +225,7 @@ func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte {
|
||||
|
||||
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
|
||||
data = ensureAmpSignature(data)
|
||||
data = normalizeAmpToolNames(data)
|
||||
data = rw.suppressAmpThinking(data)
|
||||
if len(data) == 0 {
|
||||
return data
|
||||
@@ -278,6 +325,9 @@ func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte {
|
||||
// Inject empty signature where needed
|
||||
data = ensureAmpSignature(data)
|
||||
|
||||
// Normalize tool names to canonical casing
|
||||
data = normalizeAmpToolNames(data)
|
||||
|
||||
// Rewrite model name
|
||||
if rw.originalModel != "" {
|
||||
for _, path := range modelFieldPaths {
|
||||
|
||||
@@ -175,6 +175,57 @@ func TestSanitizeAmpRequestBody_MixedInvalidThinkingAndToolUseSignature(t *testi
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAmpToolNames_NonStreaming(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"bash","input":{"cmd":"ls"}},{"type":"tool_use","id":"toolu_02","name":"read","input":{"path":"/tmp"}},{"type":"text","text":"hello"}]}`)
|
||||
result := normalizeAmpToolNames(input)
|
||||
|
||||
if !contains(result, []byte(`"name":"Bash"`)) {
|
||||
t.Errorf("expected bash->Bash, got %s", string(result))
|
||||
}
|
||||
if !contains(result, []byte(`"name":"Read"`)) {
|
||||
t.Errorf("expected read->Read, got %s", string(result))
|
||||
}
|
||||
if contains(result, []byte(`"name":"bash"`)) {
|
||||
t.Errorf("expected lowercase bash to be replaced, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAmpToolNames_Streaming(t *testing.T) {
|
||||
input := []byte(`{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","name":"grep","id":"toolu_01","input":{}}}`)
|
||||
result := normalizeAmpToolNames(input)
|
||||
|
||||
if !contains(result, []byte(`"name":"Grep"`)) {
|
||||
t.Errorf("expected grep->Grep in streaming, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAmpToolNames_AlreadyCorrect(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
|
||||
result := normalizeAmpToolNames(input)
|
||||
|
||||
if string(result) != string(input) {
|
||||
t.Errorf("expected no modification for correctly-cased tool, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAmpToolNames_GlobPreserved(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`)
|
||||
result := normalizeAmpToolNames(input)
|
||||
|
||||
if string(result) != string(input) {
|
||||
t.Errorf("expected glob to remain lowercase, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAmpToolNames_UnknownToolUntouched(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"edit_file","input":{"path":"/tmp/x"}}]}`)
|
||||
result := normalizeAmpToolNames(input)
|
||||
|
||||
if string(result) != string(input) {
|
||||
t.Errorf("expected no modification for unknown tool, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func contains(data, substr []byte) bool {
|
||||
for i := 0; i <= len(data)-len(substr); i++ {
|
||||
if string(data[i:i+len(substr)]) == string(substr) {
|
||||
|
||||
Reference in New Issue
Block a user