mirror of
https://github.com/7836246/cursor2api.git
synced 2026-07-01 02:04:34 +08:00
feat: implement toolify function calling via VM sandbox prompt
- Add virtual machine sandbox prompt injection for tool calls - Parse <vm_write> and <vm_exec> tags from model response - Handle tool_result to avoid infinite loops - Extract tool_result content for model to respond - Optimize streaming with batch flush - Remove unused MCP and tools modules
This commit is contained in:
@@ -8,8 +8,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"cursor2api/internal/browser"
|
||||
"cursor2api/internal/config"
|
||||
"cursor2api/internal/tools"
|
||||
"cursor2api/internal/toolify"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -19,18 +18,18 @@ import (
|
||||
|
||||
// MessagesRequest Anthropic Messages API 请求格式
|
||||
type MessagesRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Stream bool `json:"stream"`
|
||||
System interface{} `json:"system,omitempty"` // 可以是 string 或 []ContentBlock
|
||||
Tools []tools.ToolDefinition `json:"tools,omitempty"` // 工具定义
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Stream bool `json:"stream"`
|
||||
System interface{} `json:"system,omitempty"` // 可以是 string 或 []ContentBlock
|
||||
Tools []toolify.ToolDefinition `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
// Message 消息格式
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content interface{} `json:"content"` // 可以是 string 或 []ContentBlock 或 []ToolResult
|
||||
Content interface{} `json:"content"` // 可以是 string 或 []ContentBlock
|
||||
}
|
||||
|
||||
// MessagesResponse Anthropic Messages API 响应格式
|
||||
@@ -45,7 +44,7 @@ type MessagesResponse struct {
|
||||
Usage Usage `json:"usage"`
|
||||
}
|
||||
|
||||
// ContentBlock 内容块(支持 text 和 tool_use)
|
||||
// ContentBlock 内容块
|
||||
type ContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
@@ -60,17 +59,6 @@ type Usage struct {
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
// 全局工具执行器和解析器
|
||||
var (
|
||||
toolExecutor *tools.Executor
|
||||
toolParser *tools.Parser
|
||||
)
|
||||
|
||||
func init() {
|
||||
toolExecutor = tools.NewExecutor()
|
||||
toolParser = tools.NewParser()
|
||||
}
|
||||
|
||||
// CursorSSEEvent Cursor SSE 事件格式
|
||||
type CursorSSEEvent struct {
|
||||
Type string `json:"type"`
|
||||
@@ -112,30 +100,11 @@ func getTextContent(content interface{}) string {
|
||||
|
||||
// mapModelName 将模型名称映射到 Cursor 支持的格式
|
||||
func mapModelName(model string) string {
|
||||
model = strings.ToLower(model)
|
||||
|
||||
// 已经是 Cursor 格式
|
||||
if strings.Contains(model, "/") {
|
||||
return model
|
||||
// 直接透传模型名称,不做转换
|
||||
if model == "" {
|
||||
return "claude-opus-4-5-20251101"
|
||||
}
|
||||
|
||||
// Claude 模型
|
||||
if strings.Contains(model, "claude") {
|
||||
return "anthropic/claude-sonnet-4.5"
|
||||
}
|
||||
|
||||
// GPT 模型
|
||||
if strings.Contains(model, "gpt") {
|
||||
return "openai/gpt-5-nano"
|
||||
}
|
||||
|
||||
// Gemini 模型
|
||||
if strings.Contains(model, "gemini") {
|
||||
return "google/gemini-2.5-flash"
|
||||
}
|
||||
|
||||
// 默认使用 Claude
|
||||
return "anthropic/claude-sonnet-4.5"
|
||||
return model
|
||||
}
|
||||
|
||||
// ================== 处理器函数 ==================
|
||||
@@ -173,9 +142,9 @@ func Messages(c *gin.Context) {
|
||||
cursorReq := convertToCursor(req)
|
||||
|
||||
if req.Stream {
|
||||
handleStream(c, cursorReq, req.Model)
|
||||
handleStream(c, cursorReq, req.Model, req.Tools)
|
||||
} else {
|
||||
handleNonStream(c, cursorReq, req.Model)
|
||||
handleNonStream(c, cursorReq, req.Model, req.Tools)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,13 +154,8 @@ func Messages(c *gin.Context) {
|
||||
func convertToCursor(req MessagesRequest) browser.CursorChatRequest {
|
||||
messages := make([]browser.CursorMessage, 0, len(req.Messages)+1)
|
||||
|
||||
// 构建系统消息(包含工具定义)
|
||||
// 构建系统消息
|
||||
sysText := getTextContent(req.System)
|
||||
if len(req.Tools) > 0 {
|
||||
toolPrompt := tools.GenerateToolPrompt(req.Tools)
|
||||
sysText += toolPrompt
|
||||
}
|
||||
|
||||
if sysText != "" {
|
||||
messages = append(messages, browser.CursorMessage{
|
||||
Parts: []browser.CursorPart{{Type: "text", Text: sysText}},
|
||||
@@ -200,10 +164,39 @@ func convertToCursor(req MessagesRequest) browser.CursorChatRequest {
|
||||
})
|
||||
}
|
||||
|
||||
// 检测是否有 tool_result(表示工具已执行过)
|
||||
hasToolResult := false
|
||||
for _, msg := range req.Messages {
|
||||
if msg.Role == "user" {
|
||||
if content, ok := msg.Content.([]interface{}); ok {
|
||||
for _, item := range content {
|
||||
if block, ok := item.(map[string]interface{}); ok {
|
||||
if block["type"] == "tool_result" {
|
||||
hasToolResult = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 只有第一次调用时才注入工具提示(没有 tool_result)
|
||||
toolPrompt := ""
|
||||
if len(req.Tools) > 0 && !hasToolResult {
|
||||
toolPrompt = toolify.GenerateToolPrompt(req.Tools)
|
||||
}
|
||||
|
||||
// 添加用户/助手消息
|
||||
firstUserMsg := true
|
||||
for _, msg := range req.Messages {
|
||||
text := extractMessageText(msg)
|
||||
if text != "" {
|
||||
// 把工具提示放在第一条用户消息前面
|
||||
if msg.Role == "user" && firstUserMsg && toolPrompt != "" {
|
||||
text = toolPrompt + "\n\n" + text
|
||||
firstUserMsg = false
|
||||
}
|
||||
messages = append(messages, browser.CursorMessage{
|
||||
Parts: []browser.CursorPart{{Type: "text", Text: text}},
|
||||
ID: generateID(),
|
||||
@@ -213,19 +206,15 @@ func convertToCursor(req MessagesRequest) browser.CursorChatRequest {
|
||||
}
|
||||
|
||||
return browser.CursorChatRequest{
|
||||
Context: []browser.CursorContext{{
|
||||
Type: "file",
|
||||
Content: "",
|
||||
FilePath: "/docs/",
|
||||
}},
|
||||
Model: mapModelName(req.Model),
|
||||
ID: generateID(),
|
||||
Messages: messages,
|
||||
Trigger: "submit-message",
|
||||
Tools: req.Tools, // 透传工具定义
|
||||
}
|
||||
}
|
||||
|
||||
// extractMessageText 从消息中提取文本(处理 tool_result)
|
||||
// extractMessageText 从消息中提取文本
|
||||
func extractMessageText(msg Message) string {
|
||||
content := msg.Content
|
||||
if content == nil {
|
||||
@@ -242,40 +231,32 @@ func extractMessageText(msg Message) string {
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
blockType, _ := block["type"].(string)
|
||||
switch blockType {
|
||||
switch block["type"] {
|
||||
case "text":
|
||||
if text, ok := block["text"].(string); ok {
|
||||
texts = append(texts, text)
|
||||
}
|
||||
case "tool_result":
|
||||
// 处理工具结果
|
||||
toolUseID, _ := block["tool_use_id"].(string)
|
||||
resultContent := block["content"]
|
||||
isError, _ := block["is_error"].(bool)
|
||||
|
||||
resultText := ""
|
||||
switch rc := resultContent.(type) {
|
||||
case string:
|
||||
resultText = rc
|
||||
case []interface{}:
|
||||
for _, rcItem := range rc {
|
||||
if rcBlock, ok := rcItem.(map[string]interface{}); ok {
|
||||
if rcBlock["type"] == "text" {
|
||||
if t, ok := rcBlock["text"].(string); ok {
|
||||
resultText += t
|
||||
// 提取 tool_result 内容
|
||||
toolID := ""
|
||||
if id, ok := block["tool_use_id"].(string); ok {
|
||||
toolID = id
|
||||
}
|
||||
resultContent := ""
|
||||
if c, ok := block["content"].(string); ok {
|
||||
resultContent = c
|
||||
} else if c, ok := block["content"].([]interface{}); ok {
|
||||
for _, item := range c {
|
||||
if b, ok := item.(map[string]interface{}); ok {
|
||||
if b["type"] == "text" {
|
||||
if t, ok := b["text"].(string); ok {
|
||||
resultContent += t
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
prefix := "工具执行结果"
|
||||
if isError {
|
||||
prefix = "工具执行错误"
|
||||
}
|
||||
texts = append(texts, fmt.Sprintf("[%s (ID: %s)]\n%s", prefix, toolUseID, resultText))
|
||||
texts = append(texts, fmt.Sprintf("[Tool %s result]: %s", toolID, resultContent))
|
||||
}
|
||||
}
|
||||
return strings.Join(texts, "\n")
|
||||
@@ -287,7 +268,7 @@ func extractMessageText(msg Message) string {
|
||||
// ================== API 处理 ==================
|
||||
|
||||
// handleStream 处理流式请求
|
||||
func handleStream(c *gin.Context, cursorReq browser.CursorChatRequest, model string) {
|
||||
func handleStream(c *gin.Context, cursorReq browser.CursorChatRequest, model string, tools []toolify.ToolDefinition) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
@@ -301,13 +282,29 @@ func handleStream(c *gin.Context, cursorReq browser.CursorChatRequest, model str
|
||||
c.Writer.WriteString(fmt.Sprintf(`data: {"type":"message_start","message":{"id":"%s","type":"message","role":"assistant","content":[],"model":"%s","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":100,"output_tokens":0}}}`+"\n\n", id, model))
|
||||
flusher.Flush()
|
||||
|
||||
c.Writer.WriteString("event: content_block_start\n")
|
||||
c.Writer.WriteString(`data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}` + "\n\n")
|
||||
flusher.Flush()
|
||||
var buffer, fullResponse strings.Builder
|
||||
blockIndex := 0
|
||||
toolCount := 0
|
||||
|
||||
// 用于累积完整响应和 SSE 行
|
||||
var buffer strings.Builder
|
||||
var fullResponse strings.Builder
|
||||
// 发送工具调用的辅助函数
|
||||
sendToolCall := func(toolName, argsJSON string) {
|
||||
toolID := fmt.Sprintf("toolu_%d", toolCount)
|
||||
toolCount++
|
||||
|
||||
var args map[string]interface{}
|
||||
json.Unmarshal([]byte(argsJSON), &args)
|
||||
inputJSON, _ := json.Marshal(args)
|
||||
partialJSONStr, _ := json.Marshal(string(inputJSON))
|
||||
|
||||
c.Writer.WriteString("event: content_block_start\n")
|
||||
c.Writer.WriteString(fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"%s","name":"%s","input":{}}}`+"\n\n", blockIndex, toolID, toolName))
|
||||
c.Writer.WriteString("event: content_block_delta\n")
|
||||
c.Writer.WriteString(fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":%s}}`+"\n\n", blockIndex, string(partialJSONStr)))
|
||||
c.Writer.WriteString("event: content_block_stop\n")
|
||||
c.Writer.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`+"\n\n", blockIndex))
|
||||
blockIndex++
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
svc := browser.GetService()
|
||||
err := svc.SendStreamRequest(cursorReq, func(chunk string) {
|
||||
@@ -315,7 +312,6 @@ func handleStream(c *gin.Context, cursorReq browser.CursorChatRequest, model str
|
||||
content := buffer.String()
|
||||
lines := strings.Split(content, "\n")
|
||||
|
||||
// 保留最后一个可能不完整的行
|
||||
if !strings.HasSuffix(content, "\n") && len(lines) > 0 {
|
||||
buffer.Reset()
|
||||
buffer.WriteString(lines[len(lines)-1])
|
||||
@@ -340,10 +336,7 @@ func handleStream(c *gin.Context, cursorReq browser.CursorChatRequest, model str
|
||||
|
||||
if event.Type == "text-delta" && event.Delta != "" {
|
||||
fullResponse.WriteString(event.Delta)
|
||||
deltaJSON, _ := json.Marshal(event.Delta)
|
||||
c.Writer.WriteString("event: content_block_delta\n")
|
||||
c.Writer.WriteString(`data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":` + string(deltaJSON) + `}}` + "\n\n")
|
||||
flusher.Flush()
|
||||
// 只累积,不实时解析(避免重复执行)
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -352,109 +345,44 @@ func handleStream(c *gin.Context, cursorReq browser.CursorChatRequest, model str
|
||||
c.Writer.WriteString("event: error\n")
|
||||
c.Writer.WriteString(`data: {"type":"error","error":{"message":"` + err.Error() + `"}}` + "\n\n")
|
||||
flusher.Flush()
|
||||
return
|
||||
}
|
||||
|
||||
c.Writer.WriteString("event: content_block_stop\n")
|
||||
c.Writer.WriteString(`data: {"type":"content_block_stop","index":0}` + "\n\n")
|
||||
flusher.Flush()
|
||||
|
||||
// 检测是否有工具调用,决定 stop_reason
|
||||
stopReason := "end_turn"
|
||||
// 解析完整响应
|
||||
responseText := fullResponse.String()
|
||||
toolCalls, _ := toolParser.ParseToolCalls(responseText)
|
||||
toolCalls, cleanText := toolify.ParseToolCalls(responseText)
|
||||
|
||||
// 如果没有工具调用,检查是否是拒绝响应,自动执行(需要配置开启)
|
||||
cfg := config.Get()
|
||||
if len(toolCalls) == 0 && cfg.Browser.AutoExecute && tools.DetectRefusal(responseText) {
|
||||
if cmd := tools.ExtractCommandFromRefusal(responseText); cmd != "" {
|
||||
// 自动执行提取的命令
|
||||
output, execErr := toolExecutor.Execute("bash", map[string]interface{}{
|
||||
"command": cmd,
|
||||
})
|
||||
// 发送干净文本
|
||||
if cleanText != "" {
|
||||
textJSON, _ := json.Marshal(cleanText)
|
||||
c.Writer.WriteString("event: content_block_start\n")
|
||||
c.Writer.WriteString(fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`+"\n\n", blockIndex))
|
||||
c.Writer.WriteString("event: content_block_delta\n")
|
||||
c.Writer.WriteString(fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":%s}}`+"\n\n", blockIndex, string(textJSON)))
|
||||
c.Writer.WriteString("event: content_block_stop\n")
|
||||
c.Writer.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`+"\n\n", blockIndex))
|
||||
flusher.Flush()
|
||||
blockIndex++
|
||||
}
|
||||
|
||||
resultText := output
|
||||
isError := false
|
||||
if execErr != nil {
|
||||
resultText = execErr.Error()
|
||||
isError = true
|
||||
}
|
||||
|
||||
// 清除之前发送的 AI 拒绝文本,发送一个新的干净的文本块
|
||||
// 注意:SSE 流已经发送了拒绝文本,我们只能追加结果
|
||||
// 发送执行结果作为新的文本块(替代之前的拒绝内容)
|
||||
statusEmoji := "✅"
|
||||
statusText := "执行成功"
|
||||
if isError {
|
||||
statusEmoji = "❌"
|
||||
statusText = "执行失败"
|
||||
}
|
||||
|
||||
// 简洁的结果输出,不显示 AI 的废话
|
||||
var resultMsg string
|
||||
if resultText == "" {
|
||||
resultMsg = fmt.Sprintf("\n\n---\n%s **%s**\n```bash\n%s\n```", statusEmoji, statusText, cmd)
|
||||
} else {
|
||||
resultMsg = fmt.Sprintf("\n\n---\n%s **%s**\n```bash\n%s\n```\n输出:\n```\n%s\n```", statusEmoji, statusText, cmd, resultText)
|
||||
}
|
||||
resultJSON, _ := json.Marshal(resultMsg)
|
||||
|
||||
c.Writer.WriteString("event: content_block_start\n")
|
||||
c.Writer.WriteString(`data: {"type":"content_block_start","index":1,"content_block":{"type":"text","text":""}}` + "\n\n")
|
||||
flusher.Flush()
|
||||
|
||||
c.Writer.WriteString("event: content_block_delta\n")
|
||||
c.Writer.WriteString(`data: {"type":"content_block_delta","index":1,"delta":{"type":"text_delta","text":` + string(resultJSON) + `}}` + "\n\n")
|
||||
flusher.Flush()
|
||||
|
||||
c.Writer.WriteString("event: content_block_stop\n")
|
||||
c.Writer.WriteString(`data: {"type":"content_block_stop","index":1}` + "\n\n")
|
||||
flusher.Flush()
|
||||
|
||||
// 设置 stop_reason 为 end_turn 而不是 tool_use
|
||||
stopReason = "end_turn"
|
||||
}
|
||||
} else if len(toolCalls) > 0 {
|
||||
// 发送工具调用
|
||||
stopReason := "end_turn"
|
||||
if len(toolCalls) > 0 {
|
||||
stopReason = "tool_use"
|
||||
// 发送工具调用块
|
||||
for i, call := range toolCalls {
|
||||
toolID := "toolu_" + generateID()
|
||||
inputJSON, _ := json.Marshal(call.Input)
|
||||
|
||||
c.Writer.WriteString("event: content_block_start\n")
|
||||
c.Writer.WriteString(fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"%s","name":"%s","input":{}}}`+"\n\n", i+1, toolID, call.Name))
|
||||
flusher.Flush()
|
||||
|
||||
c.Writer.WriteString("event: content_block_delta\n")
|
||||
c.Writer.WriteString(fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":"%s"}}`+"\n\n", i+1, escapeJSON(string(inputJSON))))
|
||||
flusher.Flush()
|
||||
|
||||
c.Writer.WriteString("event: content_block_stop\n")
|
||||
c.Writer.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`+"\n\n", i+1))
|
||||
flusher.Flush()
|
||||
for _, call := range toolCalls {
|
||||
sendToolCall(call.Function.Name, call.Function.Arguments)
|
||||
}
|
||||
}
|
||||
|
||||
c.Writer.WriteString("event: message_delta\n")
|
||||
c.Writer.WriteString(fmt.Sprintf(`data: {"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"output_tokens":100}}`+"\n\n", stopReason))
|
||||
flusher.Flush()
|
||||
|
||||
c.Writer.WriteString("event: message_stop\n")
|
||||
c.Writer.WriteString(`data: {"type":"message_stop"}` + "\n\n")
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
// escapeJSON 转义 JSON 字符串中的特殊字符
|
||||
func escapeJSON(s string) string {
|
||||
s = strings.ReplaceAll(s, `\`, `\\`)
|
||||
s = strings.ReplaceAll(s, `"`, `\"`)
|
||||
s = strings.ReplaceAll(s, "\n", `\n`)
|
||||
s = strings.ReplaceAll(s, "\r", `\r`)
|
||||
s = strings.ReplaceAll(s, "\t", `\t`)
|
||||
return s
|
||||
}
|
||||
|
||||
// handleNonStream 处理非流式请求
|
||||
func handleNonStream(c *gin.Context, cursorReq browser.CursorChatRequest, model string) {
|
||||
func handleNonStream(c *gin.Context, cursorReq browser.CursorChatRequest, model string, tools []toolify.ToolDefinition) {
|
||||
svc := browser.GetService()
|
||||
result, err := svc.SendRequest(cursorReq)
|
||||
if err != nil {
|
||||
@@ -485,15 +413,32 @@ func handleNonStream(c *gin.Context, cursorReq browser.CursorChatRequest, model
|
||||
}
|
||||
|
||||
responseText := fullText.String()
|
||||
contentBlocks := parseResponseToBlocks(responseText, nil)
|
||||
|
||||
// 确定 stop_reason
|
||||
var contentBlocks []ContentBlock
|
||||
stopReason := "end_turn"
|
||||
for _, block := range contentBlocks {
|
||||
if block.Type == "tool_use" {
|
||||
|
||||
// 检测工具调用
|
||||
if len(tools) > 0 {
|
||||
toolCalls, cleanText := toolify.ParseToolCalls(responseText)
|
||||
if len(toolCalls) > 0 {
|
||||
stopReason = "tool_use"
|
||||
break
|
||||
if cleanText != "" {
|
||||
contentBlocks = append(contentBlocks, ContentBlock{Type: "text", Text: cleanText})
|
||||
}
|
||||
for _, call := range toolCalls {
|
||||
var args map[string]interface{}
|
||||
json.Unmarshal([]byte(call.Function.Arguments), &args)
|
||||
contentBlocks = append(contentBlocks, ContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: "toolu_" + call.ID,
|
||||
Name: call.Function.Name,
|
||||
Input: args,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
contentBlocks = append(contentBlocks, ContentBlock{Type: "text", Text: responseText})
|
||||
}
|
||||
} else {
|
||||
contentBlocks = append(contentBlocks, ContentBlock{Type: "text", Text: responseText})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, MessagesResponse{
|
||||
@@ -506,83 +451,3 @@ func handleNonStream(c *gin.Context, cursorReq browser.CursorChatRequest, model
|
||||
Usage: Usage{InputTokens: 100, OutputTokens: 100},
|
||||
})
|
||||
}
|
||||
|
||||
// parseResponseToBlocks 解析 AI 响应为内容块(检测工具调用)
|
||||
func parseResponseToBlocks(text string, userMessages []string) []ContentBlock {
|
||||
var blocks []ContentBlock
|
||||
|
||||
// 检测工具调用
|
||||
toolCalls, remainingText := toolParser.ParseToolCalls(text)
|
||||
|
||||
// 如果没有工具调用,检查是否是拒绝响应(需要配置开启)
|
||||
cfg := config.Get()
|
||||
if len(toolCalls) == 0 && cfg.Browser.AutoExecute && tools.DetectRefusal(text) {
|
||||
// 尝试从拒绝响应中提取命令并自动执行
|
||||
if cmd := tools.ExtractCommandFromRefusal(text); cmd != "" {
|
||||
// 自动执行提取的命令
|
||||
output, err := toolExecutor.Execute("bash", map[string]interface{}{
|
||||
"command": cmd,
|
||||
})
|
||||
|
||||
resultText := output
|
||||
isError := false
|
||||
if err != nil {
|
||||
resultText = err.Error()
|
||||
isError = true
|
||||
}
|
||||
|
||||
// 返回工具使用和结果
|
||||
toolID := "toolu_" + generateID()
|
||||
blocks = append(blocks, ContentBlock{
|
||||
Type: "text",
|
||||
Text: "正在执行命令...",
|
||||
})
|
||||
blocks = append(blocks, ContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: toolID,
|
||||
Name: "bash",
|
||||
Input: map[string]interface{}{"command": cmd},
|
||||
})
|
||||
|
||||
// 添加执行结果说明
|
||||
statusText := "✅ 命令执行成功"
|
||||
if isError {
|
||||
statusText = "❌ 命令执行失败"
|
||||
}
|
||||
blocks = append(blocks, ContentBlock{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("\n\n%s:\n```\n%s\n```", statusText, resultText),
|
||||
})
|
||||
|
||||
return blocks
|
||||
}
|
||||
}
|
||||
|
||||
// 添加文本块
|
||||
if remainingText != "" {
|
||||
blocks = append(blocks, ContentBlock{
|
||||
Type: "text",
|
||||
Text: remainingText,
|
||||
})
|
||||
}
|
||||
|
||||
// 添加工具调用块
|
||||
for _, call := range toolCalls {
|
||||
blocks = append(blocks, ContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: "toolu_" + generateID(),
|
||||
Name: call.Name,
|
||||
Input: call.Input,
|
||||
})
|
||||
}
|
||||
|
||||
// 如果没有任何内容,添加空文本块
|
||||
if len(blocks) == 0 {
|
||||
blocks = append(blocks, ContentBlock{
|
||||
Type: "text",
|
||||
Text: text,
|
||||
})
|
||||
}
|
||||
|
||||
return blocks
|
||||
}
|
||||
|
||||
@@ -1,125 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"cursor2api/internal/tools"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ToolExecuteRequest 工具执行请求
|
||||
type ToolExecuteRequest struct {
|
||||
ToolName string `json:"tool_name"`
|
||||
Input map[string]interface{} `json:"input"`
|
||||
WorkDir string `json:"work_dir,omitempty"`
|
||||
}
|
||||
|
||||
// ToolExecuteResponse 工具执行响应
|
||||
type ToolExecuteResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Output string `json:"output"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// ExecuteTool 执行工具调用(供本地测试使用)
|
||||
func ExecuteTool(c *gin.Context) {
|
||||
var req ToolExecuteRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ToolExecuteResponse{
|
||||
Success: false,
|
||||
Error: "无效的请求格式: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
executor := tools.NewExecutor()
|
||||
if req.WorkDir != "" {
|
||||
executor.SetWorkDir(req.WorkDir)
|
||||
}
|
||||
|
||||
output, err := executor.Execute(req.ToolName, req.Input)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, ToolExecuteResponse{
|
||||
Success: false,
|
||||
Output: output,
|
||||
Error: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, ToolExecuteResponse{
|
||||
Success: true,
|
||||
Output: output,
|
||||
})
|
||||
}
|
||||
|
||||
// ListTools 列出可用工具
|
||||
func ListTools(c *gin.Context) {
|
||||
toolList := []tools.ToolDefinition{
|
||||
{
|
||||
Name: "bash",
|
||||
Description: "执行 bash 命令",
|
||||
InputSchema: tools.InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]tools.Property{
|
||||
"command": {Type: "string", Description: "要执行的命令"},
|
||||
"cwd": {Type: "string", Description: "工作目录(可选)"},
|
||||
},
|
||||
Required: []string{"command"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "read_file",
|
||||
Description: "读取文件内容",
|
||||
InputSchema: tools.InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]tools.Property{
|
||||
"path": {Type: "string", Description: "文件路径"},
|
||||
},
|
||||
Required: []string{"path"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "write_file",
|
||||
Description: "写入文件",
|
||||
InputSchema: tools.InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]tools.Property{
|
||||
"path": {Type: "string", Description: "文件路径"},
|
||||
"content": {Type: "string", Description: "文件内容"},
|
||||
},
|
||||
Required: []string{"path", "content"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "list_dir",
|
||||
Description: "列出目录内容",
|
||||
InputSchema: tools.InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]tools.Property{
|
||||
"path": {Type: "string", Description: "目录路径"},
|
||||
},
|
||||
Required: []string{"path"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "edit",
|
||||
Description: "编辑文件(查找替换)",
|
||||
InputSchema: tools.InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]tools.Property{
|
||||
"path": {Type: "string", Description: "文件路径"},
|
||||
"old_string": {Type: "string", Description: "要替换的内容"},
|
||||
"new_string": {Type: "string", Description: "替换后的内容"},
|
||||
"replace_all": {Type: "boolean", Description: "是否替换所有匹配"},
|
||||
},
|
||||
Required: []string{"path", "old_string", "new_string"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"tools": toolList,
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user