mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-05-31 21:39:52 +08:00
refactor: align web search with executor layer patterns
Consolidate web search handler, SSE event generation, stream analysis, and MCP HTTP I/O into the executor layer. Merge the separate kiro_websearch_handler.go back into kiro_executor.go to align with the single-file-per-executor convention. Translator retains only pure data types, detection, and payload transformation. Key changes: - Move SSE construction (search indicators, fallback text, message_start) from translator to executor, consistent with streamToChannel pattern - Move MCP handler (callMcpAPI, setMcpHeaders, fetchToolDescription) from translator to executor alongside other HTTP I/O - Reuse applyDynamicFingerprint for MCP UA headers (eliminate duplication) - Centralize MCP endpoint URL via BuildMcpEndpoint in translator - Add atomic Set/GetWebSearchDescription for cross-layer tool desc cache - Thread context.Context through MCP HTTP calls for cancellation support - Thread usage reporter through all web search API call paths - Add token expiry pre-check before MCP/GAR calls - Clean up dead code (GenerateMessageID, webSearchAuthContext fp logic, ContainsWebSearchTool, StripWebSearchTool)
This commit is contained in:
@@ -16,6 +16,7 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -385,6 +386,35 @@ func buildKiroEndpointConfigs(region string) []kiroEndpointConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// resolveKiroAPIRegion determines the AWS region for Kiro API calls.
|
||||
// Region priority:
|
||||
// 1. auth.Metadata["api_region"] - explicit API region override
|
||||
// 2. ProfileARN region - extracted from arn:aws:service:REGION:account:resource
|
||||
// 3. kiroDefaultRegion (us-east-1) - fallback
|
||||
// Note: OIDC "region" is NOT used - it's for token refresh, not API calls
|
||||
func resolveKiroAPIRegion(auth *cliproxyauth.Auth) string {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return kiroDefaultRegion
|
||||
}
|
||||
// Priority 1: Explicit api_region override
|
||||
if r, ok := auth.Metadata["api_region"].(string); ok && r != "" {
|
||||
log.Debugf("kiro: using region %s (source: api_region)", r)
|
||||
return r
|
||||
}
|
||||
// Priority 2: Extract from ProfileARN
|
||||
if profileArn, ok := auth.Metadata["profile_arn"].(string); ok && profileArn != "" {
|
||||
if arnRegion := extractRegionFromProfileARN(profileArn); arnRegion != "" {
|
||||
log.Debugf("kiro: using region %s (source: profile_arn)", arnRegion)
|
||||
return arnRegion
|
||||
}
|
||||
}
|
||||
// Note: OIDC "region" field is NOT used for API endpoint
|
||||
// Kiro API only exists in us-east-1, while OIDC region can vary (e.g., ap-northeast-2)
|
||||
// Using OIDC region for API calls causes DNS failures
|
||||
log.Debugf("kiro: using region %s (source: default)", kiroDefaultRegion)
|
||||
return kiroDefaultRegion
|
||||
}
|
||||
|
||||
// kiroEndpointConfigs is kept for backward compatibility with default us-east-1 region.
|
||||
// Prefer using buildKiroEndpointConfigs(region) for dynamic region support.
|
||||
var kiroEndpointConfigs = buildKiroEndpointConfigs(kiroDefaultRegion)
|
||||
@@ -403,30 +433,8 @@ func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig {
|
||||
return kiroEndpointConfigs
|
||||
}
|
||||
|
||||
// Determine API region with priority: api_region > profile_arn > region > default
|
||||
region := kiroDefaultRegion
|
||||
regionSource := "default"
|
||||
|
||||
if auth.Metadata != nil {
|
||||
// Priority 1: Explicit api_region override
|
||||
if r, ok := auth.Metadata["api_region"].(string); ok && r != "" {
|
||||
region = r
|
||||
regionSource = "api_region"
|
||||
} else {
|
||||
// Priority 2: Extract from ProfileARN
|
||||
if profileArn, ok := auth.Metadata["profile_arn"].(string); ok && profileArn != "" {
|
||||
if arnRegion := extractRegionFromProfileARN(profileArn); arnRegion != "" {
|
||||
region = arnRegion
|
||||
regionSource = "profile_arn"
|
||||
}
|
||||
}
|
||||
// Note: OIDC "region" field is NOT used for API endpoint
|
||||
// Kiro API only exists in us-east-1, while OIDC region can vary (e.g., ap-northeast-2)
|
||||
// Using OIDC region for API calls causes DNS failures
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("kiro: using region %s (source: %s)", region, regionSource)
|
||||
// Determine API region using shared resolution logic
|
||||
region := resolveKiroAPIRegion(auth)
|
||||
|
||||
// Build endpoint configs for the specified region
|
||||
endpointConfigs := buildKiroEndpointConfigs(region)
|
||||
@@ -520,7 +528,7 @@ func buildKiroPayloadForFormat(body []byte, modelID, profileArn, origin string,
|
||||
log.Debugf("kiro: using OpenAI payload builder for source format: %s", sourceFormat.String())
|
||||
return kiroopenai.BuildKiroPayloadFromOpenAI(body, modelID, profileArn, origin, isAgentic, isChatOnly, headers, nil)
|
||||
case "kiro":
|
||||
// Body is already in Kiro format — pass through directly (used by callKiroRawAndBuffer)
|
||||
// Body is already in Kiro format — pass through directly
|
||||
log.Debugf("kiro: body already in Kiro format, passing through directly")
|
||||
return body, false
|
||||
default:
|
||||
@@ -640,17 +648,7 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
rateLimiter.WaitForToken(tokenKey)
|
||||
log.Debugf("kiro: rate limiter cleared for token %s", tokenKey)
|
||||
|
||||
// Check for pure web_search request
|
||||
// Route to MCP endpoint instead of normal Kiro API
|
||||
if kiroclaude.HasWebSearchTool(req.Payload) {
|
||||
log.Infof("kiro: detected pure web_search request (non-stream), routing to MCP endpoint")
|
||||
return e.handleWebSearch(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
// Check if token is expired before making request
|
||||
// Check if token is expired before making request (covers both normal and web_search paths)
|
||||
if e.isTokenExpired(accessToken) {
|
||||
log.Infof("kiro: access token expired, attempting recovery")
|
||||
|
||||
@@ -679,6 +677,16 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
}
|
||||
}
|
||||
|
||||
// Check for pure web_search request
|
||||
// Route to MCP endpoint instead of normal Kiro API
|
||||
if kiroclaude.HasWebSearchTool(req.Payload) {
|
||||
log.Infof("kiro: detected pure web_search request (non-stream), routing to MCP endpoint")
|
||||
return e.handleWebSearch(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("kiro")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
@@ -1068,17 +1076,7 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
rateLimiter.WaitForToken(tokenKey)
|
||||
log.Debugf("kiro: stream rate limiter cleared for token %s", tokenKey)
|
||||
|
||||
// Check for pure web_search request
|
||||
// Route to MCP endpoint instead of normal Kiro API
|
||||
if kiroclaude.HasWebSearchTool(req.Payload) {
|
||||
log.Infof("kiro: detected pure web_search request, routing to MCP endpoint")
|
||||
return e.handleWebSearchStream(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
// Check if token is expired before making request
|
||||
// Check if token is expired before making request (covers both normal and web_search paths)
|
||||
if e.isTokenExpired(accessToken) {
|
||||
log.Infof("kiro: access token expired, attempting recovery before stream request")
|
||||
|
||||
@@ -1107,6 +1105,16 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
}
|
||||
}
|
||||
|
||||
// Check for pure web_search request
|
||||
// Route to MCP endpoint instead of normal Kiro API
|
||||
if kiroclaude.HasWebSearchTool(req.Payload) {
|
||||
log.Infof("kiro: detected pure web_search request, routing to MCP endpoint")
|
||||
return e.handleWebSearchStream(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("kiro")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
@@ -4114,6 +4122,238 @@ func (e *KiroExecutor) isTokenExpired(accessToken string) bool {
|
||||
return isExpired
|
||||
}
|
||||
|
||||
// ══════════════════════════════════════════════════════════════════════════════
|
||||
// Web Search Handler (MCP API)
|
||||
// ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
// fetchToolDescription caching:
|
||||
// Uses a mutex + fetched flag to ensure only one goroutine fetches at a time,
|
||||
// with automatic retry on failure:
|
||||
// - On failure, fetched stays false so subsequent calls will retry
|
||||
// - On success, fetched is set to true — subsequent calls skip immediately (mutex-free fast path)
|
||||
// The cached description is stored in the translator package via kiroclaude.SetWebSearchDescription(),
|
||||
// enabling the translator's convertClaudeToolsToKiro to read it when building Kiro requests.
|
||||
var (
|
||||
toolDescMu sync.Mutex
|
||||
toolDescFetched atomic.Bool
|
||||
)
|
||||
|
||||
// fetchToolDescription calls MCP tools/list to get the web_search tool description
|
||||
// and caches it. Safe to call concurrently — only one goroutine fetches at a time.
|
||||
// If the fetch fails, subsequent calls will retry. On success, no further fetches occur.
|
||||
// The httpClient parameter allows reusing a shared pooled HTTP client.
|
||||
func fetchToolDescription(ctx context.Context, mcpEndpoint, authToken string, httpClient *http.Client, auth *cliproxyauth.Auth, authAttrs map[string]string) {
|
||||
// Fast path: already fetched successfully, no lock needed
|
||||
if toolDescFetched.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
toolDescMu.Lock()
|
||||
defer toolDescMu.Unlock()
|
||||
|
||||
// Double-check after acquiring lock
|
||||
if toolDescFetched.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
handler := newWebSearchHandler(ctx, mcpEndpoint, authToken, httpClient, auth, authAttrs)
|
||||
reqBody := []byte(`{"id":"tools_list","jsonrpc":"2.0","method":"tools/list"}`)
|
||||
log.Debugf("kiro/websearch MCP tools/list request: %d bytes", len(reqBody))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", mcpEndpoint, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: failed to create tools/list request: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Reuse same headers as callMcpAPI
|
||||
handler.setMcpHeaders(req)
|
||||
|
||||
resp, err := handler.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: tools/list request failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil || resp.StatusCode != http.StatusOK {
|
||||
log.Warnf("kiro/websearch: tools/list returned status %d", resp.StatusCode)
|
||||
return
|
||||
}
|
||||
log.Debugf("kiro/websearch MCP tools/list response: [%d] %d bytes", resp.StatusCode, len(body))
|
||||
|
||||
// Parse: {"result":{"tools":[{"name":"web_search","description":"..."}]}}
|
||||
var result struct {
|
||||
Result *struct {
|
||||
Tools []struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
} `json:"tools"`
|
||||
} `json:"result"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil || result.Result == nil {
|
||||
log.Warnf("kiro/websearch: failed to parse tools/list response")
|
||||
return
|
||||
}
|
||||
|
||||
for _, tool := range result.Result.Tools {
|
||||
if tool.Name == "web_search" && tool.Description != "" {
|
||||
kiroclaude.SetWebSearchDescription(tool.Description)
|
||||
toolDescFetched.Store(true) // success — no more fetches
|
||||
log.Infof("kiro/websearch: cached web_search description from tools/list (%d bytes)", len(tool.Description))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// web_search tool not found in response
|
||||
log.Warnf("kiro/websearch: web_search tool not found in tools/list response")
|
||||
}
|
||||
|
||||
// webSearchHandler handles web search requests via Kiro MCP API
|
||||
type webSearchHandler struct {
|
||||
ctx context.Context
|
||||
mcpEndpoint string
|
||||
httpClient *http.Client
|
||||
authToken string
|
||||
auth *cliproxyauth.Auth // for applyDynamicFingerprint
|
||||
authAttrs map[string]string // optional, for custom headers from auth.Attributes
|
||||
}
|
||||
|
||||
// newWebSearchHandler creates a new webSearchHandler.
|
||||
// If httpClient is nil, a default client with 30s timeout is used.
|
||||
// Pass a shared pooled client (e.g. from getKiroPooledHTTPClient) for connection reuse.
|
||||
func newWebSearchHandler(ctx context.Context, mcpEndpoint, authToken string, httpClient *http.Client, auth *cliproxyauth.Auth, authAttrs map[string]string) *webSearchHandler {
|
||||
if httpClient == nil {
|
||||
httpClient = &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
return &webSearchHandler{
|
||||
ctx: ctx,
|
||||
mcpEndpoint: mcpEndpoint,
|
||||
httpClient: httpClient,
|
||||
authToken: authToken,
|
||||
auth: auth,
|
||||
authAttrs: authAttrs,
|
||||
}
|
||||
}
|
||||
|
||||
// setMcpHeaders sets standard MCP API headers on the request,
|
||||
// aligned with the GAR request pattern.
|
||||
func (h *webSearchHandler) setMcpHeaders(req *http.Request) {
|
||||
// 1. Content-Type & Accept (aligned with GAR)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "*/*")
|
||||
|
||||
// 2. Kiro-specific headers (aligned with GAR)
|
||||
req.Header.Set("x-amzn-kiro-agent-mode", "vibe")
|
||||
req.Header.Set("x-amzn-codewhisperer-optout", "true")
|
||||
|
||||
// 3. User-Agent: Reuse applyDynamicFingerprint for consistency
|
||||
applyDynamicFingerprint(req, h.auth)
|
||||
|
||||
// 4. AWS SDK identifiers
|
||||
req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
||||
req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
|
||||
|
||||
// 5. Authentication
|
||||
req.Header.Set("Authorization", "Bearer "+h.authToken)
|
||||
|
||||
// 6. Custom headers from auth attributes
|
||||
util.ApplyCustomHeadersFromAttrs(req, h.authAttrs)
|
||||
}
|
||||
|
||||
// mcpMaxRetries is the maximum number of retries for MCP API calls.
|
||||
const mcpMaxRetries = 2
|
||||
|
||||
// callMcpAPI calls the Kiro MCP API with the given request.
|
||||
// Includes retry logic with exponential backoff for retryable errors.
|
||||
func (h *webSearchHandler) callMcpAPI(request *kiroclaude.McpRequest) (*kiroclaude.McpResponse, error) {
|
||||
requestBody, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal MCP request: %w", err)
|
||||
}
|
||||
log.Debugf("kiro/websearch MCP request → %s (%d bytes)", h.mcpEndpoint, len(requestBody))
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt <= mcpMaxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
backoff := time.Duration(1<<attempt) * time.Second
|
||||
if backoff > 10*time.Second {
|
||||
backoff = 10 * time.Second
|
||||
}
|
||||
log.Warnf("kiro/websearch: MCP retry %d/%d after %v (last error: %v)", attempt, mcpMaxRetries, backoff, lastErr)
|
||||
select {
|
||||
case <-h.ctx.Done():
|
||||
return nil, h.ctx.Err()
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(h.ctx, "POST", h.mcpEndpoint, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||||
}
|
||||
|
||||
h.setMcpHeaders(req)
|
||||
|
||||
resp, err := h.httpClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("MCP API request failed: %w", err)
|
||||
continue // network error → retry
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("failed to read MCP response: %w", err)
|
||||
continue // read error → retry
|
||||
}
|
||||
log.Debugf("kiro/websearch MCP response ← [%d] (%d bytes)", resp.StatusCode, len(body))
|
||||
|
||||
// Retryable HTTP status codes (aligned with GAR: 502, 503, 504)
|
||||
if resp.StatusCode >= 502 && resp.StatusCode <= 504 {
|
||||
lastErr = fmt.Errorf("MCP API returned retryable status %d: %s", resp.StatusCode, string(body))
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("MCP API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var mcpResponse kiroclaude.McpResponse
|
||||
if err := json.Unmarshal(body, &mcpResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse MCP response: %w", err)
|
||||
}
|
||||
|
||||
if mcpResponse.Error != nil {
|
||||
code := -1
|
||||
if mcpResponse.Error.Code != nil {
|
||||
code = *mcpResponse.Error.Code
|
||||
}
|
||||
msg := "Unknown error"
|
||||
if mcpResponse.Error.Message != nil {
|
||||
msg = *mcpResponse.Error.Message
|
||||
}
|
||||
return nil, fmt.Errorf("MCP error %d: %s", code, msg)
|
||||
}
|
||||
|
||||
return &mcpResponse, nil
|
||||
}
|
||||
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
// webSearchAuthAttrs extracts auth attributes for MCP calls.
|
||||
// Used by handleWebSearch and handleWebSearchStream to pass custom headers.
|
||||
func webSearchAuthAttrs(auth *cliproxyauth.Auth) map[string]string {
|
||||
if auth != nil {
|
||||
return auth.Attributes
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const maxWebSearchIterations = 5
|
||||
|
||||
// handleWebSearchStream handles web_search requests:
|
||||
@@ -4136,58 +4376,63 @@ func (e *KiroExecutor) handleWebSearchStream(
|
||||
return e.callKiroDirectStream(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
// Build MCP endpoint based on region
|
||||
region := kiroDefaultRegion
|
||||
if auth != nil && auth.Metadata != nil {
|
||||
if r, ok := auth.Metadata["api_region"].(string); ok && r != "" {
|
||||
region = r
|
||||
}
|
||||
}
|
||||
mcpEndpoint := fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region)
|
||||
// Build MCP endpoint using shared region resolution (supports api_region + ProfileARN fallback)
|
||||
region := resolveKiroAPIRegion(auth)
|
||||
mcpEndpoint := kiroclaude.BuildMcpEndpoint(region)
|
||||
|
||||
// ── Step 1: tools/list (SYNC) — cache tool description ──
|
||||
{
|
||||
tokenKey := getTokenKey(auth)
|
||||
fp := getGlobalFingerprintManager().GetFingerprint(tokenKey)
|
||||
var authAttrs map[string]string
|
||||
if auth != nil {
|
||||
authAttrs = auth.Attributes
|
||||
}
|
||||
kiroclaude.FetchToolDescription(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), fp, authAttrs)
|
||||
authAttrs := webSearchAuthAttrs(auth)
|
||||
fetchToolDescription(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
|
||||
}
|
||||
|
||||
// Create output channel
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
|
||||
// Usage reporting: track web search requests like normal streaming requests
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
|
||||
go func() {
|
||||
var wsErr error
|
||||
defer reporter.trackFailure(ctx, &wsErr)
|
||||
defer close(out)
|
||||
|
||||
// Send message_start event to client
|
||||
messageStartEvent := kiroclaude.SseEvent{
|
||||
Event: "message_start",
|
||||
Data: map[string]interface{}{
|
||||
"type": "message_start",
|
||||
"message": map[string]interface{}{
|
||||
"id": kiroclaude.GenerateMessageID(),
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": req.Model,
|
||||
"content": []interface{}{},
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": len(req.Payload) / 4,
|
||||
"output_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
// Estimate input tokens using tokenizer (matching streamToChannel pattern)
|
||||
var totalUsage usage.Detail
|
||||
if enc, tokErr := getTokenizer(req.Model); tokErr == nil {
|
||||
if inp, e := countClaudeChatTokens(enc, req.Payload); e == nil && inp > 0 {
|
||||
totalUsage.InputTokens = inp
|
||||
} else {
|
||||
totalUsage.InputTokens = int64(len(req.Payload) / 4)
|
||||
}
|
||||
} else {
|
||||
totalUsage.InputTokens = int64(len(req.Payload) / 4)
|
||||
}
|
||||
if totalUsage.InputTokens == 0 && len(req.Payload) > 0 {
|
||||
totalUsage.InputTokens = 1
|
||||
}
|
||||
var accumulatedOutputLen int
|
||||
defer func() {
|
||||
if wsErr != nil {
|
||||
return // let trackFailure handle failure reporting
|
||||
}
|
||||
totalUsage.OutputTokens = int64(accumulatedOutputLen / 4)
|
||||
if accumulatedOutputLen > 0 && totalUsage.OutputTokens == 0 {
|
||||
totalUsage.OutputTokens = 1
|
||||
}
|
||||
reporter.publish(ctx, totalUsage)
|
||||
}()
|
||||
|
||||
// Send message_start event to client (aligned with streamToChannel pattern)
|
||||
// Use payloadRequestedModel to return user's original model alias
|
||||
msgStart := kiroclaude.BuildClaudeMessageStartEvent(
|
||||
payloadRequestedModel(opts, req.Model),
|
||||
totalUsage.InputTokens,
|
||||
)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: []byte(messageStartEvent.ToSSEString())}:
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: append(msgStart, '\n', '\n')}:
|
||||
}
|
||||
|
||||
// ── Step 2+: MCP search → InjectToolResultsClaude → callKiroAndBuffer loop ──
|
||||
@@ -4216,14 +4461,10 @@ func (e *KiroExecutor) handleWebSearchStream(
|
||||
|
||||
// MCP search
|
||||
_, mcpRequest := kiroclaude.CreateMcpRequest(currentQuery)
|
||||
tokenKey := getTokenKey(auth)
|
||||
fp := getGlobalFingerprintManager().GetFingerprint(tokenKey)
|
||||
var authAttrs map[string]string
|
||||
if auth != nil {
|
||||
authAttrs = auth.Attributes
|
||||
}
|
||||
handler := kiroclaude.NewWebSearchHandler(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), fp, authAttrs)
|
||||
mcpResponse, mcpErr := handler.CallMcpAPI(mcpRequest)
|
||||
|
||||
authAttrs := webSearchAuthAttrs(auth)
|
||||
handler := newWebSearchHandler(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
|
||||
mcpResponse, mcpErr := handler.callMcpAPI(mcpRequest)
|
||||
|
||||
var searchResults *kiroclaude.WebSearchResults
|
||||
if mcpErr != nil {
|
||||
@@ -4255,8 +4496,9 @@ func (e *KiroExecutor) handleWebSearchStream(
|
||||
currentClaudePayload, err = kiroclaude.InjectToolResultsClaude(currentClaudePayload, currentToolUseId, currentQuery, searchResults)
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: failed to inject tool results: %v", err)
|
||||
wsErr = fmt.Errorf("failed to inject tool results: %w", err)
|
||||
e.sendFallbackText(ctx, out, contentBlockIndex, currentQuery, searchResults)
|
||||
break
|
||||
return
|
||||
}
|
||||
|
||||
// Call GAR with modified Claude payload (full translation pipeline)
|
||||
@@ -4265,8 +4507,9 @@ func (e *KiroExecutor) handleWebSearchStream(
|
||||
kiroChunks, kiroErr := e.callKiroAndBuffer(ctx, auth, modifiedReq, opts, accessToken, profileArn)
|
||||
if kiroErr != nil {
|
||||
log.Warnf("kiro/websearch: Kiro API failed at iteration %d: %v", iteration+1, kiroErr)
|
||||
wsErr = fmt.Errorf("Kiro API failed at iteration %d: %w", iteration+1, kiroErr)
|
||||
e.sendFallbackText(ctx, out, contentBlockIndex, currentQuery, searchResults)
|
||||
break
|
||||
return
|
||||
}
|
||||
|
||||
// Analyze response
|
||||
@@ -4297,12 +4540,14 @@ func (e *KiroExecutor) handleWebSearchStream(
|
||||
if !shouldForward {
|
||||
continue
|
||||
}
|
||||
accumulatedOutputLen += len(adjusted)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: adjusted}:
|
||||
}
|
||||
} else {
|
||||
accumulatedOutputLen += len(chunk)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
@@ -4320,8 +4565,103 @@ func (e *KiroExecutor) handleWebSearchStream(
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// handleWebSearch handles web_search requests for non-streaming Execute path.
|
||||
// Performs MCP search synchronously, injects results into the request payload,
|
||||
// then calls the normal non-streaming Kiro API path which returns a proper
|
||||
// Claude JSON response (not SSE chunks).
|
||||
func (e *KiroExecutor) handleWebSearch(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
accessToken, profileArn string,
|
||||
) (cliproxyexecutor.Response, error) {
|
||||
// Extract search query from Claude Code's web_search tool_use
|
||||
query := kiroclaude.ExtractSearchQuery(req.Payload)
|
||||
if query == "" {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to extract search query, falling back to normal Execute")
|
||||
// Fall through to normal non-streaming path
|
||||
return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
// Build MCP endpoint using shared region resolution (supports api_region + ProfileARN fallback)
|
||||
region := resolveKiroAPIRegion(auth)
|
||||
mcpEndpoint := kiroclaude.BuildMcpEndpoint(region)
|
||||
|
||||
// Step 1: Fetch/cache tool description (sync)
|
||||
{
|
||||
authAttrs := webSearchAuthAttrs(auth)
|
||||
fetchToolDescription(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
|
||||
}
|
||||
|
||||
// Step 2: Perform MCP search
|
||||
_, mcpRequest := kiroclaude.CreateMcpRequest(query)
|
||||
|
||||
authAttrs := webSearchAuthAttrs(auth)
|
||||
handler := newWebSearchHandler(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
|
||||
mcpResponse, mcpErr := handler.callMcpAPI(mcpRequest)
|
||||
|
||||
var searchResults *kiroclaude.WebSearchResults
|
||||
if mcpErr != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: MCP API call failed: %v, continuing with empty results", mcpErr)
|
||||
} else {
|
||||
searchResults = kiroclaude.ParseSearchResults(mcpResponse)
|
||||
}
|
||||
|
||||
resultCount := 0
|
||||
if searchResults != nil {
|
||||
resultCount = len(searchResults.Results)
|
||||
}
|
||||
log.Infof("kiro/websearch: non-stream: got %d search results for query: %s", resultCount, query)
|
||||
|
||||
// Step 3: Replace restrictive web_search tool description (align with streaming path)
|
||||
simplifiedPayload, simplifyErr := kiroclaude.ReplaceWebSearchToolDescription(bytes.Clone(req.Payload))
|
||||
if simplifyErr != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to simplify web_search tool: %v, using original payload", simplifyErr)
|
||||
simplifiedPayload = bytes.Clone(req.Payload)
|
||||
}
|
||||
|
||||
// Step 4: Inject search tool_use + tool_result into Claude payload
|
||||
currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID())
|
||||
modifiedPayload, err := kiroclaude.InjectToolResultsClaude(simplifiedPayload, currentToolUseId, query, searchResults)
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to inject tool results: %v, falling back", err)
|
||||
return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
// Step 5: Call Kiro API via the normal non-streaming path (executeWithRetry)
|
||||
// This path uses parseEventStream → BuildClaudeResponse → TranslateNonStream
|
||||
// to produce a proper Claude JSON response
|
||||
modifiedReq := req
|
||||
modifiedReq.Payload = modifiedPayload
|
||||
|
||||
resp, err := e.executeNonStreamFallback(ctx, auth, modifiedReq, opts, accessToken, profileArn)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Step 6: Inject server_tool_use + web_search_tool_result into response
|
||||
// so Claude Code can display "Did X searches in Ys"
|
||||
indicators := []kiroclaude.SearchIndicator{
|
||||
{
|
||||
ToolUseID: currentToolUseId,
|
||||
Query: query,
|
||||
Results: searchResults,
|
||||
},
|
||||
}
|
||||
injectedPayload, injErr := kiroclaude.InjectSearchIndicatorsInResponse(resp.Payload, indicators)
|
||||
if injErr != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to inject search indicators: %v", injErr)
|
||||
} else {
|
||||
resp.Payload = injectedPayload
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// callKiroAndBuffer calls the Kiro API and buffers all response chunks.
|
||||
// Returns the buffered chunks for analysis before forwarding to client.
|
||||
// Usage reporting is NOT done here — the caller (handleWebSearchStream) manages its own reporter.
|
||||
func (e *KiroExecutor) callKiroAndBuffer(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
@@ -4338,10 +4678,7 @@ func (e *KiroExecutor) callKiroAndBuffer(
|
||||
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
||||
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
||||
|
||||
tokenKey := ""
|
||||
if auth != nil {
|
||||
tokenKey = auth.ID
|
||||
}
|
||||
tokenKey := getTokenKey(auth)
|
||||
|
||||
kiroStream, err := e.executeStreamWithRetry(
|
||||
ctx, auth, req, opts, accessToken, effectiveProfileArn,
|
||||
@@ -4367,51 +4704,6 @@ func (e *KiroExecutor) callKiroAndBuffer(
|
||||
return chunks, nil
|
||||
}
|
||||
|
||||
// callKiroRawAndBuffer calls the Kiro API with a pre-built Kiro payload (no translation).
|
||||
// Used in the web search loop where the payload is modified directly in Kiro format.
|
||||
func (e *KiroExecutor) callKiroRawAndBuffer(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
accessToken, profileArn string,
|
||||
kiroBody []byte,
|
||||
) ([][]byte, error) {
|
||||
kiroModelID := e.mapModelToKiro(req.Model)
|
||||
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
||||
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
||||
|
||||
tokenKey := ""
|
||||
if auth != nil {
|
||||
tokenKey = auth.ID
|
||||
}
|
||||
log.Debugf("kiro/websearch GAR raw request: %d bytes", len(kiroBody))
|
||||
|
||||
kiroFormat := sdktranslator.FromString("kiro")
|
||||
kiroStream, err := e.executeStreamWithRetry(
|
||||
ctx, auth, req, opts, accessToken, effectiveProfileArn,
|
||||
nil, kiroBody, kiroFormat, nil, "", kiroModelID, isAgentic, isChatOnly, tokenKey,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Buffer all chunks
|
||||
var chunks [][]byte
|
||||
for chunk := range kiroStream {
|
||||
if chunk.Err != nil {
|
||||
return chunks, chunk.Err
|
||||
}
|
||||
if len(chunk.Payload) > 0 {
|
||||
chunks = append(chunks, bytes.Clone(chunk.Payload))
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("kiro/websearch GAR raw response: %d chunks buffered", len(chunks))
|
||||
|
||||
return chunks, nil
|
||||
}
|
||||
|
||||
// callKiroDirectStream creates a direct streaming channel to Kiro API without search.
|
||||
func (e *KiroExecutor) callKiroDirectStream(
|
||||
ctx context.Context,
|
||||
@@ -4428,18 +4720,22 @@ func (e *KiroExecutor) callKiroDirectStream(
|
||||
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
||||
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
||||
|
||||
tokenKey := ""
|
||||
if auth != nil {
|
||||
tokenKey = auth.ID
|
||||
}
|
||||
tokenKey := getTokenKey(auth)
|
||||
|
||||
return e.executeStreamWithRetry(
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
var streamErr error
|
||||
defer reporter.trackFailure(ctx, &streamErr)
|
||||
|
||||
stream, streamErr := e.executeStreamWithRetry(
|
||||
ctx, auth, req, opts, accessToken, effectiveProfileArn,
|
||||
nil, body, from, nil, "", kiroModelID, isAgentic, isChatOnly, tokenKey,
|
||||
nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey,
|
||||
)
|
||||
return stream, streamErr
|
||||
}
|
||||
|
||||
// sendFallbackText sends a simple text response when the Kiro API fails during the search loop.
|
||||
// Delegates SSE event construction to kiroclaude.BuildFallbackTextEvents() for alignment
|
||||
// with how streamToChannel() uses BuildClaude*Event() functions.
|
||||
func (e *KiroExecutor) sendFallbackText(
|
||||
ctx context.Context,
|
||||
out chan<- cliproxyexecutor.StreamChunk,
|
||||
@@ -4447,182 +4743,14 @@ func (e *KiroExecutor) sendFallbackText(
|
||||
query string,
|
||||
searchResults *kiroclaude.WebSearchResults,
|
||||
) {
|
||||
// Generate a simple text summary from search results
|
||||
summary := kiroclaude.FormatSearchContextPrompt(query, searchResults)
|
||||
|
||||
events := []kiroclaude.SseEvent{
|
||||
{
|
||||
Event: "content_block_start",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": contentBlockIndex,
|
||||
"content_block": map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Event: "content_block_delta",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": contentBlockIndex,
|
||||
"delta": map[string]interface{}{
|
||||
"type": "text_delta",
|
||||
"text": summary,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Event: "content_block_stop",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": contentBlockIndex,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
events := kiroclaude.BuildFallbackTextEvents(contentBlockIndex, query, searchResults)
|
||||
for _, event := range events {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: []byte(event.ToSSEString())}:
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: append(event, '\n', '\n')}:
|
||||
}
|
||||
}
|
||||
|
||||
// Send message_delta with end_turn and message_stop
|
||||
msgDelta := kiroclaude.SseEvent{
|
||||
Event: "message_delta",
|
||||
Data: map[string]interface{}{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]interface{}{
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"output_tokens": len(summary) / 4,
|
||||
},
|
||||
},
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: []byte(msgDelta.ToSSEString())}:
|
||||
}
|
||||
|
||||
msgStop := kiroclaude.SseEvent{
|
||||
Event: "message_stop",
|
||||
Data: map[string]interface{}{
|
||||
"type": "message_stop",
|
||||
},
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: []byte(msgStop.ToSSEString())}:
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// handleWebSearch handles web_search requests for non-streaming Execute path.
|
||||
// Performs MCP search synchronously, injects results into the request payload,
|
||||
// then calls the normal non-streaming Kiro API path which returns a proper
|
||||
// Claude JSON response (not SSE chunks).
|
||||
func (e *KiroExecutor) handleWebSearch(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
accessToken, profileArn string,
|
||||
) (cliproxyexecutor.Response, error) {
|
||||
// Extract search query from Claude Code's web_search tool_use
|
||||
query := kiroclaude.ExtractSearchQuery(req.Payload)
|
||||
if query == "" {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to extract search query, falling back to normal Execute")
|
||||
// Fall through to normal non-streaming path
|
||||
return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
// Build MCP endpoint based on region
|
||||
region := kiroDefaultRegion
|
||||
if auth != nil && auth.Metadata != nil {
|
||||
if r, ok := auth.Metadata["api_region"].(string); ok && r != "" {
|
||||
region = r
|
||||
}
|
||||
}
|
||||
mcpEndpoint := fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region)
|
||||
|
||||
// Step 1: Fetch/cache tool description (sync)
|
||||
{
|
||||
tokenKey := getTokenKey(auth)
|
||||
fp := getGlobalFingerprintManager().GetFingerprint(tokenKey)
|
||||
var authAttrs map[string]string
|
||||
if auth != nil {
|
||||
authAttrs = auth.Attributes
|
||||
}
|
||||
kiroclaude.FetchToolDescription(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), fp, authAttrs)
|
||||
}
|
||||
|
||||
// Step 2: Perform MCP search
|
||||
_, mcpRequest := kiroclaude.CreateMcpRequest(query)
|
||||
tokenKey := getTokenKey(auth)
|
||||
fp := getGlobalFingerprintManager().GetFingerprint(tokenKey)
|
||||
var authAttrs map[string]string
|
||||
if auth != nil {
|
||||
authAttrs = auth.Attributes
|
||||
}
|
||||
handler := kiroclaude.NewWebSearchHandler(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), fp, authAttrs)
|
||||
mcpResponse, mcpErr := handler.CallMcpAPI(mcpRequest)
|
||||
|
||||
var searchResults *kiroclaude.WebSearchResults
|
||||
if mcpErr != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: MCP API call failed: %v, continuing with empty results", mcpErr)
|
||||
} else {
|
||||
searchResults = kiroclaude.ParseSearchResults(mcpResponse)
|
||||
}
|
||||
|
||||
resultCount := 0
|
||||
if searchResults != nil {
|
||||
resultCount = len(searchResults.Results)
|
||||
}
|
||||
log.Infof("kiro/websearch: non-stream: got %d search results for query: %s", resultCount, query)
|
||||
|
||||
// Step 3: Inject search tool_use + tool_result into Claude payload
|
||||
currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID())
|
||||
modifiedPayload, err := kiroclaude.InjectToolResultsClaude(bytes.Clone(req.Payload), currentToolUseId, query, searchResults)
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to inject tool results: %v, falling back", err)
|
||||
return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
// Step 4: Call Kiro API via the normal non-streaming path (executeWithRetry)
|
||||
// This path uses parseEventStream → BuildClaudeResponse → TranslateNonStream
|
||||
// to produce a proper Claude JSON response
|
||||
modifiedReq := req
|
||||
modifiedReq.Payload = modifiedPayload
|
||||
|
||||
resp, err := e.executeNonStreamFallback(ctx, auth, modifiedReq, opts, accessToken, profileArn)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Step 5: Inject server_tool_use + web_search_tool_result into response
|
||||
// so Claude Code can display "Did X searches in Ys"
|
||||
indicators := []kiroclaude.SearchIndicator{
|
||||
{
|
||||
ToolUseID: currentToolUseId,
|
||||
Query: query,
|
||||
Results: searchResults,
|
||||
},
|
||||
}
|
||||
injectedPayload, injErr := kiroclaude.InjectSearchIndicatorsInResponse(resp.Payload, indicators)
|
||||
if injErr != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to inject search indicators: %v", injErr)
|
||||
} else {
|
||||
resp.Payload = injectedPayload
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// executeNonStreamFallback runs the standard non-streaming Execute path for a request.
|
||||
|
||||
Reference in New Issue
Block a user