From c68aa3bfb682f81c39727e92d0ee6c5c672ea09d Mon Sep 17 00:00:00 2001 From: chinadoiphin Date: Tue, 16 Dec 2025 20:32:16 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0=20Anthropic=20tool?= =?UTF-8?q?=5Fuse=20=E5=8D=8F=E8=AE=AE=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - internal/tools/types.go: 定义工具相关类型 - internal/tools/executor.go: 工具执行器(bash/文件操作) - internal/tools/parser.go: 解析 AI 输出中的工具调用 - internal/handler/anthropic.go: 支持 tools 参数和 tool_use 响应 - internal/handler/tools.go: 工具执行和列表接口 - cmd/server/main.go: 注册工具路由 支持的工具: - bash: 执行命令 - read_file: 读取文件 - write_file: 写入文件 - list_dir: 列出目录 - edit: 查找替换编辑 --- cmd/server/main.go | 4 + internal/handler/anthropic.go | 210 ++++++++++++++++++++++++--- internal/handler/tools.go | 125 ++++++++++++++++ internal/tools/executor.go | 260 ++++++++++++++++++++++++++++++++++ internal/tools/parser.go | 144 +++++++++++++++++++ internal/tools/types.go | 51 +++++++ 6 files changed, 774 insertions(+), 20 deletions(-) create mode 100644 internal/handler/tools.go create mode 100644 internal/tools/executor.go create mode 100644 internal/tools/parser.go create mode 100644 internal/tools/types.go diff --git a/cmd/server/main.go b/cmd/server/main.go index cd54593..e65429b 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -38,6 +38,10 @@ func main() { r.POST("/v1/messages/count_tokens", handler.CountTokens) r.POST("/messages/count_tokens", handler.CountTokens) + // 工具相关接口 + r.GET("/tools", handler.ListTools) + r.POST("/tools/execute", handler.ExecuteTool) + // 健康检查 r.GET("/health", func(c *gin.Context) { c.JSON(200, gin.H{"status": "ok"}) diff --git a/internal/handler/anthropic.go b/internal/handler/anthropic.go index d961fb6..1cba5d2 100644 --- a/internal/handler/anthropic.go +++ b/internal/handler/anthropic.go @@ -8,6 +8,7 @@ import ( "strings" "cursor2api/internal/browser" + "cursor2api/internal/tools" "github.com/gin-gonic/gin" "github.com/google/uuid" @@ -17,17 +18,18 @@ import ( // MessagesRequest Anthropic Messages API 请求格式 type MessagesRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - MaxTokens int `json:"max_tokens"` - Stream bool `json:"stream"` - System interface{} `json:"system,omitempty"` // 可以是 string 或 []ContentBlock + Model string `json:"model"` + Messages []Message `json:"messages"` + MaxTokens int `json:"max_tokens"` + Stream bool `json:"stream"` + System interface{} `json:"system,omitempty"` // 可以是 string 或 []ContentBlock + Tools []tools.ToolDefinition `json:"tools,omitempty"` // 工具定义 } // Message 消息格式 type Message struct { Role string `json:"role"` - Content interface{} `json:"content"` // 可以是 string 或 []ContentBlock + Content interface{} `json:"content"` // 可以是 string 或 []ContentBlock 或 []ToolResult } // MessagesResponse Anthropic Messages API 响应格式 @@ -42,10 +44,13 @@ type MessagesResponse struct { Usage Usage `json:"usage"` } -// ContentBlock 内容块 +// ContentBlock 内容块(支持 text 和 tool_use) type ContentBlock struct { - Type string `json:"type"` - Text string `json:"text"` + Type string `json:"type"` + Text string `json:"text,omitempty"` + ID string `json:"id,omitempty"` // tool_use + Name string `json:"name,omitempty"` // tool_use + Input map[string]interface{} `json:"input,omitempty"` // tool_use } // Usage token 使用统计 @@ -54,6 +59,17 @@ type Usage struct { OutputTokens int `json:"output_tokens"` } +// 全局工具执行器和解析器 +var ( + toolExecutor *tools.Executor + toolParser *tools.Parser +) + +func init() { + toolExecutor = tools.NewExecutor() + toolParser = tools.NewParser() +} + // CursorSSEEvent Cursor SSE 事件格式 type CursorSSEEvent struct { Type string `json:"type"` @@ -168,8 +184,14 @@ func Messages(c *gin.Context) { func convertToCursor(req MessagesRequest) browser.CursorChatRequest { messages := make([]browser.CursorMessage, 0, len(req.Messages)+1) - // 添加 system 消息 - if sysText := getTextContent(req.System); sysText != "" { + // 构建系统消息(包含工具定义) + sysText := getTextContent(req.System) + if len(req.Tools) > 0 { + toolPrompt := tools.GenerateToolPrompt(req.Tools) + sysText += toolPrompt + } + + if sysText != "" { messages = append(messages, browser.CursorMessage{ Parts: []browser.CursorPart{{Type: "text", Text: sysText}}, ID: generateID(), @@ -179,11 +201,14 @@ func convertToCursor(req MessagesRequest) browser.CursorChatRequest { // 添加用户/助手消息 for _, msg := range req.Messages { - messages = append(messages, browser.CursorMessage{ - Parts: []browser.CursorPart{{Type: "text", Text: getTextContent(msg.Content)}}, - ID: generateID(), - Role: msg.Role, - }) + text := extractMessageText(msg) + if text != "" { + messages = append(messages, browser.CursorMessage{ + Parts: []browser.CursorPart{{Type: "text", Text: text}}, + ID: generateID(), + Role: msg.Role, + }) + } } return browser.CursorChatRequest{ @@ -199,6 +224,65 @@ func convertToCursor(req MessagesRequest) browser.CursorChatRequest { } } +// extractMessageText 从消息中提取文本(处理 tool_result) +func extractMessageText(msg Message) string { + content := msg.Content + if content == nil { + return "" + } + + switch v := content.(type) { + case string: + return v + case []interface{}: + var texts []string + for _, item := range v { + block, ok := item.(map[string]interface{}) + if !ok { + continue + } + + blockType, _ := block["type"].(string) + switch blockType { + case "text": + if text, ok := block["text"].(string); ok { + texts = append(texts, text) + } + case "tool_result": + // 处理工具结果 + toolUseID, _ := block["tool_use_id"].(string) + resultContent := block["content"] + isError, _ := block["is_error"].(bool) + + resultText := "" + switch rc := resultContent.(type) { + case string: + resultText = rc + case []interface{}: + for _, rcItem := range rc { + if rcBlock, ok := rcItem.(map[string]interface{}); ok { + if rcBlock["type"] == "text" { + if t, ok := rcBlock["text"].(string); ok { + resultText += t + } + } + } + } + } + + prefix := "工具执行结果" + if isError { + prefix = "工具执行错误" + } + texts = append(texts, fmt.Sprintf("[%s (ID: %s)]\n%s", prefix, toolUseID, resultText)) + } + } + return strings.Join(texts, "\n") + default: + return fmt.Sprintf("%v", v) + } +} + // ================== API 处理 ================== // handleStream 处理流式请求 @@ -220,8 +304,9 @@ func handleStream(c *gin.Context, cursorReq browser.CursorChatRequest, model str c.Writer.WriteString(`data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}` + "\n\n") flusher.Flush() - // 用于累积不完整的 SSE 行 + // 用于累积完整响应和 SSE 行 var buffer strings.Builder + var fullResponse strings.Builder svc := browser.GetService() err := svc.SendStreamRequest(cursorReq, func(chunk string) { @@ -253,6 +338,7 @@ func handleStream(c *gin.Context, cursorReq browser.CursorChatRequest, model str } if event.Type == "text-delta" && event.Delta != "" { + fullResponse.WriteString(event.Delta) deltaJSON, _ := json.Marshal(event.Delta) c.Writer.WriteString("event: content_block_delta\n") c.Writer.WriteString(`data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":` + string(deltaJSON) + `}}` + "\n\n") @@ -271,8 +357,34 @@ func handleStream(c *gin.Context, cursorReq browser.CursorChatRequest, model str c.Writer.WriteString(`data: {"type":"content_block_stop","index":0}` + "\n\n") flusher.Flush() + // 检测是否有工具调用,决定 stop_reason + stopReason := "end_turn" + responseText := fullResponse.String() + toolCalls, _ := toolParser.ParseToolCalls(responseText) + + if len(toolCalls) > 0 { + stopReason = "tool_use" + // 发送工具调用块 + for i, call := range toolCalls { + toolID := "toolu_" + generateID() + inputJSON, _ := json.Marshal(call.Input) + + c.Writer.WriteString("event: content_block_start\n") + c.Writer.WriteString(fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"%s","name":"%s","input":{}}}`+"\n\n", i+1, toolID, call.Name)) + flusher.Flush() + + c.Writer.WriteString("event: content_block_delta\n") + c.Writer.WriteString(fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":"%s"}}`+"\n\n", i+1, escapeJSON(string(inputJSON)))) + flusher.Flush() + + c.Writer.WriteString("event: content_block_stop\n") + c.Writer.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`+"\n\n", i+1)) + flusher.Flush() + } + } + c.Writer.WriteString("event: message_delta\n") - c.Writer.WriteString(`data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":100}}` + "\n\n") + c.Writer.WriteString(fmt.Sprintf(`data: {"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"output_tokens":100}}`+"\n\n", stopReason)) flusher.Flush() c.Writer.WriteString("event: message_stop\n") @@ -280,6 +392,16 @@ func handleStream(c *gin.Context, cursorReq browser.CursorChatRequest, model str flusher.Flush() } +// escapeJSON 转义 JSON 字符串中的特殊字符 +func escapeJSON(s string) string { + s = strings.ReplaceAll(s, `\`, `\\`) + s = strings.ReplaceAll(s, `"`, `\"`) + s = strings.ReplaceAll(s, "\n", `\n`) + s = strings.ReplaceAll(s, "\r", `\r`) + s = strings.ReplaceAll(s, "\t", `\t`) + return s +} + // handleNonStream 处理非流式请求 func handleNonStream(c *gin.Context, cursorReq browser.CursorChatRequest, model string) { svc := browser.GetService() @@ -311,13 +433,61 @@ func handleNonStream(c *gin.Context, cursorReq browser.CursorChatRequest, model } } + responseText := fullText.String() + contentBlocks := parseResponseToBlocks(responseText) + + // 确定 stop_reason + stopReason := "end_turn" + for _, block := range contentBlocks { + if block.Type == "tool_use" { + stopReason = "tool_use" + break + } + } + c.JSON(http.StatusOK, MessagesResponse{ ID: "msg_" + generateID(), Type: "message", Role: "assistant", - Content: []ContentBlock{{Type: "text", Text: fullText.String()}}, + Content: contentBlocks, Model: model, - StopReason: "end_turn", + StopReason: stopReason, Usage: Usage{InputTokens: 100, OutputTokens: 100}, }) } + +// parseResponseToBlocks 解析 AI 响应为内容块(检测工具调用) +func parseResponseToBlocks(text string) []ContentBlock { + var blocks []ContentBlock + + // 检测工具调用 + toolCalls, remainingText := toolParser.ParseToolCalls(text) + + // 添加文本块 + if remainingText != "" { + blocks = append(blocks, ContentBlock{ + Type: "text", + Text: remainingText, + }) + } + + // 添加工具调用块 + for _, call := range toolCalls { + blocks = append(blocks, ContentBlock{ + Type: "tool_use", + ID: "toolu_" + generateID(), + Name: call.Name, + Input: call.Input, + }) + } + + // 如果没有任何内容,添加空文本块 + if len(blocks) == 0 { + blocks = append(blocks, ContentBlock{ + Type: "text", + Text: text, + }) + } + + return blocks +} diff --git a/internal/handler/tools.go b/internal/handler/tools.go new file mode 100644 index 0000000..5e95f8d --- /dev/null +++ b/internal/handler/tools.go @@ -0,0 +1,125 @@ +package handler + +import ( + "net/http" + + "cursor2api/internal/tools" + + "github.com/gin-gonic/gin" +) + +// ToolExecuteRequest 工具执行请求 +type ToolExecuteRequest struct { + ToolName string `json:"tool_name"` + Input map[string]interface{} `json:"input"` + WorkDir string `json:"work_dir,omitempty"` +} + +// ToolExecuteResponse 工具执行响应 +type ToolExecuteResponse struct { + Success bool `json:"success"` + Output string `json:"output"` + Error string `json:"error,omitempty"` +} + +// ExecuteTool 执行工具调用(供本地测试使用) +func ExecuteTool(c *gin.Context) { + var req ToolExecuteRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, ToolExecuteResponse{ + Success: false, + Error: "无效的请求格式: " + err.Error(), + }) + return + } + + executor := tools.NewExecutor() + if req.WorkDir != "" { + executor.SetWorkDir(req.WorkDir) + } + + output, err := executor.Execute(req.ToolName, req.Input) + if err != nil { + c.JSON(http.StatusOK, ToolExecuteResponse{ + Success: false, + Output: output, + Error: err.Error(), + }) + return + } + + c.JSON(http.StatusOK, ToolExecuteResponse{ + Success: true, + Output: output, + }) +} + +// ListTools 列出可用工具 +func ListTools(c *gin.Context) { + toolList := []tools.ToolDefinition{ + { + Name: "bash", + Description: "执行 bash 命令", + InputSchema: tools.InputSchema{ + Type: "object", + Properties: map[string]tools.Property{ + "command": {Type: "string", Description: "要执行的命令"}, + "cwd": {Type: "string", Description: "工作目录(可选)"}, + }, + Required: []string{"command"}, + }, + }, + { + Name: "read_file", + Description: "读取文件内容", + InputSchema: tools.InputSchema{ + Type: "object", + Properties: map[string]tools.Property{ + "path": {Type: "string", Description: "文件路径"}, + }, + Required: []string{"path"}, + }, + }, + { + Name: "write_file", + Description: "写入文件", + InputSchema: tools.InputSchema{ + Type: "object", + Properties: map[string]tools.Property{ + "path": {Type: "string", Description: "文件路径"}, + "content": {Type: "string", Description: "文件内容"}, + }, + Required: []string{"path", "content"}, + }, + }, + { + Name: "list_dir", + Description: "列出目录内容", + InputSchema: tools.InputSchema{ + Type: "object", + Properties: map[string]tools.Property{ + "path": {Type: "string", Description: "目录路径"}, + }, + Required: []string{"path"}, + }, + }, + { + Name: "edit", + Description: "编辑文件(查找替换)", + InputSchema: tools.InputSchema{ + Type: "object", + Properties: map[string]tools.Property{ + "path": {Type: "string", Description: "文件路径"}, + "old_string": {Type: "string", Description: "要替换的内容"}, + "new_string": {Type: "string", Description: "替换后的内容"}, + "replace_all": {Type: "boolean", Description: "是否替换所有匹配"}, + }, + Required: []string{"path", "old_string", "new_string"}, + }, + }, + } + + c.JSON(http.StatusOK, gin.H{ + "tools": toolList, + }) +} diff --git a/internal/tools/executor.go b/internal/tools/executor.go new file mode 100644 index 0000000..0706d58 --- /dev/null +++ b/internal/tools/executor.go @@ -0,0 +1,260 @@ +package tools + +import ( + "bytes" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "time" +) + +// Executor 工具执行器 +type Executor struct { + workDir string + allowedDirs []string + timeout time.Duration +} + +// NewExecutor 创建工具执行器 +func NewExecutor() *Executor { + homeDir, _ := os.UserHomeDir() + return &Executor{ + workDir: homeDir, + allowedDirs: []string{homeDir, "/tmp"}, + timeout: 30 * time.Second, + } +} + +// SetWorkDir 设置工作目录 +func (e *Executor) SetWorkDir(dir string) { + e.workDir = dir +} + +// Execute 执行工具调用 +func (e *Executor) Execute(toolName string, input map[string]interface{}) (string, error) { + switch toolName { + case "bash", "run_command": + return e.executeBash(input) + case "read_file": + return e.readFile(input) + case "write_file", "write_to_file": + return e.writeFile(input) + case "list_dir", "list_directory": + return e.listDir(input) + case "edit", "str_replace_editor": + return e.editFile(input) + default: + return "", fmt.Errorf("未知工具: %s", toolName) + } +} + +// executeBash 执行 bash 命令 +func (e *Executor) executeBash(input map[string]interface{}) (string, error) { + command, ok := input["command"].(string) + if !ok { + // 尝试其他字段名 + if cmd, ok := input["CommandLine"].(string); ok { + command = cmd + } else { + return "", fmt.Errorf("缺少 command 参数") + } + } + + // 获取工作目录 + cwd := e.workDir + if dir, ok := input["cwd"].(string); ok && dir != "" { + cwd = dir + } else if dir, ok := input["Cwd"].(string); ok && dir != "" { + cwd = dir + } + + cmd := exec.Command("bash", "-c", command) + cmd.Dir = cwd + + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + // 设置超时 + done := make(chan error, 1) + go func() { + done <- cmd.Run() + }() + + select { + case err := <-done: + output := stdout.String() + if stderr.Len() > 0 { + if output != "" { + output += "\n" + } + output += stderr.String() + } + if err != nil { + return output, fmt.Errorf("命令执行失败: %v\n%s", err, output) + } + return output, nil + case <-time.After(e.timeout): + cmd.Process.Kill() + return "", fmt.Errorf("命令执行超时 (%v)", e.timeout) + } +} + +// readFile 读取文件 +func (e *Executor) readFile(input map[string]interface{}) (string, error) { + path, ok := input["path"].(string) + if !ok { + if p, ok := input["file_path"].(string); ok { + path = p + } else { + return "", fmt.Errorf("缺少 path 参数") + } + } + + // 处理相对路径 + if !filepath.IsAbs(path) { + path = filepath.Join(e.workDir, path) + } + + content, err := os.ReadFile(path) + if err != nil { + return "", fmt.Errorf("读取文件失败: %v", err) + } + + return string(content), nil +} + +// writeFile 写入文件 +func (e *Executor) writeFile(input map[string]interface{}) (string, error) { + path, ok := input["path"].(string) + if !ok { + if p, ok := input["file_path"].(string); ok { + path = p + } else if p, ok := input["TargetFile"].(string); ok { + path = p + } else { + return "", fmt.Errorf("缺少 path 参数") + } + } + + content, ok := input["content"].(string) + if !ok { + if c, ok := input["CodeContent"].(string); ok { + content = c + } else { + return "", fmt.Errorf("缺少 content 参数") + } + } + + // 处理相对路径 + if !filepath.IsAbs(path) { + path = filepath.Join(e.workDir, path) + } + + // 创建目录 + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0755); err != nil { + return "", fmt.Errorf("创建目录失败: %v", err) + } + + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + return "", fmt.Errorf("写入文件失败: %v", err) + } + + return fmt.Sprintf("已写入文件: %s (%d 字节)", path, len(content)), nil +} + +// listDir 列出目录内容 +func (e *Executor) listDir(input map[string]interface{}) (string, error) { + path, ok := input["path"].(string) + if !ok { + if p, ok := input["DirectoryPath"].(string); ok { + path = p + } else { + path = e.workDir + } + } + + // 处理相对路径 + if !filepath.IsAbs(path) { + path = filepath.Join(e.workDir, path) + } + + entries, err := os.ReadDir(path) + if err != nil { + return "", fmt.Errorf("读取目录失败: %v", err) + } + + var result strings.Builder + result.WriteString(fmt.Sprintf("目录: %s\n\n", path)) + + for _, entry := range entries { + info, _ := entry.Info() + if entry.IsDir() { + result.WriteString(fmt.Sprintf("[DIR] %s/\n", entry.Name())) + } else { + size := int64(0) + if info != nil { + size = info.Size() + } + result.WriteString(fmt.Sprintf("[FILE] %s (%d bytes)\n", entry.Name(), size)) + } + } + + return result.String(), nil +} + +// editFile 编辑文件(查找替换) +func (e *Executor) editFile(input map[string]interface{}) (string, error) { + path, ok := input["path"].(string) + if !ok { + if p, ok := input["file_path"].(string); ok { + path = p + } else { + return "", fmt.Errorf("缺少 path 参数") + } + } + + oldStr, _ := input["old_string"].(string) + newStr, _ := input["new_string"].(string) + + if oldStr == "" { + return "", fmt.Errorf("缺少 old_string 参数") + } + + // 处理相对路径 + if !filepath.IsAbs(path) { + path = filepath.Join(e.workDir, path) + } + + content, err := os.ReadFile(path) + if err != nil { + return "", fmt.Errorf("读取文件失败: %v", err) + } + + original := string(content) + if !strings.Contains(original, oldStr) { + return "", fmt.Errorf("未找到要替换的内容") + } + + // 替换 + replaceAll := false + if ra, ok := input["replace_all"].(bool); ok { + replaceAll = ra + } + + var modified string + if replaceAll { + modified = strings.ReplaceAll(original, oldStr, newStr) + } else { + modified = strings.Replace(original, oldStr, newStr, 1) + } + + if err := os.WriteFile(path, []byte(modified), 0644); err != nil { + return "", fmt.Errorf("写入文件失败: %v", err) + } + + return fmt.Sprintf("已编辑文件: %s", path), nil +} diff --git a/internal/tools/parser.go b/internal/tools/parser.go new file mode 100644 index 0000000..1675c2a --- /dev/null +++ b/internal/tools/parser.go @@ -0,0 +1,144 @@ +package tools + +import ( + "encoding/json" + "regexp" + "strings" +) + +// Parser 解析 AI 输出中的工具调用 +type Parser struct{} + +// NewParser 创建解析器 +func NewParser() *Parser { + return &Parser{} +} + +// toolCallPattern 匹配工具调用的 JSON 块 +var toolCallPatterns = []*regexp.Regexp{ + // 标准 JSON 块格式 + regexp.MustCompile(`(?s)\s*(\{.*?\})\s*`), + // 代码块格式 + regexp.MustCompile("(?s)```json\\s*\\n(\\{[^`]*?\"tool\"[^`]*?\\})\\s*\\n```"), + regexp.MustCompile("(?s)```\\s*\\n(\\{[^`]*?\"tool\"[^`]*?\\})\\s*\\n```"), + // 单行 JSON 格式 + regexp.MustCompile(`(\{"tool"\s*:\s*"[^"]+"\s*,\s*"[^}]+\})`), +} + +// ParseToolCalls 从 AI 输出中解析工具调用 +func (p *Parser) ParseToolCalls(output string) ([]ParsedToolCall, string) { + var calls []ParsedToolCall + remainingText := output + + for _, pattern := range toolCallPatterns { + matches := pattern.FindAllStringSubmatch(output, -1) + for _, match := range matches { + if len(match) < 2 { + continue + } + + jsonStr := match[1] + var rawCall map[string]interface{} + if err := json.Unmarshal([]byte(jsonStr), &rawCall); err != nil { + continue + } + + // 提取工具名称 + toolName := "" + if name, ok := rawCall["tool"].(string); ok { + toolName = name + } else if name, ok := rawCall["name"].(string); ok { + toolName = name + } + + if toolName == "" { + continue + } + + // 提取输入参数 + input := make(map[string]interface{}) + if inp, ok := rawCall["input"].(map[string]interface{}); ok { + input = inp + } else { + // 其他字段作为输入 + for k, v := range rawCall { + if k != "tool" && k != "name" && k != "type" { + input[k] = v + } + } + } + + calls = append(calls, ParsedToolCall{ + Name: toolName, + Input: input, + }) + + // 从剩余文本中移除已解析的工具调用 + remainingText = strings.Replace(remainingText, match[0], "", 1) + } + } + + // 清理剩余文本 + remainingText = strings.TrimSpace(remainingText) + + return calls, remainingText +} + +// GenerateToolPrompt 生成工具使用的系统提示 +func GenerateToolPrompt(tools []ToolDefinition) string { + if len(tools) == 0 { + return "" + } + + var sb strings.Builder + sb.WriteString("\n\n## 可用工具\n\n") + sb.WriteString("当你需要执行操作时,请使用以下格式调用工具:\n\n") + sb.WriteString("\n{\"tool\": \"工具名称\", \"参数名\": \"参数值\"}\n\n\n") + sb.WriteString("可用的工具:\n\n") + + for _, tool := range tools { + sb.WriteString("### ") + sb.WriteString(tool.Name) + sb.WriteString("\n") + if tool.Description != "" { + sb.WriteString(tool.Description) + sb.WriteString("\n") + } + sb.WriteString("参数:\n") + + for name, prop := range tool.InputSchema.Properties { + required := "" + for _, r := range tool.InputSchema.Required { + if r == name { + required = " (必需)" + break + } + } + sb.WriteString("- `") + sb.WriteString(name) + sb.WriteString("`") + sb.WriteString(required) + sb.WriteString(": ") + sb.WriteString(prop.Description) + sb.WriteString("\n") + } + sb.WriteString("\n") + } + + sb.WriteString("重要提示:\n") + sb.WriteString("- 每次只调用一个工具\n") + sb.WriteString("- 工具调用必须使用 标签包裹\n") + sb.WriteString("- 等待工具执行结果后再继续\n") + + return sb.String() +} + +// IsToolCallResponse 检查输出是否包含工具调用 +func (p *Parser) IsToolCallResponse(output string) bool { + for _, pattern := range toolCallPatterns { + if pattern.MatchString(output) { + return true + } + } + return false +} diff --git a/internal/tools/types.go b/internal/tools/types.go new file mode 100644 index 0000000..4e3dec8 --- /dev/null +++ b/internal/tools/types.go @@ -0,0 +1,51 @@ +// Package tools 提供工具调用解析和执行功能 +package tools + +// ToolDefinition Anthropic 工具定义格式 +type ToolDefinition struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema InputSchema `json:"input_schema"` +} + +// InputSchema JSON Schema 格式的输入参数定义 +type InputSchema struct { + Type string `json:"type"` + Properties map[string]Property `json:"properties,omitempty"` + Required []string `json:"required,omitempty"` +} + +// Property JSON Schema 属性定义 +type Property struct { + Type string `json:"type"` + Description string `json:"description,omitempty"` + Enum []string `json:"enum,omitempty"` +} + +// ToolUse AI 返回的工具调用 +type ToolUse struct { + Type string `json:"type"` // "tool_use" + ID string `json:"id"` + Name string `json:"name"` + Input map[string]interface{} `json:"input"` +} + +// ToolResult 工具执行结果 +type ToolResult struct { + Type string `json:"type"` // "tool_result" + ToolUseID string `json:"tool_use_id"` + Content interface{} `json:"content"` // string 或 []ContentBlock + IsError bool `json:"is_error,omitempty"` +} + +// ContentBlock 内容块(用于 tool_result) +type ContentBlock struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` +} + +// ParsedToolCall 从 AI 输出解析的工具调用 +type ParsedToolCall struct { + Name string `json:"name"` + Input map[string]interface{} `json:"input"` +}