diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index d35570ce..265b4f8c 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -3,6 +3,9 @@ package management import ( "bytes" "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -23,6 +26,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini" iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow" + kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" @@ -2154,9 +2158,35 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec func (h *Handler) GetAuthStatus(c *gin.Context) { state := c.Query("state") - if err, ok := getOAuthStatus(state); ok { - if err != "" { - c.JSON(200, gin.H{"status": "error", "error": err}) + if statusValue, ok := getOAuthStatus(state); ok { + if statusValue != "" { + // Check for device_code prefix (Kiro AWS Builder ID flow) + // Format: "device_code|verification_url|user_code" + // Using "|" as separator because URLs contain ":" + if strings.HasPrefix(statusValue, "device_code|") { + parts := strings.SplitN(statusValue, "|", 3) + if len(parts) == 3 { + c.JSON(200, gin.H{ + "status": "device_code", + "verification_url": parts[1], + "user_code": parts[2], + }) + return + } + } + // Check for auth_url prefix (Kiro social auth flow) + // Format: "auth_url|url" + // Using "|" as separator because URLs contain ":" + if strings.HasPrefix(statusValue, "auth_url|") { + authURL := strings.TrimPrefix(statusValue, "auth_url|") + c.JSON(200, gin.H{ + "status": "auth_url", + "url": authURL, + }) + return + } + 
// Otherwise treat as error + c.JSON(200, gin.H{"status": "error", "error": statusValue}) } else { c.JSON(200, gin.H{"status": "wait"}) return @@ -2166,3 +2196,295 @@ func (h *Handler) GetAuthStatus(c *gin.Context) { } deleteOAuthStatus(state) } + +const kiroCallbackPort = 9876 + +func (h *Handler) RequestKiroToken(c *gin.Context) { + ctx := context.Background() + + // Get the login method from query parameter (default: aws for device code flow) + method := strings.ToLower(strings.TrimSpace(c.Query("method"))) + if method == "" { + method = "aws" + } + + fmt.Println("Initializing Kiro authentication...") + + state := fmt.Sprintf("kiro-%d", time.Now().UnixNano()) + + switch method { + case "aws", "builder-id": + // AWS Builder ID uses device code flow (no callback needed) + go func() { + ssoClient := kiroauth.NewSSOOIDCClient(h.cfg) + + // Step 1: Register client + fmt.Println("Registering client...") + regResp, err := ssoClient.RegisterClient(ctx) + if err != nil { + log.Errorf("Failed to register client: %v", err) + setOAuthStatus(state, "Failed to register client") + return + } + + // Step 2: Start device authorization + fmt.Println("Starting device authorization...") + authResp, err := ssoClient.StartDeviceAuthorization(ctx, regResp.ClientID, regResp.ClientSecret) + if err != nil { + log.Errorf("Failed to start device auth: %v", err) + setOAuthStatus(state, "Failed to start device authorization") + return + } + + // Store the verification URL for the frontend to display + // Using "|" as separator because URLs contain ":" + setOAuthStatus(state, "device_code|"+authResp.VerificationURIComplete+"|"+authResp.UserCode) + + // Step 3: Poll for token + fmt.Println("Waiting for authorization...") + interval := 5 * time.Second + if authResp.Interval > 0 { + interval = time.Duration(authResp.Interval) * time.Second + } + deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second) + + for time.Now().Before(deadline) { + select { + case <-ctx.Done(): + 
setOAuthStatus(state, "Authorization cancelled") + return + case <-time.After(interval): + tokenResp, err := ssoClient.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode) + if err != nil { + errStr := err.Error() + if strings.Contains(errStr, "authorization_pending") { + continue + } + if strings.Contains(errStr, "slow_down") { + interval += 5 * time.Second + continue + } + log.Errorf("Token creation failed: %v", err) + setOAuthStatus(state, "Token creation failed") + return + } + + // Success! Save the token + expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken) + + idPart := kiroauth.SanitizeEmailForFilename(email) + if idPart == "" { + idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000) + } + + now := time.Now() + fileName := fmt.Sprintf("kiro-aws-%s.json", idPart) + + record := &coreauth.Auth{ + ID: fileName, + Provider: "kiro", + FileName: fileName, + Metadata: map[string]any{ + "type": "kiro", + "access_token": tokenResp.AccessToken, + "refresh_token": tokenResp.RefreshToken, + "expires_at": expiresAt.Format(time.RFC3339), + "auth_method": "builder-id", + "provider": "AWS", + "client_id": regResp.ClientID, + "client_secret": regResp.ClientSecret, + "email": email, + "last_refresh": now.Format(time.RFC3339), + }, + } + + savedPath, errSave := h.saveTokenRecord(ctx, record) + if errSave != nil { + log.Errorf("Failed to save authentication tokens: %v", errSave) + setOAuthStatus(state, "Failed to save authentication tokens") + return + } + + fmt.Printf("Authentication successful! 
Token saved to %s\n", savedPath) + if email != "" { + fmt.Printf("Authenticated as: %s\n", email) + } + deleteOAuthStatus(state) + return + } + } + + setOAuthStatus(state, "Authorization timed out") + }() + + // Return immediately with the state for polling + c.JSON(200, gin.H{"status": "ok", "state": state, "method": "device_code"}) + + case "google", "github": + // Social auth uses protocol handler - for WEB UI we use a callback forwarder + provider := "Google" + if method == "github" { + provider = "Github" + } + + isWebUI := isWebUIRequest(c) + if isWebUI { + targetURL, errTarget := h.managementCallbackURL("/kiro/callback") + if errTarget != nil { + log.WithError(errTarget).Error("failed to compute kiro callback target") + c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) + return + } + if _, errStart := startCallbackForwarder(kiroCallbackPort, "kiro", targetURL); errStart != nil { + log.WithError(errStart).Error("failed to start kiro callback forwarder") + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) + return + } + } + + go func() { + if isWebUI { + defer stopCallbackForwarder(kiroCallbackPort) + } + + socialClient := kiroauth.NewSocialAuthClient(h.cfg) + + // Generate PKCE codes + codeVerifier, codeChallenge, err := generateKiroPKCE() + if err != nil { + log.Errorf("Failed to generate PKCE: %v", err) + setOAuthStatus(state, "Failed to generate PKCE") + return + } + + // Build login URL + authURL := fmt.Sprintf("%s/login?idp=%s&redirect_uri=%s&code_challenge=%s&code_challenge_method=S256&state=%s&prompt=select_account", + "https://prod.us-east-1.auth.desktop.kiro.dev", + provider, + url.QueryEscape(kiroauth.KiroRedirectURI), + codeChallenge, + state, + ) + + // Store auth URL for frontend + // Using "|" as separator because URLs contain ":" + setOAuthStatus(state, "auth_url|"+authURL) + + // Wait for callback file + waitFile := filepath.Join(h.cfg.AuthDir, 
fmt.Sprintf(".oauth-kiro-%s.oauth", state)) + deadline := time.Now().Add(5 * time.Minute) + + for { + if time.Now().After(deadline) { + log.Error("oauth flow timed out") + setOAuthStatus(state, "OAuth flow timed out") + return + } + if data, errR := os.ReadFile(waitFile); errR == nil { + var m map[string]string + _ = json.Unmarshal(data, &m) + _ = os.Remove(waitFile) + if errStr := m["error"]; errStr != "" { + log.Errorf("Authentication failed: %s", errStr) + setOAuthStatus(state, "Authentication failed") + return + } + if m["state"] != state { + log.Errorf("State mismatch") + setOAuthStatus(state, "State mismatch") + return + } + code := m["code"] + if code == "" { + log.Error("No authorization code received") + setOAuthStatus(state, "No authorization code received") + return + } + + // Exchange code for tokens + tokenReq := &kiroauth.CreateTokenRequest{ + Code: code, + CodeVerifier: codeVerifier, + RedirectURI: kiroauth.KiroRedirectURI, + } + + tokenResp, errToken := socialClient.CreateToken(ctx, tokenReq) + if errToken != nil { + log.Errorf("Failed to exchange code for tokens: %v", errToken) + setOAuthStatus(state, "Failed to exchange code for tokens") + return + } + + // Save the token + expiresIn := tokenResp.ExpiresIn + if expiresIn <= 0 { + expiresIn = 3600 + } + expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second) + email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken) + + idPart := kiroauth.SanitizeEmailForFilename(email) + if idPart == "" { + idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000) + } + + now := time.Now() + fileName := fmt.Sprintf("kiro-%s-%s.json", strings.ToLower(provider), idPart) + + record := &coreauth.Auth{ + ID: fileName, + Provider: "kiro", + FileName: fileName, + Metadata: map[string]any{ + "type": "kiro", + "access_token": tokenResp.AccessToken, + "refresh_token": tokenResp.RefreshToken, + "profile_arn": tokenResp.ProfileArn, + "expires_at": expiresAt.Format(time.RFC3339), + "auth_method": "social", + 
"provider": provider, + "email": email, + "last_refresh": now.Format(time.RFC3339), + }, + } + + savedPath, errSave := h.saveTokenRecord(ctx, record) + if errSave != nil { + log.Errorf("Failed to save authentication tokens: %v", errSave) + setOAuthStatus(state, "Failed to save authentication tokens") + return + } + + fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) + if email != "" { + fmt.Printf("Authenticated as: %s\n", email) + } + deleteOAuthStatus(state) + return + } + time.Sleep(500 * time.Millisecond) + } + }() + + setOAuthStatus(state, "") + c.JSON(200, gin.H{"status": "ok", "state": state, "method": "social"}) + + default: + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid method, use 'aws', 'google', or 'github'"}) + } +} + +// generateKiroPKCE generates PKCE code verifier and challenge for Kiro OAuth. +func generateKiroPKCE() (verifier, challenge string, err error) { + b := make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, b); err != nil { + return "", "", fmt.Errorf("failed to generate random bytes: %w", err) + } + verifier = base64.RawURLEncoding.EncodeToString(b) + + h := sha256.Sum256([]byte(verifier)) + challenge = base64.RawURLEncoding.EncodeToString(h[:]) + + return verifier, challenge, nil +} diff --git a/internal/api/modules/amp/proxy.go b/internal/api/modules/amp/proxy.go index 6a6b1b54..91716e36 100644 --- a/internal/api/modules/amp/proxy.go +++ b/internal/api/modules/amp/proxy.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "net" "net/http" "net/http/httputil" "net/url" @@ -64,7 +65,15 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi // Modify incoming responses to handle gzip without Content-Encoding // This addresses the same issue as inline handler gzip handling, but at the proxy level proxy.ModifyResponse = func(resp *http.Response) error { - // Only process successful responses + // Log upstream error responses for diagnostics (502, 503, etc.) 
+ // These are NOT proxy connection errors - the upstream responded with an error status + if resp.StatusCode >= 500 { + log.Errorf("amp upstream responded with error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path) + } else if resp.StatusCode >= 400 { + log.Warnf("amp upstream responded with client error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path) + } + + // Only process successful responses for gzip decompression if resp.StatusCode < 200 || resp.StatusCode >= 300 { return nil } @@ -148,15 +157,29 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi return nil } - // Error handler for proxy failures + // Error handler for proxy failures with detailed error classification for diagnostics proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) { - // Check if this is a client-side cancellation (normal behavior) + // Classify the error type for better diagnostics + var errType string + if errors.Is(err, context.DeadlineExceeded) { + errType = "timeout" + } else if errors.Is(err, context.Canceled) { + errType = "canceled" + } else if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + errType = "dial_timeout" + } else if _, ok := err.(net.Error); ok { + errType = "network_error" + } else { + errType = "connection_error" + } + // Don't log as error for context canceled - it's usually client closing connection if errors.Is(err, context.Canceled) { - log.Debugf("amp upstream proxy: client canceled request for %s %s", req.Method, req.URL.Path) + log.Debugf("amp upstream proxy [%s]: client canceled request for %s %s", errType, req.Method, req.URL.Path) } else { - log.Errorf("amp upstream proxy error for %s %s: %v", req.Method, req.URL.Path, err) + log.Errorf("amp upstream proxy error [%s] for %s %s: %v", errType, req.Method, req.URL.Path, err) } + rw.Header().Set("Content-Type", "application/json") rw.WriteHeader(http.StatusBadGateway) _, _ = 
rw.Write([]byte(`{"error":"amp_upstream_proxy_error","message":"Failed to reach Amp upstream"}`)) diff --git a/internal/api/server.go b/internal/api/server.go index ade08fef..d702551e 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -421,6 +421,18 @@ func (s *Server) setupRoutes() { c.String(http.StatusOK, oauthCallbackSuccessHTML) }) + s.engine.GET("/kiro/callback", func(c *gin.Context) { + code := c.Query("code") + state := c.Query("state") + errStr := c.Query("error") + if state != "" { + file := fmt.Sprintf("%s/.oauth-kiro-%s.oauth", s.cfg.AuthDir, state) + _ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600) + } + c.Header("Content-Type", "text/html; charset=utf-8") + c.String(http.StatusOK, oauthCallbackSuccessHTML) + }) + // Management routes are registered lazily by registerManagementRoutes when a secret is configured. } @@ -586,6 +598,7 @@ func (s *Server) registerManagementRoutes() { mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken) mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken) mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken) + mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken) mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus) } } diff --git a/internal/auth/kiro/social_auth.go b/internal/auth/kiro/social_auth.go index 61c67886..2ac29bf8 100644 --- a/internal/auth/kiro/social_auth.go +++ b/internal/auth/kiro/social_auth.go @@ -126,8 +126,8 @@ func (c *SocialAuthClient) buildLoginURL(provider, redirectURI, codeChallenge, s ) } -// createToken exchanges the authorization code for tokens. -func (c *SocialAuthClient) createToken(ctx context.Context, req *CreateTokenRequest) (*SocialTokenResponse, error) { +// CreateToken exchanges the authorization code for tokens. 
+func (c *SocialAuthClient) CreateToken(ctx context.Context, req *CreateTokenRequest) (*SocialTokenResponse, error) { body, err := json.Marshal(req) if err != nil { return nil, fmt.Errorf("failed to marshal token request: %w", err) @@ -326,7 +326,7 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP RedirectURI: KiroRedirectURI, } - tokenResp, err := c.createToken(ctx, tokenReq) + tokenResp, err := c.CreateToken(ctx, tokenReq) if err != nil { return nil, fmt.Errorf("failed to exchange code for tokens: %w", err) } diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index 4df0cf67..c6282759 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -895,6 +895,7 @@ func GetKiroModels() []*ModelInfo { Description: "Claude Opus 4.5 via Kiro (2.2x credit)", ContextLength: 200000, MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, }, { ID: "kiro-claude-sonnet-4-5", @@ -906,6 +907,7 @@ func GetKiroModels() []*ModelInfo { Description: "Claude Sonnet 4.5 via Kiro (1.3x credit)", ContextLength: 200000, MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, }, { ID: "kiro-claude-sonnet-4", @@ -917,6 +919,7 @@ func GetKiroModels() []*ModelInfo { Description: "Claude Sonnet 4 via Kiro (1.3x credit)", ContextLength: 200000, MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, }, { ID: "kiro-claude-haiku-4-5", @@ -928,6 +931,7 @@ func GetKiroModels() []*ModelInfo { Description: "Claude Haiku 4.5 via Kiro (0.4x credit)", ContextLength: 200000, MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, }, // --- Agentic Variants (Optimized for coding agents with chunked writes) --- { @@ -940,6 +944,7 @@ 
func GetKiroModels() []*ModelInfo { Description: "Claude Opus 4.5 optimized for coding agents (chunked writes)", ContextLength: 200000, MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, }, { ID: "kiro-claude-sonnet-4-5-agentic", @@ -951,6 +956,7 @@ func GetKiroModels() []*ModelInfo { Description: "Claude Sonnet 4.5 optimized for coding agents (chunked writes)", ContextLength: 200000, MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, }, { ID: "kiro-claude-sonnet-4-agentic", @@ -962,6 +968,7 @@ func GetKiroModels() []*ModelInfo { Description: "Claude Sonnet 4 optimized for coding agents (chunked writes)", ContextLength: 200000, MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, }, { ID: "kiro-claude-haiku-4-5-agentic", @@ -973,6 +980,7 @@ func GetKiroModels() []*ModelInfo { Description: "Claude Haiku 4.5 optimized for coding agents (chunked writes)", ContextLength: 200000, MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, }, } } diff --git a/internal/registry/model_registry.go b/internal/registry/model_registry.go index f3517bde..8f575df4 100644 --- a/internal/registry/model_registry.go +++ b/internal/registry/model_registry.go @@ -748,7 +748,8 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string) } return result - case "claude": + case "claude", "kiro", "antigravity": + // Claude, Kiro, and Antigravity all use Claude-compatible format for Claude Code client result := map[string]any{ "id": model.ID, "object": "model", @@ -763,6 +764,19 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string) if model.DisplayName != "" { result["display_name"] = model.DisplayName } + // Add thinking support for Claude Code client + // Claude Code 
checks for "thinking" field (simple boolean) to enable tab toggle + // Also add "extended_thinking" for detailed budget info + if model.Thinking != nil { + result["thinking"] = true + result["extended_thinking"] = map[string]any{ + "supported": true, + "min": model.Thinking.Min, + "max": model.Thinking.Max, + "zero_allowed": model.Thinking.ZeroAllowed, + "dynamic_allowed": model.Thinking.DynamicAllowed, + } + } return result case "gemini": diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index bff3fb57..cbc5443b 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -36,6 +36,16 @@ const ( kiroAcceptStream = "*/*" kiroMaxMessageSize = 10 * 1024 * 1024 // 10MB max message size for event stream kiroMaxToolDescLen = 10237 // Kiro API limit is 10240 bytes, leave room for "..." + + // Event Stream frame size constants for boundary protection + // AWS Event Stream binary format: prelude (12 bytes) + headers + payload + message_crc (4 bytes) + // Prelude consists of: total_length (4) + headers_length (4) + prelude_crc (4) + minEventStreamFrameSize = 16 // Minimum: 4(total_len) + 4(headers_len) + 4(prelude_crc) + 4(message_crc) + maxEventStreamMsgSize = 10 << 20 // Maximum message length: 10MB + + // Event Stream error type constants + ErrStreamFatal = "fatal" // Connection/authentication errors, not recoverable + ErrStreamMalformed = "malformed" // Format errors, data cannot be parsed // kiroUserAgent matches amq2api format for User-Agent header kiroUserAgent = "aws-sdk-rust/1.3.9 os/macos lang/rust/1.87.0" // kiroFullUserAgent is the complete x-amz-user-agent header matching amq2api @@ -102,6 +112,13 @@ You MUST follow these rules for ALL file operations. Violation causes server tim REMEMBER: When in doubt, write LESS per operation. 
Multiple small operations > one large operation.` ) +// Real-time usage estimation configuration +// These control how often usage updates are sent during streaming +var ( + usageUpdateCharThreshold = 5000 // Send usage update every 5000 characters + usageUpdateTimeInterval = 15 * time.Second // Or every 15 seconds, whichever comes first +) + // kiroEndpointConfig bundles endpoint URL with its compatible Origin and AmzTarget values. // This solves the "triple mismatch" problem where different endpoints require matching // Origin and X-Amz-Target header values. @@ -495,7 +512,7 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. } }() - content, toolUses, usageInfo, err := e.parseEventStream(httpResp.Body) + content, toolUses, usageInfo, stopReason, err := e.parseEventStream(httpResp.Body) if err != nil { recordAPIResponseError(ctx, e.cfg, err) return resp, err @@ -503,14 +520,14 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. // Fallback for usage if missing from upstream if usageInfo.TotalTokens == 0 { - if enc, encErr := tokenizerForModel(req.Model); encErr == nil { + if enc, encErr := getTokenizer(req.Model); encErr == nil { if inp, countErr := countOpenAIChatTokens(enc, opts.OriginalRequest); countErr == nil { usageInfo.InputTokens = inp } } if len(content) > 0 { // Use tiktoken for more accurate output token calculation - if enc, encErr := tokenizerForModel(req.Model); encErr == nil { + if enc, encErr := getTokenizer(req.Model); encErr == nil { if tokenCount, countErr := enc.Count(content); countErr == nil { usageInfo.OutputTokens = int64(tokenCount) } @@ -530,7 +547,8 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. 
reporter.publish(ctx, usageInfo) // Build response in Claude format for Kiro translator - kiroResponse := e.buildClaudeResponse(content, toolUses, req.Model, usageInfo) + // stopReason is extracted from upstream response by parseEventStream + kiroResponse := e.buildClaudeResponse(content, toolUses, req.Model, usageInfo, stopReason) out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, kiroResponse, nil) resp = cliproxyexecutor.Response{Payload: []byte(out)} return resp, nil @@ -970,11 +988,40 @@ func (e *KiroExecutor) mapModelToKiro(model string) string { return "claude-sonnet-4.5" } +// EventStreamError represents an Event Stream processing error +type EventStreamError struct { + Type string // "fatal", "malformed" + Message string + Cause error +} + +func (e *EventStreamError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("event stream %s: %s: %v", e.Type, e.Message, e.Cause) + } + return fmt.Sprintf("event stream %s: %s", e.Type, e.Message) +} + +// eventStreamMessage represents a parsed AWS Event Stream message +type eventStreamMessage struct { + EventType string // Event type from headers (e.g., "assistantResponseEvent") + Payload []byte // JSON payload of the message +} + // Kiro API request structs - field order determines JSON key order type kiroPayload struct { ConversationState kiroConversationState `json:"conversationState"` ProfileArn string `json:"profileArn,omitempty"` + InferenceConfig *kiroInferenceConfig `json:"inferenceConfig,omitempty"` +} + +// kiroInferenceConfig contains inference parameters for the Kiro API. +// NOTE: This is an experimental addition - Kiro/Amazon Q API may not support these parameters. +// If the API ignores or rejects these fields, response length is controlled internally by the model. 
+type kiroInferenceConfig struct { + MaxTokens int `json:"maxTokens,omitempty"` // Maximum output tokens (may be ignored by API) + Temperature float64 `json:"temperature,omitempty"` // Sampling temperature (may be ignored by API) } type kiroConversationState struct { @@ -1058,7 +1105,25 @@ type kiroToolUse struct { // isAgentic parameter enables chunked write optimization prompt for -agentic model variants. // isChatOnly parameter disables tool calling for -chat model variants (pure conversation mode). // Supports thinking mode - when Claude API thinking parameter is present, injects thinkingHint. +// +// max_tokens support: Kiro/Amazon Q API may not officially support max_tokens parameter. +// We attempt to pass it via inferenceConfig.maxTokens, but the API may ignore it. +// Response truncation can be detected via stop_reason == "max_tokens" in the response. func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool) []byte { + // Extract max_tokens for potential use in inferenceConfig + var maxTokens int64 + if mt := gjson.GetBytes(claudeBody, "max_tokens"); mt.Exists() { + maxTokens = mt.Int() + } + + // Extract temperature if specified + var temperature float64 + var hasTemperature bool + if temp := gjson.GetBytes(claudeBody, "temperature"); temp.Exists() { + temperature = temp.Float() + hasTemperature = true + } + // Normalize origin value for Kiro API compatibility // Kiro API only accepts "CLI" or "AI_EDITOR" as valid origin values switch origin { @@ -1118,10 +1183,18 @@ func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, // Read budget_tokens if specified - this value comes from: // - Claude API: thinking.budget_tokens directly // - OpenAI API: reasoning_effort -> budget_tokens (low:4000, medium:16000, high:32000) - if bt := thinkingField.Get("budget_tokens"); bt.Exists() && bt.Int() > 0 { + if bt := thinkingField.Get("budget_tokens"); bt.Exists() { budgetTokens = 
bt.Int() + // If budget_tokens <= 0, disable thinking explicitly + // This allows users to disable thinking by setting budget_tokens to 0 + if budgetTokens <= 0 { + thinkingEnabled = false + log.Debugf("kiro: thinking mode disabled via budget_tokens <= 0") + } + } + if thinkingEnabled { + log.Debugf("kiro: thinking mode enabled via Claude API parameter, budget_tokens: %d", budgetTokens) } - log.Debugf("kiro: thinking mode enabled via Claude API parameter, budget_tokens: %d", budgetTokens) } } @@ -1317,6 +1390,18 @@ func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, }} } + // Build inferenceConfig if we have any inference parameters + var inferenceConfig *kiroInferenceConfig + if maxTokens > 0 || hasTemperature { + inferenceConfig = &kiroInferenceConfig{} + if maxTokens > 0 { + inferenceConfig.MaxTokens = int(maxTokens) + } + if hasTemperature { + inferenceConfig.Temperature = temperature + } + } + // Build payload with correct field order (matches struct definition) payload := kiroPayload{ ConversationState: kiroConversationState{ @@ -1325,7 +1410,8 @@ func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, CurrentMessage: currentMessage, History: history, // Now always included (non-nil slice) }, - ProfileArn: profileArn, + ProfileArn: profileArn, + InferenceConfig: inferenceConfig, } result, err := json.Marshal(payload) @@ -1485,12 +1571,14 @@ func (e *KiroExecutor) buildAssistantMessageStruct(msg gjson.Result) kiroAssista // NOTE: Tool calling is now supported via userInputMessageContext.tools and toolResults // parseEventStream parses AWS Event Stream binary format. -// Extracts text content and tool uses from the response. +// Extracts text content, tool uses, and stop_reason from the response. // Supports embedded [Called ...] tool calls and input buffering for toolUseEvent. 
-func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, usage.Detail, error) { +// Returns: content, toolUses, usageInfo, stopReason, error +func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, usage.Detail, string, error) { var content strings.Builder var toolUses []kiroToolUse var usageInfo usage.Detail + var stopReason string // Extracted from upstream response reader := bufio.NewReader(body) // Tool use state tracking for input buffering and deduplication @@ -1498,59 +1586,28 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, var currentToolUse *toolUseState for { - prelude := make([]byte, 8) - _, err := io.ReadFull(reader, prelude) - if err == io.EOF { + msg, eventErr := e.readEventStreamMessage(reader) + if eventErr != nil { + log.Errorf("kiro: parseEventStream error: %v", eventErr) + return content.String(), toolUses, usageInfo, stopReason, eventErr + } + if msg == nil { + // Normal end of stream (EOF) break } - if err != nil { - return content.String(), toolUses, usageInfo, fmt.Errorf("failed to read prelude: %w", err) - } - totalLen := binary.BigEndian.Uint32(prelude[0:4]) - if totalLen < 8 { - return content.String(), toolUses, usageInfo, fmt.Errorf("invalid message length: %d", totalLen) - } - if totalLen > kiroMaxMessageSize { - return content.String(), toolUses, usageInfo, fmt.Errorf("message too large: %d bytes", totalLen) - } - headersLen := binary.BigEndian.Uint32(prelude[4:8]) - - remaining := make([]byte, totalLen-8) - _, err = io.ReadFull(reader, remaining) - if err != nil { - return content.String(), toolUses, usageInfo, fmt.Errorf("failed to read message: %w", err) - } - - // Validate headersLen to prevent slice out of bounds - if headersLen+4 > uint32(len(remaining)) { - log.Warnf("kiro: invalid headersLen %d exceeds remaining buffer %d", headersLen, len(remaining)) + eventType := msg.EventType + payload := msg.Payload + if len(payload) == 0 { continue } - // 
Extract event type from headers - eventType := e.extractEventType(remaining[:headersLen+4]) - - payloadStart := 4 + headersLen - payloadEnd := uint32(len(remaining)) - 4 - if payloadStart >= payloadEnd { - continue - } - - payload := remaining[payloadStart:payloadEnd] - var event map[string]interface{} if err := json.Unmarshal(payload, &event); err != nil { log.Debugf("kiro: skipping malformed event: %v", err) continue } - // DIAGNOSTIC: Log all received event types for debugging - log.Debugf("kiro: parseEventStream received event type: %s", eventType) - if log.IsLevelEnabled(log.TraceLevel) { - log.Tracef("kiro: parseEventStream event payload: %s", string(payload)) - } - // Check for error/exception events in the payload (Kiro API may return errors with HTTP 200) // These can appear as top-level fields or nested within the event if errType, hasErrType := event["_type"].(string); hasErrType { @@ -1560,7 +1617,7 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, errMsg = msg } log.Errorf("kiro: received AWS error in event stream: type=%s, message=%s", errType, errMsg) - return "", nil, usageInfo, fmt.Errorf("kiro API error: %s - %s", errType, errMsg) + return "", nil, usageInfo, stopReason, fmt.Errorf("kiro API error: %s - %s", errType, errMsg) } if errType, hasErrType := event["type"].(string); hasErrType && (errType == "error" || errType == "exception") { // Generic error event @@ -1573,7 +1630,18 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, } } log.Errorf("kiro: received error event in stream: type=%s, message=%s", errType, errMsg) - return "", nil, usageInfo, fmt.Errorf("kiro API error: %s", errMsg) + return "", nil, usageInfo, stopReason, fmt.Errorf("kiro API error: %s", errMsg) + } + + // Extract stop_reason from various event formats + // Kiro/Amazon Q API may include stop_reason in different locations + if sr := getString(event, "stop_reason"); sr != "" { + stopReason = sr + 
log.Debugf("kiro: parseEventStream found stop_reason (top-level): %s", stopReason) + } + if sr := getString(event, "stopReason"); sr != "" { + stopReason = sr + log.Debugf("kiro: parseEventStream found stopReason (top-level): %s", stopReason) } // Handle different event types @@ -1588,6 +1656,15 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, if contentText, ok := assistantResp["content"].(string); ok { content.WriteString(contentText) } + // Extract stop_reason from assistantResponseEvent + if sr := getString(assistantResp, "stop_reason"); sr != "" { + stopReason = sr + log.Debugf("kiro: parseEventStream found stop_reason in assistantResponseEvent: %s", stopReason) + } + if sr := getString(assistantResp, "stopReason"); sr != "" { + stopReason = sr + log.Debugf("kiro: parseEventStream found stopReason in assistantResponseEvent: %s", stopReason) + } // Extract tool uses from response if toolUsesRaw, ok := assistantResp["toolUses"].([]interface{}); ok { for _, tuRaw := range toolUsesRaw { @@ -1653,6 +1730,17 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, if outputTokens, ok := event["outputTokens"].(float64); ok { usageInfo.OutputTokens = int64(outputTokens) } + + case "messageStopEvent", "message_stop": + // Handle message stop events which may contain stop_reason + if sr := getString(event, "stop_reason"); sr != "" { + stopReason = sr + log.Debugf("kiro: parseEventStream found stop_reason in messageStopEvent: %s", stopReason) + } + if sr := getString(event, "stopReason"); sr != "" { + stopReason = sr + log.Debugf("kiro: parseEventStream found stopReason in messageStopEvent: %s", stopReason) + } } // Also check nested supplementaryWebLinksEvent @@ -1674,10 +1762,166 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, // Deduplicate all tool uses toolUses = deduplicateToolUses(toolUses) - return cleanedContent, toolUses, usageInfo, nil + // Apply fallback logic for 
stop_reason if not provided by upstream + // Priority: upstream stopReason > tool_use detection > end_turn default + if stopReason == "" { + if len(toolUses) > 0 { + stopReason = "tool_use" + log.Debugf("kiro: parseEventStream using fallback stop_reason: tool_use (detected %d tool uses)", len(toolUses)) + } else { + stopReason = "end_turn" + log.Debugf("kiro: parseEventStream using fallback stop_reason: end_turn") + } + } + + // Log warning if response was truncated due to max_tokens + if stopReason == "max_tokens" { + log.Warnf("kiro: response truncated due to max_tokens limit") + } + + return cleanedContent, toolUses, usageInfo, stopReason, nil +} + +// readEventStreamMessage reads and validates a single AWS Event Stream message. +// Returns the parsed message or a structured error for different failure modes. +// This function implements boundary protection and detailed error classification. +// +// AWS Event Stream binary format: +// - Prelude (12 bytes): total_length (4) + headers_length (4) + prelude_crc (4) +// - Headers (variable): header entries +// - Payload (variable): JSON data +// - Message CRC (4 bytes): CRC32C of entire message (not validated, just skipped) +func (e *KiroExecutor) readEventStreamMessage(reader *bufio.Reader) (*eventStreamMessage, *EventStreamError) { + // Read prelude (first 12 bytes: total_len + headers_len + prelude_crc) + prelude := make([]byte, 12) + _, err := io.ReadFull(reader, prelude) + if err == io.EOF { + return nil, nil // Normal end of stream + } + if err != nil { + return nil, &EventStreamError{ + Type: ErrStreamFatal, + Message: "failed to read prelude", + Cause: err, + } + } + + totalLength := binary.BigEndian.Uint32(prelude[0:4]) + headersLength := binary.BigEndian.Uint32(prelude[4:8]) + // Note: prelude[8:12] is prelude_crc - we read it but don't validate (no CRC check per requirements) + + // Boundary check: minimum frame size + if totalLength < minEventStreamFrameSize { + return nil, &EventStreamError{ + Type: 
ErrStreamMalformed, + Message: fmt.Sprintf("invalid message length: %d (minimum is %d)", totalLength, minEventStreamFrameSize), + } + } + + // Boundary check: maximum message size + if totalLength > maxEventStreamMsgSize { + return nil, &EventStreamError{ + Type: ErrStreamMalformed, + Message: fmt.Sprintf("message too large: %d bytes (maximum is %d)", totalLength, maxEventStreamMsgSize), + } + } + + // Boundary check: headers length within message bounds + // Message structure: prelude(12) + headers(headersLength) + payload + message_crc(4) + // So: headersLength must be <= totalLength - 16 (12 for prelude + 4 for message_crc) + if headersLength > totalLength-16 { + return nil, &EventStreamError{ + Type: ErrStreamMalformed, + Message: fmt.Sprintf("headers length %d exceeds message bounds (total: %d)", headersLength, totalLength), + } + } + + // Read the rest of the message (total - 12 bytes already read) + remaining := make([]byte, totalLength-12) + _, err = io.ReadFull(reader, remaining) + if err != nil { + return nil, &EventStreamError{ + Type: ErrStreamFatal, + Message: "failed to read message body", + Cause: err, + } + } + + // Extract event type from headers + // Headers start at beginning of 'remaining', length is headersLength + var eventType string + if headersLength > 0 && headersLength <= uint32(len(remaining)) { + eventType = e.extractEventTypeFromBytes(remaining[:headersLength]) + } + + // Calculate payload boundaries + // Payload starts after headers, ends before message_crc (last 4 bytes) + payloadStart := headersLength + payloadEnd := uint32(len(remaining)) - 4 // Skip message_crc at end + + // Validate payload boundaries + if payloadStart >= payloadEnd { + // No payload, return empty message + return &eventStreamMessage{ + EventType: eventType, + Payload: nil, + }, nil + } + + payload := remaining[payloadStart:payloadEnd] + + return &eventStreamMessage{ + EventType: eventType, + Payload: payload, + }, nil +} + +// extractEventTypeFromBytes extracts 
the event type from raw header bytes (without prelude CRC prefix) +func (e *KiroExecutor) extractEventTypeFromBytes(headers []byte) string { + offset := 0 + for offset < len(headers) { + if offset >= len(headers) { + break + } + nameLen := int(headers[offset]) + offset++ + if offset+nameLen > len(headers) { + break + } + name := string(headers[offset : offset+nameLen]) + offset += nameLen + + if offset >= len(headers) { + break + } + valueType := headers[offset] + offset++ + + if valueType == 7 { // String type + if offset+2 > len(headers) { + break + } + valueLen := int(binary.BigEndian.Uint16(headers[offset : offset+2])) + offset += 2 + if offset+valueLen > len(headers) { + break + } + value := string(headers[offset : offset+valueLen]) + offset += valueLen + + if name == ":event-type" { + return value + } + } else { + // Other value types are not decoded; stop scanning headers + break + } + } + return "" } // extractEventType extracts the event type from AWS Event Stream headers +// Note: This is the legacy version that expects headerBytes to include prelude CRC prefix func (e *KiroExecutor) extractEventType(headerBytes []byte) string { // Skip prelude CRC (4 bytes) if len(headerBytes) < 4 { @@ -1737,15 +1981,24 @@ func getString(m map[string]interface{}, key string) string { // buildClaudeResponse constructs a Claude-compatible response. // Supports tool_use blocks when tools are present in the response. -func (e *KiroExecutor) buildClaudeResponse(content string, toolUses []kiroToolUse, model string, usageInfo usage.Detail) []byte { +// Supports thinking blocks - parses thinking tags (thinkingStartTag/thinkingEndTag) and converts them to Claude thinking content blocks. +// stopReason is passed from upstream; fallback logic applied if empty. +func (e *KiroExecutor) buildClaudeResponse(content string, toolUses []kiroToolUse, model string, usageInfo usage.Detail, stopReason string) []byte { var contentBlocks []map[string]interface{} - // Add text content if present + // Extract thinking blocks and text from content + // This handles thinkingStartTag ... thinkingEndTag
tags from Kiro's response if content != "" { - contentBlocks = append(contentBlocks, map[string]interface{}{ - "type": "text", - "text": content, - }) + blocks := e.extractThinkingFromContent(content) + contentBlocks = append(contentBlocks, blocks...) + + // DIAGNOSTIC: Log if thinking blocks were extracted + for _, block := range blocks { + if block["type"] == "thinking" { + thinkingContent := block["thinking"].(string) + log.Infof("kiro: buildClaudeResponse extracted thinking block (len: %d)", len(thinkingContent)) + } + } } // Add tool_use blocks @@ -1766,10 +2019,18 @@ func (e *KiroExecutor) buildClaudeResponse(content string, toolUses []kiroToolUs }) } - // Determine stop reason - stopReason := "end_turn" - if len(toolUses) > 0 { - stopReason = "tool_use" + // Use upstream stopReason; apply fallback logic if not provided + if stopReason == "" { + stopReason = "end_turn" + if len(toolUses) > 0 { + stopReason = "tool_use" + } + log.Debugf("kiro: buildClaudeResponse using fallback stop_reason: %s", stopReason) + } + + // Log warning if response was truncated due to max_tokens + if stopReason == "max_tokens" { + log.Warnf("kiro: response truncated due to max_tokens limit (buildClaudeResponse)") } response := map[string]interface{}{ @@ -1788,6 +2049,101 @@ func (e *KiroExecutor) buildClaudeResponse(content string, toolUses []kiroToolUs return result } +// extractThinkingFromContent parses content to extract thinking blocks and text. +// Returns a list of content blocks in the order they appear in the content. +// Handles interleaved thinking and text blocks correctly. +// Based on the streaming implementation's thinking tag handling. 
+func (e *KiroExecutor) extractThinkingFromContent(content string) []map[string]interface{} { + var blocks []map[string]interface{} + + if content == "" { + return blocks + } + + // Check if content contains thinking tags at all + if !strings.Contains(content, thinkingStartTag) { + // No thinking tags, return as plain text + return []map[string]interface{}{ + { + "type": "text", + "text": content, + }, + } + } + + log.Debugf("kiro: extractThinkingFromContent - found thinking tags in content (len: %d)", len(content)) + + remaining := content + + for len(remaining) > 0 { + // Look for the opening thinking tag (thinkingStartTag) + startIdx := strings.Index(remaining, thinkingStartTag) + + if startIdx == -1 { + // No more thinking tags, add remaining as text + if strings.TrimSpace(remaining) != "" { + blocks = append(blocks, map[string]interface{}{ + "type": "text", + "text": remaining, + }) + } + break + } + + // Add text before thinking tag (if any meaningful content) + if startIdx > 0 { + textBefore := remaining[:startIdx] + if strings.TrimSpace(textBefore) != "" { + blocks = append(blocks, map[string]interface{}{ + "type": "text", + "text": textBefore, + }) + } + } + + // Move past the opening tag + remaining = remaining[startIdx+len(thinkingStartTag):] + + // Find closing tag + endIdx := strings.Index(remaining, thinkingEndTag) + + if endIdx == -1 { + // No closing tag found, treat rest as thinking content (incomplete response) + if strings.TrimSpace(remaining) != "" { + blocks = append(blocks, map[string]interface{}{ + "type": "thinking", + "thinking": remaining, + }) + log.Warnf("kiro: extractThinkingFromContent - missing closing tag") + } + break + } + + // Extract thinking content between tags + thinkContent := remaining[:endIdx] + if strings.TrimSpace(thinkContent) != "" { + blocks = append(blocks, map[string]interface{}{ + "type": "thinking", + "thinking": thinkContent, + }) + log.Debugf("kiro: extractThinkingFromContent - extracted thinking block (len: %d)", len(thinkContent)) + } + + // Move
past the closing tag + remaining = remaining[endIdx+len(thinkingEndTag):] + } + + // If no blocks were created (all whitespace), return empty text block + if len(blocks) == 0 { + blocks = append(blocks, map[string]interface{}{ + "type": "text", + "text": "", + }) + } + + return blocks +} + // NOTE: Tool uses are now extracted from API response, not parsed from text @@ -1795,24 +2151,33 @@ func (e *KiroExecutor) buildClaudeResponse(content string, toolUses []kiroToolUs // Supports tool calling - emits tool_use content blocks when tools are used. // Includes embedded [Called ...] tool call parsing and input buffering for toolUseEvent. // Implements duplicate content filtering using lastContentEvent detection (based on AIClient-2-API). +// Extracts stop_reason from upstream events when available. func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out chan<- cliproxyexecutor.StreamChunk, targetFormat sdktranslator.Format, model string, originalReq, claudeBody []byte, reporter *usageReporter) { reader := bufio.NewReaderSize(body, 20*1024*1024) // 20MB buffer to match other providers var totalUsage usage.Detail - var hasToolUses bool // Track if any tool uses were emitted + var hasToolUses bool // Track if any tool uses were emitted + var upstreamStopReason string // Track stop_reason from upstream events // Tool use state tracking for input buffering and deduplication processedIDs := make(map[string]bool) var currentToolUse *toolUseState - // Duplicate content detection - tracks last content event to filter duplicates - // Based on AIClient-2-API implementation for Kiro - var lastContentEvent string + // NOTE: Duplicate content filtering removed - it was causing legitimate repeated + // content (like consecutive newlines) to be incorrectly filtered out. + // The previous implementation compared lastContentEvent == contentDelta which + // is too aggressive for streaming scenarios. 
// Streaming token calculation - accumulate content for real-time token counting // Based on AIClient-2-API implementation var accumulatedContent strings.Builder accumulatedContent.Grow(4096) // Pre-allocate 4KB capacity to reduce reallocations + // Real-time usage estimation state + // These track when to send periodic usage updates during streaming + var lastUsageUpdateLen int // Last accumulated content length when usage was sent + var lastUsageUpdateTime = time.Now() // Last time usage update was sent + var lastReportedOutputTokens int64 // Last reported output token count + // Translator param for maintaining tool call state across streaming events // IMPORTANT: This must persist across all TranslateStream calls var translatorParam any @@ -1820,24 +2185,37 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out // Thinking mode state tracking - based on amq2api implementation // Tracks whether we're inside a block and handles partial tags inThinkBlock := false - pendingStartTagChars := 0 // Number of chars that might be start of - pendingEndTagChars := 0 // Number of chars that might be start of - isThinkingBlockOpen := false // Track if thinking content block is open - thinkingBlockIndex := -1 // Index of the thinking content block + pendingStartTagChars := 0 // Number of chars that might be start of + pendingEndTagChars := 0 // Number of chars that might be start of + isThinkingBlockOpen := false // Track if thinking content block is open + thinkingBlockIndex := -1 // Index of the thinking content block // Pre-calculate input tokens from request if possible - if enc, err := tokenizerForModel(model); err == nil { - // Try OpenAI format first, then fall back to raw byte count estimation - if inp, err := countOpenAIChatTokens(enc, originalReq); err == nil && inp > 0 { - totalUsage.InputTokens = inp + // Kiro uses Claude format, so try Claude format first, then OpenAI format, then fallback + if enc, err := getTokenizer(model); err == 
nil { + var inputTokens int64 + var countMethod string + + // Try Claude format first (Kiro uses Claude API format) + if inp, err := countClaudeChatTokens(enc, claudeBody); err == nil && inp > 0 { + inputTokens = inp + countMethod = "claude" + } else if inp, err := countOpenAIChatTokens(enc, originalReq); err == nil && inp > 0 { + // Fallback to OpenAI format (for OpenAI-compatible requests) + inputTokens = inp + countMethod = "openai" } else { - // Fallback: estimate from raw request size (roughly 4 chars per token) - totalUsage.InputTokens = int64(len(originalReq) / 4) - if totalUsage.InputTokens == 0 && len(originalReq) > 0 { - totalUsage.InputTokens = 1 + // Final fallback: estimate from raw request size (roughly 4 chars per token) + inputTokens = int64(len(claudeBody) / 4) + if inputTokens == 0 && len(claudeBody) > 0 { + inputTokens = 1 } + countMethod = "estimate" } - log.Debugf("kiro: streamToChannel pre-calculated input tokens: %d (request size: %d bytes)", totalUsage.InputTokens, len(originalReq)) + + totalUsage.InputTokens = inputTokens + log.Debugf("kiro: streamToChannel pre-calculated input tokens: %d (method: %s, claude body: %d bytes, original req: %d bytes)", + totalUsage.InputTokens, countMethod, len(claudeBody), len(originalReq)) } contentBlockIndex := -1 @@ -1857,9 +2235,17 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out default: } - prelude := make([]byte, 8) - _, err := io.ReadFull(reader, prelude) - if err == io.EOF { + msg, eventErr := e.readEventStreamMessage(reader) + if eventErr != nil { + // Log the error + log.Errorf("kiro: streamToChannel error: %v", eventErr) + + // Send error to channel for client notification + out <- cliproxyexecutor.StreamChunk{Err: eventErr} + return + } + if msg == nil { + // Normal end of stream (EOF) // Flush any incomplete tool use before ending stream if currentToolUse != nil && !processedIDs[currentToolUse.toolUseID] { log.Warnf("kiro: flushing incomplete tool use at EOF: %s 
(ID: %s)", currentToolUse.name, currentToolUse.toolUseID) @@ -1905,46 +2291,64 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out hasToolUses = true currentToolUse = nil } + + // Flush any pending tag characters at EOF + // These are partial tag prefixes that were held back waiting for more data + // Since no more data is coming, output them as regular text + var pendingText string + if pendingStartTagChars > 0 { + pendingText = thinkingStartTag[:pendingStartTagChars] + log.Debugf("kiro: flushing pending start tag chars at EOF: %q", pendingText) + pendingStartTagChars = 0 + } + if pendingEndTagChars > 0 { + pendingText += thinkingEndTag[:pendingEndTagChars] + log.Debugf("kiro: flushing pending end tag chars at EOF: %q", pendingText) + pendingEndTagChars = 0 + } + + // Output pending text if any + if pendingText != "" { + // If we're in a thinking block, output as thinking content + if inThinkBlock && isThinkingBlockOpen { + thinkingEvent := e.buildClaudeThinkingDeltaEvent(pendingText, thinkingBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } else { + // Output as regular text + if !isTextBlockOpen { + contentBlockIndex++ + isTextBlockOpen = true + blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + claudeEvent := e.buildClaudeStreamEvent(pendingText, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), 
targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + } break } - if err != nil { - out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("failed to read prelude: %w", err)} - return - } - totalLen := binary.BigEndian.Uint32(prelude[0:4]) - if totalLen < 8 { - out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("invalid message length: %d", totalLen)} - return - } - if totalLen > kiroMaxMessageSize { - out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("message too large: %d bytes", totalLen)} - return - } - headersLen := binary.BigEndian.Uint32(prelude[4:8]) - - remaining := make([]byte, totalLen-8) - _, err = io.ReadFull(reader, remaining) - if err != nil { - out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("failed to read message: %w", err)} - return - } - - // Validate headersLen to prevent slice out of bounds - if headersLen+4 > uint32(len(remaining)) { - log.Warnf("kiro: invalid headersLen %d exceeds remaining buffer %d", headersLen, len(remaining)) + eventType := msg.EventType + payload := msg.Payload + if len(payload) == 0 { continue } - - eventType := e.extractEventType(remaining[:headersLen+4]) - - payloadStart := 4 + headersLen - payloadEnd := uint32(len(remaining)) - 4 - if payloadStart >= payloadEnd { - continue - } - - payload := remaining[payloadStart:payloadEnd] appendAPIResponseChunk(ctx, e.cfg, payload) var event map[string]interface{} @@ -1953,12 +2357,6 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out continue } - // DIAGNOSTIC: Log all received event types for debugging - log.Debugf("kiro: streamToChannel received event type: %s", eventType) - if log.IsLevelEnabled(log.TraceLevel) { - log.Tracef("kiro: streamToChannel event payload: %s", string(payload)) - } - // Check for error/exception events in the payload (Kiro API may return errors with HTTP 
200) // These can appear as top-level fields or nested within the event if errType, hasErrType := event["_type"].(string); hasErrType { @@ -1986,6 +2384,17 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out return } + // Extract stop_reason from various event formats (streaming) + // Kiro/Amazon Q API may include stop_reason in different locations + if sr := getString(event, "stop_reason"); sr != "" { + upstreamStopReason = sr + log.Debugf("kiro: streamToChannel found stop_reason (top-level): %s", upstreamStopReason) + } + if sr := getString(event, "stopReason"); sr != "" { + upstreamStopReason = sr + log.Debugf("kiro: streamToChannel found stopReason (top-level): %s", upstreamStopReason) + } + // Send message_start on first event if !messageStartSent { msgStart := e.buildClaudeMessageStartEvent(model, totalUsage.InputTokens) @@ -2004,6 +2413,17 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out log.Debugf("kiro: streamToChannel ignoring followupPrompt event") continue + case "messageStopEvent", "message_stop": + // Handle message stop events which may contain stop_reason + if sr := getString(event, "stop_reason"); sr != "" { + upstreamStopReason = sr + log.Debugf("kiro: streamToChannel found stop_reason in messageStopEvent: %s", upstreamStopReason) + } + if sr := getString(event, "stopReason"); sr != "" { + upstreamStopReason = sr + log.Debugf("kiro: streamToChannel found stopReason in messageStopEvent: %s", upstreamStopReason) + } + case "assistantResponseEvent": var contentDelta string var toolUses []map[string]interface{} @@ -2012,6 +2432,15 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out if c, ok := assistantResp["content"].(string); ok { contentDelta = c } + // Extract stop_reason from assistantResponseEvent + if sr := getString(assistantResp, "stop_reason"); sr != "" { + upstreamStopReason = sr + log.Debugf("kiro: streamToChannel found stop_reason in 
assistantResponseEvent: %s", upstreamStopReason) + } + if sr := getString(assistantResp, "stopReason"); sr != "" { + upstreamStopReason = sr + log.Debugf("kiro: streamToChannel found stopReason in assistantResponseEvent: %s", upstreamStopReason) + } // Extract tool uses from response if tus, ok := assistantResp["toolUses"].([]interface{}); ok { for _, tuRaw := range tus { @@ -2035,19 +2464,61 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } } - // Handle text content with duplicate detection and thinking mode support + // Handle text content with thinking mode support if contentDelta != "" { - // Check for duplicate content - skip if identical to last content event - // Based on AIClient-2-API implementation for Kiro - if contentDelta == lastContentEvent { - log.Debugf("kiro: skipping duplicate content event (len: %d)", len(contentDelta)) - continue - } - lastContentEvent = contentDelta + // NOTE: Duplicate content filtering was removed because it incorrectly + // filtered out legitimate repeated content (like consecutive newlines "\n\n"). + // Streaming naturally can have identical chunks that are valid content. 
outputLen += len(contentDelta) // Accumulate content for streaming token calculation accumulatedContent.WriteString(contentDelta) + + // Real-time usage estimation: Check if we should send a usage update + // This helps clients track context usage during long thinking sessions + shouldSendUsageUpdate := false + if accumulatedContent.Len()-lastUsageUpdateLen >= usageUpdateCharThreshold { + shouldSendUsageUpdate = true + } else if time.Since(lastUsageUpdateTime) >= usageUpdateTimeInterval && accumulatedContent.Len() > lastUsageUpdateLen { + shouldSendUsageUpdate = true + } + + if shouldSendUsageUpdate { + // Calculate current output tokens using tiktoken + var currentOutputTokens int64 + if enc, encErr := getTokenizer(model); encErr == nil { + if tokenCount, countErr := enc.Count(accumulatedContent.String()); countErr == nil { + currentOutputTokens = int64(tokenCount) + } + } + // Fallback to character estimation if tiktoken fails + if currentOutputTokens == 0 { + currentOutputTokens = int64(accumulatedContent.Len() / 4) + if currentOutputTokens == 0 { + currentOutputTokens = 1 + } + } + + // Only send update if token count has changed significantly (at least 10 tokens) + if currentOutputTokens > lastReportedOutputTokens+10 { + // Send ping event with usage information + // This is a non-blocking update that clients can optionally process + pingEvent := e.buildClaudePingEventWithUsage(totalUsage.InputTokens, currentOutputTokens) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, pingEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + lastReportedOutputTokens = currentOutputTokens + log.Debugf("kiro: sent real-time usage update - input: %d, output: %d (accumulated: %d chars)", + totalUsage.InputTokens, currentOutputTokens, accumulatedContent.Len()) + } + + lastUsageUpdateLen = 
accumulatedContent.Len() + lastUsageUpdateTime = time.Now() + } // Process content with thinking tag detection - based on amq2api implementation // This handles and tags that may span across chunks @@ -2414,10 +2885,10 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } // Streaming token calculation - calculate output tokens from accumulated content - // This provides more accurate token counting than simple character division + // Only use local estimation if server didn't provide usage (server-side usage takes priority) if totalUsage.OutputTokens == 0 && accumulatedContent.Len() > 0 { // Try to use tiktoken for accurate counting - if enc, err := tokenizerForModel(model); err == nil { + if enc, err := getTokenizer(model); err == nil { if tokenCount, countErr := enc.Count(accumulatedContent.String()); countErr == nil { totalUsage.OutputTokens = int64(tokenCount) log.Debugf("kiro: streamToChannel calculated output tokens using tiktoken: %d", totalUsage.OutputTokens) @@ -2446,10 +2917,21 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } totalUsage.TotalTokens = totalUsage.InputTokens + totalUsage.OutputTokens - // Determine stop reason based on whether tool uses were emitted - stopReason := "end_turn" - if hasToolUses { - stopReason = "tool_use" + // Determine stop reason: prefer upstream, then detect tool_use, default to end_turn + stopReason := upstreamStopReason + if stopReason == "" { + if hasToolUses { + stopReason = "tool_use" + log.Debugf("kiro: streamToChannel using fallback stop_reason: tool_use") + } else { + stopReason = "end_turn" + log.Debugf("kiro: streamToChannel using fallback stop_reason: end_turn") + } + } + + // Log warning if response was truncated due to max_tokens + if stopReason == "max_tokens" { + log.Warnf("kiro: response truncated due to max_tokens limit (streamToChannel)") } // Send message_delta event @@ -2595,6 +3077,24 @@ func (e *KiroExecutor) buildClaudeFinalEvent() 
[]byte { return []byte("event: message_stop\ndata: " + string(result)) } +// buildClaudePingEventWithUsage creates a ping event with embedded usage information. +// This is used for real-time usage estimation during streaming. +// The usage field is a non-standard extension that clients can optionally process. +// Clients that don't recognize the usage field will simply ignore it. +func (e *KiroExecutor) buildClaudePingEventWithUsage(inputTokens, outputTokens int64) []byte { + event := map[string]interface{}{ + "type": "ping", + "usage": map[string]interface{}{ + "input_tokens": inputTokens, + "output_tokens": outputTokens, + "total_tokens": inputTokens + outputTokens, + "estimated": true, // Flag to indicate this is an estimate, not final + }, + } + result, _ := json.Marshal(event) + return []byte("event: ping\ndata: " + string(result)) +} + // buildClaudeThinkingDeltaEvent creates a thinking_delta event for Claude API compatibility. // This is used when streaming thinking content wrapped in tags. 
func (e *KiroExecutor) buildClaudeThinkingDeltaEvent(thinkingDelta string, index int) []byte { @@ -2674,10 +3174,21 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c // Also check if expires_at is now in the future with sufficient buffer if expiresAt, ok := auth.Metadata["expires_at"].(string); ok { if expTime, err := time.Parse(time.RFC3339, expiresAt); err == nil { - // If token expires more than 2 minutes from now, it's still valid - if time.Until(expTime) > 2*time.Minute { + // If token expires more than 5 minutes from now, it's still valid + if time.Until(expTime) > 5*time.Minute { log.Debugf("kiro executor: token is still valid (expires in %v), skipping refresh", time.Until(expTime)) - return auth, nil + // CRITICAL FIX: Set NextRefreshAfter to prevent frequent refresh checks + // Without this, shouldRefresh() will return true again in 5 seconds + updated := auth.Clone() + // Set next refresh to 5 minutes before expiry, or at least 30 seconds from now + nextRefresh := expTime.Add(-5 * time.Minute) + minNextRefresh := time.Now().Add(30 * time.Second) + if nextRefresh.Before(minNextRefresh) { + nextRefresh = minNextRefresh + } + updated.NextRefreshAfter = nextRefresh + log.Debugf("kiro executor: setting NextRefreshAfter to %v (in %v)", nextRefresh.Format(time.RFC3339), time.Until(nextRefresh)) + return updated, nil } } } @@ -2761,9 +3272,9 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c updated.Attributes["profile_arn"] = tokenData.ProfileArn } - // Set next refresh time to 30 minutes before expiry + // NextRefreshAfter is aligned with RefreshLead (5min) if expiresAt, parseErr := time.Parse(time.RFC3339, tokenData.ExpiresAt); parseErr == nil { - updated.NextRefreshAfter = expiresAt.Add(-30 * time.Minute) + updated.NextRefreshAfter = expiresAt.Add(-5 * time.Minute) } log.Infof("kiro executor: token refreshed successfully, expires at %s", tokenData.ExpiresAt) @@ -2780,7 +3291,7 @@ func (e 
*KiroExecutor) streamEventStream(ctx context.Context, body io.Reader, c var translatorParam any // Pre-calculate input tokens from request if possible - if enc, err := tokenizerForModel(model); err == nil { + if enc, err := getTokenizer(model); err == nil { // Try OpenAI format first, then fall back to raw byte count estimation if inp, err := countOpenAIChatTokens(enc, originalReq); err == nil && inp > 0 { totalUsage.InputTokens = inp diff --git a/internal/runtime/executor/token_helpers.go b/internal/runtime/executor/token_helpers.go index f4236f9b..3dd2a2b5 100644 --- a/internal/runtime/executor/token_helpers.go +++ b/internal/runtime/executor/token_helpers.go @@ -2,43 +2,107 @@ package executor import ( "fmt" + "regexp" + "strconv" "strings" + "sync" "github.com/tidwall/gjson" "github.com/tiktoken-go/tokenizer" ) +// tokenizerCache stores tokenizer instances to avoid repeated creation +var tokenizerCache sync.Map + +// TokenizerWrapper wraps a tokenizer codec with an adjustment factor for models +// where tiktoken may not accurately estimate token counts (e.g., Claude models) +type TokenizerWrapper struct { + Codec tokenizer.Codec + AdjustmentFactor float64 // 1.0 means no adjustment, >1.0 means tiktoken underestimates +} + +// Count returns the token count with adjustment factor applied +func (tw *TokenizerWrapper) Count(text string) (int, error) { + count, err := tw.Codec.Count(text) + if err != nil { + return 0, err + } + if tw.AdjustmentFactor != 1.0 && tw.AdjustmentFactor > 0 { + return int(float64(count) * tw.AdjustmentFactor), nil + } + return count, nil +} + +// getTokenizer returns a cached tokenizer for the given model. +// This improves performance by avoiding repeated tokenizer creation. 
+func getTokenizer(model string) (*TokenizerWrapper, error) { + // Check cache first + if cached, ok := tokenizerCache.Load(model); ok { + return cached.(*TokenizerWrapper), nil + } + + // Cache miss, create new tokenizer + wrapper, err := tokenizerForModel(model) + if err != nil { + return nil, err + } + + // Store in cache (use LoadOrStore to handle race conditions) + actual, _ := tokenizerCache.LoadOrStore(model, wrapper) + return actual.(*TokenizerWrapper), nil +} + // tokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id. -func tokenizerForModel(model string) (tokenizer.Codec, error) { +// For Claude models, applies a 1.1 adjustment factor since tiktoken may underestimate. +func tokenizerForModel(model string) (*TokenizerWrapper, error) { sanitized := strings.ToLower(strings.TrimSpace(model)) + + // Claude models use cl100k_base with 1.1 adjustment factor + // because tiktoken may underestimate Claude's actual token count + if strings.Contains(sanitized, "claude") || strings.HasPrefix(sanitized, "kiro-") || strings.HasPrefix(sanitized, "amazonq-") { + enc, err := tokenizer.Get(tokenizer.Cl100kBase) + if err != nil { + return nil, err + } + return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.1}, nil + } + + var enc tokenizer.Codec + var err error + switch { case sanitized == "": - return tokenizer.Get(tokenizer.Cl100kBase) + enc, err = tokenizer.Get(tokenizer.Cl100kBase) case strings.HasPrefix(sanitized, "gpt-5"): - return tokenizer.ForModel(tokenizer.GPT5) + enc, err = tokenizer.ForModel(tokenizer.GPT5) case strings.HasPrefix(sanitized, "gpt-5.1"): - return tokenizer.ForModel(tokenizer.GPT5) + enc, err = tokenizer.ForModel(tokenizer.GPT5) case strings.HasPrefix(sanitized, "gpt-4.1"): - return tokenizer.ForModel(tokenizer.GPT41) + enc, err = tokenizer.ForModel(tokenizer.GPT41) case strings.HasPrefix(sanitized, "gpt-4o"): - return tokenizer.ForModel(tokenizer.GPT4o) + enc, err = tokenizer.ForModel(tokenizer.GPT4o) case 
strings.HasPrefix(sanitized, "gpt-4"): - return tokenizer.ForModel(tokenizer.GPT4) + enc, err = tokenizer.ForModel(tokenizer.GPT4) case strings.HasPrefix(sanitized, "gpt-3.5"), strings.HasPrefix(sanitized, "gpt-3"): - return tokenizer.ForModel(tokenizer.GPT35Turbo) + enc, err = tokenizer.ForModel(tokenizer.GPT35Turbo) case strings.HasPrefix(sanitized, "o1"): - return tokenizer.ForModel(tokenizer.O1) + enc, err = tokenizer.ForModel(tokenizer.O1) case strings.HasPrefix(sanitized, "o3"): - return tokenizer.ForModel(tokenizer.O3) + enc, err = tokenizer.ForModel(tokenizer.O3) case strings.HasPrefix(sanitized, "o4"): - return tokenizer.ForModel(tokenizer.O4Mini) + enc, err = tokenizer.ForModel(tokenizer.O4Mini) default: - return tokenizer.Get(tokenizer.O200kBase) + enc, err = tokenizer.Get(tokenizer.O200kBase) } + + if err != nil { + return nil, err + } + return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.0}, nil } // countOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads. -func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) { +func countOpenAIChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) { if enc == nil { return 0, fmt.Errorf("encoder is nil") } @@ -62,11 +126,206 @@ func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) { return 0, nil } + // Count text tokens count, err := enc.Count(joined) if err != nil { return 0, err } - return int64(count), nil + + // Extract and add image tokens from placeholders + imageTokens := extractImageTokens(joined) + + return int64(count) + int64(imageTokens), nil +} + +// countClaudeChatTokens approximates prompt tokens for Claude API chat completions payloads. +// This handles Claude's message format with system, messages, and tools. +// Image tokens are estimated based on image dimensions when available. 
+func countClaudeChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) { + if enc == nil { + return 0, fmt.Errorf("encoder is nil") + } + if len(payload) == 0 { + return 0, nil + } + + root := gjson.ParseBytes(payload) + segments := make([]string, 0, 32) + + // Collect system prompt (can be string or array of content blocks) + collectClaudeSystem(root.Get("system"), &segments) + + // Collect messages + collectClaudeMessages(root.Get("messages"), &segments) + + // Collect tools + collectClaudeTools(root.Get("tools"), &segments) + + joined := strings.TrimSpace(strings.Join(segments, "\n")) + if joined == "" { + return 0, nil + } + + // Count text tokens + count, err := enc.Count(joined) + if err != nil { + return 0, err + } + + // Extract and add image tokens from placeholders + imageTokens := extractImageTokens(joined) + + return int64(count) + int64(imageTokens), nil +} + +// imageTokenPattern matches [IMAGE:xxx tokens] format for extracting estimated image tokens +var imageTokenPattern = regexp.MustCompile(`\[IMAGE:(\d+) tokens\]`) + +// extractImageTokens extracts image token estimates from placeholder text. +// Placeholders are in the format [IMAGE:xxx tokens] where xxx is the estimated token count. +func extractImageTokens(text string) int { + matches := imageTokenPattern.FindAllStringSubmatch(text, -1) + total := 0 + for _, match := range matches { + if len(match) > 1 { + if tokens, err := strconv.Atoi(match[1]); err == nil { + total += tokens + } + } + } + return total +} + +// estimateImageTokens calculates estimated tokens for an image based on dimensions. +// Based on Claude's image token calculation: tokens ≈ (width * height) / 750 +// Minimum 85 tokens, maximum 1590 tokens (for 1568x1568 images). 
+func estimateImageTokens(width, height float64) int { + if width <= 0 || height <= 0 { + // No valid dimensions, use default estimate (medium-sized image) + return 1000 + } + + tokens := int(width * height / 750) + + // Apply bounds + if tokens < 85 { + tokens = 85 + } + if tokens > 1590 { + tokens = 1590 + } + + return tokens +} + +// collectClaudeSystem extracts text from Claude's system field. +// System can be a string or an array of content blocks. +func collectClaudeSystem(system gjson.Result, segments *[]string) { + if !system.Exists() { + return + } + if system.Type == gjson.String { + addIfNotEmpty(segments, system.String()) + return + } + if system.IsArray() { + system.ForEach(func(_, block gjson.Result) bool { + blockType := block.Get("type").String() + if blockType == "text" || blockType == "" { + addIfNotEmpty(segments, block.Get("text").String()) + } + // Also handle plain string blocks + if block.Type == gjson.String { + addIfNotEmpty(segments, block.String()) + } + return true + }) + } +} + +// collectClaudeMessages extracts text from Claude's messages array. +func collectClaudeMessages(messages gjson.Result, segments *[]string) { + if !messages.Exists() || !messages.IsArray() { + return + } + messages.ForEach(func(_, message gjson.Result) bool { + addIfNotEmpty(segments, message.Get("role").String()) + collectClaudeContent(message.Get("content"), segments) + return true + }) +} + +// collectClaudeContent extracts text from Claude's content field. +// Content can be a string or an array of content blocks. +// For images, estimates token count based on dimensions when available. 
+func collectClaudeContent(content gjson.Result, segments *[]string) { + if !content.Exists() { + return + } + if content.Type == gjson.String { + addIfNotEmpty(segments, content.String()) + return + } + if content.IsArray() { + content.ForEach(func(_, part gjson.Result) bool { + partType := part.Get("type").String() + switch partType { + case "text": + addIfNotEmpty(segments, part.Get("text").String()) + case "image": + // Estimate image tokens based on dimensions if available + source := part.Get("source") + if source.Exists() { + width := source.Get("width").Float() + height := source.Get("height").Float() + if width > 0 && height > 0 { + tokens := estimateImageTokens(width, height) + addIfNotEmpty(segments, fmt.Sprintf("[IMAGE:%d tokens]", tokens)) + } else { + // No dimensions available, use default estimate + addIfNotEmpty(segments, "[IMAGE:1000 tokens]") + } + } else { + // No source info, use default estimate + addIfNotEmpty(segments, "[IMAGE:1000 tokens]") + } + case "tool_use": + addIfNotEmpty(segments, part.Get("id").String()) + addIfNotEmpty(segments, part.Get("name").String()) + if input := part.Get("input"); input.Exists() { + addIfNotEmpty(segments, input.Raw) + } + case "tool_result": + addIfNotEmpty(segments, part.Get("tool_use_id").String()) + collectClaudeContent(part.Get("content"), segments) + case "thinking": + addIfNotEmpty(segments, part.Get("thinking").String()) + default: + // For unknown types, try to extract any text content + if part.Type == gjson.String { + addIfNotEmpty(segments, part.String()) + } else if part.Type == gjson.JSON { + addIfNotEmpty(segments, part.Raw) + } + } + return true + }) + } +} + +// collectClaudeTools extracts text from Claude's tools array. 
+func collectClaudeTools(tools gjson.Result, segments *[]string) { + if !tools.Exists() || !tools.IsArray() { + return + } + tools.ForEach(func(_, tool gjson.Result) bool { + addIfNotEmpty(segments, tool.Get("name").String()) + addIfNotEmpty(segments, tool.Get("description").String()) + if inputSchema := tool.Get("input_schema"); inputSchema.Exists() { + addIfNotEmpty(segments, inputSchema.Raw) + } + return true + }) } // buildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators. diff --git a/sdk/api/handlers/claude/code_handlers.go b/sdk/api/handlers/claude/code_handlers.go index 8a57a0cc..be2028e1 100644 --- a/sdk/api/handlers/claude/code_handlers.go +++ b/sdk/api/handlers/claude/code_handlers.go @@ -219,12 +219,12 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [ } func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { - // v6.1: Intelligent Buffered Streamer strategy - // Enhanced buffering with larger buffer size (16KB) and longer flush interval (120ms). - // Smart flush only when buffer is sufficiently filled (≥50%), dramatically reducing - // flush frequency from ~12.5Hz to ~5-8Hz while maintaining low latency. - writer := bufio.NewWriterSize(c.Writer, 16*1024) // 4KB → 16KB - ticker := time.NewTicker(120 * time.Millisecond) // 80ms → 120ms + // v6.2: Immediate flush strategy for SSE streams + // SSE requires immediate data delivery to prevent client timeouts. + // Previous buffering strategy (16KB buffer, 8KB threshold) caused delays + // because SSE events are typically small (< 1KB), leading to client retries. 
+ writer := bufio.NewWriterSize(c.Writer, 4*1024) // 4KB buffer (smaller for faster flush) + ticker := time.NewTicker(50 * time.Millisecond) // 50ms interval for responsive streaming defer ticker.Stop() var chunkIdx int @@ -238,10 +238,9 @@ func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http. return case <-ticker.C: - // Smart flush: only flush when buffer has sufficient data (≥50% full) - // This reduces flush frequency while ensuring data flows naturally - buffered := writer.Buffered() - if buffered >= 8*1024 { // At least 8KB (50% of 16KB buffer) + // Flush any buffered data on timer to ensure responsiveness + // For SSE, we flush whenever there's any data to prevent client timeouts + if writer.Buffered() > 0 { if err := writer.Flush(); err != nil { // Error flushing, cancel and return cancel(err) @@ -254,6 +253,7 @@ func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http. if !ok { // Stream ended, flush remaining data _ = writer.Flush() + flusher.Flush() cancel(nil) return } @@ -263,6 +263,12 @@ func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http. // The handler just needs to forward it without reassembly. if len(chunk) > 0 { _, _ = writer.Write(chunk) + // Immediately flush for first few chunks to establish connection quickly + // This prevents client timeout/retry on slow backends like Kiro + if chunkIdx < 3 { + _ = writer.Flush() + flusher.Flush() + } } chunkIdx++ diff --git a/sdk/auth/kiro.go b/sdk/auth/kiro.go index b95d103b..1eed4b94 100644 --- a/sdk/auth/kiro.go +++ b/sdk/auth/kiro.go @@ -47,8 +47,9 @@ func (a *KiroAuthenticator) Provider() string { } // RefreshLead indicates how soon before expiry a refresh should be attempted. +// Set to 5 minutes to match Antigravity and avoid frequent refresh checks while still ensuring timely token refresh. 
func (a *KiroAuthenticator) RefreshLead() *time.Duration { - d := 30 * time.Minute + d := 5 * time.Minute return &d } @@ -103,7 +104,8 @@ func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts "source": "aws-builder-id", "email": tokenData.Email, }, - NextRefreshAfter: expiresAt.Add(-30 * time.Minute), + // NextRefreshAfter is aligned with RefreshLead (5min) + NextRefreshAfter: expiresAt.Add(-5 * time.Minute), } if tokenData.Email != "" { @@ -165,7 +167,8 @@ func (a *KiroAuthenticator) LoginWithGoogle(ctx context.Context, cfg *config.Con "source": "google-oauth", "email": tokenData.Email, }, - NextRefreshAfter: expiresAt.Add(-30 * time.Minute), + // NextRefreshAfter is aligned with RefreshLead (5min) + NextRefreshAfter: expiresAt.Add(-5 * time.Minute), } if tokenData.Email != "" { @@ -227,7 +230,8 @@ func (a *KiroAuthenticator) LoginWithGitHub(ctx context.Context, cfg *config.Con "source": "github-oauth", "email": tokenData.Email, }, - NextRefreshAfter: expiresAt.Add(-30 * time.Minute), + // NextRefreshAfter is aligned with RefreshLead (5min) + NextRefreshAfter: expiresAt.Add(-5 * time.Minute), } if tokenData.Email != "" { @@ -291,7 +295,8 @@ func (a *KiroAuthenticator) ImportFromKiroIDE(ctx context.Context, cfg *config.C "source": "kiro-ide-import", "email": tokenData.Email, }, - NextRefreshAfter: expiresAt.Add(-30 * time.Minute), + // NextRefreshAfter is aligned with RefreshLead (5min) + NextRefreshAfter: expiresAt.Add(-5 * time.Minute), } // Display the email if extracted @@ -351,7 +356,8 @@ func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, aut updated.Metadata["refresh_token"] = tokenData.RefreshToken updated.Metadata["expires_at"] = tokenData.ExpiresAt updated.Metadata["last_refresh"] = now.Format(time.RFC3339) // For double-check optimization - updated.NextRefreshAfter = expiresAt.Add(-30 * time.Minute) + // NextRefreshAfter is aligned with RefreshLead (5min) + updated.NextRefreshAfter = expiresAt.Add(-5 * 
time.Minute) return updated, nil } diff --git a/sdk/cliproxy/auth/manager.go b/sdk/cliproxy/auth/manager.go index dc7887e7..eba33bb8 100644 --- a/sdk/cliproxy/auth/manager.go +++ b/sdk/cliproxy/auth/manager.go @@ -40,7 +40,7 @@ type RefreshEvaluator interface { const ( refreshCheckInterval = 5 * time.Second refreshPendingBackoff = time.Minute - refreshFailureBackoff = 5 * time.Minute + refreshFailureBackoff = 1 * time.Minute quotaBackoffBase = time.Second quotaBackoffMax = 30 * time.Minute ) @@ -1471,7 +1471,9 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) { updated.Runtime = auth.Runtime } updated.LastRefreshedAt = now - updated.NextRefreshAfter = time.Time{} + // Preserve NextRefreshAfter set by the Authenticator + // If the Authenticator set a reasonable refresh time, it should not be overwritten + // If the Authenticator did not set it (zero value), shouldRefresh will use default logic updated.LastError = nil updated.UpdatedAt = now _, _ = m.Update(ctx, updated)