Merge branch 'main' into main

This commit is contained in:
Luis Pater
2026-02-15 14:47:52 +08:00
committed by GitHub
12 changed files with 1669 additions and 1334 deletions

View File

@@ -1007,7 +1007,12 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
exec := &AntigravityExecutor{cfg: cfg}
token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth)
if errToken != nil || token == "" {
if errToken != nil {
log.Warnf("antigravity executor: fetch models failed for %s: token error: %v", auth.ID, errToken)
return nil
}
if token == "" {
log.Warnf("antigravity executor: fetch models failed for %s: got empty token", auth.ID)
return nil
}
if updatedAuth != nil {
@@ -1021,6 +1026,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
modelsURL := baseURL + antigravityModelsPath
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader([]byte(`{}`)))
if errReq != nil {
log.Warnf("antigravity executor: fetch models failed for %s: create request error: %v", auth.ID, errReq)
return nil
}
httpReq.Header.Set("Content-Type", "application/json")
@@ -1033,12 +1039,14 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
log.Warnf("antigravity executor: fetch models failed for %s: context canceled: %v", auth.ID, errDo)
return nil
}
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
log.Warnf("antigravity executor: fetch models failed for %s: request error: %v", auth.ID, errDo)
return nil
}
@@ -1051,6 +1059,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
log.Warnf("antigravity executor: fetch models failed for %s: read body error: %v", auth.ID, errRead)
return nil
}
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
@@ -1058,11 +1067,13 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
log.Warnf("antigravity executor: fetch models failed for %s: unexpected status %d, body: %s", auth.ID, httpResp.StatusCode, string(bodyBytes))
return nil
}
result := gjson.GetBytes(bodyBytes, "models")
if !result.Exists() {
log.Warnf("antigravity executor: fetch models failed for %s: no models field in response, body: %s", auth.ID, string(bodyBytes))
return nil
}

View File

@@ -110,7 +110,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
useResponses := useGitHubCopilotResponsesEndpoint(from)
useResponses := useGitHubCopilotResponsesEndpoint(from, req.Model)
to := sdktranslator.FromString("openai")
if useResponses {
to = sdktranslator.FromString("openai-response")
@@ -133,6 +133,12 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
return resp, err
}
if useResponses {
body = normalizeGitHubCopilotResponsesInput(body)
body = normalizeGitHubCopilotResponsesTools(body)
} else {
body = normalizeGitHubCopilotChatTools(body)
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "stream", false)
@@ -209,7 +215,12 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
}
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, &param)
converted := ""
if useResponses && from.String() == "claude" {
converted = translateGitHubCopilotResponsesNonStreamToClaude(data)
} else {
converted = sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, &param)
}
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
reporter.ensurePublished(ctx)
return resp, nil
@@ -226,7 +237,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
useResponses := useGitHubCopilotResponsesEndpoint(from)
useResponses := useGitHubCopilotResponsesEndpoint(from, req.Model)
to := sdktranslator.FromString("openai")
if useResponses {
to = sdktranslator.FromString("openai-response")
@@ -249,6 +260,12 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
return nil, err
}
if useResponses {
body = normalizeGitHubCopilotResponsesInput(body)
body = normalizeGitHubCopilotResponsesTools(body)
} else {
body = normalizeGitHubCopilotChatTools(body)
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "stream", true)
@@ -349,7 +366,12 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
}
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), &param)
var chunks []string
if useResponses && from.String() == "claude" {
chunks = translateGitHubCopilotResponsesStreamToClaude(bytes.Clone(line), &param)
} else {
chunks = sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), &param)
}
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
}
@@ -503,8 +525,12 @@ func (e *GitHubCopilotExecutor) normalizeModel(model string, body []byte) []byte
return body
}
func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format) bool {
return sourceFormat.String() == "openai-response"
func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format, model string) bool {
if sourceFormat.String() == "openai-response" {
return true
}
baseModel := strings.ToLower(thinking.ParseSuffix(model).ModelName)
return strings.Contains(baseModel, "codex")
}
// flattenAssistantContent converts assistant message content from array format
@@ -539,6 +565,411 @@ func flattenAssistantContent(body []byte) []byte {
return result
}
func normalizeGitHubCopilotChatTools(body []byte) []byte {
tools := gjson.GetBytes(body, "tools")
if tools.Exists() {
filtered := "[]"
if tools.IsArray() {
for _, tool := range tools.Array() {
if tool.Get("type").String() != "function" {
continue
}
filtered, _ = sjson.SetRaw(filtered, "-1", tool.Raw)
}
}
body, _ = sjson.SetRawBytes(body, "tools", []byte(filtered))
}
toolChoice := gjson.GetBytes(body, "tool_choice")
if !toolChoice.Exists() {
return body
}
if toolChoice.Type == gjson.String {
switch toolChoice.String() {
case "auto", "none", "required":
return body
}
}
body, _ = sjson.SetBytes(body, "tool_choice", "auto")
return body
}
func normalizeGitHubCopilotResponsesInput(body []byte) []byte {
input := gjson.GetBytes(body, "input")
if input.Exists() {
if input.Type == gjson.String {
return body
}
inputString := input.Raw
if input.Type != gjson.JSON {
inputString = input.String()
}
body, _ = sjson.SetBytes(body, "input", inputString)
return body
}
var parts []string
if system := gjson.GetBytes(body, "system"); system.Exists() {
if text := strings.TrimSpace(collectTextFromNode(system)); text != "" {
parts = append(parts, text)
}
}
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
for _, msg := range messages.Array() {
if text := strings.TrimSpace(collectTextFromNode(msg.Get("content"))); text != "" {
parts = append(parts, text)
}
}
}
body, _ = sjson.SetBytes(body, "input", strings.Join(parts, "\n"))
return body
}
func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
tools := gjson.GetBytes(body, "tools")
if tools.Exists() {
filtered := "[]"
if tools.IsArray() {
for _, tool := range tools.Array() {
toolType := tool.Get("type").String()
// Accept OpenAI format (type="function") and Claude format
// (no type field, but has top-level name + input_schema).
if toolType != "" && toolType != "function" {
continue
}
name := tool.Get("name").String()
if name == "" {
name = tool.Get("function.name").String()
}
if name == "" {
continue
}
normalized := `{"type":"function","name":""}`
normalized, _ = sjson.Set(normalized, "name", name)
if desc := tool.Get("description").String(); desc != "" {
normalized, _ = sjson.Set(normalized, "description", desc)
} else if desc = tool.Get("function.description").String(); desc != "" {
normalized, _ = sjson.Set(normalized, "description", desc)
}
if params := tool.Get("parameters"); params.Exists() {
normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw)
} else if params = tool.Get("function.parameters"); params.Exists() {
normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw)
} else if params = tool.Get("input_schema"); params.Exists() {
normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw)
}
filtered, _ = sjson.SetRaw(filtered, "-1", normalized)
}
}
body, _ = sjson.SetRawBytes(body, "tools", []byte(filtered))
}
toolChoice := gjson.GetBytes(body, "tool_choice")
if !toolChoice.Exists() {
return body
}
if toolChoice.Type == gjson.String {
switch toolChoice.String() {
case "auto", "none", "required":
return body
default:
body, _ = sjson.SetBytes(body, "tool_choice", "auto")
return body
}
}
if toolChoice.Type == gjson.JSON {
choiceType := toolChoice.Get("type").String()
if choiceType == "function" {
name := toolChoice.Get("name").String()
if name == "" {
name = toolChoice.Get("function.name").String()
}
if name != "" {
normalized := `{"type":"function","name":""}`
normalized, _ = sjson.Set(normalized, "name", name)
body, _ = sjson.SetRawBytes(body, "tool_choice", []byte(normalized))
return body
}
}
}
body, _ = sjson.SetBytes(body, "tool_choice", "auto")
return body
}
func collectTextFromNode(node gjson.Result) string {
if !node.Exists() {
return ""
}
if node.Type == gjson.String {
return node.String()
}
if node.IsArray() {
var parts []string
for _, item := range node.Array() {
if item.Type == gjson.String {
if text := item.String(); text != "" {
parts = append(parts, text)
}
continue
}
if text := item.Get("text").String(); text != "" {
parts = append(parts, text)
continue
}
if nested := collectTextFromNode(item.Get("content")); nested != "" {
parts = append(parts, nested)
}
}
return strings.Join(parts, "\n")
}
if node.Type == gjson.JSON {
if text := node.Get("text").String(); text != "" {
return text
}
if nested := collectTextFromNode(node.Get("content")); nested != "" {
return nested
}
return node.Raw
}
return node.String()
}
type githubCopilotResponsesStreamToolState struct {
Index int
ID string
Name string
}
type githubCopilotResponsesStreamState struct {
MessageStarted bool
MessageStopSent bool
TextBlockStarted bool
TextBlockIndex int
NextContentIndex int
HasToolUse bool
OutputIndexToTool map[int]*githubCopilotResponsesStreamToolState
ItemIDToTool map[string]*githubCopilotResponsesStreamToolState
}
func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) string {
root := gjson.ParseBytes(data)
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
out, _ = sjson.Set(out, "id", root.Get("id").String())
out, _ = sjson.Set(out, "model", root.Get("model").String())
hasToolUse := false
if output := root.Get("output"); output.Exists() && output.IsArray() {
for _, item := range output.Array() {
switch item.Get("type").String() {
case "message":
if content := item.Get("content"); content.Exists() && content.IsArray() {
for _, part := range content.Array() {
if part.Get("type").String() != "output_text" {
continue
}
text := part.Get("text").String()
if text == "" {
continue
}
block := `{"type":"text","text":""}`
block, _ = sjson.Set(block, "text", text)
out, _ = sjson.SetRaw(out, "content.-1", block)
}
}
case "function_call":
hasToolUse = true
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
toolID := item.Get("call_id").String()
if toolID == "" {
toolID = item.Get("id").String()
}
toolUse, _ = sjson.Set(toolUse, "id", toolID)
toolUse, _ = sjson.Set(toolUse, "name", item.Get("name").String())
if args := item.Get("arguments").String(); args != "" && gjson.Valid(args) {
argObj := gjson.Parse(args)
if argObj.IsObject() {
toolUse, _ = sjson.SetRaw(toolUse, "input", argObj.Raw)
}
}
out, _ = sjson.SetRaw(out, "content.-1", toolUse)
}
}
}
inputTokens := root.Get("usage.input_tokens").Int()
outputTokens := root.Get("usage.output_tokens").Int()
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
if hasToolUse {
out, _ = sjson.Set(out, "stop_reason", "tool_use")
} else {
out, _ = sjson.Set(out, "stop_reason", "end_turn")
}
return out
}
func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []string {
if *param == nil {
*param = &githubCopilotResponsesStreamState{
TextBlockIndex: -1,
OutputIndexToTool: make(map[int]*githubCopilotResponsesStreamToolState),
ItemIDToTool: make(map[string]*githubCopilotResponsesStreamToolState),
}
}
state := (*param).(*githubCopilotResponsesStreamState)
if !bytes.HasPrefix(line, dataTag) {
return nil
}
payload := bytes.TrimSpace(line[5:])
if bytes.Equal(payload, []byte("[DONE]")) {
return nil
}
if !gjson.ValidBytes(payload) {
return nil
}
event := gjson.GetBytes(payload, "type").String()
results := make([]string, 0, 4)
ensureMessageStart := func() {
if state.MessageStarted {
return
}
messageStart := `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}}`
messageStart, _ = sjson.Set(messageStart, "message.id", gjson.GetBytes(payload, "response.id").String())
messageStart, _ = sjson.Set(messageStart, "message.model", gjson.GetBytes(payload, "response.model").String())
results = append(results, "event: message_start\ndata: "+messageStart+"\n\n")
state.MessageStarted = true
}
startTextBlockIfNeeded := func() {
if state.TextBlockStarted {
return
}
if state.TextBlockIndex < 0 {
state.TextBlockIndex = state.NextContentIndex
state.NextContentIndex++
}
contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
contentBlockStart, _ = sjson.Set(contentBlockStart, "index", state.TextBlockIndex)
results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n")
state.TextBlockStarted = true
}
stopTextBlockIfNeeded := func() {
if !state.TextBlockStarted {
return
}
contentBlockStop := `{"type":"content_block_stop","index":0}`
contentBlockStop, _ = sjson.Set(contentBlockStop, "index", state.TextBlockIndex)
results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n")
state.TextBlockStarted = false
state.TextBlockIndex = -1
}
resolveTool := func(itemID string, outputIndex int) *githubCopilotResponsesStreamToolState {
if itemID != "" {
if tool, ok := state.ItemIDToTool[itemID]; ok {
return tool
}
}
if tool, ok := state.OutputIndexToTool[outputIndex]; ok {
if itemID != "" {
state.ItemIDToTool[itemID] = tool
}
return tool
}
return nil
}
switch event {
case "response.created":
ensureMessageStart()
case "response.output_text.delta":
ensureMessageStart()
startTextBlockIfNeeded()
delta := gjson.GetBytes(payload, "delta").String()
if delta != "" {
contentDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
contentDelta, _ = sjson.Set(contentDelta, "index", state.TextBlockIndex)
contentDelta, _ = sjson.Set(contentDelta, "delta.text", delta)
results = append(results, "event: content_block_delta\ndata: "+contentDelta+"\n\n")
}
case "response.output_item.added":
if gjson.GetBytes(payload, "item.type").String() != "function_call" {
break
}
ensureMessageStart()
stopTextBlockIfNeeded()
state.HasToolUse = true
tool := &githubCopilotResponsesStreamToolState{
Index: state.NextContentIndex,
ID: gjson.GetBytes(payload, "item.call_id").String(),
Name: gjson.GetBytes(payload, "item.name").String(),
}
if tool.ID == "" {
tool.ID = gjson.GetBytes(payload, "item.id").String()
}
state.NextContentIndex++
outputIndex := int(gjson.GetBytes(payload, "output_index").Int())
state.OutputIndexToTool[outputIndex] = tool
if itemID := gjson.GetBytes(payload, "item.id").String(); itemID != "" {
state.ItemIDToTool[itemID] = tool
}
contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
contentBlockStart, _ = sjson.Set(contentBlockStart, "index", tool.Index)
contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.id", tool.ID)
contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.name", tool.Name)
results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n")
case "response.output_item.delta":
item := gjson.GetBytes(payload, "item")
if item.Get("type").String() != "function_call" {
break
}
tool := resolveTool(item.Get("id").String(), int(gjson.GetBytes(payload, "output_index").Int()))
if tool == nil {
break
}
partial := gjson.GetBytes(payload, "delta").String()
if partial == "" {
partial = item.Get("arguments").String()
}
if partial == "" {
break
}
inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index)
inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial)
results = append(results, "event: content_block_delta\ndata: "+inputDelta+"\n\n")
case "response.output_item.done":
if gjson.GetBytes(payload, "item.type").String() != "function_call" {
break
}
tool := resolveTool(gjson.GetBytes(payload, "item.id").String(), int(gjson.GetBytes(payload, "output_index").Int()))
if tool == nil {
break
}
contentBlockStop := `{"type":"content_block_stop","index":0}`
contentBlockStop, _ = sjson.Set(contentBlockStop, "index", tool.Index)
results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n")
case "response.completed":
ensureMessageStart()
stopTextBlockIfNeeded()
if !state.MessageStopSent {
stopReason := "end_turn"
if state.HasToolUse {
stopReason = "tool_use"
}
messageDelta := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
messageDelta, _ = sjson.Set(messageDelta, "delta.stop_reason", stopReason)
messageDelta, _ = sjson.Set(messageDelta, "usage.input_tokens", gjson.GetBytes(payload, "response.usage.input_tokens").Int())
messageDelta, _ = sjson.Set(messageDelta, "usage.output_tokens", gjson.GetBytes(payload, "response.usage.output_tokens").Int())
results = append(results, "event: message_delta\ndata: "+messageDelta+"\n\n")
results = append(results, "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n")
state.MessageStopSent = true
}
}
return results
}
// isHTTPSuccess checks if the status code indicates success (2xx).
func isHTTPSuccess(statusCode int) bool {
return statusCode >= 200 && statusCode < 300

View File

@@ -1,8 +1,10 @@
package executor
import (
"strings"
"testing"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
@@ -52,3 +54,189 @@ func TestGitHubCopilotNormalizeModel_StripsSuffix(t *testing.T) {
})
}
}
func TestUseGitHubCopilotResponsesEndpoint_OpenAIResponseSource(t *testing.T) {
t.Parallel()
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai-response"), "claude-3-5-sonnet") {
t.Fatal("expected openai-response source to use /responses")
}
}
func TestUseGitHubCopilotResponsesEndpoint_CodexModel(t *testing.T) {
t.Parallel()
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5-codex") {
t.Fatal("expected codex model to use /responses")
}
}
func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) {
t.Parallel()
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "claude-3-5-sonnet") {
t.Fatal("expected default openai source with non-codex model to use /chat/completions")
}
}
func TestNormalizeGitHubCopilotChatTools_KeepFunctionOnly(t *testing.T) {
t.Parallel()
body := []byte(`{"tools":[{"type":"function","function":{"name":"ok"}},{"type":"code_interpreter"}],"tool_choice":"auto"}`)
got := normalizeGitHubCopilotChatTools(body)
tools := gjson.GetBytes(got, "tools").Array()
if len(tools) != 1 {
t.Fatalf("tools len = %d, want 1", len(tools))
}
if tools[0].Get("type").String() != "function" {
t.Fatalf("tool type = %q, want function", tools[0].Get("type").String())
}
}
func TestNormalizeGitHubCopilotChatTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) {
t.Parallel()
body := []byte(`{"tools":[],"tool_choice":{"type":"function","function":{"name":"x"}}}`)
got := normalizeGitHubCopilotChatTools(body)
if gjson.GetBytes(got, "tool_choice").String() != "auto" {
t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw)
}
}
func TestNormalizeGitHubCopilotResponsesInput_MissingInputExtractedFromSystemAndMessages(t *testing.T) {
t.Parallel()
body := []byte(`{"system":"sys text","messages":[{"role":"user","content":"user text"},{"role":"assistant","content":[{"type":"text","text":"assistant text"}]}]}`)
got := normalizeGitHubCopilotResponsesInput(body)
in := gjson.GetBytes(got, "input")
if in.Type != gjson.String {
t.Fatalf("input type = %v, want string", in.Type)
}
if !strings.Contains(in.String(), "sys text") || !strings.Contains(in.String(), "user text") || !strings.Contains(in.String(), "assistant text") {
t.Fatalf("input = %q, want merged text", in.String())
}
}
func TestNormalizeGitHubCopilotResponsesInput_NonStringInputStringified(t *testing.T) {
t.Parallel()
body := []byte(`{"input":{"foo":"bar"}}`)
got := normalizeGitHubCopilotResponsesInput(body)
in := gjson.GetBytes(got, "input")
if in.Type != gjson.String {
t.Fatalf("input type = %v, want string", in.Type)
}
if !strings.Contains(in.String(), "foo") {
t.Fatalf("input = %q, want stringified object", in.String())
}
}
func TestNormalizeGitHubCopilotResponsesTools_FlattenFunctionTools(t *testing.T) {
t.Parallel()
body := []byte(`{"tools":[{"type":"function","function":{"name":"sum","description":"d","parameters":{"type":"object"}}},{"type":"web_search"}]}`)
got := normalizeGitHubCopilotResponsesTools(body)
tools := gjson.GetBytes(got, "tools").Array()
if len(tools) != 1 {
t.Fatalf("tools len = %d, want 1", len(tools))
}
if tools[0].Get("name").String() != "sum" {
t.Fatalf("tools[0].name = %q, want sum", tools[0].Get("name").String())
}
if !tools[0].Get("parameters").Exists() {
t.Fatal("expected parameters to be preserved")
}
}
func TestNormalizeGitHubCopilotResponsesTools_ClaudeFormatTools(t *testing.T) {
t.Parallel()
body := []byte(`{"tools":[{"name":"Bash","description":"Run commands","input_schema":{"type":"object","properties":{"command":{"type":"string"}},"required":["command"]}},{"name":"Read","description":"Read files","input_schema":{"type":"object","properties":{"path":{"type":"string"}}}}]}`)
got := normalizeGitHubCopilotResponsesTools(body)
tools := gjson.GetBytes(got, "tools").Array()
if len(tools) != 2 {
t.Fatalf("tools len = %d, want 2", len(tools))
}
if tools[0].Get("type").String() != "function" {
t.Fatalf("tools[0].type = %q, want function", tools[0].Get("type").String())
}
if tools[0].Get("name").String() != "Bash" {
t.Fatalf("tools[0].name = %q, want Bash", tools[0].Get("name").String())
}
if tools[0].Get("description").String() != "Run commands" {
t.Fatalf("tools[0].description = %q, want 'Run commands'", tools[0].Get("description").String())
}
if !tools[0].Get("parameters").Exists() {
t.Fatal("expected parameters to be set from input_schema")
}
if tools[0].Get("parameters.properties.command").Exists() != true {
t.Fatal("expected parameters.properties.command to exist")
}
if tools[1].Get("name").String() != "Read" {
t.Fatalf("tools[1].name = %q, want Read", tools[1].Get("name").String())
}
}
func TestNormalizeGitHubCopilotResponsesTools_FlattenToolChoiceFunctionObject(t *testing.T) {
t.Parallel()
body := []byte(`{"tool_choice":{"type":"function","function":{"name":"sum"}}}`)
got := normalizeGitHubCopilotResponsesTools(body)
if gjson.GetBytes(got, "tool_choice.type").String() != "function" {
t.Fatalf("tool_choice.type = %q, want function", gjson.GetBytes(got, "tool_choice.type").String())
}
if gjson.GetBytes(got, "tool_choice.name").String() != "sum" {
t.Fatalf("tool_choice.name = %q, want sum", gjson.GetBytes(got, "tool_choice.name").String())
}
}
func TestNormalizeGitHubCopilotResponsesTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) {
t.Parallel()
body := []byte(`{"tool_choice":{"type":"function"}}`)
got := normalizeGitHubCopilotResponsesTools(body)
if gjson.GetBytes(got, "tool_choice").String() != "auto" {
t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw)
}
}
func TestTranslateGitHubCopilotResponsesNonStreamToClaude_TextMapping(t *testing.T) {
t.Parallel()
resp := []byte(`{"id":"resp_1","model":"gpt-5-codex","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":3,"output_tokens":5}}`)
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
if gjson.Get(out, "type").String() != "message" {
t.Fatalf("type = %q, want message", gjson.Get(out, "type").String())
}
if gjson.Get(out, "content.0.type").String() != "text" {
t.Fatalf("content.0.type = %q, want text", gjson.Get(out, "content.0.type").String())
}
if gjson.Get(out, "content.0.text").String() != "hello" {
t.Fatalf("content.0.text = %q, want hello", gjson.Get(out, "content.0.text").String())
}
}
func TestTranslateGitHubCopilotResponsesNonStreamToClaude_ToolUseMapping(t *testing.T) {
t.Parallel()
resp := []byte(`{"id":"resp_2","model":"gpt-5-codex","output":[{"type":"function_call","id":"fc_1","call_id":"call_1","name":"sum","arguments":"{\"a\":1}"}],"usage":{"input_tokens":1,"output_tokens":2}}`)
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
if gjson.Get(out, "content.0.type").String() != "tool_use" {
t.Fatalf("content.0.type = %q, want tool_use", gjson.Get(out, "content.0.type").String())
}
if gjson.Get(out, "content.0.name").String() != "sum" {
t.Fatalf("content.0.name = %q, want sum", gjson.Get(out, "content.0.name").String())
}
if gjson.Get(out, "stop_reason").String() != "tool_use" {
t.Fatalf("stop_reason = %q, want tool_use", gjson.Get(out, "stop_reason").String())
}
}
func TestTranslateGitHubCopilotResponsesStreamToClaude_TextLifecycle(t *testing.T) {
t.Parallel()
var param any
created := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5-codex"}}`), &param)
if len(created) == 0 || !strings.Contains(created[0], "message_start") {
t.Fatalf("created events = %#v, want message_start", created)
}
delta := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.output_text.delta","delta":"he"}`), &param)
joinedDelta := strings.Join(delta, "")
if !strings.Contains(joinedDelta, "content_block_start") || !strings.Contains(joinedDelta, "text_delta") {
t.Fatalf("delta events = %#v, want content_block_start + text_delta", delta)
}
completed := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.completed","response":{"usage":{"input_tokens":7,"output_tokens":9}}}`), &param)
joinedCompleted := strings.Join(completed, "")
if !strings.Contains(joinedCompleted, "message_delta") || !strings.Contains(joinedCompleted, "message_stop") {
t.Fatalf("completed events = %#v, want message_delta + message_stop", completed)
}
}

View File

@@ -16,6 +16,7 @@ import (
"path/filepath"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
@@ -385,6 +386,35 @@ func buildKiroEndpointConfigs(region string) []kiroEndpointConfig {
}
}
// resolveKiroAPIRegion determines the AWS region for Kiro API calls.
// Region priority:
// 1. auth.Metadata["api_region"] - explicit API region override
// 2. ProfileARN region - extracted from arn:aws:service:REGION:account:resource
// 3. kiroDefaultRegion (us-east-1) - fallback
// Note: OIDC "region" is NOT used - it's for token refresh, not API calls
func resolveKiroAPIRegion(auth *cliproxyauth.Auth) string {
if auth == nil || auth.Metadata == nil {
return kiroDefaultRegion
}
// Priority 1: Explicit api_region override
if r, ok := auth.Metadata["api_region"].(string); ok && r != "" {
log.Debugf("kiro: using region %s (source: api_region)", r)
return r
}
// Priority 2: Extract from ProfileARN
if profileArn, ok := auth.Metadata["profile_arn"].(string); ok && profileArn != "" {
if arnRegion := extractRegionFromProfileARN(profileArn); arnRegion != "" {
log.Debugf("kiro: using region %s (source: profile_arn)", arnRegion)
return arnRegion
}
}
// Note: OIDC "region" field is NOT used for API endpoint
// Kiro API only exists in us-east-1, while OIDC region can vary (e.g., ap-northeast-2)
// Using OIDC region for API calls causes DNS failures
log.Debugf("kiro: using region %s (source: default)", kiroDefaultRegion)
return kiroDefaultRegion
}
// kiroEndpointConfigs is kept for backward compatibility with default us-east-1 region.
// Prefer using buildKiroEndpointConfigs(region) for dynamic region support.
var kiroEndpointConfigs = buildKiroEndpointConfigs(kiroDefaultRegion)
@@ -403,30 +433,8 @@ func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig {
return kiroEndpointConfigs
}
// Determine API region with priority: api_region > profile_arn > region > default
region := kiroDefaultRegion
regionSource := "default"
if auth.Metadata != nil {
// Priority 1: Explicit api_region override
if r, ok := auth.Metadata["api_region"].(string); ok && r != "" {
region = r
regionSource = "api_region"
} else {
// Priority 2: Extract from ProfileARN
if profileArn, ok := auth.Metadata["profile_arn"].(string); ok && profileArn != "" {
if arnRegion := extractRegionFromProfileARN(profileArn); arnRegion != "" {
region = arnRegion
regionSource = "profile_arn"
}
}
// Note: OIDC "region" field is NOT used for API endpoint
// Kiro API only exists in us-east-1, while OIDC region can vary (e.g., ap-northeast-2)
// Using OIDC region for API calls causes DNS failures
}
}
log.Debugf("kiro: using region %s (source: %s)", region, regionSource)
// Determine API region using shared resolution logic
region := resolveKiroAPIRegion(auth)
// Build endpoint configs for the specified region
endpointConfigs := buildKiroEndpointConfigs(region)
@@ -520,7 +528,7 @@ func buildKiroPayloadForFormat(body []byte, modelID, profileArn, origin string,
log.Debugf("kiro: using OpenAI payload builder for source format: %s", sourceFormat.String())
return kiroopenai.BuildKiroPayloadFromOpenAI(body, modelID, profileArn, origin, isAgentic, isChatOnly, headers, nil)
case "kiro":
// Body is already in Kiro format — pass through directly (used by callKiroRawAndBuffer)
// Body is already in Kiro format — pass through directly
log.Debugf("kiro: body already in Kiro format, passing through directly")
return body, false
default:
@@ -640,17 +648,7 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
rateLimiter.WaitForToken(tokenKey)
log.Debugf("kiro: rate limiter cleared for token %s", tokenKey)
// Check for pure web_search request
// Route to MCP endpoint instead of normal Kiro API
if kiroclaude.HasWebSearchTool(req.Payload) {
log.Infof("kiro: detected pure web_search request (non-stream), routing to MCP endpoint")
return e.handleWebSearch(ctx, auth, req, opts, accessToken, profileArn)
}
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
// Check if token is expired before making request
// Check if token is expired before making request (covers both normal and web_search paths)
if e.isTokenExpired(accessToken) {
log.Infof("kiro: access token expired, attempting recovery")
@@ -679,6 +677,16 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
}
}
// Check for pure web_search request
// Route to MCP endpoint instead of normal Kiro API
if kiroclaude.HasWebSearchTool(req.Payload) {
log.Infof("kiro: detected pure web_search request (non-stream), routing to MCP endpoint")
return e.handleWebSearch(ctx, auth, req, opts, accessToken, profileArn)
}
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("kiro")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
@@ -1025,8 +1033,9 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
// Build response in Claude format for Kiro translator
// stopReason is extracted from upstream response by parseEventStream
kiroResponse := kiroclaude.BuildClaudeResponse(content, toolUses, req.Model, usageInfo, stopReason)
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, kiroResponse, nil)
requestedModel := payloadRequestedModel(opts, req.Model)
kiroResponse := kiroclaude.BuildClaudeResponse(content, toolUses, requestedModel, usageInfo, stopReason)
out := sdktranslator.TranslateNonStream(ctx, to, from, requestedModel, bytes.Clone(opts.OriginalRequest), body, kiroResponse, nil)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
return resp, nil
}
@@ -1068,17 +1077,7 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
rateLimiter.WaitForToken(tokenKey)
log.Debugf("kiro: stream rate limiter cleared for token %s", tokenKey)
// Check for pure web_search request
// Route to MCP endpoint instead of normal Kiro API
if kiroclaude.HasWebSearchTool(req.Payload) {
log.Infof("kiro: detected pure web_search request, routing to MCP endpoint")
return e.handleWebSearchStream(ctx, auth, req, opts, accessToken, profileArn)
}
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
// Check if token is expired before making request
// Check if token is expired before making request (covers both normal and web_search paths)
if e.isTokenExpired(accessToken) {
log.Infof("kiro: access token expired, attempting recovery before stream request")
@@ -1107,6 +1106,16 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
}
}
// Check for pure web_search request
// Route to MCP endpoint instead of normal Kiro API
if kiroclaude.HasWebSearchTool(req.Payload) {
log.Infof("kiro: detected pure web_search request, routing to MCP endpoint")
return e.handleWebSearchStream(ctx, auth, req, opts, accessToken, profileArn)
}
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("kiro")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
@@ -1423,7 +1432,7 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
// So we always enable thinking parsing for Kiro responses
log.Debugf("kiro: stream thinkingEnabled = %v (always true for Kiro)", thinkingEnabled)
e.streamToChannel(ctx, resp.Body, out, from, req.Model, opts.OriginalRequest, body, reporter, thinkingEnabled)
e.streamToChannel(ctx, resp.Body, out, from, payloadRequestedModel(opts, req.Model), opts.OriginalRequest, body, reporter, thinkingEnabled)
}(httpResp, thinkingEnabled)
return out, nil
@@ -4114,6 +4123,238 @@ func (e *KiroExecutor) isTokenExpired(accessToken string) bool {
return isExpired
}
// ══════════════════════════════════════════════════════════════════════════════
// Web Search Handler (MCP API)
// ══════════════════════════════════════════════════════════════════════════════
// fetchToolDescription caching:
// Uses a mutex + fetched flag to ensure only one goroutine fetches at a time,
// with automatic retry on failure:
// - On failure, fetched stays false so subsequent calls will retry
// - On success, fetched is set to true — subsequent calls skip immediately (mutex-free fast path)
// The cached description is stored in the translator package via kiroclaude.SetWebSearchDescription(),
// enabling the translator's convertClaudeToolsToKiro to read it when building Kiro requests.
var (
toolDescMu sync.Mutex
toolDescFetched atomic.Bool
)
// fetchToolDescription calls MCP tools/list to get the web_search tool description
// and caches it. Safe to call concurrently — only one goroutine fetches at a time.
// If the fetch fails, subsequent calls will retry. On success, no further fetches occur.
// The httpClient parameter allows reusing a shared pooled HTTP client.
func fetchToolDescription(ctx context.Context, mcpEndpoint, authToken string, httpClient *http.Client, auth *cliproxyauth.Auth, authAttrs map[string]string) {
// Fast path: already fetched successfully, no lock needed
if toolDescFetched.Load() {
return
}
toolDescMu.Lock()
defer toolDescMu.Unlock()
// Double-check after acquiring lock
if toolDescFetched.Load() {
return
}
handler := newWebSearchHandler(ctx, mcpEndpoint, authToken, httpClient, auth, authAttrs)
reqBody := []byte(`{"id":"tools_list","jsonrpc":"2.0","method":"tools/list"}`)
log.Debugf("kiro/websearch MCP tools/list request: %d bytes", len(reqBody))
req, err := http.NewRequestWithContext(ctx, "POST", mcpEndpoint, bytes.NewReader(reqBody))
if err != nil {
log.Warnf("kiro/websearch: failed to create tools/list request: %v", err)
return
}
// Reuse same headers as callMcpAPI
handler.setMcpHeaders(req)
resp, err := handler.httpClient.Do(req)
if err != nil {
log.Warnf("kiro/websearch: tools/list request failed: %v", err)
return
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil || resp.StatusCode != http.StatusOK {
log.Warnf("kiro/websearch: tools/list returned status %d", resp.StatusCode)
return
}
log.Debugf("kiro/websearch MCP tools/list response: [%d] %d bytes", resp.StatusCode, len(body))
// Parse: {"result":{"tools":[{"name":"web_search","description":"..."}]}}
var result struct {
Result *struct {
Tools []struct {
Name string `json:"name"`
Description string `json:"description"`
} `json:"tools"`
} `json:"result"`
}
if err := json.Unmarshal(body, &result); err != nil || result.Result == nil {
log.Warnf("kiro/websearch: failed to parse tools/list response")
return
}
for _, tool := range result.Result.Tools {
if tool.Name == "web_search" && tool.Description != "" {
kiroclaude.SetWebSearchDescription(tool.Description)
toolDescFetched.Store(true) // success — no more fetches
log.Infof("kiro/websearch: cached web_search description from tools/list (%d bytes)", len(tool.Description))
return
}
}
// web_search tool not found in response
log.Warnf("kiro/websearch: web_search tool not found in tools/list response")
}
// webSearchHandler handles web search requests via Kiro MCP API
type webSearchHandler struct {
ctx context.Context
mcpEndpoint string
httpClient *http.Client
authToken string
auth *cliproxyauth.Auth // for applyDynamicFingerprint
authAttrs map[string]string // optional, for custom headers from auth.Attributes
}
// newWebSearchHandler creates a new webSearchHandler.
// If httpClient is nil, a default client with 30s timeout is used.
// Pass a shared pooled client (e.g. from getKiroPooledHTTPClient) for connection reuse.
func newWebSearchHandler(ctx context.Context, mcpEndpoint, authToken string, httpClient *http.Client, auth *cliproxyauth.Auth, authAttrs map[string]string) *webSearchHandler {
if httpClient == nil {
httpClient = &http.Client{
Timeout: 30 * time.Second,
}
}
return &webSearchHandler{
ctx: ctx,
mcpEndpoint: mcpEndpoint,
httpClient: httpClient,
authToken: authToken,
auth: auth,
authAttrs: authAttrs,
}
}
// setMcpHeaders sets standard MCP API headers on the request,
// aligned with the GAR request pattern.
func (h *webSearchHandler) setMcpHeaders(req *http.Request) {
// 1. Content-Type & Accept (aligned with GAR)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "*/*")
// 2. Kiro-specific headers (aligned with GAR)
req.Header.Set("x-amzn-kiro-agent-mode", "vibe")
req.Header.Set("x-amzn-codewhisperer-optout", "true")
// 3. User-Agent: Reuse applyDynamicFingerprint for consistency
applyDynamicFingerprint(req, h.auth)
// 4. AWS SDK identifiers
req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
// 5. Authentication
req.Header.Set("Authorization", "Bearer "+h.authToken)
// 6. Custom headers from auth attributes
util.ApplyCustomHeadersFromAttrs(req, h.authAttrs)
}
// mcpMaxRetries is the maximum number of retries for MCP API calls.
const mcpMaxRetries = 2
// callMcpAPI calls the Kiro MCP API with the given request.
// Includes retry logic with exponential backoff for retryable errors.
func (h *webSearchHandler) callMcpAPI(request *kiroclaude.McpRequest) (*kiroclaude.McpResponse, error) {
requestBody, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("failed to marshal MCP request: %w", err)
}
log.Debugf("kiro/websearch MCP request → %s (%d bytes)", h.mcpEndpoint, len(requestBody))
var lastErr error
for attempt := 0; attempt <= mcpMaxRetries; attempt++ {
if attempt > 0 {
backoff := time.Duration(1<<attempt) * time.Second
if backoff > 10*time.Second {
backoff = 10 * time.Second
}
log.Warnf("kiro/websearch: MCP retry %d/%d after %v (last error: %v)", attempt, mcpMaxRetries, backoff, lastErr)
select {
case <-h.ctx.Done():
return nil, h.ctx.Err()
case <-time.After(backoff):
}
}
req, err := http.NewRequestWithContext(h.ctx, "POST", h.mcpEndpoint, bytes.NewReader(requestBody))
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
}
h.setMcpHeaders(req)
resp, err := h.httpClient.Do(req)
if err != nil {
lastErr = fmt.Errorf("MCP API request failed: %w", err)
continue // network error → retry
}
body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
lastErr = fmt.Errorf("failed to read MCP response: %w", err)
continue // read error → retry
}
log.Debugf("kiro/websearch MCP response ← [%d] (%d bytes)", resp.StatusCode, len(body))
// Retryable HTTP status codes (aligned with GAR: 502, 503, 504)
if resp.StatusCode >= 502 && resp.StatusCode <= 504 {
lastErr = fmt.Errorf("MCP API returned retryable status %d: %s", resp.StatusCode, string(body))
continue
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("MCP API returned status %d: %s", resp.StatusCode, string(body))
}
var mcpResponse kiroclaude.McpResponse
if err := json.Unmarshal(body, &mcpResponse); err != nil {
return nil, fmt.Errorf("failed to parse MCP response: %w", err)
}
if mcpResponse.Error != nil {
code := -1
if mcpResponse.Error.Code != nil {
code = *mcpResponse.Error.Code
}
msg := "Unknown error"
if mcpResponse.Error.Message != nil {
msg = *mcpResponse.Error.Message
}
return nil, fmt.Errorf("MCP error %d: %s", code, msg)
}
return &mcpResponse, nil
}
return nil, lastErr
}
// webSearchAuthAttrs extracts auth attributes for MCP calls.
// Used by handleWebSearch and handleWebSearchStream to pass custom headers.
func webSearchAuthAttrs(auth *cliproxyauth.Auth) map[string]string {
if auth != nil {
return auth.Attributes
}
return nil
}
const maxWebSearchIterations = 5
// handleWebSearchStream handles web_search requests:
@@ -4136,58 +4377,63 @@ func (e *KiroExecutor) handleWebSearchStream(
return e.callKiroDirectStream(ctx, auth, req, opts, accessToken, profileArn)
}
// Build MCP endpoint based on region
region := kiroDefaultRegion
if auth != nil && auth.Metadata != nil {
if r, ok := auth.Metadata["api_region"].(string); ok && r != "" {
region = r
}
}
mcpEndpoint := fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region)
// Build MCP endpoint using shared region resolution (supports api_region + ProfileARN fallback)
region := resolveKiroAPIRegion(auth)
mcpEndpoint := kiroclaude.BuildMcpEndpoint(region)
// ── Step 1: tools/list (SYNC) — cache tool description ──
{
tokenKey := getTokenKey(auth)
fp := getGlobalFingerprintManager().GetFingerprint(tokenKey)
var authAttrs map[string]string
if auth != nil {
authAttrs = auth.Attributes
}
kiroclaude.FetchToolDescription(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), fp, authAttrs)
authAttrs := webSearchAuthAttrs(auth)
fetchToolDescription(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
}
// Create output channel
out := make(chan cliproxyexecutor.StreamChunk)
// Usage reporting: track web search requests like normal streaming requests
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
go func() {
var wsErr error
defer reporter.trackFailure(ctx, &wsErr)
defer close(out)
// Send message_start event to client
messageStartEvent := kiroclaude.SseEvent{
Event: "message_start",
Data: map[string]interface{}{
"type": "message_start",
"message": map[string]interface{}{
"id": kiroclaude.GenerateMessageID(),
"type": "message",
"role": "assistant",
"model": req.Model,
"content": []interface{}{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]interface{}{
"input_tokens": len(req.Payload) / 4,
"output_tokens": 0,
"cache_creation_input_tokens": 0,
"cache_read_input_tokens": 0,
},
},
},
// Estimate input tokens using tokenizer (matching streamToChannel pattern)
var totalUsage usage.Detail
if enc, tokErr := getTokenizer(req.Model); tokErr == nil {
if inp, e := countClaudeChatTokens(enc, req.Payload); e == nil && inp > 0 {
totalUsage.InputTokens = inp
} else {
totalUsage.InputTokens = int64(len(req.Payload) / 4)
}
} else {
totalUsage.InputTokens = int64(len(req.Payload) / 4)
}
if totalUsage.InputTokens == 0 && len(req.Payload) > 0 {
totalUsage.InputTokens = 1
}
var accumulatedOutputLen int
defer func() {
if wsErr != nil {
return // let trackFailure handle failure reporting
}
totalUsage.OutputTokens = int64(accumulatedOutputLen / 4)
if accumulatedOutputLen > 0 && totalUsage.OutputTokens == 0 {
totalUsage.OutputTokens = 1
}
reporter.publish(ctx, totalUsage)
}()
// Send message_start event to client (aligned with streamToChannel pattern)
// Use payloadRequestedModel to return user's original model alias
msgStart := kiroclaude.BuildClaudeMessageStartEvent(
payloadRequestedModel(opts, req.Model),
totalUsage.InputTokens,
)
select {
case <-ctx.Done():
return
case out <- cliproxyexecutor.StreamChunk{Payload: []byte(messageStartEvent.ToSSEString())}:
case out <- cliproxyexecutor.StreamChunk{Payload: append(msgStart, '\n', '\n')}:
}
// ── Step 2+: MCP search → InjectToolResultsClaude → callKiroAndBuffer loop ──
@@ -4211,19 +4457,15 @@ func (e *KiroExecutor) handleWebSearchStream(
currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID())
for iteration := 0; iteration < maxWebSearchIterations; iteration++ {
log.Infof("kiro/websearch: search iteration %d/%d — query: %s",
iteration+1, maxWebSearchIterations, currentQuery)
log.Infof("kiro/websearch: search iteration %d/%d",
iteration+1, maxWebSearchIterations)
// MCP search
_, mcpRequest := kiroclaude.CreateMcpRequest(currentQuery)
tokenKey := getTokenKey(auth)
fp := getGlobalFingerprintManager().GetFingerprint(tokenKey)
var authAttrs map[string]string
if auth != nil {
authAttrs = auth.Attributes
}
handler := kiroclaude.NewWebSearchHandler(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), fp, authAttrs)
mcpResponse, mcpErr := handler.CallMcpAPI(mcpRequest)
authAttrs := webSearchAuthAttrs(auth)
handler := newWebSearchHandler(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
mcpResponse, mcpErr := handler.callMcpAPI(mcpRequest)
var searchResults *kiroclaude.WebSearchResults
if mcpErr != nil {
@@ -4245,7 +4487,7 @@ func (e *KiroExecutor) handleWebSearchStream(
select {
case <-ctx.Done():
return
case out <- cliproxyexecutor.StreamChunk{Payload: []byte(event.ToSSEString())}:
case out <- cliproxyexecutor.StreamChunk{Payload: event}:
}
}
contentBlockIndex += 2
@@ -4255,8 +4497,9 @@ func (e *KiroExecutor) handleWebSearchStream(
currentClaudePayload, err = kiroclaude.InjectToolResultsClaude(currentClaudePayload, currentToolUseId, currentQuery, searchResults)
if err != nil {
log.Warnf("kiro/websearch: failed to inject tool results: %v", err)
wsErr = fmt.Errorf("failed to inject tool results: %w", err)
e.sendFallbackText(ctx, out, contentBlockIndex, currentQuery, searchResults)
break
return
}
// Call GAR with modified Claude payload (full translation pipeline)
@@ -4265,14 +4508,15 @@ func (e *KiroExecutor) handleWebSearchStream(
kiroChunks, kiroErr := e.callKiroAndBuffer(ctx, auth, modifiedReq, opts, accessToken, profileArn)
if kiroErr != nil {
log.Warnf("kiro/websearch: Kiro API failed at iteration %d: %v", iteration+1, kiroErr)
wsErr = fmt.Errorf("Kiro API failed at iteration %d: %w", iteration+1, kiroErr)
e.sendFallbackText(ctx, out, contentBlockIndex, currentQuery, searchResults)
break
return
}
// Analyze response
analysis := kiroclaude.AnalyzeBufferedStream(kiroChunks)
log.Infof("kiro/websearch: iteration %d — stop_reason: %s, has_tool_use: %v, query: %s, toolUseId: %s",
iteration+1, analysis.StopReason, analysis.HasWebSearchToolUse, analysis.WebSearchQuery, analysis.WebSearchToolUseId)
log.Infof("kiro/websearch: iteration %d — stop_reason: %s, has_tool_use: %v",
iteration+1, analysis.StopReason, analysis.HasWebSearchToolUse)
if analysis.HasWebSearchToolUse && analysis.WebSearchQuery != "" && iteration+1 < maxWebSearchIterations {
// Model wants another search
@@ -4297,12 +4541,14 @@ func (e *KiroExecutor) handleWebSearchStream(
if !shouldForward {
continue
}
accumulatedOutputLen += len(adjusted)
select {
case <-ctx.Done():
return
case out <- cliproxyexecutor.StreamChunk{Payload: adjusted}:
}
} else {
accumulatedOutputLen += len(chunk)
select {
case <-ctx.Done():
return
@@ -4320,8 +4566,103 @@ func (e *KiroExecutor) handleWebSearchStream(
return out, nil
}
// handleWebSearch handles web_search requests for non-streaming Execute path.
// Performs MCP search synchronously, injects results into the request payload,
// then calls the normal non-streaming Kiro API path which returns a proper
// Claude JSON response (not SSE chunks).
func (e *KiroExecutor) handleWebSearch(
ctx context.Context,
auth *cliproxyauth.Auth,
req cliproxyexecutor.Request,
opts cliproxyexecutor.Options,
accessToken, profileArn string,
) (cliproxyexecutor.Response, error) {
// Extract search query from Claude Code's web_search tool_use
query := kiroclaude.ExtractSearchQuery(req.Payload)
if query == "" {
log.Warnf("kiro/websearch: non-stream: failed to extract search query, falling back to normal Execute")
// Fall through to normal non-streaming path
return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn)
}
// Build MCP endpoint using shared region resolution (supports api_region + ProfileARN fallback)
region := resolveKiroAPIRegion(auth)
mcpEndpoint := kiroclaude.BuildMcpEndpoint(region)
// Step 1: Fetch/cache tool description (sync)
{
authAttrs := webSearchAuthAttrs(auth)
fetchToolDescription(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
}
// Step 2: Perform MCP search
_, mcpRequest := kiroclaude.CreateMcpRequest(query)
authAttrs := webSearchAuthAttrs(auth)
handler := newWebSearchHandler(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
mcpResponse, mcpErr := handler.callMcpAPI(mcpRequest)
var searchResults *kiroclaude.WebSearchResults
if mcpErr != nil {
log.Warnf("kiro/websearch: non-stream: MCP API call failed: %v, continuing with empty results", mcpErr)
} else {
searchResults = kiroclaude.ParseSearchResults(mcpResponse)
}
resultCount := 0
if searchResults != nil {
resultCount = len(searchResults.Results)
}
log.Infof("kiro/websearch: non-stream: got %d search results", resultCount)
// Step 3: Replace restrictive web_search tool description (align with streaming path)
simplifiedPayload, simplifyErr := kiroclaude.ReplaceWebSearchToolDescription(bytes.Clone(req.Payload))
if simplifyErr != nil {
log.Warnf("kiro/websearch: non-stream: failed to simplify web_search tool: %v, using original payload", simplifyErr)
simplifiedPayload = bytes.Clone(req.Payload)
}
// Step 4: Inject search tool_use + tool_result into Claude payload
currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID())
modifiedPayload, err := kiroclaude.InjectToolResultsClaude(simplifiedPayload, currentToolUseId, query, searchResults)
if err != nil {
log.Warnf("kiro/websearch: non-stream: failed to inject tool results: %v, falling back", err)
return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn)
}
// Step 5: Call Kiro API via the normal non-streaming path (executeWithRetry)
// This path uses parseEventStream → BuildClaudeResponse → TranslateNonStream
// to produce a proper Claude JSON response
modifiedReq := req
modifiedReq.Payload = modifiedPayload
resp, err := e.executeNonStreamFallback(ctx, auth, modifiedReq, opts, accessToken, profileArn)
if err != nil {
return resp, err
}
// Step 6: Inject server_tool_use + web_search_tool_result into response
// so Claude Code can display "Did X searches in Ys"
indicators := []kiroclaude.SearchIndicator{
{
ToolUseID: currentToolUseId,
Query: query,
Results: searchResults,
},
}
injectedPayload, injErr := kiroclaude.InjectSearchIndicatorsInResponse(resp.Payload, indicators)
if injErr != nil {
log.Warnf("kiro/websearch: non-stream: failed to inject search indicators: %v", injErr)
} else {
resp.Payload = injectedPayload
}
return resp, nil
}
// callKiroAndBuffer calls the Kiro API and buffers all response chunks.
// Returns the buffered chunks for analysis before forwarding to client.
// Usage reporting is NOT done here — the caller (handleWebSearchStream) manages its own reporter.
func (e *KiroExecutor) callKiroAndBuffer(
ctx context.Context,
auth *cliproxyauth.Auth,
@@ -4338,10 +4679,7 @@ func (e *KiroExecutor) callKiroAndBuffer(
isAgentic, isChatOnly := determineAgenticMode(req.Model)
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
tokenKey := ""
if auth != nil {
tokenKey = auth.ID
}
tokenKey := getTokenKey(auth)
kiroStream, err := e.executeStreamWithRetry(
ctx, auth, req, opts, accessToken, effectiveProfileArn,
@@ -4367,51 +4705,6 @@ func (e *KiroExecutor) callKiroAndBuffer(
return chunks, nil
}
// callKiroRawAndBuffer calls the Kiro API with a pre-built Kiro payload (no translation).
// Used in the web search loop where the payload is modified directly in Kiro format.
func (e *KiroExecutor) callKiroRawAndBuffer(
ctx context.Context,
auth *cliproxyauth.Auth,
req cliproxyexecutor.Request,
opts cliproxyexecutor.Options,
accessToken, profileArn string,
kiroBody []byte,
) ([][]byte, error) {
kiroModelID := e.mapModelToKiro(req.Model)
isAgentic, isChatOnly := determineAgenticMode(req.Model)
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
tokenKey := ""
if auth != nil {
tokenKey = auth.ID
}
log.Debugf("kiro/websearch GAR raw request: %d bytes", len(kiroBody))
kiroFormat := sdktranslator.FromString("kiro")
kiroStream, err := e.executeStreamWithRetry(
ctx, auth, req, opts, accessToken, effectiveProfileArn,
nil, kiroBody, kiroFormat, nil, "", kiroModelID, isAgentic, isChatOnly, tokenKey,
)
if err != nil {
return nil, err
}
// Buffer all chunks
var chunks [][]byte
for chunk := range kiroStream {
if chunk.Err != nil {
return chunks, chunk.Err
}
if len(chunk.Payload) > 0 {
chunks = append(chunks, bytes.Clone(chunk.Payload))
}
}
log.Debugf("kiro/websearch GAR raw response: %d chunks buffered", len(chunks))
return chunks, nil
}
// callKiroDirectStream creates a direct streaming channel to Kiro API without search.
func (e *KiroExecutor) callKiroDirectStream(
ctx context.Context,
@@ -4428,18 +4721,22 @@ func (e *KiroExecutor) callKiroDirectStream(
isAgentic, isChatOnly := determineAgenticMode(req.Model)
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
tokenKey := ""
if auth != nil {
tokenKey = auth.ID
}
tokenKey := getTokenKey(auth)
return e.executeStreamWithRetry(
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
var streamErr error
defer reporter.trackFailure(ctx, &streamErr)
stream, streamErr := e.executeStreamWithRetry(
ctx, auth, req, opts, accessToken, effectiveProfileArn,
nil, body, from, nil, "", kiroModelID, isAgentic, isChatOnly, tokenKey,
nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey,
)
return stream, streamErr
}
// sendFallbackText sends a simple text response when the Kiro API fails during the search loop.
// Delegates SSE event construction to kiroclaude.BuildFallbackTextEvents() for alignment
// with how streamToChannel() uses BuildClaude*Event() functions.
func (e *KiroExecutor) sendFallbackText(
ctx context.Context,
out chan<- cliproxyexecutor.StreamChunk,
@@ -4447,182 +4744,14 @@ func (e *KiroExecutor) sendFallbackText(
query string,
searchResults *kiroclaude.WebSearchResults,
) {
// Generate a simple text summary from search results
summary := kiroclaude.FormatSearchContextPrompt(query, searchResults)
events := []kiroclaude.SseEvent{
{
Event: "content_block_start",
Data: map[string]interface{}{
"type": "content_block_start",
"index": contentBlockIndex,
"content_block": map[string]interface{}{
"type": "text",
"text": "",
},
},
},
{
Event: "content_block_delta",
Data: map[string]interface{}{
"type": "content_block_delta",
"index": contentBlockIndex,
"delta": map[string]interface{}{
"type": "text_delta",
"text": summary,
},
},
},
{
Event: "content_block_stop",
Data: map[string]interface{}{
"type": "content_block_stop",
"index": contentBlockIndex,
},
},
}
events := kiroclaude.BuildFallbackTextEvents(contentBlockIndex, query, searchResults)
for _, event := range events {
select {
case <-ctx.Done():
return
case out <- cliproxyexecutor.StreamChunk{Payload: []byte(event.ToSSEString())}:
case out <- cliproxyexecutor.StreamChunk{Payload: append(event, '\n', '\n')}:
}
}
// Send message_delta with end_turn and message_stop
msgDelta := kiroclaude.SseEvent{
Event: "message_delta",
Data: map[string]interface{}{
"type": "message_delta",
"delta": map[string]interface{}{
"stop_reason": "end_turn",
"stop_sequence": nil,
},
"usage": map[string]interface{}{
"output_tokens": len(summary) / 4,
},
},
}
select {
case <-ctx.Done():
return
case out <- cliproxyexecutor.StreamChunk{Payload: []byte(msgDelta.ToSSEString())}:
}
msgStop := kiroclaude.SseEvent{
Event: "message_stop",
Data: map[string]interface{}{
"type": "message_stop",
},
}
select {
case <-ctx.Done():
return
case out <- cliproxyexecutor.StreamChunk{Payload: []byte(msgStop.ToSSEString())}:
}
}
// handleWebSearch handles web_search requests for non-streaming Execute path.
// Performs MCP search synchronously, injects results into the request payload,
// then calls the normal non-streaming Kiro API path which returns a proper
// Claude JSON response (not SSE chunks).
func (e *KiroExecutor) handleWebSearch(
ctx context.Context,
auth *cliproxyauth.Auth,
req cliproxyexecutor.Request,
opts cliproxyexecutor.Options,
accessToken, profileArn string,
) (cliproxyexecutor.Response, error) {
// Extract search query from Claude Code's web_search tool_use
query := kiroclaude.ExtractSearchQuery(req.Payload)
if query == "" {
log.Warnf("kiro/websearch: non-stream: failed to extract search query, falling back to normal Execute")
// Fall through to normal non-streaming path
return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn)
}
// Build MCP endpoint based on region
region := kiroDefaultRegion
if auth != nil && auth.Metadata != nil {
if r, ok := auth.Metadata["api_region"].(string); ok && r != "" {
region = r
}
}
mcpEndpoint := fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region)
// Step 1: Fetch/cache tool description (sync)
{
tokenKey := getTokenKey(auth)
fp := getGlobalFingerprintManager().GetFingerprint(tokenKey)
var authAttrs map[string]string
if auth != nil {
authAttrs = auth.Attributes
}
kiroclaude.FetchToolDescription(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), fp, authAttrs)
}
// Step 2: Perform MCP search
_, mcpRequest := kiroclaude.CreateMcpRequest(query)
tokenKey := getTokenKey(auth)
fp := getGlobalFingerprintManager().GetFingerprint(tokenKey)
var authAttrs map[string]string
if auth != nil {
authAttrs = auth.Attributes
}
handler := kiroclaude.NewWebSearchHandler(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), fp, authAttrs)
mcpResponse, mcpErr := handler.CallMcpAPI(mcpRequest)
var searchResults *kiroclaude.WebSearchResults
if mcpErr != nil {
log.Warnf("kiro/websearch: non-stream: MCP API call failed: %v, continuing with empty results", mcpErr)
} else {
searchResults = kiroclaude.ParseSearchResults(mcpResponse)
}
resultCount := 0
if searchResults != nil {
resultCount = len(searchResults.Results)
}
log.Infof("kiro/websearch: non-stream: got %d search results for query: %s", resultCount, query)
// Step 3: Inject search tool_use + tool_result into Claude payload
currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID())
modifiedPayload, err := kiroclaude.InjectToolResultsClaude(bytes.Clone(req.Payload), currentToolUseId, query, searchResults)
if err != nil {
log.Warnf("kiro/websearch: non-stream: failed to inject tool results: %v, falling back", err)
return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn)
}
// Step 4: Call Kiro API via the normal non-streaming path (executeWithRetry)
// This path uses parseEventStream → BuildClaudeResponse → TranslateNonStream
// to produce a proper Claude JSON response
modifiedReq := req
modifiedReq.Payload = modifiedPayload
resp, err := e.executeNonStreamFallback(ctx, auth, modifiedReq, opts, accessToken, profileArn)
if err != nil {
return resp, err
}
// Step 5: Inject server_tool_use + web_search_tool_result into response
// so Claude Code can display "Did X searches in Ys"
indicators := []kiroclaude.SearchIndicator{
{
ToolUseID: currentToolUseId,
Query: query,
Results: searchResults,
},
}
injectedPayload, injErr := kiroclaude.InjectSearchIndicatorsInResponse(resp.Payload, indicators)
if injErr != nil {
log.Warnf("kiro/websearch: non-stream: failed to inject search indicators: %v", injErr)
} else {
resp.Payload = injectedPayload
}
return resp, nil
}
// executeNonStreamFallback runs the standard non-streaming Execute path for a request.