From e4c957078c8eeaadddb2336e471e1b6940bd7142 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sun, 17 May 2026 01:02:35 +0800 Subject: [PATCH] feat(auth): add OAuth2 support for xAI with PKCE and token persistence - Implemented xAI OAuth2 integration with PKCE (Proof Key for Code Exchange) support. - Added logic for token exchange, refresh, and persistent storage in JSON format. - Created `xai` package with helpers for OAuth discovery, API token handling, and URL building. - Introduced `XAIExecutor` for integrating xAI credentials into runtime HTTP requests. - Added unit tests to validate OAuth flow, token persistence, and endpoint validation. --- cmd/server/main.go | 4 + config.example.yaml | 7 +- .../api/handlers/management/auth_files.go | 180 ++++++ .../api/handlers/management/oauth_sessions.go | 2 + internal/api/server.go | 15 + internal/auth/xai/pkce.go | 20 + internal/auth/xai/token.go | 104 ++++ internal/auth/xai/types.go | 72 +++ internal/auth/xai/xai.go | 304 ++++++++++ internal/auth/xai/xai_auth_test.go | 105 ++++ internal/cmd/auth_manager.go | 3 +- internal/cmd/xai_login.go | 44 ++ internal/config/config.go | 2 +- internal/registry/model_definitions.go | 10 + internal/registry/model_updater.go | 2 + internal/registry/models/models.json | 107 +++- internal/runtime/executor/xai_executor.go | 570 ++++++++++++++++++ .../runtime/executor/xai_executor_test.go | 138 +++++ internal/tui/oauth_tab.go | 3 + sdk/auth/refresh_registry.go | 1 + sdk/auth/xai.go | 282 +++++++++ sdk/auth/xai_test.go | 37 ++ sdk/cliproxy/service.go | 6 + .../service_xai_executor_binding_test.go | 36 ++ 24 files changed, 2050 insertions(+), 4 deletions(-) create mode 100644 internal/auth/xai/pkce.go create mode 100644 internal/auth/xai/token.go create mode 100644 internal/auth/xai/types.go create mode 100644 internal/auth/xai/xai.go create mode 100644 internal/auth/xai/xai_auth_test.go create mode 100644 internal/cmd/xai_login.go create mode 100644 internal/runtime/executor/xai_executor.go create mode 100644 internal/runtime/executor/xai_executor_test.go create mode 100644 sdk/auth/xai.go create mode 100644 sdk/auth/xai_test.go create mode 100644 sdk/cliproxy/service_xai_executor_binding_test.go diff --git a/cmd/server/main.go b/cmd/server/main.go index 1a5688eb9..392fd4bcc 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -182,6 +182,7 @@ func main() { var oauthCallbackPort int var antigravityLogin bool var kimiLogin bool + var xaiLogin bool var projectID string var vertexImport string var vertexImportPrefix string @@ -203,6 +204,7 @@ func main() { flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)") flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth") flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth") + flag.BoolVar(&xaiLogin, "xai-login", false, "Login to xAI using OAuth") flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)") flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path") flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file") @@ -656,6 +658,8 @@ func main() { cmd.DoClaudeLogin(cfg, options) } else if kimiLogin { cmd.DoKimiLogin(cfg, options) + } else if xaiLogin { + cmd.DoXAILogin(cfg, options) } else { // In cloud deploy mode without config file, just wait for shutdown signals if isCloudDeploy && !configFileExists { diff --git a/config.example.yaml b/config.example.yaml index d49c378cb..464f97eaf 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -345,7 +345,7 @@ nonstream-keepalive-interval: 0 # Global OAuth model name aliases (per channel) # These aliases rename model IDs for both model listing and request routing. -# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, kimi. +# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, kimi, xai. # NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode. # NOTE: Because aliases affect the merged /v1 model list and merged request routing, overlapping # client-visible names can become ambiguous across providers. /api/provider/{provider}/... helps @@ -375,6 +375,9 @@ nonstream-keepalive-interval: 0 # kimi: # - name: "kimi-k2.5" # alias: "k2.5" +# xai: +# - name: "grok-4.3" +# alias: "grok-latest" # OAuth provider excluded models # oauth-excluded-models: @@ -395,6 +398,8 @@ nonstream-keepalive-interval: 0 # - "gpt-5-codex-mini" # kimi: # - "kimi-k2-thinking" +# xai: +# - "grok-3-mini" # Optional payload configuration # payload: diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index 775a31a49..3fe6e678b 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -27,6 +27,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/codex" geminiAuth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/gemini" "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/kimi" + xaiauth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/xai" "github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v7/internal/misc" "github.com/router-for-me/CLIProxyAPI/v7/internal/registry" @@ -2132,6 +2133,185 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } +func (h *Handler) RequestXAIToken(c *gin.Context) { + ctx := context.Background() + ctx = PopulateAuthContext(ctx, c) + + fmt.Println("Initializing xAI authentication...") + + pkceCodes, errPKCE := xaiauth.GeneratePKCECodes() + if errPKCE != nil { + log.Errorf("Failed to generate xAI PKCE codes: %v", errPKCE) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"}) + return + } + + state, errState := misc.GenerateRandomState() + if errState != nil { + log.Errorf("Failed to generate state parameter: %v", errState) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"}) + return + } + + nonce, errNonce := misc.GenerateRandomState() + if errNonce != nil { + log.Errorf("Failed to generate nonce parameter: %v", errNonce) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate nonce parameter"}) + return + } + + authSvc := xaiauth.NewXAIAuth(h.cfg) + discovery, errDiscover := authSvc.Discover(ctx) + if errDiscover != nil { + log.Errorf("Failed to discover xAI OAuth endpoints: %v", errDiscover) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to discover oauth endpoints"}) + return + } + + redirectURI := fmt.Sprintf("http://%s:%d%s", xaiauth.RedirectHost, xaiauth.CallbackPort, xaiauth.RedirectPath) + authURL, errAuthURL := xaiauth.BuildAuthorizeURL(xaiauth.AuthorizeURLParams{ + AuthorizationEndpoint: discovery.AuthorizationEndpoint, + RedirectURI: redirectURI, + CodeChallenge: pkceCodes.CodeChallenge, + State: state, + Nonce: nonce, + }) + if errAuthURL != nil { + log.Errorf("Failed to generate xAI authorization URL: %v", errAuthURL) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"}) + return + } + + RegisterOAuthSession(state, "xai") + + isWebUI := isWebUIRequest(c) + var forwarder *callbackForwarder + if isWebUI { + targetURL, errTarget := h.managementCallbackURL("/xai/callback") + if errTarget != nil { + log.WithError(errTarget).Error("failed to compute xai callback target") + c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) + return + } + var errStart error + if forwarder, errStart = startCallbackForwarder(xaiauth.CallbackPort, "xai", targetURL); errStart != nil { + log.WithError(errStart).Error("failed to start xai callback forwarder") + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) + return + } + } + + go func() { + if isWebUI { + defer stopCallbackForwarderInstance(xaiauth.CallbackPort, forwarder) + } + + waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-xai-%s.oauth", state)) + deadline := time.Now().Add(5 * time.Minute) + var authCode string + for { + if !IsOAuthSessionPending(state, "xai") { + return + } + if time.Now().After(deadline) { + log.Error("xai oauth flow timed out") + SetOAuthSessionError(state, "OAuth flow timed out") + return + } + if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil { + var payload map[string]string + _ = json.Unmarshal(data, &payload) + _ = os.Remove(waitFile) + if errStr := strings.TrimSpace(payload["error"]); errStr != "" { + log.Errorf("xAI authentication failed: %s", errStr) + SetOAuthSessionError(state, "Authentication failed: "+errStr) + return + } + if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state { + log.Errorf("xAI authentication failed: state mismatch") + SetOAuthSessionError(state, "Authentication failed: state mismatch") + return + } + authCode = strings.TrimSpace(payload["code"]) + if authCode == "" { + log.Error("xAI authentication failed: code not found") + SetOAuthSessionError(state, "Authentication failed: code not found") + return + } + break + } + time.Sleep(500 * time.Millisecond) + } + + bundle, errExchange := authSvc.ExchangeCodeForTokens(ctx, authCode, redirectURI, pkceCodes, discovery.TokenEndpoint) + if errExchange != nil { + log.Errorf("Failed to exchange xAI token: %v", errExchange) + SetOAuthSessionError(state, oauthSessionErrorWithCause("Failed to exchange authorization code for tokens", errExchange)) + return + } + + tokenStorage := authSvc.CreateTokenStorage(bundle) + if tokenStorage == nil || strings.TrimSpace(tokenStorage.AccessToken) == "" { + log.Error("xAI token exchange returned empty access token") + SetOAuthSessionError(state, "Failed to exchange token") + return + } + + fileName := xaiauth.CredentialFileName(tokenStorage.Email, tokenStorage.Subject) + label := strings.TrimSpace(tokenStorage.Email) + if label == "" { + label = "xAI" + } + + metadata := map[string]any{ + "type": "xai", + "access_token": tokenStorage.AccessToken, + "refresh_token": tokenStorage.RefreshToken, + "id_token": tokenStorage.IDToken, + "token_type": tokenStorage.TokenType, + "expires_in": tokenStorage.ExpiresIn, + "expired": tokenStorage.Expire, + "last_refresh": tokenStorage.LastRefresh, + "base_url": tokenStorage.BaseURL, + "redirect_uri": tokenStorage.RedirectURI, + "token_endpoint": tokenStorage.TokenEndpoint, + "auth_kind": "oauth", + } + if tokenStorage.Email != "" { + metadata["email"] = tokenStorage.Email + } + if tokenStorage.Subject != "" { + metadata["sub"] = tokenStorage.Subject + } + + record := &coreauth.Auth{ + ID: fileName, + Provider: "xai", + FileName: fileName, + Label: label, + Storage: tokenStorage, + Metadata: metadata, + Attributes: map[string]string{ + "auth_kind": "oauth", + "base_url": tokenStorage.BaseURL, + }, + } + savedPath, errSave := h.saveTokenRecord(ctx, record) + if errSave != nil { + log.Errorf("Failed to save xAI token to file: %v", errSave) + SetOAuthSessionError(state, "Failed to save token to file") + return + } + + CompleteOAuthSession(state) + CompleteOAuthSessionsByProvider("xai") + fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) + fmt.Println("You can now use xAI services through this CLI") + }() + + c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) +} + func (h *Handler) RequestKimiToken(c *gin.Context) { ctx := context.Background() ctx = PopulateAuthContext(ctx, c) diff --git a/internal/api/handlers/management/oauth_sessions.go b/internal/api/handlers/management/oauth_sessions.go index 56273019d..a74f7d560 100644 --- a/internal/api/handlers/management/oauth_sessions.go +++ b/internal/api/handlers/management/oauth_sessions.go @@ -242,6 +242,8 @@ func NormalizeOAuthProvider(provider string) (string, error) { return "gemini", nil case "antigravity", "anti-gravity": return "antigravity", nil + case "xai", "x-ai", "x.ai", "grok": + return "xai", nil default: return "", errUnsupportedOAuthFlow } diff --git a/internal/api/server.go b/internal/api/server.go index 492061a47..499c4acb5 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -484,6 +484,20 @@ func (s *Server) setupRoutes() { c.String(http.StatusOK, oauthCallbackSuccessHTML) }) + s.engine.GET("/xai/callback", func(c *gin.Context) { + code := c.Query("code") + state := c.Query("state") + errStr := c.Query("error") + if errStr == "" { + errStr = c.Query("error_description") + } + if state != "" { + _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "xai", state, code, errStr) + } + c.Header("Content-Type", "text/html; charset=utf-8") + c.String(http.StatusOK, oauthCallbackSuccessHTML) + }) + // Management routes are registered lazily by registerManagementRoutes when a secret is configured. } @@ -685,6 +699,7 @@ func (s *Server) registerManagementRoutes() { mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken) mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken) mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken) + mgmt.GET("/xai-auth-url", s.mgmt.RequestXAIToken) mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback) mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus) } diff --git a/internal/auth/xai/pkce.go b/internal/auth/xai/pkce.go new file mode 100644 index 000000000..54d2c23df --- /dev/null +++ b/internal/auth/xai/pkce.go @@ -0,0 +1,20 @@ +package xai + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "fmt" +) + +// GeneratePKCECodes creates a verifier/challenge pair for the OAuth flow. +func GeneratePKCECodes() (*PKCECodes, error) { + bytes := make([]byte, 96) + if _, err := rand.Read(bytes); err != nil { + return nil, fmt.Errorf("xai pkce: generate verifier: %w", err) + } + verifier := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(bytes) + hash := sha256.Sum256([]byte(verifier)) + challenge := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(hash[:]) + return &PKCECodes{CodeVerifier: verifier, CodeChallenge: challenge}, nil +} diff --git a/internal/auth/xai/token.go b/internal/auth/xai/token.go new file mode 100644 index 000000000..183d0f379 --- /dev/null +++ b/internal/auth/xai/token.go @@ -0,0 +1,104 @@ +package xai + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/misc" + log "github.com/sirupsen/logrus" +) + +// TokenStorage stores xAI OAuth credentials on disk. +type TokenStorage struct { + Type string `json:"type"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + IDToken string `json:"id_token,omitempty"` + TokenType string `json:"token_type,omitempty"` + ExpiresIn int `json:"expires_in,omitempty"` + Expire string `json:"expired,omitempty"` + LastRefresh string `json:"last_refresh,omitempty"` + Email string `json:"email,omitempty"` + Subject string `json:"sub,omitempty"` + BaseURL string `json:"base_url,omitempty"` + RedirectURI string `json:"redirect_uri,omitempty"` + TokenEndpoint string `json:"token_endpoint,omitempty"` + AuthKind string `json:"auth_kind,omitempty"` + + Metadata map[string]any `json:"-"` +} + +// SetMetadata allows the token store to merge status fields before saving. +func (ts *TokenStorage) SetMetadata(meta map[string]any) { + ts.Metadata = meta +} + +// SaveTokenToFile writes xAI credentials to a JSON auth file. +func (ts *TokenStorage) SaveTokenToFile(authFilePath string) error { + misc.LogSavingCredentials(authFilePath) + ts.Type = "xai" + ts.AuthKind = "oauth" + if errMkdirAll := os.MkdirAll(filepath.Dir(authFilePath), 0o700); errMkdirAll != nil { + return fmt.Errorf("xai token storage: create directory: %w", errMkdirAll) + } + file, err := os.Create(authFilePath) + if err != nil { + return fmt.Errorf("xai token storage: create token file: %w", err) + } + defer func() { + if errClose := file.Close(); errClose != nil { + log.Errorf("xai token storage: close token file error: %v", errClose) + } + }() + + data, errMerge := misc.MergeMetadata(ts, ts.Metadata) + if errMerge != nil { + return fmt.Errorf("xai token storage: merge metadata: %w", errMerge) + } + encoder := json.NewEncoder(file) + encoder.SetIndent("", " ") + if err = encoder.Encode(data); err != nil { + return fmt.Errorf("xai token storage: write token file: %w", err) + } + return nil +} + +// CredentialFileName returns the filename used for xAI credentials. +func CredentialFileName(email, subject string) string { + email = sanitizeFileSegment(email) + if email != "" { + return fmt.Sprintf("xai-%s.json", email) + } + subject = sanitizeFileSegment(subject) + if subject != "" { + return fmt.Sprintf("xai-%s.json", subject) + } + return fmt.Sprintf("xai-%d.json", time.Now().UnixMilli()) +} + +func sanitizeFileSegment(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "" + } + var b strings.Builder + for _, r := range value { + switch { + case r >= 'a' && r <= 'z': + b.WriteRune(r) + case r >= 'A' && r <= 'Z': + b.WriteRune(r) + case r >= '0' && r <= '9': + b.WriteRune(r) + case r == '@' || r == '.' || r == '_' || r == '-': + b.WriteRune(r) + default: + b.WriteRune('-') + } + } + return strings.Trim(b.String(), "-") +} diff --git a/internal/auth/xai/types.go b/internal/auth/xai/types.go new file mode 100644 index 000000000..0a2b82081 --- /dev/null +++ b/internal/auth/xai/types.go @@ -0,0 +1,72 @@ +// Package xai provides OAuth2 authentication helpers for xAI Grok. +package xai + +import "time" + +const ( + // DefaultAPIBaseURL is the default xAI Responses API base URL. + DefaultAPIBaseURL = "https://api.x.ai/v1" + // Issuer is xAI's OAuth issuer. + Issuer = "https://auth.x.ai" + // DiscoveryURL is the OIDC discovery endpoint used to resolve OAuth endpoints. + DiscoveryURL = Issuer + "/.well-known/openid-configuration" + // ClientID is the public xAI Grok CLI OAuth client ID. + ClientID = "b1a00492-073a-47ea-816f-4c329264a828" + // Scope is the OAuth scope set required for xAI API access. + Scope = "openid profile email offline_access grok-cli:access api:access" + // RedirectHost is the loopback host used by xAI OAuth. + RedirectHost = "127.0.0.1" + // CallbackPort is the preferred loopback callback port. + CallbackPort = 56121 + // RedirectPath is the loopback callback path registered by the xAI client. + RedirectPath = "/callback" +) + +var refreshLead = 5 * time.Minute + +// RefreshLead returns the refresh lead time for xAI OAuth credentials. +func RefreshLead() time.Duration { + return refreshLead +} + +// PKCECodes holds the PKCE verifier/challenge pair. +type PKCECodes struct { + CodeVerifier string + CodeChallenge string +} + +// AuthorizeURLParams contains the values used to build the xAI OAuth URL. +type AuthorizeURLParams struct { + AuthorizationEndpoint string + RedirectURI string + CodeChallenge string + State string + Nonce string +} + +// Discovery contains OAuth endpoints resolved from xAI OIDC discovery. +type Discovery struct { + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` +} + +// TokenData holds xAI OAuth token data. +type TokenData struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + IDToken string `json:"id_token,omitempty"` + TokenType string `json:"token_type,omitempty"` + ExpiresIn int `json:"expires_in,omitempty"` + Expire string `json:"expired,omitempty"` + Email string `json:"email,omitempty"` + Subject string `json:"sub,omitempty"` +} + +// AuthBundle aggregates token data and OAuth metadata for persistence. +type AuthBundle struct { + TokenData TokenData + LastRefresh string + BaseURL string + RedirectURI string + TokenEndpoint string +} diff --git a/internal/auth/xai/xai.go b/internal/auth/xai/xai.go new file mode 100644 index 000000000..aa34c8732 --- /dev/null +++ b/internal/auth/xai/xai.go @@ -0,0 +1,304 @@ +package xai + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + "github.com/router-for-me/CLIProxyAPI/v7/internal/util" + log "github.com/sirupsen/logrus" +) + +// XAIAuth performs xAI OAuth discovery, token exchange, and refresh. +type XAIAuth struct { + httpClient *http.Client +} + +// NewXAIAuth creates an xAI OAuth helper using config proxy settings. +func NewXAIAuth(cfg *config.Config) *XAIAuth { + return NewXAIAuthWithProxyURL(cfg, "") +} + +// NewXAIAuthWithProxyURL creates an xAI OAuth helper with an explicit proxy URL. +func NewXAIAuthWithProxyURL(cfg *config.Config, proxyURL string) *XAIAuth { + effectiveProxyURL := strings.TrimSpace(proxyURL) + var sdkCfg config.SDKConfig + if cfg != nil { + sdkCfg = cfg.SDKConfig + if effectiveProxyURL == "" { + effectiveProxyURL = strings.TrimSpace(cfg.ProxyURL) + } + } + sdkCfg.ProxyURL = effectiveProxyURL + return &XAIAuth{httpClient: util.SetProxy(&sdkCfg, &http.Client{})} +} + +// ValidateOAuthEndpoint validates an endpoint returned by xAI discovery. +func ValidateOAuthEndpoint(rawURL string, field string) (string, error) { + rawURL = strings.TrimSpace(rawURL) + if rawURL == "" { + return "", fmt.Errorf("xai discovery %s is empty", field) + } + parsed, err := url.Parse(rawURL) + if err != nil { + return "", fmt.Errorf("xai discovery %s is invalid: %w", field, err) + } + if parsed.Scheme != "https" { + return "", fmt.Errorf("xai discovery %s must use https: %q", field, rawURL) + } + host := strings.ToLower(strings.TrimSpace(parsed.Hostname())) + if host != "x.ai" && !strings.HasSuffix(host, ".x.ai") { + return "", fmt.Errorf("xai discovery %s host %q is not on x.ai", field, host) + } + return rawURL, nil +} + +// BuildAuthorizeURL builds the browser URL for xAI OAuth. +func BuildAuthorizeURL(params AuthorizeURLParams) (string, error) { + endpoint, err := ValidateOAuthEndpoint(params.AuthorizationEndpoint, "authorization_endpoint") + if err != nil { + return "", err + } + if strings.TrimSpace(params.RedirectURI) == "" { + return "", fmt.Errorf("xai authorize URL: redirect URI is required") + } + if strings.TrimSpace(params.CodeChallenge) == "" { + return "", fmt.Errorf("xai authorize URL: code challenge is required") + } + if strings.TrimSpace(params.State) == "" { + return "", fmt.Errorf("xai authorize URL: state is required") + } + if strings.TrimSpace(params.Nonce) == "" { + return "", fmt.Errorf("xai authorize URL: nonce is required") + } + values := url.Values{ + "response_type": {"code"}, + "client_id": {ClientID}, + "redirect_uri": {strings.TrimSpace(params.RedirectURI)}, + "scope": {Scope}, + "code_challenge": {strings.TrimSpace(params.CodeChallenge)}, + "code_challenge_method": {"S256"}, + "state": {strings.TrimSpace(params.State)}, + "nonce": {strings.TrimSpace(params.Nonce)}, + "plan": {"generic"}, + "referrer": {"cli-proxy-api"}, + } + return endpoint + "?" + values.Encode(), nil +} + +// Discover resolves xAI OAuth endpoints through OIDC discovery. +func (a *XAIAuth) Discover(ctx context.Context) (*Discovery, error) { + if ctx == nil { + ctx = context.Background() + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, DiscoveryURL, nil) + if err != nil { + return nil, fmt.Errorf("xai discovery: create request: %w", err) + } + req.Header.Set("Accept", "application/json") + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("xai discovery: request failed: %w", err) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("xai discovery: close response body error: %v", errClose) + } + }() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("xai discovery: read response: %w", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("xai discovery failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + var payload struct { + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + } + if err = json.Unmarshal(body, &payload); err != nil { + return nil, fmt.Errorf("xai discovery: parse response: %w", err) + } + authorizationEndpoint, err := ValidateOAuthEndpoint(payload.AuthorizationEndpoint, "authorization_endpoint") + if err != nil { + return nil, err + } + tokenEndpoint, err := ValidateOAuthEndpoint(payload.TokenEndpoint, "token_endpoint") + if err != nil { + return nil, err + } + return &Discovery{AuthorizationEndpoint: authorizationEndpoint, TokenEndpoint: tokenEndpoint}, nil +} + +// ExchangeCodeForTokens exchanges an authorization code for xAI OAuth tokens. +func (a *XAIAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectURI string, pkceCodes *PKCECodes, tokenEndpoint string) (*AuthBundle, error) { + if pkceCodes == nil { + return nil, fmt.Errorf("xai token exchange: PKCE codes are required") + } + if strings.TrimSpace(code) == "" { + return nil, fmt.Errorf("xai token exchange: authorization code is required") + } + if strings.TrimSpace(redirectURI) == "" { + return nil, fmt.Errorf("xai token exchange: redirect URI is required") + } + if strings.TrimSpace(tokenEndpoint) == "" { + discovery, errDiscover := a.Discover(ctx) + if errDiscover != nil { + return nil, errDiscover + } + tokenEndpoint = discovery.TokenEndpoint + } + form := url.Values{ + "grant_type": {"authorization_code"}, + "code": {strings.TrimSpace(code)}, + "redirect_uri": {strings.TrimSpace(redirectURI)}, + "client_id": {ClientID}, + "code_verifier": {pkceCodes.CodeVerifier}, + } + tokenData, err := a.postTokenForm(ctx, tokenEndpoint, form) + if err != nil { + return nil, err + } + return &AuthBundle{ + TokenData: *tokenData, + LastRefresh: time.Now().UTC().Format(time.RFC3339), + BaseURL: DefaultAPIBaseURL, + RedirectURI: strings.TrimSpace(redirectURI), + TokenEndpoint: strings.TrimSpace(tokenEndpoint), + }, nil +} + +// RefreshTokens refreshes an xAI access token. +func (a *XAIAuth) RefreshTokens(ctx context.Context, refreshToken, tokenEndpoint string) (*TokenData, error) { + if strings.TrimSpace(refreshToken) == "" { + return nil, fmt.Errorf("xai token refresh: refresh token is required") + } + if strings.TrimSpace(tokenEndpoint) == "" { + discovery, errDiscover := a.Discover(ctx) + if errDiscover != nil { + return nil, errDiscover + } + tokenEndpoint = discovery.TokenEndpoint + } + form := url.Values{ + "grant_type": {"refresh_token"}, + "client_id": {ClientID}, + "refresh_token": {strings.TrimSpace(refreshToken)}, + } + return a.postTokenForm(ctx, tokenEndpoint, form) +} + +func (a *XAIAuth) postTokenForm(ctx context.Context, tokenEndpoint string, form url.Values) (*TokenData, error) { + if ctx == nil { + ctx = context.Background() + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimSpace(tokenEndpoint), strings.NewReader(form.Encode())) + if err != nil { + return nil, fmt.Errorf("xai token request: create request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("xai token request failed: %w", err) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("xai token request: close response body error: %v", errClose) + } + }() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("xai token response: read body: %w", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("xai token request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + var payload struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + IDToken string `json:"id_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + } + if err = json.Unmarshal(body, &payload); err != nil { + return nil, fmt.Errorf("xai token response: parse body: %w", err) + } + if strings.TrimSpace(payload.AccessToken) == "" { + return nil, fmt.Errorf("xai token response missing access_token") + } + email, subject := parseJWTIdentity(payload.IDToken) + return &TokenData{ + AccessToken: strings.TrimSpace(payload.AccessToken), + RefreshToken: strings.TrimSpace(payload.RefreshToken), + IDToken: strings.TrimSpace(payload.IDToken), + TokenType: strings.TrimSpace(payload.TokenType), + ExpiresIn: payload.ExpiresIn, + Expire: time.Now().Add(time.Duration(payload.ExpiresIn) * time.Second).UTC().Format(time.RFC3339), + Email: email, + Subject: subject, + }, nil +} + +// CreateTokenStorage converts an auth bundle into persistable storage. +func (a *XAIAuth) CreateTokenStorage(bundle *AuthBundle) *TokenStorage { + if bundle == nil { + return nil + } + return &TokenStorage{ + Type: "xai", + AccessToken: bundle.TokenData.AccessToken, + RefreshToken: bundle.TokenData.RefreshToken, + IDToken: bundle.TokenData.IDToken, + TokenType: bundle.TokenData.TokenType, + ExpiresIn: bundle.TokenData.ExpiresIn, + Expire: bundle.TokenData.Expire, + LastRefresh: bundle.LastRefresh, + Email: strings.TrimSpace(bundle.TokenData.Email), + Subject: bundle.TokenData.Subject, + BaseURL: firstNonEmpty(bundle.BaseURL, DefaultAPIBaseURL), + RedirectURI: bundle.RedirectURI, + TokenEndpoint: bundle.TokenEndpoint, + AuthKind: "oauth", + } +} + +func parseJWTIdentity(token string) (email string, subject string) { + parts := strings.Split(token, ".") + if len(parts) < 2 { + return "", "" + } + payload := parts[1] + payload += strings.Repeat("=", (4-len(payload)%4)%4) + raw, err := base64.URLEncoding.DecodeString(payload) + if err != nil { + return "", "" + } + var claims map[string]any + if err = json.Unmarshal(raw, &claims); err != nil { + return "", "" + } + if v, ok := claims["email"].(string); ok { + email = strings.TrimSpace(v) + } + if v, ok := claims["sub"].(string); ok { + subject = strings.TrimSpace(v) + } + return email, subject +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed + } + } + return "" +} diff --git a/internal/auth/xai/xai_auth_test.go b/internal/auth/xai/xai_auth_test.go new file mode 100644 index 000000000..80f2ef222 --- /dev/null +++ b/internal/auth/xai/xai_auth_test.go @@ -0,0 +1,105 @@ +package xai + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" +) + +func TestBuildAuthorizeURLIncludesXAIRequiredParameters(t *testing.T) { + authURL, err := BuildAuthorizeURL(AuthorizeURLParams{ + AuthorizationEndpoint: "https://auth.x.ai/oauth/authorize", + RedirectURI: "http://127.0.0.1:56121/callback", + CodeChallenge: "challenge", + State: "state-123", + Nonce: "nonce-123", + }) + if err != nil { + t.Fatalf("BuildAuthorizeURL() error = %v", err) + } + + parsed, errParse := url.Parse(authURL) + if errParse != nil { + t.Fatalf("parse authorize URL: %v", errParse) + } + if parsed.Scheme != "https" || parsed.Host != "auth.x.ai" || parsed.Path != "/oauth/authorize" { + t.Fatalf("authorize URL endpoint = %s://%s%s", parsed.Scheme, parsed.Host, parsed.Path) + } + + query := parsed.Query() + want := map[string]string{ + "response_type": "code", + "client_id": ClientID, + "redirect_uri": "http://127.0.0.1:56121/callback", + "scope": Scope, + "code_challenge": "challenge", + "code_challenge_method": "S256", + "state": "state-123", + "nonce": "nonce-123", + "plan": "generic", + "referrer": "cli-proxy-api", + } + for key, value := range want { + if got := query.Get(key); got != value { + t.Fatalf("%s = %q, want %q", key, got, value) + } + } +} + +func TestValidateOAuthEndpointRejectsNonXAIOrigin(t *testing.T) { + if _, err := ValidateOAuthEndpoint("https://auth.x.ai/oauth/token", "token_endpoint"); err != nil { + t.Fatalf("ValidateOAuthEndpoint(xai) error = %v", err) + } + if _, err := ValidateOAuthEndpoint("http://auth.x.ai/oauth/token", "token_endpoint"); err == nil { + t.Fatal("expected non-HTTPS endpoint to be rejected") + } + if _, err := ValidateOAuthEndpoint("https://evil.example/oauth/token", "token_endpoint"); err == nil { + t.Fatal("expected non-xAI endpoint to be rejected") + } +} + +func TestRefreshTokensPostsClientIDAndRefreshToken(t *testing.T) { + var gotForm url.Values + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Fatalf("method = %s, want POST", r.Method) + } + if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "application/x-www-form-urlencoded") { + t.Fatalf("Content-Type = %q, want form", got) + } + if err := r.ParseForm(); err != nil { + t.Fatalf("ParseForm() error = %v", err) + } + gotForm = r.PostForm + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "new-access", + "refresh_token": "new-refresh", + "token_type": "Bearer", + "expires_in": 3600, + }) + })) + defer server.Close() + + auth := NewXAIAuth(nil) + tokenData, err := auth.RefreshTokens(context.Background(), "old-refresh", server.URL) + if err != nil { + t.Fatalf("RefreshTokens() error = %v", err) + } + if tokenData.AccessToken != "new-access" { + t.Fatalf("access token = %q, want new-access", tokenData.AccessToken) + } + if gotForm.Get("grant_type") != "refresh_token" { + t.Fatalf("grant_type = %q, want refresh_token", gotForm.Get("grant_type")) + } + if gotForm.Get("client_id") != ClientID { + t.Fatalf("client_id = %q, want %q", gotForm.Get("client_id"), ClientID) + } + if gotForm.Get("refresh_token") != "old-refresh" { + t.Fatalf("refresh_token = %q, want old-refresh", gotForm.Get("refresh_token")) + } +} diff --git a/internal/cmd/auth_manager.go b/internal/cmd/auth_manager.go index 7896a7023..a5882e654 100644 --- a/internal/cmd/auth_manager.go +++ b/internal/cmd/auth_manager.go @@ -6,7 +6,7 @@ import ( // newAuthManager creates a new authentication manager instance with all supported // authenticators and a file-based token store. It initializes authenticators for -// Gemini, Codex, Claude, Antigravity, and Kimi providers. +// Gemini, Codex, Claude, Antigravity, Kimi, and xAI providers. // // Returns: // - *sdkAuth.Manager: A configured authentication manager instance @@ -18,6 +18,7 @@ func newAuthManager() *sdkAuth.Manager { sdkAuth.NewClaudeAuthenticator(), sdkAuth.NewAntigravityAuthenticator(), sdkAuth.NewKimiAuthenticator(), + sdkAuth.NewXAIAuthenticator(), ) return manager } diff --git a/internal/cmd/xai_login.go b/internal/cmd/xai_login.go new file mode 100644 index 000000000..c03490439 --- /dev/null +++ b/internal/cmd/xai_login.go @@ -0,0 +1,44 @@ +package cmd + +import ( + "context" + "fmt" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth" + log "github.com/sirupsen/logrus" +) + +// DoXAILogin triggers the OAuth flow for the xAI provider and saves tokens. +func DoXAILogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + promptFn := options.Prompt + if promptFn == nil { + promptFn = defaultProjectPrompt() + } + + manager := newAuthManager() + authOpts := &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + CallbackPort: options.CallbackPort, + Metadata: map[string]string{}, + Prompt: promptFn, + } + + record, savedPath, err := manager.Login(context.Background(), "xai", cfg, authOpts) + if err != nil { + log.Errorf("xAI authentication failed: %v", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + if record != nil && record.Label != "" { + fmt.Printf("Authenticated as %s\n", record.Label) + } + fmt.Println("xAI authentication successful!") +} diff --git a/internal/config/config.go b/internal/config/config.go index e032b43d4..9e0357223 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -137,7 +137,7 @@ type Config struct { // OAuthModelAlias defines global model name aliases for OAuth/file-backed auth channels. // These aliases affect both model listing and model routing for supported channels: - // gemini-cli, vertex, aistudio, antigravity, claude, codex, kimi. + // gemini-cli, vertex, aistudio, antigravity, claude, codex, kimi, xai. // // NOTE: This does not apply to existing per-credential model alias features under: // gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode. diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index 7ac6b469a..2a6ebe120 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -21,6 +21,7 @@ type staticModelsJSON struct { CodexPro []*ModelInfo `json:"codex-pro"` Kimi []*ModelInfo `json:"kimi"` Antigravity []*ModelInfo `json:"antigravity"` + XAI []*ModelInfo `json:"xai"` } // GetClaudeModels returns the standard Claude model definitions. @@ -78,6 +79,11 @@ func GetAntigravityModels() []*ModelInfo { return cloneModelInfos(getModels().Antigravity) } +// GetXAIModels returns the standard xAI Grok model definitions. +func GetXAIModels() []*ModelInfo { + return cloneModelInfos(getModels().XAI) +} + // WithCodexBuiltins injects hard-coded Codex-only model definitions that should // not depend on remote models.json updates. Built-ins replace any matching IDs // already present in the provided slice. @@ -167,6 +173,7 @@ func cloneModelInfos(models []*ModelInfo) []*ModelInfo { // - codex // - kimi // - antigravity +// - xai func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo { key := strings.ToLower(strings.TrimSpace(channel)) switch key { @@ -186,6 +193,8 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo { return GetKimiModels() case "antigravity": return GetAntigravityModels() + case "xai", "x-ai", "grok": + return GetXAIModels() default: return nil } @@ -208,6 +217,7 @@ func LookupStaticModelInfo(modelID string) *ModelInfo { data.CodexPro, data.Kimi, data.Antigravity, + data.XAI, } for _, models := range allModels { for _, m := range models { diff --git a/internal/registry/model_updater.go b/internal/registry/model_updater.go index 2512a296b..ac0caffe2 100644 --- a/internal/registry/model_updater.go +++ b/internal/registry/model_updater.go @@ -215,6 +215,7 @@ func detectChangedProviders(oldData, newData *staticModelsJSON) []string { {"codex", oldData.CodexPro, newData.CodexPro}, {"kimi", oldData.Kimi, newData.Kimi}, {"antigravity", oldData.Antigravity, newData.Antigravity}, + {"xai", oldData.XAI, newData.XAI}, } seen := make(map[string]bool, len(sections)) @@ -335,6 +336,7 @@ func validateModelsCatalog(data *staticModelsJSON) error { {name: "codex-pro", models: data.CodexPro}, {name: "kimi", models: data.Kimi}, {name: "antigravity", models: data.Antigravity}, + {name: "xai", models: data.XAI}, } for _, section := range requiredSections { diff --git a/internal/registry/models/models.json b/internal/registry/models/models.json index fa56bb42a..9837e401f 100644 --- a/internal/registry/models/models.json +++ b/internal/registry/models/models.json @@ -46,7 +46,8 @@ "levels": [ "low", "medium", - "high" + "high", + "xhigh" ] } }, @@ -2064,5 +2065,109 @@ ] } } + ], + "xai": [ + { + "id": "grok-4.3", + "object": "model", + "created": 1775606400, + "owned_by": "xai", + "type": "xai", + "display_name": "Grok 4.3", + "name": "grok-4.3", + "description": "xAI Grok 4.3 model for the Responses API.", + "context_length": 1000000, + "max_completion_tokens": 65536, + "thinking": { + "zero_allowed": true, + "levels": [ + "none", + "low", + "medium", + "high" + ] + } + }, + { + "id": "grok-4.20-0309-reasoning", + "object": "model", + "created": 1773014400, + "owned_by": "xai", + "type": "xai", + "display_name": "Grok 4.20 0309 Reasoning", + "name": "grok-4.20-0309-reasoning", + "description": "xAI Grok 4.20 0309 reasoning model for the Responses API.", + "context_length": 2000000, + "max_completion_tokens": 65536 + }, + { + "id": "grok-4.20-0309-non-reasoning", + "object": "model", + "created": 1773014400, + "owned_by": "xai", + "type": "xai", + "display_name": "Grok 4.20 0309 Non Reasoning", + "name": "grok-4.20-0309-non-reasoning", + "description": "xAI Grok 4.20 0309 non-reasoning model for the Responses API.", + "context_length": 2000000, + "max_completion_tokens": 65536 + }, + { + "id": "grok-4.20-multi-agent-0309", + "object": "model", + "created": 1773014400, + "owned_by": "xai", + "type": "xai", + "display_name": "Grok 4.20 Multi Agent 0309", + "name": "grok-4.20-multi-agent-0309", + "description": "xAI Grok 4.20 multi-agent model for the Responses API.", + "context_length": 2000000, + "max_completion_tokens": 65536, + "thinking": { + "levels": [ + "low", + "medium", + "high" + ] + } + }, + { + "id": "grok-3-mini", + "object": "model", + "created": 1740960000, + "owned_by": "xai", + "type": "xai", + "display_name": "Grok 3 Mini", + "name": "grok-3-mini", + "description": "xAI Grok 3 Mini model for the Responses API.", + "context_length": 131072, + "max_completion_tokens": 32768, + "thinking": { + "levels": [ + "low", + "medium", + "high" + ] + } + }, + { + "id": "grok-3-mini-fast", + "object": "model", + "created": 1740960000, + "owned_by": "xai", + "type": "xai", + "display_name": "Grok 3 Mini Fast", + "name": "grok-3-mini-fast", + "description": "xAI Grok 3 Mini Fast model for the Responses API.", + "context_length": 131072, + "max_completion_tokens": 32768, + "thinking": { + "levels": [ + "low", + "medium", + "high" + ] + } + } ] } diff --git a/internal/runtime/executor/xai_executor.go b/internal/runtime/executor/xai_executor.go new file mode 100644 index 000000000..b26fdfd23 --- /dev/null +++ b/internal/runtime/executor/xai_executor.go @@ -0,0 +1,570 @@ +package executor + +import ( + "bufio" + "bytes" + "context" + "fmt" + "io" + "net/http" + "sort" + "strings" + "time" + + xaiauth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/xai" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + "github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor/helps" + "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking" + "github.com/router-for-me/CLIProxyAPI/v7/internal/util" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "github.com/tiktoken-go/tokenizer" +) + +var xaiDataTag = []byte("data:") + +// XAIExecutor is a stateless executor for xAI Grok's Responses API. +type XAIExecutor struct { + cfg *config.Config +} + +// NewXAIExecutor creates a new xAI executor. +func NewXAIExecutor(cfg *config.Config) *XAIExecutor { + return &XAIExecutor{cfg: cfg} +} + +// Identifier returns the provider identifier. +func (e *XAIExecutor) Identifier() string { + return "xai" +} + +// PrepareRequest injects xAI credentials into the outgoing HTTP request. +func (e *XAIExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error { + if req == nil { + return nil + } + token, _ := xaiCreds(auth) + if strings.TrimSpace(token) != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + var attrs map[string]string + if auth != nil { + attrs = auth.Attributes + } + util.ApplyCustomHeadersFromAttrs(req, attrs) + return nil +} + +// HttpRequest injects xAI credentials into the request and executes it. +func (e *XAIExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) { + if req == nil { + return nil, fmt.Errorf("xai executor: request is nil") + } + if ctx == nil { + ctx = req.Context() + } + httpReq := req.WithContext(ctx) + if errPrepare := e.PrepareRequest(httpReq, auth); errPrepare != nil { + return nil, errPrepare + } + httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + return httpClient.Do(httpReq) +} + +func (e *XAIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { + token, baseURL := xaiCreds(auth) + if baseURL == "" { + baseURL = xaiauth.DefaultAPIBaseURL + } + + prepared, err := e.prepareResponsesRequest(ctx, req, opts, true) + if err != nil { + return resp, err + } + + reporter := helps.NewUsageReporter(ctx, e.Identifier(), prepared.baseModel, auth) + defer reporter.TrackFailure(ctx, &err) + + url := strings.TrimSuffix(baseURL, "/") + "/responses" + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(prepared.body)) + if err != nil { + return resp, err + } + applyXAIHeaders(httpReq, auth, token, true, prepared.sessionID) + e.recordXAIRequest(ctx, auth, url, httpReq.Header.Clone(), prepared.body) + + httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, err := httpClient.Do(httpReq) + if err != nil { + helps.RecordAPIResponseError(ctx, e.cfg, err) + return resp, err + } + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("xai executor: close response body error: %v", errClose) + } + }() + helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + data, errRead := io.ReadAll(httpResp.Body) + if errRead != nil { + helps.RecordAPIResponseError(ctx, e.cfg, errRead) + return resp, errRead + } + helps.AppendAPIResponseChunk(ctx, e.cfg, data) + helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) + return resp, statusErr{code: httpResp.StatusCode, msg: string(data)} + } + + data, err := io.ReadAll(httpResp.Body) + if err != nil { + helps.RecordAPIResponseError(ctx, e.cfg, err) + return resp, err + } + helps.AppendAPIResponseChunk(ctx, e.cfg, data) + + outputItemsByIndex := make(map[int64][]byte) + var outputItemsFallback [][]byte + for _, line := range bytes.Split(data, []byte("\n")) { + if !bytes.HasPrefix(line, xaiDataTag) { + continue + } + eventData := bytes.TrimSpace(line[len(xaiDataTag):]) + switch gjson.GetBytes(eventData, "type").String() { + case "response.output_item.done": + xaiCollectOutputItemDone(eventData, outputItemsByIndex, &outputItemsFallback) + case "response.completed": + if detail, ok := helps.ParseCodexUsage(eventData); ok { + reporter.Publish(ctx, detail) + } + completedData := xaiPatchCompletedOutput(eventData, outputItemsByIndex, outputItemsFallback) + var param any + out := sdktranslator.TranslateNonStream(ctx, prepared.to, prepared.from, req.Model, prepared.originalPayload, prepared.body, completedData, ¶m) + return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil + } + } + + return resp, statusErr{code: http.StatusRequestTimeout, msg: "xai stream error: stream disconnected before response.completed"} +} + +func (e *XAIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { + token, baseURL := xaiCreds(auth) + if baseURL == "" { + baseURL = xaiauth.DefaultAPIBaseURL + } + + prepared, err := e.prepareResponsesRequest(ctx, req, opts, true) + if err != nil { + return nil, err + } + + reporter := helps.NewUsageReporter(ctx, e.Identifier(), prepared.baseModel, auth) + defer reporter.TrackFailure(ctx, &err) + + url := strings.TrimSuffix(baseURL, "/") + "/responses" + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(prepared.body)) + if err != nil { + return nil, err + } + applyXAIHeaders(httpReq, auth, token, true, prepared.sessionID) + e.recordXAIRequest(ctx, auth, url, httpReq.Header.Clone(), prepared.body) + + httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, err := httpClient.Do(httpReq) + if err != nil { + helps.RecordAPIResponseError(ctx, e.cfg, err) + return nil, err + } + helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + data, errRead := io.ReadAll(httpResp.Body) + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("xai executor: close response body error: %v", errClose) + } + if errRead != nil { + helps.RecordAPIResponseError(ctx, e.cfg, errRead) + return nil, errRead + } + helps.AppendAPIResponseChunk(ctx, e.cfg, data) + helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) + return nil, statusErr{code: httpResp.StatusCode, msg: string(data)} + } + + out := make(chan cliproxyexecutor.StreamChunk) + go func() { + defer close(out) + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("xai executor: close response body error: %v", errClose) + } + }() + scanner := bufio.NewScanner(httpResp.Body) + scanner.Buffer(nil, 52_428_800) + var param any + outputItemsByIndex := make(map[int64][]byte) + var outputItemsFallback [][]byte + for scanner.Scan() { + line := scanner.Bytes() + helps.AppendAPIResponseChunk(ctx, e.cfg, line) + translatedLine := bytes.Clone(line) + if bytes.HasPrefix(line, xaiDataTag) { + eventData := bytes.TrimSpace(line[len(xaiDataTag):]) + switch gjson.GetBytes(eventData, "type").String() { + case "response.output_item.done": + xaiCollectOutputItemDone(eventData, outputItemsByIndex, &outputItemsFallback) + case "response.completed": + if detail, ok := helps.ParseCodexUsage(eventData); ok { + reporter.Publish(ctx, detail) + } + eventData = xaiPatchCompletedOutput(eventData, outputItemsByIndex, outputItemsFallback) + translatedLine = append([]byte("data: "), eventData...) + } + } + chunks := sdktranslator.TranslateStream(ctx, prepared.to, prepared.from, req.Model, prepared.originalPayload, prepared.body, translatedLine, ¶m) + for i := range chunks { + select { + case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}: + case <-ctx.Done(): + return + } + } + } + if errScan := scanner.Err(); errScan != nil { + helps.RecordAPIResponseError(ctx, e.cfg, errScan) + reporter.PublishFailure(ctx, errScan) + select { + case out <- cliproxyexecutor.StreamChunk{Err: errScan}: + case <-ctx.Done(): + } + } + }() + return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil +} + +// CountTokens estimates token count for xAI Responses requests. +func (e *XAIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + prepared, err := e.prepareResponsesRequest(ctx, req, opts, false) + if err != nil { + return cliproxyexecutor.Response{}, err + } + enc, err := tokenizer.Get(tokenizer.Cl100kBase) + if err != nil { + return cliproxyexecutor.Response{}, fmt.Errorf("xai executor: tokenizer init failed: %w", err) + } + count, err := enc.Count(string(prepared.body)) + if err != nil { + return cliproxyexecutor.Response{}, fmt.Errorf("xai executor: token counting failed: %w", err) + } + usageJSON := fmt.Sprintf(`{"response":{"usage":{"input_tokens":%d,"output_tokens":0,"total_tokens":%d}}}`, count, count) + translated := sdktranslator.TranslateTokenCount(ctx, prepared.to, prepared.from, int64(count), []byte(usageJSON)) + return cliproxyexecutor.Response{Payload: translated}, nil +} + +// Refresh refreshes xAI OAuth credentials using the stored refresh token. +func (e *XAIExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + log.Debugf("xai executor: refresh called") + if refreshed, handled, err := helps.RefreshAuthViaHome(ctx, e.cfg, auth); handled { + return refreshed, err + } + if auth == nil { + return nil, statusErr{code: http.StatusInternalServerError, msg: "xai executor: auth is nil"} + } + refreshToken := xaiMetadataString(auth.Metadata, "refresh_token") + if refreshToken == "" { + return auth, nil + } + tokenEndpoint := xaiMetadataString(auth.Metadata, "token_endpoint") + svc := xaiauth.NewXAIAuthWithProxyURL(e.cfg, auth.ProxyURL) + td, err := svc.RefreshTokens(ctx, refreshToken, tokenEndpoint) + if err != nil { + return nil, err + } + if auth.Metadata == nil { + auth.Metadata = make(map[string]any) + } + auth.Metadata["type"] = "xai" + auth.Metadata["auth_kind"] = "oauth" + auth.Metadata["access_token"] = td.AccessToken + if td.RefreshToken != "" { + auth.Metadata["refresh_token"] = td.RefreshToken + } + if td.IDToken != "" { + auth.Metadata["id_token"] = td.IDToken + } + if td.TokenType != "" { + auth.Metadata["token_type"] = td.TokenType + } + if td.ExpiresIn > 0 { + auth.Metadata["expires_in"] = td.ExpiresIn + } + if td.Expire != "" { + auth.Metadata["expired"] = td.Expire + } + if td.Email != "" { + auth.Metadata["email"] = td.Email + } + if td.Subject != "" { + auth.Metadata["sub"] = td.Subject + } + if tokenEndpoint != "" { + auth.Metadata["token_endpoint"] = tokenEndpoint + } + if xaiMetadataString(auth.Metadata, "base_url") == "" { + auth.Metadata["base_url"] = xaiauth.DefaultAPIBaseURL + } + auth.Metadata["last_refresh"] = time.Now().UTC().Format(time.RFC3339) + if auth.Attributes == nil { + auth.Attributes = make(map[string]string) + } + auth.Attributes["auth_kind"] = "oauth" + if strings.TrimSpace(auth.Attributes["base_url"]) == "" { + auth.Attributes["base_url"] = xaiauth.DefaultAPIBaseURL + } + return auth, nil +} + +type xaiPreparedRequest struct { + baseModel string + from sdktranslator.Format + to sdktranslator.Format + originalPayload []byte + body []byte + sessionID string +} + +func (e *XAIExecutor) prepareResponsesRequest(ctx context.Context, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) (*xaiPreparedRequest, error) { + baseModel := thinking.ParseSuffix(req.Model).ModelName + from := opts.SourceFormat + to := sdktranslator.FromString("codex") + originalPayloadSource := req.Payload + if len(opts.OriginalRequest) > 0 { + originalPayloadSource = opts.OriginalRequest + } + originalPayload := bytes.Clone(originalPayloadSource) + originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, stream) + body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream) + + var err error + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) + if err != nil { + return nil, err + } + + requestedModel := helps.PayloadRequestedModel(opts, req.Model) + requestPath := helps.PayloadRequestPath(opts) + body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath) + body, _ = sjson.SetBytes(body, "model", baseModel) + body, _ = sjson.SetBytes(body, "stream", stream) + body, _ = sjson.DeleteBytes(body, "previous_response_id") + body, _ = sjson.DeleteBytes(body, "prompt_cache_retention") + body, _ = sjson.DeleteBytes(body, "safety_identifier") + body, _ = sjson.DeleteBytes(body, "stream_options") + body = normalizeCodexInstructions(body) + body = sanitizeXAIResponsesBody(body, baseModel) + + sessionID := xaiExecutionSessionID(req, opts) + if sessionID != "" { + body, _ = sjson.SetBytes(body, "prompt_cache_key", sessionID) + } + + return &xaiPreparedRequest{ + baseModel: baseModel, + from: from, + to: to, + originalPayload: originalPayload, + body: body, + sessionID: sessionID, + }, nil +} + +func (e *XAIExecutor) recordXAIRequest(ctx context.Context, auth *cliproxyauth.Auth, url string, headers http.Header, body []byte) { + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: headers, + Body: body, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) +} + +func xaiCreds(auth *cliproxyauth.Auth) (token, baseURL string) { + if auth == nil { + return "", "" + } + if auth.Attributes != nil { + token = strings.TrimSpace(auth.Attributes["api_key"]) + baseURL = strings.TrimSpace(auth.Attributes["base_url"]) + } + if auth.Metadata != nil { + if token == "" { + token = xaiMetadataString(auth.Metadata, "access_token") + } + if baseURL == "" { + baseURL = xaiMetadataString(auth.Metadata, "base_url") + } + } + return token, baseURL +} + +func applyXAIHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool, sessionID string) { + r.Header.Set("Content-Type", "application/json") + if strings.TrimSpace(token) != "" { + r.Header.Set("Authorization", "Bearer "+token) + } + if stream { + r.Header.Set("Accept", "text/event-stream") + } else { + r.Header.Set("Accept", "application/json") + } + r.Header.Set("Connection", "Keep-Alive") + if sessionID != "" { + r.Header.Set("x-grok-conv-id", sessionID) + } + var attrs map[string]string + if auth != nil { + attrs = auth.Attributes + } + util.ApplyCustomHeadersFromAttrs(r, attrs) +} + +func xaiExecutionSessionID(req cliproxyexecutor.Request, opts cliproxyexecutor.Options) string { + if value := xaiMetadataString(opts.Metadata, cliproxyexecutor.ExecutionSessionMetadataKey); value != "" { + return value + } + if value := xaiMetadataString(req.Metadata, cliproxyexecutor.ExecutionSessionMetadataKey); value != "" { + return value + } + if promptCacheKey := gjson.GetBytes(req.Payload, "prompt_cache_key"); promptCacheKey.Exists() { + return strings.TrimSpace(promptCacheKey.String()) + } + return "" +} + +func xaiMetadataString(meta map[string]any, key string) string { + if len(meta) == 0 || key == "" { + return "" + } + value, ok := meta[key] + if !ok || value == nil { + return "" + } + switch typed := value.(type) { + case string: + return strings.TrimSpace(typed) + case fmt.Stringer: + return strings.TrimSpace(typed.String()) + default: + return strings.TrimSpace(fmt.Sprint(typed)) + } +} + +func sanitizeXAIResponsesBody(body []byte, model string) []byte { + body = removeXAIEncryptedReasoningInclude(body) + if !xaiSupportsReasoningEffort(model) { + body, _ = sjson.DeleteBytes(body, "reasoning") + } + return body +} + +func removeXAIEncryptedReasoningInclude(body []byte) []byte { + include := gjson.GetBytes(body, "include") + if !include.Exists() || !include.IsArray() { + return body + } + kept := make([]string, 0, len(include.Array())) + for _, item := range include.Array() { + value := strings.TrimSpace(item.String()) + if value == "" || value == "reasoning.encrypted_content" { + continue + } + kept = append(kept, value) + } + body, _ = sjson.SetBytes(body, "include", kept) + return body +} + +func xaiSupportsReasoningEffort(model string) bool { + name := strings.ToLower(strings.TrimSpace(thinking.ParseSuffix(model).ModelName)) + if idx := strings.LastIndex(name, "/"); idx >= 0 { + name = name[idx+1:] + } + switch { + case strings.HasPrefix(name, "grok-3-mini"): + return true + case strings.HasPrefix(name, "grok-4.20-multi-agent"): + return true + case strings.HasPrefix(name, "grok-4.3"): + return true + default: + return false + } +} + +func xaiCollectOutputItemDone(eventData []byte, outputItemsByIndex map[int64][]byte, outputItemsFallback *[][]byte) { + itemResult := gjson.GetBytes(eventData, "item") + if !itemResult.Exists() || itemResult.Type != gjson.JSON { + return + } + outputIndexResult := gjson.GetBytes(eventData, "output_index") + if outputIndexResult.Exists() { + outputItemsByIndex[outputIndexResult.Int()] = []byte(itemResult.Raw) + return + } + *outputItemsFallback = append(*outputItemsFallback, []byte(itemResult.Raw)) +} + +func xaiPatchCompletedOutput(eventData []byte, outputItemsByIndex map[int64][]byte, outputItemsFallback [][]byte) []byte { + outputResult := gjson.GetBytes(eventData, "response.output") + shouldPatchOutput := (!outputResult.Exists() || !outputResult.IsArray() || len(outputResult.Array()) == 0) && (len(outputItemsByIndex) > 0 || len(outputItemsFallback) > 0) + if !shouldPatchOutput { + return eventData + } + + indexes := make([]int64, 0, len(outputItemsByIndex)) + for idx := range outputItemsByIndex { + indexes = append(indexes, idx) + } + sort.Slice(indexes, func(i, j int) bool { + return indexes[i] < indexes[j] + }) + + outputArray := []byte("[]") + var buf bytes.Buffer + buf.WriteByte('[') + wrote := false + for _, idx := range indexes { + if wrote { + buf.WriteByte(',') + } + buf.Write(outputItemsByIndex[idx]) + wrote = true + } + for _, item := range outputItemsFallback { + if wrote { + buf.WriteByte(',') + } + buf.Write(item) + wrote = true + } + buf.WriteByte(']') + if wrote { + outputArray = buf.Bytes() + } + + patched, _ := sjson.SetRawBytes(eventData, "response.output", outputArray) + return patched +} diff --git a/internal/runtime/executor/xai_executor_test.go b/internal/runtime/executor/xai_executor_test.go new file mode 100644 index 000000000..a08d512bf --- /dev/null +++ b/internal/runtime/executor/xai_executor_test.go @@ -0,0 +1,138 @@ +package executor + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + _ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator" + "github.com/tidwall/gjson" +) + +func TestXAIExecutorExecuteShapesResponsesRequest(t *testing.T) { + var gotPath string + var gotAuth string + var gotGrokConvID string + var gotOriginator string + var gotAccountID string + var gotBody []byte + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + gotAuth = r.Header.Get("Authorization") + gotGrokConvID = r.Header.Get("x-grok-conv-id") + gotOriginator = r.Header.Get("Originator") + gotAccountID = r.Header.Get("Chatgpt-Account-Id") + var errRead error + gotBody, errRead = io.ReadAll(r.Body) + if errRead != nil { + t.Fatalf("read body: %v", errRead) + } + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":0,\"status\":\"completed\",\"model\":\"grok-4.3\",\"output\":[{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"ok\"}]}],\"usage\":{\"input_tokens\":1,\"output_tokens\":1,\"total_tokens\":2}}}\n\n")) + })) + defer server.Close() + + exec := NewXAIExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{ + ID: "xai-auth", + Provider: "xai", + Attributes: map[string]string{ + "base_url": server.URL, + "auth_kind": "oauth", + }, + Metadata: map[string]any{ + "access_token": "xai-token", + "email": "user@example.com", + }, + } + + _, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{ + Model: "grok-4.3", + Payload: []byte(`{"model":"grok-4.3","input":"hello","include":["reasoning.encrypted_content"],"reasoning":{"effort":"high"}}`), + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FormatOpenAIResponse, + Stream: false, + Metadata: map[string]any{ + cliproxyexecutor.ExecutionSessionMetadataKey: "conv-xai-1", + }, + }) + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + + if gotPath != "/responses" { + t.Fatalf("path = %q, want /responses", gotPath) + } + if gotAuth != "Bearer xai-token" { + t.Fatalf("Authorization = %q, want Bearer xai-token", gotAuth) + } + if gotGrokConvID != "conv-xai-1" { + t.Fatalf("x-grok-conv-id = %q, want conv-xai-1", gotGrokConvID) + } + if gotOriginator != "" { + t.Fatalf("Originator = %q, want empty", gotOriginator) + } + if gotAccountID != "" { + t.Fatalf("Chatgpt-Account-Id = %q, want empty", gotAccountID) + } + if gjson.GetBytes(gotBody, "prompt_cache_key").String() != "conv-xai-1" { + t.Fatalf("prompt_cache_key missing from body: %s", string(gotBody)) + } + if !gjson.GetBytes(gotBody, "stream").Bool() { + t.Fatalf("stream = false, want true; body=%s", string(gotBody)) + } + if gjson.GetBytes(gotBody, "reasoning.effort").String() != "high" { + t.Fatalf("reasoning.effort = %q, want high; body=%s", gjson.GetBytes(gotBody, "reasoning.effort").String(), string(gotBody)) + } + for _, include := range gjson.GetBytes(gotBody, "include").Array() { + if include.String() == "reasoning.encrypted_content" { + t.Fatalf("xai request must not ask for encrypted reasoning content: %s", string(gotBody)) + } + } +} + +func TestXAIExecutorOmitsUnsupportedReasoningEffort(t *testing.T) { + var gotBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var errRead error + gotBody, errRead = io.ReadAll(r.Body) + if errRead != nil { + t.Fatalf("read body: %v", errRead) + } + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":0,\"status\":\"completed\",\"model\":\"grok-4\",\"output\":[{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"ok\"}]}]}}\n\n")) + })) + defer server.Close() + + exec := NewXAIExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{ + Provider: "xai", + Attributes: map[string]string{ + "base_url": server.URL, + "auth_kind": "oauth", + }, + Metadata: map[string]any{"access_token": "xai-token"}, + } + + _, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{ + Model: "grok-4", + Payload: []byte(`{"model":"grok-4","input":"hello","reasoning":{"effort":"high"}}`), + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FormatOpenAIResponse, + Stream: false, + }) + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + + if gjson.GetBytes(gotBody, "reasoning").Exists() { + t.Fatalf("unsupported xAI model must omit reasoning key: %s", string(gotBody)) + } +} diff --git a/internal/tui/oauth_tab.go b/internal/tui/oauth_tab.go index bed17e4fa..bd3aac3f6 100644 --- a/internal/tui/oauth_tab.go +++ b/internal/tui/oauth_tab.go @@ -24,6 +24,7 @@ var oauthProviders = []oauthProvider{ {"Codex (OpenAI)", "codex-auth-url", "🟩"}, {"Antigravity", "antigravity-auth-url", "🟪"}, {"Kimi", "kimi-auth-url", "🟫"}, + {"xAI", "xai-auth-url", "⬛"}, } // oauthTabModel handles OAuth login flows. @@ -280,6 +281,8 @@ func (m oauthTabModel) submitCallback(callbackURL string) tea.Cmd { providerKey = "antigravity" case "kimi-auth-url": providerKey = "kimi" + case "xai-auth-url": + providerKey = "xai" } break } diff --git a/sdk/auth/refresh_registry.go b/sdk/auth/refresh_registry.go index fe2523150..634c69d3e 100644 --- a/sdk/auth/refresh_registry.go +++ b/sdk/auth/refresh_registry.go @@ -13,6 +13,7 @@ func init() { registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() }) registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() }) registerRefreshLead("kimi", func() Authenticator { return NewKimiAuthenticator() }) + registerRefreshLead("xai", func() Authenticator { return NewXAIAuthenticator() }) } func registerRefreshLead(provider string, factory func() Authenticator) { diff --git a/sdk/auth/xai.go b/sdk/auth/xai.go new file mode 100644 index 000000000..1ab248d63 --- /dev/null +++ b/sdk/auth/xai.go @@ -0,0 +1,282 @@ +package auth + +import ( + "context" + "fmt" + "net" + "net/http" + "strings" + "time" + + xaiauth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/xai" + "github.com/router-for-me/CLIProxyAPI/v7/internal/browser" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + "github.com/router-for-me/CLIProxyAPI/v7/internal/misc" + "github.com/router-for-me/CLIProxyAPI/v7/internal/util" + coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" + log "github.com/sirupsen/logrus" +) + +// XAIAuthenticator implements the xAI Grok OAuth loopback flow. +type XAIAuthenticator struct{} + +// NewXAIAuthenticator constructs a new xAI authenticator. +func NewXAIAuthenticator() Authenticator { + return &XAIAuthenticator{} +} + +// Provider returns the provider key for xAI. +func (XAIAuthenticator) Provider() string { + return "xai" +} + +// RefreshLead instructs the manager to refresh before token expiry. +func (XAIAuthenticator) RefreshLead() *time.Duration { + lead := xaiauth.RefreshLead() + return &lead +} + +// Login launches a local OAuth flow to obtain xAI tokens and persists them. +func (a XAIAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { + if cfg == nil { + return nil, fmt.Errorf("cliproxy auth: configuration is required") + } + if ctx == nil { + ctx = context.Background() + } + if opts == nil { + opts = &LoginOptions{} + } + + callbackPort := xaiauth.CallbackPort + if opts.CallbackPort > 0 { + callbackPort = opts.CallbackPort + } + + pkceCodes, err := xaiauth.GeneratePKCECodes() + if err != nil { + return nil, fmt.Errorf("xai pkce generation failed: %w", err) + } + state, err := misc.GenerateRandomState() + if err != nil { + return nil, fmt.Errorf("xai state generation failed: %w", err) + } + nonce, err := misc.GenerateRandomState() + if err != nil { + return nil, fmt.Errorf("xai nonce generation failed: %w", err) + } + + authSvc := xaiauth.NewXAIAuth(cfg) + discovery, err := authSvc.Discover(ctx) + if err != nil { + return nil, err + } + + srv, port, callbackCh, errServer := startXAICallbackServer(callbackPort) + if errServer != nil { + return nil, fmt.Errorf("xai: failed to start callback server: %w", errServer) + } + defer func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if errShutdown := srv.Shutdown(shutdownCtx); errShutdown != nil { + log.Warnf("xai callback server shutdown error: %v", errShutdown) + } + }() + + redirectURI := fmt.Sprintf("http://%s:%d%s", xaiauth.RedirectHost, port, xaiauth.RedirectPath) + authURL, err := xaiauth.BuildAuthorizeURL(xaiauth.AuthorizeURLParams{ + AuthorizationEndpoint: discovery.AuthorizationEndpoint, + RedirectURI: redirectURI, + CodeChallenge: pkceCodes.CodeChallenge, + State: state, + Nonce: nonce, + }) + if err != nil { + return nil, err + } + + if !opts.NoBrowser { + fmt.Println("Opening browser for xAI authentication") + if !browser.IsAvailable() { + log.Warn("No browser available; please open the URL manually") + util.PrintSSHTunnelInstructions(port) + fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) + } else if errOpen := browser.OpenURL(authURL); errOpen != nil { + log.Warnf("Failed to open browser automatically: %v", errOpen) + util.PrintSSHTunnelInstructions(port) + fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) + } + } else { + util.PrintSSHTunnelInstructions(port) + fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) + } + + fmt.Println("Waiting for xAI authentication callback...") + + var result callbackResult + timeoutTimer := time.NewTimer(5 * time.Minute) + defer timeoutTimer.Stop() + + var manualPromptTimer *time.Timer + var manualPromptC <-chan time.Time + if opts.Prompt != nil { + manualPromptTimer = time.NewTimer(15 * time.Second) + manualPromptC = manualPromptTimer.C + defer manualPromptTimer.Stop() + } + + var manualInputCh <-chan string + var manualInputErrCh <-chan error + +waitForCallback: + for { + select { + case result = <-callbackCh: + break waitForCallback + case <-manualPromptC: + manualPromptC = nil + if manualPromptTimer != nil { + manualPromptTimer.Stop() + } + select { + case result = <-callbackCh: + break waitForCallback + default: + } + manualInputCh, manualInputErrCh = misc.AsyncPrompt(opts.Prompt, "Paste the xAI callback Token (or press Enter to keep waiting): ") + continue + case input := <-manualInputCh: + manualInputCh = nil + manualInputErrCh = nil + manualResult, ok, errParse := parseXAIManualCallbackToken(input, state) + if errParse != nil { + return nil, errParse + } + if !ok { + continue + } + result = manualResult + break waitForCallback + case errManual := <-manualInputErrCh: + return nil, errManual + case <-timeoutTimer.C: + return nil, fmt.Errorf("xai: authentication timed out") + } + } + + if result.Error != "" { + return nil, fmt.Errorf("xai: authentication failed: %s", result.Error) + } + if result.State != state { + return nil, fmt.Errorf("xai: invalid state") + } + if result.Code == "" { + return nil, fmt.Errorf("xai: missing authorization code") + } + + bundle, errExchange := authSvc.ExchangeCodeForTokens(ctx, result.Code, redirectURI, pkceCodes, discovery.TokenEndpoint) + if errExchange != nil { + return nil, fmt.Errorf("xai: token exchange failed: %w", errExchange) + } + tokenStorage := authSvc.CreateTokenStorage(bundle) + if tokenStorage == nil || strings.TrimSpace(tokenStorage.AccessToken) == "" { + return nil, fmt.Errorf("xai token storage missing access token") + } + + fileName := xaiauth.CredentialFileName(tokenStorage.Email, tokenStorage.Subject) + label := strings.TrimSpace(tokenStorage.Email) + if label == "" { + label = "xAI" + } + + metadata := map[string]any{ + "type": "xai", + "access_token": tokenStorage.AccessToken, + "refresh_token": tokenStorage.RefreshToken, + "id_token": tokenStorage.IDToken, + "token_type": tokenStorage.TokenType, + "expires_in": tokenStorage.ExpiresIn, + "expired": tokenStorage.Expire, + "last_refresh": tokenStorage.LastRefresh, + "base_url": tokenStorage.BaseURL, + "redirect_uri": tokenStorage.RedirectURI, + "token_endpoint": tokenStorage.TokenEndpoint, + "auth_kind": "oauth", + } + if tokenStorage.Email != "" { + metadata["email"] = tokenStorage.Email + } + if tokenStorage.Subject != "" { + metadata["sub"] = tokenStorage.Subject + } + + fmt.Println("xAI authentication successful") + + return &coreauth.Auth{ + ID: fileName, + Provider: a.Provider(), + FileName: fileName, + Label: label, + Storage: tokenStorage, + Metadata: metadata, + Attributes: map[string]string{ + "auth_kind": "oauth", + "base_url": tokenStorage.BaseURL, + }, + }, nil +} + +func parseXAIManualCallbackToken(input string, state string) (callbackResult, bool, error) { + token := strings.TrimSpace(input) + if token == "" { + return callbackResult{}, false, nil + } + if strings.Contains(token, "://") || strings.Contains(token, "?") || strings.Contains(token, "code=") { + return callbackResult{}, false, fmt.Errorf("xai: paste only the callback token") + } + return callbackResult{Code: token, State: state}, true, nil +} + +func startXAICallbackServer(port int) (*http.Server, int, <-chan callbackResult, error) { + if port <= 0 { + port = xaiauth.CallbackPort + } + addr := fmt.Sprintf("%s:%d", xaiauth.RedirectHost, port) + listener, err := net.Listen("tcp", addr) + if err != nil { + return nil, 0, nil, err + } + port = listener.Addr().(*net.TCPAddr).Port + resultCh := make(chan callbackResult, 1) + + mux := http.NewServeMux() + mux.HandleFunc(xaiauth.RedirectPath, func(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + result := callbackResult{ + Code: strings.TrimSpace(q.Get("code")), + Error: strings.TrimSpace(q.Get("error")), + State: strings.TrimSpace(q.Get("state")), + } + resultCh <- result + w.Header().Set("Content-Type", "text/html; charset=utf-8") + if result.Code != "" && result.Error == "" { + _, _ = w.Write([]byte("

Login successful

You can close this window.

")) + return + } + _, _ = w.Write([]byte("

Login failed

Please check the CLI output.

")) + }) + + srv := &http.Server{ + Handler: mux, + ReadHeaderTimeout: 5 * time.Second, + WriteTimeout: 5 * time.Second, + } + go func() { + if errServe := srv.Serve(listener); errServe != nil && !strings.Contains(errServe.Error(), "Server closed") { + log.Warnf("xai callback server error: %v", errServe) + } + }() + + return srv, port, resultCh, nil +} diff --git a/sdk/auth/xai_test.go b/sdk/auth/xai_test.go new file mode 100644 index 000000000..6d755d0d1 --- /dev/null +++ b/sdk/auth/xai_test.go @@ -0,0 +1,37 @@ +package auth + +import "testing" + +func TestXAIAuthenticatorProviderAndRefreshLead(t *testing.T) { + authenticator := NewXAIAuthenticator() + if authenticator.Provider() != "xai" { + t.Fatalf("Provider() = %q, want xai", authenticator.Provider()) + } + lead := authenticator.RefreshLead() + if lead == nil || *lead <= 0 { + t.Fatalf("RefreshLead() = %v, want positive duration", lead) + } +} + +func TestParseXAIManualCallbackTokenAcceptsRawCode(t *testing.T) { + result, ok, err := parseXAIManualCallbackToken(" V0auoESADonzF4bY_Ag2whBFnVeqzHJm6nW2uW012rqCCW5cstFV58qvDFBvnPBXXe0rZSKOcs3PwwfACKp1qg ", "state-1") + if err != nil { + t.Fatalf("parseXAIManualCallbackToken() error = %v", err) + } + if !ok { + t.Fatal("parseXAIManualCallbackToken() ok = false, want true") + } + if result.Code != "V0auoESADonzF4bY_Ag2whBFnVeqzHJm6nW2uW012rqCCW5cstFV58qvDFBvnPBXXe0rZSKOcs3PwwfACKp1qg" { + t.Fatalf("Code = %q", result.Code) + } + if result.State != "state-1" { + t.Fatalf("State = %q, want state-1", result.State) + } +} + +func TestParseXAIManualCallbackTokenRejectsCallbackURL(t *testing.T) { + _, _, err := parseXAIManualCallbackToken("http://127.0.0.1:56121/callback?state=state-1&code=token-1", "state-1") + if err == nil { + t.Fatal("parseXAIManualCallbackToken() error = nil, want error") + } +} diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 823daad0b..039efab2f 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -116,6 +116,7 @@ func newDefaultAuthManager() *sdkAuth.Manager { sdkAuth.NewGeminiAuthenticator(), sdkAuth.NewCodexAuthenticator(), sdkAuth.NewClaudeAuthenticator(), + sdkAuth.NewXAIAuthenticator(), ) } @@ -433,6 +434,8 @@ func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace s.coreManager.RegisterExecutor(executor.NewClaudeExecutor(s.cfg)) case "kimi": s.coreManager.RegisterExecutor(executor.NewKimiExecutor(s.cfg)) + case "xai": + s.coreManager.RegisterExecutor(executor.NewXAIExecutor(s.cfg)) default: providerKey := strings.ToLower(strings.TrimSpace(a.Provider)) if providerKey == "" { @@ -1156,6 +1159,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { case "kimi": models = registry.GetKimiModels() models = applyExcludedModels(models, excluded) + case "xai": + models = registry.GetXAIModels() + models = applyExcludedModels(models, excluded) default: // Handle OpenAI-compatibility providers by name using config if s.cfg != nil { diff --git a/sdk/cliproxy/service_xai_executor_binding_test.go b/sdk/cliproxy/service_xai_executor_binding_test.go new file mode 100644 index 000000000..0329b976c --- /dev/null +++ b/sdk/cliproxy/service_xai_executor_binding_test.go @@ -0,0 +1,36 @@ +package cliproxy + +import ( + "testing" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor" + coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" + "github.com/router-for-me/CLIProxyAPI/v7/sdk/config" +) + +func TestEnsureExecutorsForAuth_XAIBindsIndependentExecutor(t *testing.T) { + service := &Service{ + cfg: &config.Config{}, + coreManager: coreauth.NewManager(nil, nil, nil), + } + auth := &coreauth.Auth{ + ID: "xai-auth-1", + Provider: "xai", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "auth_kind": "oauth", + }, + } + + service.ensureExecutorsForAuth(auth) + resolved, ok := service.coreManager.Executor("xai") + if !ok || resolved == nil { + t.Fatal("expected xai executor after bind") + } + if _, isXAI := resolved.(*executor.XAIExecutor); !isXAI { + t.Fatalf("executor type = %T, want *executor.XAIExecutor", resolved) + } + if _, isCodex := resolved.(*executor.CodexAutoExecutor); isCodex { + t.Fatal("xai must not bind the codex auto executor") + } +}