diff --git a/cmd/server/main.go b/cmd/server/main.go
index cd54593..e65429b 100644
--- a/cmd/server/main.go
+++ b/cmd/server/main.go
@@ -38,6 +38,10 @@ func main() {
r.POST("/v1/messages/count_tokens", handler.CountTokens)
r.POST("/messages/count_tokens", handler.CountTokens)
+ // 工具相关接口
+ r.GET("/tools", handler.ListTools)
+ r.POST("/tools/execute", handler.ExecuteTool)
+
// 健康检查
r.GET("/health", func(c *gin.Context) {
c.JSON(200, gin.H{"status": "ok"})
diff --git a/internal/handler/anthropic.go b/internal/handler/anthropic.go
index d961fb6..1cba5d2 100644
--- a/internal/handler/anthropic.go
+++ b/internal/handler/anthropic.go
@@ -8,6 +8,7 @@ import (
"strings"
"cursor2api/internal/browser"
+ "cursor2api/internal/tools"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -17,17 +18,18 @@ import (
// MessagesRequest Anthropic Messages API 请求格式
type MessagesRequest struct {
- Model string `json:"model"`
- Messages []Message `json:"messages"`
- MaxTokens int `json:"max_tokens"`
- Stream bool `json:"stream"`
- System interface{} `json:"system,omitempty"` // 可以是 string 或 []ContentBlock
+ Model string `json:"model"`
+ Messages []Message `json:"messages"`
+ MaxTokens int `json:"max_tokens"`
+ Stream bool `json:"stream"`
+ System interface{} `json:"system,omitempty"` // 可以是 string 或 []ContentBlock
+ Tools []tools.ToolDefinition `json:"tools,omitempty"` // 工具定义
}
// Message 消息格式
type Message struct {
Role string `json:"role"`
- Content interface{} `json:"content"` // 可以是 string 或 []ContentBlock
+ Content interface{} `json:"content"` // 可以是 string 或 []ContentBlock 或 []ToolResult
}
// MessagesResponse Anthropic Messages API 响应格式
@@ -42,10 +44,13 @@ type MessagesResponse struct {
Usage Usage `json:"usage"`
}
-// ContentBlock 内容块
+// ContentBlock 内容块(支持 text 和 tool_use)
type ContentBlock struct {
- Type string `json:"type"`
- Text string `json:"text"`
+ Type string `json:"type"`
+ Text string `json:"text,omitempty"`
+ ID string `json:"id,omitempty"` // tool_use
+ Name string `json:"name,omitempty"` // tool_use
+ Input map[string]interface{} `json:"input,omitempty"` // tool_use
}
// Usage token 使用统计
@@ -54,6 +59,17 @@ type Usage struct {
OutputTokens int `json:"output_tokens"`
}
+// 全局工具执行器和解析器
+var (
+ toolExecutor *tools.Executor
+ toolParser *tools.Parser
+)
+
+func init() {
+ toolExecutor = tools.NewExecutor()
+ toolParser = tools.NewParser()
+}
+
// CursorSSEEvent Cursor SSE 事件格式
type CursorSSEEvent struct {
Type string `json:"type"`
@@ -168,8 +184,14 @@ func Messages(c *gin.Context) {
func convertToCursor(req MessagesRequest) browser.CursorChatRequest {
messages := make([]browser.CursorMessage, 0, len(req.Messages)+1)
- // 添加 system 消息
- if sysText := getTextContent(req.System); sysText != "" {
+ // 构建系统消息(包含工具定义)
+ sysText := getTextContent(req.System)
+ if len(req.Tools) > 0 {
+ toolPrompt := tools.GenerateToolPrompt(req.Tools)
+ sysText += toolPrompt
+ }
+
+ if sysText != "" {
messages = append(messages, browser.CursorMessage{
Parts: []browser.CursorPart{{Type: "text", Text: sysText}},
ID: generateID(),
@@ -179,11 +201,14 @@ func convertToCursor(req MessagesRequest) browser.CursorChatRequest {
// 添加用户/助手消息
for _, msg := range req.Messages {
- messages = append(messages, browser.CursorMessage{
- Parts: []browser.CursorPart{{Type: "text", Text: getTextContent(msg.Content)}},
- ID: generateID(),
- Role: msg.Role,
- })
+ text := extractMessageText(msg)
+ if text != "" {
+ messages = append(messages, browser.CursorMessage{
+ Parts: []browser.CursorPart{{Type: "text", Text: text}},
+ ID: generateID(),
+ Role: msg.Role,
+ })
+ }
}
return browser.CursorChatRequest{
@@ -199,6 +224,65 @@ func convertToCursor(req MessagesRequest) browser.CursorChatRequest {
}
}
+// extractMessageText 从消息中提取文本(处理 tool_result)
+func extractMessageText(msg Message) string {
+ content := msg.Content
+ if content == nil {
+ return ""
+ }
+
+ switch v := content.(type) {
+ case string:
+ return v
+ case []interface{}:
+ var texts []string
+ for _, item := range v {
+ block, ok := item.(map[string]interface{})
+ if !ok {
+ continue
+ }
+
+ blockType, _ := block["type"].(string)
+ switch blockType {
+ case "text":
+ if text, ok := block["text"].(string); ok {
+ texts = append(texts, text)
+ }
+ case "tool_result":
+ // 处理工具结果
+ toolUseID, _ := block["tool_use_id"].(string)
+ resultContent := block["content"]
+ isError, _ := block["is_error"].(bool)
+
+ resultText := ""
+ switch rc := resultContent.(type) {
+ case string:
+ resultText = rc
+ case []interface{}:
+ for _, rcItem := range rc {
+ if rcBlock, ok := rcItem.(map[string]interface{}); ok {
+ if rcBlock["type"] == "text" {
+ if t, ok := rcBlock["text"].(string); ok {
+ resultText += t
+ }
+ }
+ }
+ }
+ }
+
+ prefix := "工具执行结果"
+ if isError {
+ prefix = "工具执行错误"
+ }
+ texts = append(texts, fmt.Sprintf("[%s (ID: %s)]\n%s", prefix, toolUseID, resultText))
+ }
+ }
+ return strings.Join(texts, "\n")
+ default:
+ return fmt.Sprintf("%v", v)
+ }
+}
+
// ================== API 处理 ==================
// handleStream 处理流式请求
@@ -220,8 +304,9 @@ func handleStream(c *gin.Context, cursorReq browser.CursorChatRequest, model str
c.Writer.WriteString(`data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}` + "\n\n")
flusher.Flush()
- // 用于累积不完整的 SSE 行
+ // 用于累积完整响应和 SSE 行
var buffer strings.Builder
+ var fullResponse strings.Builder
svc := browser.GetService()
err := svc.SendStreamRequest(cursorReq, func(chunk string) {
@@ -253,6 +338,7 @@ func handleStream(c *gin.Context, cursorReq browser.CursorChatRequest, model str
}
if event.Type == "text-delta" && event.Delta != "" {
+ fullResponse.WriteString(event.Delta)
deltaJSON, _ := json.Marshal(event.Delta)
c.Writer.WriteString("event: content_block_delta\n")
c.Writer.WriteString(`data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":` + string(deltaJSON) + `}}` + "\n\n")
@@ -271,8 +357,34 @@ func handleStream(c *gin.Context, cursorReq browser.CursorChatRequest, model str
c.Writer.WriteString(`data: {"type":"content_block_stop","index":0}` + "\n\n")
flusher.Flush()
+ // 检测是否有工具调用,决定 stop_reason
+ stopReason := "end_turn"
+ responseText := fullResponse.String()
+ toolCalls, _ := toolParser.ParseToolCalls(responseText)
+
+ if len(toolCalls) > 0 {
+ stopReason = "tool_use"
+ // 发送工具调用块
+ for i, call := range toolCalls {
+ toolID := "toolu_" + generateID()
+ inputJSON, _ := json.Marshal(call.Input)
+
+ c.Writer.WriteString("event: content_block_start\n")
+ c.Writer.WriteString(fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"%s","name":"%s","input":{}}}`+"\n\n", i+1, toolID, call.Name))
+ flusher.Flush()
+
+ c.Writer.WriteString("event: content_block_delta\n")
+ c.Writer.WriteString(fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":"%s"}}`+"\n\n", i+1, escapeJSON(string(inputJSON))))
+ flusher.Flush()
+
+ c.Writer.WriteString("event: content_block_stop\n")
+ c.Writer.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`+"\n\n", i+1))
+ flusher.Flush()
+ }
+ }
+
c.Writer.WriteString("event: message_delta\n")
- c.Writer.WriteString(`data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":100}}` + "\n\n")
+ c.Writer.WriteString(fmt.Sprintf(`data: {"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"output_tokens":100}}`+"\n\n", stopReason))
flusher.Flush()
c.Writer.WriteString("event: message_stop\n")
@@ -280,6 +392,16 @@ func handleStream(c *gin.Context, cursorReq browser.CursorChatRequest, model str
flusher.Flush()
}
+// escapeJSON 转义 JSON 字符串中的特殊字符
+func escapeJSON(s string) string {
+ s = strings.ReplaceAll(s, `\`, `\\`)
+ s = strings.ReplaceAll(s, `"`, `\"`)
+ s = strings.ReplaceAll(s, "\n", `\n`)
+ s = strings.ReplaceAll(s, "\r", `\r`)
+ s = strings.ReplaceAll(s, "\t", `\t`)
+ return s
+}
+
// handleNonStream 处理非流式请求
func handleNonStream(c *gin.Context, cursorReq browser.CursorChatRequest, model string) {
svc := browser.GetService()
@@ -311,13 +433,61 @@ func handleNonStream(c *gin.Context, cursorReq browser.CursorChatRequest, model
}
}
+ responseText := fullText.String()
+ contentBlocks := parseResponseToBlocks(responseText)
+
+ // 确定 stop_reason
+ stopReason := "end_turn"
+ for _, block := range contentBlocks {
+ if block.Type == "tool_use" {
+ stopReason = "tool_use"
+ break
+ }
+ }
+
c.JSON(http.StatusOK, MessagesResponse{
ID: "msg_" + generateID(),
Type: "message",
Role: "assistant",
- Content: []ContentBlock{{Type: "text", Text: fullText.String()}},
+ Content: contentBlocks,
Model: model,
- StopReason: "end_turn",
+ StopReason: stopReason,
Usage: Usage{InputTokens: 100, OutputTokens: 100},
})
}
+
+// parseResponseToBlocks 解析 AI 响应为内容块(检测工具调用)
+func parseResponseToBlocks(text string) []ContentBlock {
+ var blocks []ContentBlock
+
+ // 检测工具调用
+ toolCalls, remainingText := toolParser.ParseToolCalls(text)
+
+ // 添加文本块
+ if remainingText != "" {
+ blocks = append(blocks, ContentBlock{
+ Type: "text",
+ Text: remainingText,
+ })
+ }
+
+ // 添加工具调用块
+ for _, call := range toolCalls {
+ blocks = append(blocks, ContentBlock{
+ Type: "tool_use",
+ ID: "toolu_" + generateID(),
+ Name: call.Name,
+ Input: call.Input,
+ })
+ }
+
+ // 如果没有任何内容,添加空文本块
+ if len(blocks) == 0 {
+ blocks = append(blocks, ContentBlock{
+ Type: "text",
+ Text: text,
+ })
+ }
+
+ return blocks
+}
diff --git a/internal/handler/tools.go b/internal/handler/tools.go
new file mode 100644
index 0000000..5e95f8d
--- /dev/null
+++ b/internal/handler/tools.go
@@ -0,0 +1,125 @@
+package handler
+
+import (
+ "net/http"
+
+ "cursor2api/internal/tools"
+
+ "github.com/gin-gonic/gin"
+)
+
+// ToolExecuteRequest 工具执行请求
+type ToolExecuteRequest struct {
+ ToolName string `json:"tool_name"`
+ Input map[string]interface{} `json:"input"`
+ WorkDir string `json:"work_dir,omitempty"`
+}
+
+// ToolExecuteResponse 工具执行响应
+type ToolExecuteResponse struct {
+ Success bool `json:"success"`
+ Output string `json:"output"`
+ Error string `json:"error,omitempty"`
+}
+
+// ExecuteTool 执行工具调用(供本地测试使用)
+func ExecuteTool(c *gin.Context) {
+ var req ToolExecuteRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ c.JSON(http.StatusBadRequest, ToolExecuteResponse{
+ Success: false,
+ Error: "无效的请求格式: " + err.Error(),
+ })
+ return
+ }
+
+ executor := tools.NewExecutor()
+ if req.WorkDir != "" {
+ executor.SetWorkDir(req.WorkDir)
+ }
+
+ output, err := executor.Execute(req.ToolName, req.Input)
+ if err != nil {
+ c.JSON(http.StatusOK, ToolExecuteResponse{
+ Success: false,
+ Output: output,
+ Error: err.Error(),
+ })
+ return
+ }
+
+ c.JSON(http.StatusOK, ToolExecuteResponse{
+ Success: true,
+ Output: output,
+ })
+}
+
+// ListTools 列出可用工具
+func ListTools(c *gin.Context) {
+ toolList := []tools.ToolDefinition{
+ {
+ Name: "bash",
+ Description: "执行 bash 命令",
+ InputSchema: tools.InputSchema{
+ Type: "object",
+ Properties: map[string]tools.Property{
+ "command": {Type: "string", Description: "要执行的命令"},
+ "cwd": {Type: "string", Description: "工作目录(可选)"},
+ },
+ Required: []string{"command"},
+ },
+ },
+ {
+ Name: "read_file",
+ Description: "读取文件内容",
+ InputSchema: tools.InputSchema{
+ Type: "object",
+ Properties: map[string]tools.Property{
+ "path": {Type: "string", Description: "文件路径"},
+ },
+ Required: []string{"path"},
+ },
+ },
+ {
+ Name: "write_file",
+ Description: "写入文件",
+ InputSchema: tools.InputSchema{
+ Type: "object",
+ Properties: map[string]tools.Property{
+ "path": {Type: "string", Description: "文件路径"},
+ "content": {Type: "string", Description: "文件内容"},
+ },
+ Required: []string{"path", "content"},
+ },
+ },
+ {
+ Name: "list_dir",
+ Description: "列出目录内容",
+ InputSchema: tools.InputSchema{
+ Type: "object",
+ Properties: map[string]tools.Property{
+ "path": {Type: "string", Description: "目录路径"},
+ },
+ Required: []string{"path"},
+ },
+ },
+ {
+ Name: "edit",
+ Description: "编辑文件(查找替换)",
+ InputSchema: tools.InputSchema{
+ Type: "object",
+ Properties: map[string]tools.Property{
+ "path": {Type: "string", Description: "文件路径"},
+ "old_string": {Type: "string", Description: "要替换的内容"},
+ "new_string": {Type: "string", Description: "替换后的内容"},
+ "replace_all": {Type: "boolean", Description: "是否替换所有匹配"},
+ },
+ Required: []string{"path", "old_string", "new_string"},
+ },
+ },
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "tools": toolList,
+ })
+}
diff --git a/internal/tools/executor.go b/internal/tools/executor.go
new file mode 100644
index 0000000..0706d58
--- /dev/null
+++ b/internal/tools/executor.go
@@ -0,0 +1,260 @@
+package tools
+
+import (
+ "bytes"
+ "fmt"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strings"
+ "time"
+)
+
+// Executor 工具执行器
+type Executor struct {
+ workDir string
+ allowedDirs []string
+ timeout time.Duration
+}
+
+// NewExecutor 创建工具执行器
+func NewExecutor() *Executor {
+ homeDir, _ := os.UserHomeDir()
+ return &Executor{
+ workDir: homeDir,
+ allowedDirs: []string{homeDir, "/tmp"},
+ timeout: 30 * time.Second,
+ }
+}
+
+// SetWorkDir 设置工作目录
+func (e *Executor) SetWorkDir(dir string) {
+ e.workDir = dir
+}
+
+// Execute 执行工具调用
+func (e *Executor) Execute(toolName string, input map[string]interface{}) (string, error) {
+ switch toolName {
+ case "bash", "run_command":
+ return e.executeBash(input)
+ case "read_file":
+ return e.readFile(input)
+ case "write_file", "write_to_file":
+ return e.writeFile(input)
+ case "list_dir", "list_directory":
+ return e.listDir(input)
+ case "edit", "str_replace_editor":
+ return e.editFile(input)
+ default:
+ return "", fmt.Errorf("未知工具: %s", toolName)
+ }
+}
+
+// executeBash 执行 bash 命令
+func (e *Executor) executeBash(input map[string]interface{}) (string, error) {
+ command, ok := input["command"].(string)
+ if !ok {
+ // 尝试其他字段名
+ if cmd, ok := input["CommandLine"].(string); ok {
+ command = cmd
+ } else {
+ return "", fmt.Errorf("缺少 command 参数")
+ }
+ }
+
+ // 获取工作目录
+ cwd := e.workDir
+ if dir, ok := input["cwd"].(string); ok && dir != "" {
+ cwd = dir
+ } else if dir, ok := input["Cwd"].(string); ok && dir != "" {
+ cwd = dir
+ }
+
+ cmd := exec.Command("bash", "-c", command)
+ cmd.Dir = cwd
+
+ var stdout, stderr bytes.Buffer
+ cmd.Stdout = &stdout
+ cmd.Stderr = &stderr
+
+ // 设置超时
+ done := make(chan error, 1)
+ go func() {
+ done <- cmd.Run()
+ }()
+
+ select {
+ case err := <-done:
+ output := stdout.String()
+ if stderr.Len() > 0 {
+ if output != "" {
+ output += "\n"
+ }
+ output += stderr.String()
+ }
+ if err != nil {
+ return output, fmt.Errorf("命令执行失败: %v\n%s", err, output)
+ }
+ return output, nil
+ case <-time.After(e.timeout):
+ cmd.Process.Kill()
+ return "", fmt.Errorf("命令执行超时 (%v)", e.timeout)
+ }
+}
+
+// readFile 读取文件
+func (e *Executor) readFile(input map[string]interface{}) (string, error) {
+ path, ok := input["path"].(string)
+ if !ok {
+ if p, ok := input["file_path"].(string); ok {
+ path = p
+ } else {
+ return "", fmt.Errorf("缺少 path 参数")
+ }
+ }
+
+ // 处理相对路径
+ if !filepath.IsAbs(path) {
+ path = filepath.Join(e.workDir, path)
+ }
+
+ content, err := os.ReadFile(path)
+ if err != nil {
+ return "", fmt.Errorf("读取文件失败: %v", err)
+ }
+
+ return string(content), nil
+}
+
+// writeFile 写入文件
+func (e *Executor) writeFile(input map[string]interface{}) (string, error) {
+ path, ok := input["path"].(string)
+ if !ok {
+ if p, ok := input["file_path"].(string); ok {
+ path = p
+ } else if p, ok := input["TargetFile"].(string); ok {
+ path = p
+ } else {
+ return "", fmt.Errorf("缺少 path 参数")
+ }
+ }
+
+ content, ok := input["content"].(string)
+ if !ok {
+ if c, ok := input["CodeContent"].(string); ok {
+ content = c
+ } else {
+ return "", fmt.Errorf("缺少 content 参数")
+ }
+ }
+
+ // 处理相对路径
+ if !filepath.IsAbs(path) {
+ path = filepath.Join(e.workDir, path)
+ }
+
+ // 创建目录
+ dir := filepath.Dir(path)
+ if err := os.MkdirAll(dir, 0755); err != nil {
+ return "", fmt.Errorf("创建目录失败: %v", err)
+ }
+
+ if err := os.WriteFile(path, []byte(content), 0644); err != nil {
+ return "", fmt.Errorf("写入文件失败: %v", err)
+ }
+
+ return fmt.Sprintf("已写入文件: %s (%d 字节)", path, len(content)), nil
+}
+
+// listDir 列出目录内容
+func (e *Executor) listDir(input map[string]interface{}) (string, error) {
+ path, ok := input["path"].(string)
+ if !ok {
+ if p, ok := input["DirectoryPath"].(string); ok {
+ path = p
+ } else {
+ path = e.workDir
+ }
+ }
+
+ // 处理相对路径
+ if !filepath.IsAbs(path) {
+ path = filepath.Join(e.workDir, path)
+ }
+
+ entries, err := os.ReadDir(path)
+ if err != nil {
+ return "", fmt.Errorf("读取目录失败: %v", err)
+ }
+
+ var result strings.Builder
+ result.WriteString(fmt.Sprintf("目录: %s\n\n", path))
+
+ for _, entry := range entries {
+ info, _ := entry.Info()
+ if entry.IsDir() {
+ result.WriteString(fmt.Sprintf("[DIR] %s/\n", entry.Name()))
+ } else {
+ size := int64(0)
+ if info != nil {
+ size = info.Size()
+ }
+ result.WriteString(fmt.Sprintf("[FILE] %s (%d bytes)\n", entry.Name(), size))
+ }
+ }
+
+ return result.String(), nil
+}
+
+// editFile 编辑文件(查找替换)
+func (e *Executor) editFile(input map[string]interface{}) (string, error) {
+ path, ok := input["path"].(string)
+ if !ok {
+ if p, ok := input["file_path"].(string); ok {
+ path = p
+ } else {
+ return "", fmt.Errorf("缺少 path 参数")
+ }
+ }
+
+ oldStr, _ := input["old_string"].(string)
+ newStr, _ := input["new_string"].(string)
+
+ if oldStr == "" {
+ return "", fmt.Errorf("缺少 old_string 参数")
+ }
+
+ // 处理相对路径
+ if !filepath.IsAbs(path) {
+ path = filepath.Join(e.workDir, path)
+ }
+
+ content, err := os.ReadFile(path)
+ if err != nil {
+ return "", fmt.Errorf("读取文件失败: %v", err)
+ }
+
+ original := string(content)
+ if !strings.Contains(original, oldStr) {
+ return "", fmt.Errorf("未找到要替换的内容")
+ }
+
+ // 替换
+ replaceAll := false
+ if ra, ok := input["replace_all"].(bool); ok {
+ replaceAll = ra
+ }
+
+ var modified string
+ if replaceAll {
+ modified = strings.ReplaceAll(original, oldStr, newStr)
+ } else {
+ modified = strings.Replace(original, oldStr, newStr, 1)
+ }
+
+ if err := os.WriteFile(path, []byte(modified), 0644); err != nil {
+ return "", fmt.Errorf("写入文件失败: %v", err)
+ }
+
+ return fmt.Sprintf("已编辑文件: %s", path), nil
+}
diff --git a/internal/tools/parser.go b/internal/tools/parser.go
new file mode 100644
index 0000000..1675c2a
--- /dev/null
+++ b/internal/tools/parser.go
@@ -0,0 +1,144 @@
+package tools
+
+import (
+ "encoding/json"
+ "regexp"
+ "strings"
+)
+
+// Parser 解析 AI 输出中的工具调用
+type Parser struct{}
+
+// NewParser 创建解析器
+func NewParser() *Parser {
+ return &Parser{}
+}
+
+// toolCallPattern 匹配工具调用的 JSON 块
+var toolCallPatterns = []*regexp.Regexp{
+ // 标准 JSON 块格式
+ regexp.MustCompile(`(?s)\s*(\{.*?\})\s*`),
+ // 代码块格式
+ regexp.MustCompile("(?s)```json\\s*\\n(\\{[^`]*?\"tool\"[^`]*?\\})\\s*\\n```"),
+ regexp.MustCompile("(?s)```\\s*\\n(\\{[^`]*?\"tool\"[^`]*?\\})\\s*\\n```"),
+ // 单行 JSON 格式
+ regexp.MustCompile(`(\{"tool"\s*:\s*"[^"]+"\s*,\s*"[^}]+\})`),
+}
+
+// ParseToolCalls 从 AI 输出中解析工具调用
+func (p *Parser) ParseToolCalls(output string) ([]ParsedToolCall, string) {
+ var calls []ParsedToolCall
+ remainingText := output
+
+ for _, pattern := range toolCallPatterns {
+ matches := pattern.FindAllStringSubmatch(output, -1)
+ for _, match := range matches {
+ if len(match) < 2 {
+ continue
+ }
+
+ jsonStr := match[1]
+ var rawCall map[string]interface{}
+ if err := json.Unmarshal([]byte(jsonStr), &rawCall); err != nil {
+ continue
+ }
+
+ // 提取工具名称
+ toolName := ""
+ if name, ok := rawCall["tool"].(string); ok {
+ toolName = name
+ } else if name, ok := rawCall["name"].(string); ok {
+ toolName = name
+ }
+
+ if toolName == "" {
+ continue
+ }
+
+ // 提取输入参数
+ input := make(map[string]interface{})
+ if inp, ok := rawCall["input"].(map[string]interface{}); ok {
+ input = inp
+ } else {
+ // 其他字段作为输入
+ for k, v := range rawCall {
+ if k != "tool" && k != "name" && k != "type" {
+ input[k] = v
+ }
+ }
+ }
+
+ calls = append(calls, ParsedToolCall{
+ Name: toolName,
+ Input: input,
+ })
+
+ // 从剩余文本中移除已解析的工具调用
+ remainingText = strings.Replace(remainingText, match[0], "", 1)
+ }
+ }
+
+ // 清理剩余文本
+ remainingText = strings.TrimSpace(remainingText)
+
+ return calls, remainingText
+}
+
+// GenerateToolPrompt 生成工具使用的系统提示
+func GenerateToolPrompt(tools []ToolDefinition) string {
+ if len(tools) == 0 {
+ return ""
+ }
+
+ var sb strings.Builder
+ sb.WriteString("\n\n## 可用工具\n\n")
+ sb.WriteString("当你需要执行操作时,请使用以下格式调用工具:\n\n")
+ sb.WriteString("\n{\"tool\": \"工具名称\", \"参数名\": \"参数值\"}\n\n\n")
+ sb.WriteString("可用的工具:\n\n")
+
+ for _, tool := range tools {
+ sb.WriteString("### ")
+ sb.WriteString(tool.Name)
+ sb.WriteString("\n")
+ if tool.Description != "" {
+ sb.WriteString(tool.Description)
+ sb.WriteString("\n")
+ }
+ sb.WriteString("参数:\n")
+
+ for name, prop := range tool.InputSchema.Properties {
+ required := ""
+ for _, r := range tool.InputSchema.Required {
+ if r == name {
+ required = " (必需)"
+ break
+ }
+ }
+ sb.WriteString("- `")
+ sb.WriteString(name)
+ sb.WriteString("`")
+ sb.WriteString(required)
+ sb.WriteString(": ")
+ sb.WriteString(prop.Description)
+ sb.WriteString("\n")
+ }
+ sb.WriteString("\n")
+ }
+
+ sb.WriteString("重要提示:\n")
+ sb.WriteString("- 每次只调用一个工具\n")
+ sb.WriteString("- 工具调用必须使用 标签包裹\n")
+ sb.WriteString("- 等待工具执行结果后再继续\n")
+
+ return sb.String()
+}
+
+// IsToolCallResponse 检查输出是否包含工具调用
+func (p *Parser) IsToolCallResponse(output string) bool {
+ for _, pattern := range toolCallPatterns {
+ if pattern.MatchString(output) {
+ return true
+ }
+ }
+ return false
+}
diff --git a/internal/tools/types.go b/internal/tools/types.go
new file mode 100644
index 0000000..4e3dec8
--- /dev/null
+++ b/internal/tools/types.go
@@ -0,0 +1,51 @@
+// Package tools 提供工具调用解析和执行功能
+package tools
+
+// ToolDefinition Anthropic 工具定义格式
+type ToolDefinition struct {
+ Name string `json:"name"`
+ Description string `json:"description,omitempty"`
+ InputSchema InputSchema `json:"input_schema"`
+}
+
+// InputSchema JSON Schema 格式的输入参数定义
+type InputSchema struct {
+ Type string `json:"type"`
+ Properties map[string]Property `json:"properties,omitempty"`
+ Required []string `json:"required,omitempty"`
+}
+
+// Property JSON Schema 属性定义
+type Property struct {
+ Type string `json:"type"`
+ Description string `json:"description,omitempty"`
+ Enum []string `json:"enum,omitempty"`
+}
+
+// ToolUse AI 返回的工具调用
+type ToolUse struct {
+ Type string `json:"type"` // "tool_use"
+ ID string `json:"id"`
+ Name string `json:"name"`
+ Input map[string]interface{} `json:"input"`
+}
+
+// ToolResult 工具执行结果
+type ToolResult struct {
+ Type string `json:"type"` // "tool_result"
+ ToolUseID string `json:"tool_use_id"`
+ Content interface{} `json:"content"` // string 或 []ContentBlock
+ IsError bool `json:"is_error,omitempty"`
+}
+
+// ContentBlock 内容块(用于 tool_result)
+type ContentBlock struct {
+ Type string `json:"type"`
+ Text string `json:"text,omitempty"`
+}
+
+// ParsedToolCall 从 AI 输出解析的工具调用
+type ParsedToolCall struct {
+ Name string `json:"name"`
+ Input map[string]interface{} `json:"input"`
+}