mirror of
https://github.com/7836246/cursor2api.git
synced 2026-06-01 19:39:47 +08:00
feat: 实现 Anthropic tool_use 协议支持
- internal/tools/types.go: 定义工具相关类型 - internal/tools/executor.go: 工具执行器(bash/文件操作) - internal/tools/parser.go: 解析 AI 输出中的工具调用 - internal/handler/anthropic.go: 支持 tools 参数和 tool_use 响应 - internal/handler/tools.go: 工具执行和列表接口 - cmd/server/main.go: 注册工具路由 支持的工具: - bash: 执行命令 - read_file: 读取文件 - write_file: 写入文件 - list_dir: 列出目录 - edit: 查找替换编辑
This commit is contained in:
@@ -38,6 +38,10 @@ func main() {
|
||||
r.POST("/v1/messages/count_tokens", handler.CountTokens)
|
||||
r.POST("/messages/count_tokens", handler.CountTokens)
|
||||
|
||||
// 工具相关接口
|
||||
r.GET("/tools", handler.ListTools)
|
||||
r.POST("/tools/execute", handler.ExecuteTool)
|
||||
|
||||
// 健康检查
|
||||
r.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"cursor2api/internal/browser"
|
||||
"cursor2api/internal/tools"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -17,17 +18,18 @@ import (
|
||||
|
||||
// MessagesRequest Anthropic Messages API 请求格式
|
||||
type MessagesRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Stream bool `json:"stream"`
|
||||
System interface{} `json:"system,omitempty"` // 可以是 string 或 []ContentBlock
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Stream bool `json:"stream"`
|
||||
System interface{} `json:"system,omitempty"` // 可以是 string 或 []ContentBlock
|
||||
Tools []tools.ToolDefinition `json:"tools,omitempty"` // 工具定义
|
||||
}
|
||||
|
||||
// Message 消息格式
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content interface{} `json:"content"` // 可以是 string 或 []ContentBlock
|
||||
Content interface{} `json:"content"` // 可以是 string 或 []ContentBlock 或 []ToolResult
|
||||
}
|
||||
|
||||
// MessagesResponse Anthropic Messages API 响应格式
|
||||
@@ -42,10 +44,13 @@ type MessagesResponse struct {
|
||||
Usage Usage `json:"usage"`
|
||||
}
|
||||
|
||||
// ContentBlock 内容块
|
||||
// ContentBlock 内容块(支持 text 和 tool_use)
|
||||
type ContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ID string `json:"id,omitempty"` // tool_use
|
||||
Name string `json:"name,omitempty"` // tool_use
|
||||
Input map[string]interface{} `json:"input,omitempty"` // tool_use
|
||||
}
|
||||
|
||||
// Usage token 使用统计
|
||||
@@ -54,6 +59,17 @@ type Usage struct {
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
// 全局工具执行器和解析器
|
||||
var (
|
||||
toolExecutor *tools.Executor
|
||||
toolParser *tools.Parser
|
||||
)
|
||||
|
||||
func init() {
|
||||
toolExecutor = tools.NewExecutor()
|
||||
toolParser = tools.NewParser()
|
||||
}
|
||||
|
||||
// CursorSSEEvent Cursor SSE 事件格式
|
||||
type CursorSSEEvent struct {
|
||||
Type string `json:"type"`
|
||||
@@ -168,8 +184,14 @@ func Messages(c *gin.Context) {
|
||||
func convertToCursor(req MessagesRequest) browser.CursorChatRequest {
|
||||
messages := make([]browser.CursorMessage, 0, len(req.Messages)+1)
|
||||
|
||||
// 添加 system 消息
|
||||
if sysText := getTextContent(req.System); sysText != "" {
|
||||
// 构建系统消息(包含工具定义)
|
||||
sysText := getTextContent(req.System)
|
||||
if len(req.Tools) > 0 {
|
||||
toolPrompt := tools.GenerateToolPrompt(req.Tools)
|
||||
sysText += toolPrompt
|
||||
}
|
||||
|
||||
if sysText != "" {
|
||||
messages = append(messages, browser.CursorMessage{
|
||||
Parts: []browser.CursorPart{{Type: "text", Text: sysText}},
|
||||
ID: generateID(),
|
||||
@@ -179,11 +201,14 @@ func convertToCursor(req MessagesRequest) browser.CursorChatRequest {
|
||||
|
||||
// 添加用户/助手消息
|
||||
for _, msg := range req.Messages {
|
||||
messages = append(messages, browser.CursorMessage{
|
||||
Parts: []browser.CursorPart{{Type: "text", Text: getTextContent(msg.Content)}},
|
||||
ID: generateID(),
|
||||
Role: msg.Role,
|
||||
})
|
||||
text := extractMessageText(msg)
|
||||
if text != "" {
|
||||
messages = append(messages, browser.CursorMessage{
|
||||
Parts: []browser.CursorPart{{Type: "text", Text: text}},
|
||||
ID: generateID(),
|
||||
Role: msg.Role,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return browser.CursorChatRequest{
|
||||
@@ -199,6 +224,65 @@ func convertToCursor(req MessagesRequest) browser.CursorChatRequest {
|
||||
}
|
||||
}
|
||||
|
||||
// extractMessageText 从消息中提取文本(处理 tool_result)
|
||||
func extractMessageText(msg Message) string {
|
||||
content := msg.Content
|
||||
if content == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []interface{}:
|
||||
var texts []string
|
||||
for _, item := range v {
|
||||
block, ok := item.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
blockType, _ := block["type"].(string)
|
||||
switch blockType {
|
||||
case "text":
|
||||
if text, ok := block["text"].(string); ok {
|
||||
texts = append(texts, text)
|
||||
}
|
||||
case "tool_result":
|
||||
// 处理工具结果
|
||||
toolUseID, _ := block["tool_use_id"].(string)
|
||||
resultContent := block["content"]
|
||||
isError, _ := block["is_error"].(bool)
|
||||
|
||||
resultText := ""
|
||||
switch rc := resultContent.(type) {
|
||||
case string:
|
||||
resultText = rc
|
||||
case []interface{}:
|
||||
for _, rcItem := range rc {
|
||||
if rcBlock, ok := rcItem.(map[string]interface{}); ok {
|
||||
if rcBlock["type"] == "text" {
|
||||
if t, ok := rcBlock["text"].(string); ok {
|
||||
resultText += t
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
prefix := "工具执行结果"
|
||||
if isError {
|
||||
prefix = "工具执行错误"
|
||||
}
|
||||
texts = append(texts, fmt.Sprintf("[%s (ID: %s)]\n%s", prefix, toolUseID, resultText))
|
||||
}
|
||||
}
|
||||
return strings.Join(texts, "\n")
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
// ================== API 处理 ==================
|
||||
|
||||
// handleStream 处理流式请求
|
||||
@@ -220,8 +304,9 @@ func handleStream(c *gin.Context, cursorReq browser.CursorChatRequest, model str
|
||||
c.Writer.WriteString(`data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}` + "\n\n")
|
||||
flusher.Flush()
|
||||
|
||||
// 用于累积不完整的 SSE 行
|
||||
// 用于累积完整响应和 SSE 行
|
||||
var buffer strings.Builder
|
||||
var fullResponse strings.Builder
|
||||
|
||||
svc := browser.GetService()
|
||||
err := svc.SendStreamRequest(cursorReq, func(chunk string) {
|
||||
@@ -253,6 +338,7 @@ func handleStream(c *gin.Context, cursorReq browser.CursorChatRequest, model str
|
||||
}
|
||||
|
||||
if event.Type == "text-delta" && event.Delta != "" {
|
||||
fullResponse.WriteString(event.Delta)
|
||||
deltaJSON, _ := json.Marshal(event.Delta)
|
||||
c.Writer.WriteString("event: content_block_delta\n")
|
||||
c.Writer.WriteString(`data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":` + string(deltaJSON) + `}}` + "\n\n")
|
||||
@@ -271,8 +357,34 @@ func handleStream(c *gin.Context, cursorReq browser.CursorChatRequest, model str
|
||||
c.Writer.WriteString(`data: {"type":"content_block_stop","index":0}` + "\n\n")
|
||||
flusher.Flush()
|
||||
|
||||
// 检测是否有工具调用,决定 stop_reason
|
||||
stopReason := "end_turn"
|
||||
responseText := fullResponse.String()
|
||||
toolCalls, _ := toolParser.ParseToolCalls(responseText)
|
||||
|
||||
if len(toolCalls) > 0 {
|
||||
stopReason = "tool_use"
|
||||
// 发送工具调用块
|
||||
for i, call := range toolCalls {
|
||||
toolID := "toolu_" + generateID()
|
||||
inputJSON, _ := json.Marshal(call.Input)
|
||||
|
||||
c.Writer.WriteString("event: content_block_start\n")
|
||||
c.Writer.WriteString(fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"%s","name":"%s","input":{}}}`+"\n\n", i+1, toolID, call.Name))
|
||||
flusher.Flush()
|
||||
|
||||
c.Writer.WriteString("event: content_block_delta\n")
|
||||
c.Writer.WriteString(fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":"%s"}}`+"\n\n", i+1, escapeJSON(string(inputJSON))))
|
||||
flusher.Flush()
|
||||
|
||||
c.Writer.WriteString("event: content_block_stop\n")
|
||||
c.Writer.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`+"\n\n", i+1))
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
c.Writer.WriteString("event: message_delta\n")
|
||||
c.Writer.WriteString(`data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":100}}` + "\n\n")
|
||||
c.Writer.WriteString(fmt.Sprintf(`data: {"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"output_tokens":100}}`+"\n\n", stopReason))
|
||||
flusher.Flush()
|
||||
|
||||
c.Writer.WriteString("event: message_stop\n")
|
||||
@@ -280,6 +392,16 @@ func handleStream(c *gin.Context, cursorReq browser.CursorChatRequest, model str
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
// escapeJSON 转义 JSON 字符串中的特殊字符
|
||||
func escapeJSON(s string) string {
|
||||
s = strings.ReplaceAll(s, `\`, `\\`)
|
||||
s = strings.ReplaceAll(s, `"`, `\"`)
|
||||
s = strings.ReplaceAll(s, "\n", `\n`)
|
||||
s = strings.ReplaceAll(s, "\r", `\r`)
|
||||
s = strings.ReplaceAll(s, "\t", `\t`)
|
||||
return s
|
||||
}
|
||||
|
||||
// handleNonStream 处理非流式请求
|
||||
func handleNonStream(c *gin.Context, cursorReq browser.CursorChatRequest, model string) {
|
||||
svc := browser.GetService()
|
||||
@@ -311,13 +433,61 @@ func handleNonStream(c *gin.Context, cursorReq browser.CursorChatRequest, model
|
||||
}
|
||||
}
|
||||
|
||||
responseText := fullText.String()
|
||||
contentBlocks := parseResponseToBlocks(responseText)
|
||||
|
||||
// 确定 stop_reason
|
||||
stopReason := "end_turn"
|
||||
for _, block := range contentBlocks {
|
||||
if block.Type == "tool_use" {
|
||||
stopReason = "tool_use"
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, MessagesResponse{
|
||||
ID: "msg_" + generateID(),
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Content: []ContentBlock{{Type: "text", Text: fullText.String()}},
|
||||
Content: contentBlocks,
|
||||
Model: model,
|
||||
StopReason: "end_turn",
|
||||
StopReason: stopReason,
|
||||
Usage: Usage{InputTokens: 100, OutputTokens: 100},
|
||||
})
|
||||
}
|
||||
|
||||
// parseResponseToBlocks 解析 AI 响应为内容块(检测工具调用)
|
||||
func parseResponseToBlocks(text string) []ContentBlock {
|
||||
var blocks []ContentBlock
|
||||
|
||||
// 检测工具调用
|
||||
toolCalls, remainingText := toolParser.ParseToolCalls(text)
|
||||
|
||||
// 添加文本块
|
||||
if remainingText != "" {
|
||||
blocks = append(blocks, ContentBlock{
|
||||
Type: "text",
|
||||
Text: remainingText,
|
||||
})
|
||||
}
|
||||
|
||||
// 添加工具调用块
|
||||
for _, call := range toolCalls {
|
||||
blocks = append(blocks, ContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: "toolu_" + generateID(),
|
||||
Name: call.Name,
|
||||
Input: call.Input,
|
||||
})
|
||||
}
|
||||
|
||||
// 如果没有任何内容,添加空文本块
|
||||
if len(blocks) == 0 {
|
||||
blocks = append(blocks, ContentBlock{
|
||||
Type: "text",
|
||||
Text: text,
|
||||
})
|
||||
}
|
||||
|
||||
return blocks
|
||||
}
|
||||
|
||||
125
internal/handler/tools.go
Normal file
125
internal/handler/tools.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"cursor2api/internal/tools"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ToolExecuteRequest 工具执行请求
|
||||
type ToolExecuteRequest struct {
|
||||
ToolName string `json:"tool_name"`
|
||||
Input map[string]interface{} `json:"input"`
|
||||
WorkDir string `json:"work_dir,omitempty"`
|
||||
}
|
||||
|
||||
// ToolExecuteResponse 工具执行响应
|
||||
type ToolExecuteResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Output string `json:"output"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// ExecuteTool 执行工具调用(供本地测试使用)
|
||||
func ExecuteTool(c *gin.Context) {
|
||||
var req ToolExecuteRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ToolExecuteResponse{
|
||||
Success: false,
|
||||
Error: "无效的请求格式: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
executor := tools.NewExecutor()
|
||||
if req.WorkDir != "" {
|
||||
executor.SetWorkDir(req.WorkDir)
|
||||
}
|
||||
|
||||
output, err := executor.Execute(req.ToolName, req.Input)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, ToolExecuteResponse{
|
||||
Success: false,
|
||||
Output: output,
|
||||
Error: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, ToolExecuteResponse{
|
||||
Success: true,
|
||||
Output: output,
|
||||
})
|
||||
}
|
||||
|
||||
// ListTools 列出可用工具
|
||||
func ListTools(c *gin.Context) {
|
||||
toolList := []tools.ToolDefinition{
|
||||
{
|
||||
Name: "bash",
|
||||
Description: "执行 bash 命令",
|
||||
InputSchema: tools.InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]tools.Property{
|
||||
"command": {Type: "string", Description: "要执行的命令"},
|
||||
"cwd": {Type: "string", Description: "工作目录(可选)"},
|
||||
},
|
||||
Required: []string{"command"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "read_file",
|
||||
Description: "读取文件内容",
|
||||
InputSchema: tools.InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]tools.Property{
|
||||
"path": {Type: "string", Description: "文件路径"},
|
||||
},
|
||||
Required: []string{"path"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "write_file",
|
||||
Description: "写入文件",
|
||||
InputSchema: tools.InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]tools.Property{
|
||||
"path": {Type: "string", Description: "文件路径"},
|
||||
"content": {Type: "string", Description: "文件内容"},
|
||||
},
|
||||
Required: []string{"path", "content"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "list_dir",
|
||||
Description: "列出目录内容",
|
||||
InputSchema: tools.InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]tools.Property{
|
||||
"path": {Type: "string", Description: "目录路径"},
|
||||
},
|
||||
Required: []string{"path"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "edit",
|
||||
Description: "编辑文件(查找替换)",
|
||||
InputSchema: tools.InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]tools.Property{
|
||||
"path": {Type: "string", Description: "文件路径"},
|
||||
"old_string": {Type: "string", Description: "要替换的内容"},
|
||||
"new_string": {Type: "string", Description: "替换后的内容"},
|
||||
"replace_all": {Type: "boolean", Description: "是否替换所有匹配"},
|
||||
},
|
||||
Required: []string{"path", "old_string", "new_string"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"tools": toolList,
|
||||
})
|
||||
}
|
||||
260
internal/tools/executor.go
Normal file
260
internal/tools/executor.go
Normal file
@@ -0,0 +1,260 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Executor 工具执行器
|
||||
type Executor struct {
|
||||
workDir string
|
||||
allowedDirs []string
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// NewExecutor 创建工具执行器
|
||||
func NewExecutor() *Executor {
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
return &Executor{
|
||||
workDir: homeDir,
|
||||
allowedDirs: []string{homeDir, "/tmp"},
|
||||
timeout: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// SetWorkDir 设置工作目录
|
||||
func (e *Executor) SetWorkDir(dir string) {
|
||||
e.workDir = dir
|
||||
}
|
||||
|
||||
// Execute 执行工具调用
|
||||
func (e *Executor) Execute(toolName string, input map[string]interface{}) (string, error) {
|
||||
switch toolName {
|
||||
case "bash", "run_command":
|
||||
return e.executeBash(input)
|
||||
case "read_file":
|
||||
return e.readFile(input)
|
||||
case "write_file", "write_to_file":
|
||||
return e.writeFile(input)
|
||||
case "list_dir", "list_directory":
|
||||
return e.listDir(input)
|
||||
case "edit", "str_replace_editor":
|
||||
return e.editFile(input)
|
||||
default:
|
||||
return "", fmt.Errorf("未知工具: %s", toolName)
|
||||
}
|
||||
}
|
||||
|
||||
// executeBash 执行 bash 命令
|
||||
func (e *Executor) executeBash(input map[string]interface{}) (string, error) {
|
||||
command, ok := input["command"].(string)
|
||||
if !ok {
|
||||
// 尝试其他字段名
|
||||
if cmd, ok := input["CommandLine"].(string); ok {
|
||||
command = cmd
|
||||
} else {
|
||||
return "", fmt.Errorf("缺少 command 参数")
|
||||
}
|
||||
}
|
||||
|
||||
// 获取工作目录
|
||||
cwd := e.workDir
|
||||
if dir, ok := input["cwd"].(string); ok && dir != "" {
|
||||
cwd = dir
|
||||
} else if dir, ok := input["Cwd"].(string); ok && dir != "" {
|
||||
cwd = dir
|
||||
}
|
||||
|
||||
cmd := exec.Command("bash", "-c", command)
|
||||
cmd.Dir = cwd
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
// 设置超时
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cmd.Run()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
output := stdout.String()
|
||||
if stderr.Len() > 0 {
|
||||
if output != "" {
|
||||
output += "\n"
|
||||
}
|
||||
output += stderr.String()
|
||||
}
|
||||
if err != nil {
|
||||
return output, fmt.Errorf("命令执行失败: %v\n%s", err, output)
|
||||
}
|
||||
return output, nil
|
||||
case <-time.After(e.timeout):
|
||||
cmd.Process.Kill()
|
||||
return "", fmt.Errorf("命令执行超时 (%v)", e.timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// readFile 读取文件
|
||||
func (e *Executor) readFile(input map[string]interface{}) (string, error) {
|
||||
path, ok := input["path"].(string)
|
||||
if !ok {
|
||||
if p, ok := input["file_path"].(string); ok {
|
||||
path = p
|
||||
} else {
|
||||
return "", fmt.Errorf("缺少 path 参数")
|
||||
}
|
||||
}
|
||||
|
||||
// 处理相对路径
|
||||
if !filepath.IsAbs(path) {
|
||||
path = filepath.Join(e.workDir, path)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("读取文件失败: %v", err)
|
||||
}
|
||||
|
||||
return string(content), nil
|
||||
}
|
||||
|
||||
// writeFile 写入文件
|
||||
func (e *Executor) writeFile(input map[string]interface{}) (string, error) {
|
||||
path, ok := input["path"].(string)
|
||||
if !ok {
|
||||
if p, ok := input["file_path"].(string); ok {
|
||||
path = p
|
||||
} else if p, ok := input["TargetFile"].(string); ok {
|
||||
path = p
|
||||
} else {
|
||||
return "", fmt.Errorf("缺少 path 参数")
|
||||
}
|
||||
}
|
||||
|
||||
content, ok := input["content"].(string)
|
||||
if !ok {
|
||||
if c, ok := input["CodeContent"].(string); ok {
|
||||
content = c
|
||||
} else {
|
||||
return "", fmt.Errorf("缺少 content 参数")
|
||||
}
|
||||
}
|
||||
|
||||
// 处理相对路径
|
||||
if !filepath.IsAbs(path) {
|
||||
path = filepath.Join(e.workDir, path)
|
||||
}
|
||||
|
||||
// 创建目录
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return "", fmt.Errorf("创建目录失败: %v", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
|
||||
return "", fmt.Errorf("写入文件失败: %v", err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("已写入文件: %s (%d 字节)", path, len(content)), nil
|
||||
}
|
||||
|
||||
// listDir 列出目录内容
|
||||
func (e *Executor) listDir(input map[string]interface{}) (string, error) {
|
||||
path, ok := input["path"].(string)
|
||||
if !ok {
|
||||
if p, ok := input["DirectoryPath"].(string); ok {
|
||||
path = p
|
||||
} else {
|
||||
path = e.workDir
|
||||
}
|
||||
}
|
||||
|
||||
// 处理相对路径
|
||||
if !filepath.IsAbs(path) {
|
||||
path = filepath.Join(e.workDir, path)
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("读取目录失败: %v", err)
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
result.WriteString(fmt.Sprintf("目录: %s\n\n", path))
|
||||
|
||||
for _, entry := range entries {
|
||||
info, _ := entry.Info()
|
||||
if entry.IsDir() {
|
||||
result.WriteString(fmt.Sprintf("[DIR] %s/\n", entry.Name()))
|
||||
} else {
|
||||
size := int64(0)
|
||||
if info != nil {
|
||||
size = info.Size()
|
||||
}
|
||||
result.WriteString(fmt.Sprintf("[FILE] %s (%d bytes)\n", entry.Name(), size))
|
||||
}
|
||||
}
|
||||
|
||||
return result.String(), nil
|
||||
}
|
||||
|
||||
// editFile 编辑文件(查找替换)
|
||||
func (e *Executor) editFile(input map[string]interface{}) (string, error) {
|
||||
path, ok := input["path"].(string)
|
||||
if !ok {
|
||||
if p, ok := input["file_path"].(string); ok {
|
||||
path = p
|
||||
} else {
|
||||
return "", fmt.Errorf("缺少 path 参数")
|
||||
}
|
||||
}
|
||||
|
||||
oldStr, _ := input["old_string"].(string)
|
||||
newStr, _ := input["new_string"].(string)
|
||||
|
||||
if oldStr == "" {
|
||||
return "", fmt.Errorf("缺少 old_string 参数")
|
||||
}
|
||||
|
||||
// 处理相对路径
|
||||
if !filepath.IsAbs(path) {
|
||||
path = filepath.Join(e.workDir, path)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("读取文件失败: %v", err)
|
||||
}
|
||||
|
||||
original := string(content)
|
||||
if !strings.Contains(original, oldStr) {
|
||||
return "", fmt.Errorf("未找到要替换的内容")
|
||||
}
|
||||
|
||||
// 替换
|
||||
replaceAll := false
|
||||
if ra, ok := input["replace_all"].(bool); ok {
|
||||
replaceAll = ra
|
||||
}
|
||||
|
||||
var modified string
|
||||
if replaceAll {
|
||||
modified = strings.ReplaceAll(original, oldStr, newStr)
|
||||
} else {
|
||||
modified = strings.Replace(original, oldStr, newStr, 1)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, []byte(modified), 0644); err != nil {
|
||||
return "", fmt.Errorf("写入文件失败: %v", err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("已编辑文件: %s", path), nil
|
||||
}
|
||||
144
internal/tools/parser.go
Normal file
144
internal/tools/parser.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Parser 解析 AI 输出中的工具调用
|
||||
type Parser struct{}
|
||||
|
||||
// NewParser 创建解析器
|
||||
func NewParser() *Parser {
|
||||
return &Parser{}
|
||||
}
|
||||
|
||||
// toolCallPattern 匹配工具调用的 JSON 块
|
||||
var toolCallPatterns = []*regexp.Regexp{
|
||||
// 标准 JSON 块格式
|
||||
regexp.MustCompile(`(?s)<tool_call>\s*(\{.*?\})\s*</tool_call>`),
|
||||
// 代码块格式
|
||||
regexp.MustCompile("(?s)```json\\s*\\n(\\{[^`]*?\"tool\"[^`]*?\\})\\s*\\n```"),
|
||||
regexp.MustCompile("(?s)```\\s*\\n(\\{[^`]*?\"tool\"[^`]*?\\})\\s*\\n```"),
|
||||
// 单行 JSON 格式
|
||||
regexp.MustCompile(`(\{"tool"\s*:\s*"[^"]+"\s*,\s*"[^}]+\})`),
|
||||
}
|
||||
|
||||
// ParseToolCalls 从 AI 输出中解析工具调用
|
||||
func (p *Parser) ParseToolCalls(output string) ([]ParsedToolCall, string) {
|
||||
var calls []ParsedToolCall
|
||||
remainingText := output
|
||||
|
||||
for _, pattern := range toolCallPatterns {
|
||||
matches := pattern.FindAllStringSubmatch(output, -1)
|
||||
for _, match := range matches {
|
||||
if len(match) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
jsonStr := match[1]
|
||||
var rawCall map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(jsonStr), &rawCall); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 提取工具名称
|
||||
toolName := ""
|
||||
if name, ok := rawCall["tool"].(string); ok {
|
||||
toolName = name
|
||||
} else if name, ok := rawCall["name"].(string); ok {
|
||||
toolName = name
|
||||
}
|
||||
|
||||
if toolName == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 提取输入参数
|
||||
input := make(map[string]interface{})
|
||||
if inp, ok := rawCall["input"].(map[string]interface{}); ok {
|
||||
input = inp
|
||||
} else {
|
||||
// 其他字段作为输入
|
||||
for k, v := range rawCall {
|
||||
if k != "tool" && k != "name" && k != "type" {
|
||||
input[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
calls = append(calls, ParsedToolCall{
|
||||
Name: toolName,
|
||||
Input: input,
|
||||
})
|
||||
|
||||
// 从剩余文本中移除已解析的工具调用
|
||||
remainingText = strings.Replace(remainingText, match[0], "", 1)
|
||||
}
|
||||
}
|
||||
|
||||
// 清理剩余文本
|
||||
remainingText = strings.TrimSpace(remainingText)
|
||||
|
||||
return calls, remainingText
|
||||
}
|
||||
|
||||
// GenerateToolPrompt 生成工具使用的系统提示
|
||||
func GenerateToolPrompt(tools []ToolDefinition) string {
|
||||
if len(tools) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString("\n\n## 可用工具\n\n")
|
||||
sb.WriteString("当你需要执行操作时,请使用以下格式调用工具:\n\n")
|
||||
sb.WriteString("<tool_call>\n{\"tool\": \"工具名称\", \"参数名\": \"参数值\"}\n</tool_call>\n\n")
|
||||
sb.WriteString("可用的工具:\n\n")
|
||||
|
||||
for _, tool := range tools {
|
||||
sb.WriteString("### ")
|
||||
sb.WriteString(tool.Name)
|
||||
sb.WriteString("\n")
|
||||
if tool.Description != "" {
|
||||
sb.WriteString(tool.Description)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("参数:\n")
|
||||
|
||||
for name, prop := range tool.InputSchema.Properties {
|
||||
required := ""
|
||||
for _, r := range tool.InputSchema.Required {
|
||||
if r == name {
|
||||
required = " (必需)"
|
||||
break
|
||||
}
|
||||
}
|
||||
sb.WriteString("- `")
|
||||
sb.WriteString(name)
|
||||
sb.WriteString("`")
|
||||
sb.WriteString(required)
|
||||
sb.WriteString(": ")
|
||||
sb.WriteString(prop.Description)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
sb.WriteString("重要提示:\n")
|
||||
sb.WriteString("- 每次只调用一个工具\n")
|
||||
sb.WriteString("- 工具调用必须使用 <tool_call> 标签包裹\n")
|
||||
sb.WriteString("- 等待工具执行结果后再继续\n")
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// IsToolCallResponse 检查输出是否包含工具调用
|
||||
func (p *Parser) IsToolCallResponse(output string) bool {
|
||||
for _, pattern := range toolCallPatterns {
|
||||
if pattern.MatchString(output) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
51
internal/tools/types.go
Normal file
51
internal/tools/types.go
Normal file
@@ -0,0 +1,51 @@
|
||||
// Package tools 提供工具调用解析和执行功能
|
||||
package tools
|
||||
|
||||
// ToolDefinition Anthropic 工具定义格式
|
||||
type ToolDefinition struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema InputSchema `json:"input_schema"`
|
||||
}
|
||||
|
||||
// InputSchema JSON Schema 格式的输入参数定义
|
||||
type InputSchema struct {
|
||||
Type string `json:"type"`
|
||||
Properties map[string]Property `json:"properties,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
}
|
||||
|
||||
// Property JSON Schema 属性定义
|
||||
type Property struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}
|
||||
|
||||
// ToolUse AI 返回的工具调用
|
||||
type ToolUse struct {
|
||||
Type string `json:"type"` // "tool_use"
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Input map[string]interface{} `json:"input"`
|
||||
}
|
||||
|
||||
// ToolResult 工具执行结果
|
||||
type ToolResult struct {
|
||||
Type string `json:"type"` // "tool_result"
|
||||
ToolUseID string `json:"tool_use_id"`
|
||||
Content interface{} `json:"content"` // string 或 []ContentBlock
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
}
|
||||
|
||||
// ContentBlock 内容块(用于 tool_result)
|
||||
type ContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
// ParsedToolCall 从 AI 输出解析的工具调用
|
||||
type ParsedToolCall struct {
|
||||
Name string `json:"name"`
|
||||
Input map[string]interface{} `json:"input"`
|
||||
}
|
||||
Reference in New Issue
Block a user