diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index 0455a62e..18196233 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -40,33 +40,6 @@ import ( "golang.org/x/oauth2/google" ) -var ( - oauthStatus = make(map[string]string) - oauthStatusMutex sync.RWMutex -) - -// getOAuthStatus safely retrieves an OAuth status -func getOAuthStatus(key string) (string, bool) { - oauthStatusMutex.RLock() - defer oauthStatusMutex.RUnlock() - status, ok := oauthStatus[key] - return status, ok -} - -// setOAuthStatus safely sets an OAuth status -func setOAuthStatus(key string, status string) { - oauthStatusMutex.Lock() - defer oauthStatusMutex.Unlock() - oauthStatus[key] = status -} - -// deleteOAuthStatus safely deletes an OAuth status -func deleteOAuthStatus(key string) { - oauthStatusMutex.Lock() - defer oauthStatusMutex.Unlock() - delete(oauthStatus, key) -} - var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"} const ( @@ -813,6 +786,8 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { return } + RegisterOAuthSession(state, "anthropic") + isWebUI := isWebUIRequest(c) if isWebUI { targetURL, errTarget := h.managementCallbackURL("/anthropic/callback") @@ -839,7 +814,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { deadline := time.Now().Add(timeout) for { if time.Now().After(deadline) { - setOAuthStatus(state, "Timeout waiting for OAuth callback") + SetOAuthSessionError(state, "Timeout waiting for OAuth callback") return nil, fmt.Errorf("timeout waiting for OAuth callback") } data, errRead := os.ReadFile(path) @@ -864,13 +839,13 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { if errStr := resultMap["error"]; errStr != "" { oauthErr := claude.NewOAuthError(errStr, "", http.StatusBadRequest) log.Error(claude.GetUserFriendlyMessage(oauthErr)) - setOAuthStatus(state, "Bad request") + SetOAuthSessionError(state, "Bad request") return } if resultMap["state"] != state { authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, resultMap["state"])) log.Error(claude.GetUserFriendlyMessage(authErr)) - setOAuthStatus(state, "State code error") + SetOAuthSessionError(state, "State code error") return } @@ -903,7 +878,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { if errDo != nil { authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo) log.Errorf("Failed to exchange authorization code for tokens: %v", authErr) - setOAuthStatus(state, "Failed to exchange authorization code for tokens") + SetOAuthSessionError(state, "Failed to exchange authorization code for tokens") return } defer func() { @@ -914,7 +889,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { respBody, _ := io.ReadAll(resp.Body) if resp.StatusCode != http.StatusOK { log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody)) - setOAuthStatus(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode)) + SetOAuthSessionError(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode)) return } var tResp struct { @@ -927,7 +902,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { } if errU := json.Unmarshal(respBody, &tResp); errU != nil { log.Errorf("failed to parse token response: %v", errU) - setOAuthStatus(state, "Failed to parse token response") + SetOAuthSessionError(state, "Failed to parse token response") return } bundle := &claude.ClaudeAuthBundle{ @@ -952,7 +927,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Errorf("Failed to save authentication tokens: %v", errSave) - setOAuthStatus(state, "Failed to save authentication tokens") + SetOAuthSessionError(state, "Failed to save authentication tokens") return } @@ -961,10 +936,9 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { fmt.Println("API key obtained and saved") } fmt.Println("You can now use Claude services through this CLI") - deleteOAuthStatus(state) + CompleteOAuthSession(state) }() - setOAuthStatus(state, "") c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -995,6 +969,8 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { state := fmt.Sprintf("gem-%d", time.Now().UnixNano()) authURL := conf.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent")) + RegisterOAuthSession(state, "gemini") + isWebUI := isWebUIRequest(c) if isWebUI { targetURL, errTarget := h.managementCallbackURL("/google/callback") @@ -1023,7 +999,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { for { if time.Now().After(deadline) { log.Error("oauth flow timed out") - setOAuthStatus(state, "OAuth flow timed out") + SetOAuthSessionError(state, "OAuth flow timed out") return } if data, errR := os.ReadFile(waitFile); errR == nil { @@ -1032,13 +1008,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { _ = os.Remove(waitFile) if errStr := m["error"]; errStr != "" { log.Errorf("Authentication failed: %s", errStr) - setOAuthStatus(state, "Authentication failed") + SetOAuthSessionError(state, "Authentication failed") return } authCode = m["code"] if authCode == "" { log.Errorf("Authentication failed: code not found") - setOAuthStatus(state, "Authentication failed: code not found") + SetOAuthSessionError(state, "Authentication failed: code not found") return } break @@ -1050,7 +1026,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { token, err := conf.Exchange(ctx, authCode) if err != nil { log.Errorf("Failed to exchange token: %v", err) - setOAuthStatus(state, "Failed to exchange token") + SetOAuthSessionError(state, "Failed to exchange token") return } @@ -1061,7 +1037,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) if errNewRequest != nil { log.Errorf("Could not get user info: %v", errNewRequest) - setOAuthStatus(state, "Could not get user info") + SetOAuthSessionError(state, "Could not get user info") return } req.Header.Set("Content-Type", "application/json") @@ -1070,7 +1046,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { resp, errDo := authHTTPClient.Do(req) if errDo != nil { log.Errorf("Failed to execute request: %v", errDo) - setOAuthStatus(state, "Failed to execute request") + SetOAuthSessionError(state, "Failed to execute request") return } defer func() { @@ -1082,7 +1058,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { bodyBytes, _ := io.ReadAll(resp.Body) if resp.StatusCode < 200 || resp.StatusCode >= 300 { log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) - setOAuthStatus(state, fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode)) + SetOAuthSessionError(state, fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode)) return } @@ -1091,7 +1067,6 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { fmt.Printf("Authenticated user email: %s\n", email) } else { fmt.Println("Failed to get user email from token") - setOAuthStatus(state, "Failed to get user email from token") } // Marshal/unmarshal oauth2.Token to generic map and enrich fields @@ -1099,7 +1074,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { jsonData, _ := json.Marshal(token) if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil { log.Errorf("Failed to unmarshal token: %v", errUnmarshal) - setOAuthStatus(state, "Failed to unmarshal token") + SetOAuthSessionError(state, "Failed to unmarshal token") return } @@ -1125,7 +1100,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true) if errGetClient != nil { log.Errorf("failed to get authenticated client: %v", errGetClient) - setOAuthStatus(state, "Failed to get authenticated client") + SetOAuthSessionError(state, "Failed to get authenticated client") return } fmt.Println("Authentication successful.") @@ -1135,12 +1110,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts) if errAll != nil { log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll) - setOAuthStatus(state, "Failed to complete Gemini CLI onboarding") + SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding") return } if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil { log.Errorf("Failed to verify Cloud AI API status: %v", errVerify) - setOAuthStatus(state, "Failed to verify Cloud AI API status") + SetOAuthSessionError(state, "Failed to verify Cloud AI API status") return } ts.ProjectID = strings.Join(projects, ",") @@ -1148,26 +1123,26 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { } else { if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil { log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure) - setOAuthStatus(state, "Failed to complete Gemini CLI onboarding") + SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding") return } if strings.TrimSpace(ts.ProjectID) == "" { log.Error("Onboarding did not return a project ID") - setOAuthStatus(state, "Failed to resolve project ID") + SetOAuthSessionError(state, "Failed to resolve project ID") return } isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID) if errCheck != nil { log.Errorf("Failed to verify Cloud AI API status: %v", errCheck) - setOAuthStatus(state, "Failed to verify Cloud AI API status") + SetOAuthSessionError(state, "Failed to verify Cloud AI API status") return } ts.Checked = isChecked if !isChecked { log.Error("Cloud AI API is not enabled for the selected project") - setOAuthStatus(state, "Cloud AI API not enabled") + SetOAuthSessionError(state, "Cloud AI API not enabled") return } } @@ -1190,15 +1165,14 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Errorf("Failed to save token to file: %v", errSave) - setOAuthStatus(state, "Failed to save token to file") + SetOAuthSessionError(state, "Failed to save token to file") return } - deleteOAuthStatus(state) + CompleteOAuthSession(state) fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath) }() - setOAuthStatus(state, "") c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -1234,6 +1208,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { return } + RegisterOAuthSession(state, "codex") + isWebUI := isWebUIRequest(c) if isWebUI { targetURL, errTarget := h.managementCallbackURL("/codex/callback") @@ -1262,7 +1238,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { if time.Now().After(deadline) { authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback")) log.Error(codex.GetUserFriendlyMessage(authErr)) - setOAuthStatus(state, "Timeout waiting for OAuth callback") + SetOAuthSessionError(state, "Timeout waiting for OAuth callback") return } if data, errR := os.ReadFile(waitFile); errR == nil { @@ -1272,12 +1248,12 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { if errStr := m["error"]; errStr != "" { oauthErr := codex.NewOAuthError(errStr, "", http.StatusBadRequest) log.Error(codex.GetUserFriendlyMessage(oauthErr)) - setOAuthStatus(state, "Bad Request") + SetOAuthSessionError(state, "Bad Request") return } if m["state"] != state { authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, m["state"])) - setOAuthStatus(state, "State code error") + SetOAuthSessionError(state, "State code error") log.Error(codex.GetUserFriendlyMessage(authErr)) return } @@ -1308,14 +1284,14 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { resp, errDo := httpClient.Do(req) if errDo != nil { authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo) - setOAuthStatus(state, "Failed to exchange authorization code for tokens") + SetOAuthSessionError(state, "Failed to exchange authorization code for tokens") log.Errorf("Failed to exchange authorization code for tokens: %v", authErr) return } defer func() { _ = resp.Body.Close() }() respBody, _ := io.ReadAll(resp.Body) if resp.StatusCode != http.StatusOK { - setOAuthStatus(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode)) + SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode)) log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody)) return } @@ -1326,7 +1302,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { ExpiresIn int `json:"expires_in"` } if errU := json.Unmarshal(respBody, &tokenResp); errU != nil { - setOAuthStatus(state, "Failed to parse token response") + SetOAuthSessionError(state, "Failed to parse token response") log.Errorf("failed to parse token response: %v", errU) return } @@ -1364,8 +1340,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { } savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { + SetOAuthSessionError(state, "Failed to save authentication tokens") log.Errorf("Failed to save authentication tokens: %v", errSave) - setOAuthStatus(state, "Failed to save authentication tokens") return } fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) @@ -1373,10 +1349,9 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { fmt.Println("API key obtained and saved") } fmt.Println("You can now use Codex services through this CLI") - deleteOAuthStatus(state) + CompleteOAuthSession(state) }() - setOAuthStatus(state, "") c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -1417,6 +1392,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { params.Set("state", state) authURL := "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode() + RegisterOAuthSession(state, "antigravity") + isWebUI := isWebUIRequest(c) if isWebUI { targetURL, errTarget := h.managementCallbackURL("/antigravity/callback") @@ -1443,7 +1420,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { for { if time.Now().After(deadline) { log.Error("oauth flow timed out") - setOAuthStatus(state, "OAuth flow timed out") + SetOAuthSessionError(state, "OAuth flow timed out") return } if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil { @@ -1452,18 +1429,18 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { _ = os.Remove(waitFile) if errStr := strings.TrimSpace(payload["error"]); errStr != "" { log.Errorf("Authentication failed: %s", errStr) - setOAuthStatus(state, "Authentication failed") + SetOAuthSessionError(state, "Authentication failed") return } if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state { log.Errorf("Authentication failed: state mismatch") - setOAuthStatus(state, "Authentication failed: state mismatch") + SetOAuthSessionError(state, "Authentication failed: state mismatch") return } authCode = strings.TrimSpace(payload["code"]) if authCode == "" { log.Error("Authentication failed: code not found") - setOAuthStatus(state, "Authentication failed: code not found") + SetOAuthSessionError(state, "Authentication failed: code not found") return } break @@ -1482,7 +1459,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode())) if errNewRequest != nil { log.Errorf("Failed to build token request: %v", errNewRequest) - setOAuthStatus(state, "Failed to build token request") + SetOAuthSessionError(state, "Failed to build token request") return } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -1490,7 +1467,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { resp, errDo := httpClient.Do(req) if errDo != nil { log.Errorf("Failed to execute token request: %v", errDo) - setOAuthStatus(state, "Failed to exchange token") + SetOAuthSessionError(state, "Failed to exchange token") return } defer func() { @@ -1502,7 +1479,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { bodyBytes, _ := io.ReadAll(resp.Body) log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes)) - setOAuthStatus(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode)) + SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode)) return } @@ -1514,7 +1491,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { } if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil { log.Errorf("Failed to parse token response: %v", errDecode) - setOAuthStatus(state, "Failed to parse token response") + SetOAuthSessionError(state, "Failed to parse token response") return } @@ -1523,7 +1500,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) if errInfoReq != nil { log.Errorf("Failed to build user info request: %v", errInfoReq) - setOAuthStatus(state, "Failed to build user info request") + SetOAuthSessionError(state, "Failed to build user info request") return } infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken) @@ -1531,7 +1508,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { infoResp, errInfo := httpClient.Do(infoReq) if errInfo != nil { log.Errorf("Failed to execute user info request: %v", errInfo) - setOAuthStatus(state, "Failed to execute user info request") + SetOAuthSessionError(state, "Failed to execute user info request") return } defer func() { @@ -1550,7 +1527,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { } else { bodyBytes, _ := io.ReadAll(infoResp.Body) log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes)) - setOAuthStatus(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode)) + SetOAuthSessionError(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode)) return } } @@ -1598,11 +1575,11 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Errorf("Failed to save token to file: %v", errSave) - setOAuthStatus(state, "Failed to save token to file") + SetOAuthSessionError(state, "Failed to save token to file") return } - deleteOAuthStatus(state) + CompleteOAuthSession(state) fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) if projectID != "" { fmt.Printf("Using GCP project: %s\n", projectID) @@ -1610,7 +1587,6 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { fmt.Println("You can now use Antigravity services through this CLI") }() - setOAuthStatus(state, "") c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -1632,11 +1608,13 @@ func (h *Handler) RequestQwenToken(c *gin.Context) { } authURL := deviceFlow.VerificationURIComplete + RegisterOAuthSession(state, "qwen") + go func() { fmt.Println("Waiting for authentication...") tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier) if errPollForToken != nil { - setOAuthStatus(state, "Authentication failed") + SetOAuthSessionError(state, "Authentication failed") fmt.Printf("Authentication failed: %v\n", errPollForToken) return } @@ -1655,16 +1633,15 @@ func (h *Handler) RequestQwenToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Errorf("Failed to save authentication tokens: %v", errSave) - setOAuthStatus(state, "Failed to save authentication tokens") + SetOAuthSessionError(state, "Failed to save authentication tokens") return } fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) fmt.Println("You can now use Qwen services through this CLI") - deleteOAuthStatus(state) + CompleteOAuthSession(state) }() - setOAuthStatus(state, "") c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -1677,6 +1654,8 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { authSvc := iflowauth.NewIFlowAuth(h.cfg) authURL, redirectURI := authSvc.AuthorizationURL(state, iflowauth.CallbackPort) + RegisterOAuthSession(state, "iflow") + isWebUI := isWebUIRequest(c) if isWebUI { targetURL, errTarget := h.managementCallbackURL("/iflow/callback") @@ -1703,7 +1682,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { var resultMap map[string]string for { if time.Now().After(deadline) { - setOAuthStatus(state, "Authentication failed") + SetOAuthSessionError(state, "Authentication failed") fmt.Println("Authentication failed: timeout waiting for callback") return } @@ -1716,26 +1695,26 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { } if errStr := strings.TrimSpace(resultMap["error"]); errStr != "" { - setOAuthStatus(state, "Authentication failed") + SetOAuthSessionError(state, "Authentication failed") fmt.Printf("Authentication failed: %s\n", errStr) return } if resultState := strings.TrimSpace(resultMap["state"]); resultState != state { - setOAuthStatus(state, "Authentication failed") + SetOAuthSessionError(state, "Authentication failed") fmt.Println("Authentication failed: state mismatch") return } code := strings.TrimSpace(resultMap["code"]) if code == "" { - setOAuthStatus(state, "Authentication failed") + SetOAuthSessionError(state, "Authentication failed") fmt.Println("Authentication failed: code missing") return } tokenData, errExchange := authSvc.ExchangeCodeForTokens(ctx, code, redirectURI) if errExchange != nil { - setOAuthStatus(state, "Authentication failed") + SetOAuthSessionError(state, "Authentication failed") fmt.Printf("Authentication failed: %v\n", errExchange) return } @@ -1757,8 +1736,8 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { + SetOAuthSessionError(state, "Failed to save authentication tokens") log.Errorf("Failed to save authentication tokens: %v", errSave) - setOAuthStatus(state, "Failed to save authentication tokens") return } @@ -1767,10 +1746,9 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { fmt.Println("API key obtained and saved") } fmt.Println("You can now use iFlow services through this CLI") - deleteOAuthStatus(state) + CompleteOAuthSession(state) }() - setOAuthStatus(state, "") c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -2206,44 +2184,45 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec } func (h *Handler) GetAuthStatus(c *gin.Context) { - state := c.Query("state") - if statusValue, ok := getOAuthStatus(state); ok { - if statusValue != "" { - // Check for device_code prefix (Kiro AWS Builder ID flow) - // Format: "device_code|verification_url|user_code" - // Using "|" as separator because URLs contain ":" - if strings.HasPrefix(statusValue, "device_code|") { - parts := strings.SplitN(statusValue, "|", 3) - if len(parts) == 3 { - c.JSON(200, gin.H{ - "status": "device_code", - "verification_url": parts[1], - "user_code": parts[2], - }) - return - } - } - // Check for auth_url prefix (Kiro social auth flow) - // Format: "auth_url|url" - // Using "|" as separator because URLs contain ":" - if strings.HasPrefix(statusValue, "auth_url|") { - authURL := strings.TrimPrefix(statusValue, "auth_url|") - c.JSON(200, gin.H{ - "status": "auth_url", - "url": authURL, + state := strings.TrimSpace(c.Query("state")) + if state == "" { + c.JSON(http.StatusOK, gin.H{"status": "ok"}) + return + } + if err := ValidateOAuthState(state); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid state"}) + return + } + + _, status, ok := GetOAuthSession(state) + if !ok { + c.JSON(http.StatusOK, gin.H{"status": "ok"}) + return + } + if status != "" { + if strings.HasPrefix(status, "device_code|") { + parts := strings.SplitN(status, "|", 3) + if len(parts) == 3 { + c.JSON(http.StatusOK, gin.H{ + "status": "device_code", + "verification_url": parts[1], + "user_code": parts[2], }) return } - // Otherwise treat as error - c.JSON(200, gin.H{"status": "error", "error": statusValue}) - } else { - c.JSON(200, gin.H{"status": "wait"}) + } + if strings.HasPrefix(status, "auth_url|") { + authURL := strings.TrimPrefix(status, "auth_url|") + c.JSON(http.StatusOK, gin.H{ + "status": "auth_url", + "url": authURL, + }) return } - } else { - c.JSON(200, gin.H{"status": "ok"}) + c.JSON(http.StatusOK, gin.H{"status": "error", "error": status}) + return } - deleteOAuthStatus(state) + c.JSON(http.StatusOK, gin.H{"status": "wait"}) } const kiroCallbackPort = 9876 @@ -2263,31 +2242,33 @@ func (h *Handler) RequestKiroToken(c *gin.Context) { switch method { case "aws", "builder-id": + RegisterOAuthSession(state, "kiro") + // AWS Builder ID uses device code flow (no callback needed) go func() { ssoClient := kiroauth.NewSSOOIDCClient(h.cfg) // Step 1: Register client fmt.Println("Registering client...") - regResp, err := ssoClient.RegisterClient(ctx) - if err != nil { - log.Errorf("Failed to register client: %v", err) - setOAuthStatus(state, "Failed to register client") + regResp, errRegister := ssoClient.RegisterClient(ctx) + if errRegister != nil { + log.Errorf("Failed to register client: %v", errRegister) + SetOAuthSessionError(state, "Failed to register client") return } // Step 2: Start device authorization fmt.Println("Starting device authorization...") - authResp, err := ssoClient.StartDeviceAuthorization(ctx, regResp.ClientID, regResp.ClientSecret) - if err != nil { - log.Errorf("Failed to start device auth: %v", err) - setOAuthStatus(state, "Failed to start device authorization") + authResp, errAuth := ssoClient.StartDeviceAuthorization(ctx, regResp.ClientID, regResp.ClientSecret) + if errAuth != nil { + log.Errorf("Failed to start device auth: %v", errAuth) + SetOAuthSessionError(state, "Failed to start device authorization") return } - // Store the verification URL for the frontend to display - // Using "|" as separator because URLs contain ":" - setOAuthStatus(state, "device_code|"+authResp.VerificationURIComplete+"|"+authResp.UserCode) + // Store the verification URL for the frontend to display. + // Using "|" as separator because URLs contain ":". + SetOAuthSessionError(state, "device_code|"+authResp.VerificationURIComplete+"|"+authResp.UserCode) // Step 3: Poll for token fmt.Println("Waiting for authorization...") @@ -2300,12 +2281,12 @@ func (h *Handler) RequestKiroToken(c *gin.Context) { for time.Now().Before(deadline) { select { case <-ctx.Done(): - setOAuthStatus(state, "Authorization cancelled") + SetOAuthSessionError(state, "Authorization cancelled") return case <-time.After(interval): - tokenResp, err := ssoClient.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode) - if err != nil { - errStr := err.Error() + tokenResp, errToken := ssoClient.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode) + if errToken != nil { + errStr := errToken.Error() if strings.Contains(errStr, "authorization_pending") { continue } @@ -2313,8 +2294,8 @@ func (h *Handler) RequestKiroToken(c *gin.Context) { interval += 5 * time.Second continue } - log.Errorf("Token creation failed: %v", err) - setOAuthStatus(state, "Token creation failed") + log.Errorf("Token creation failed: %v", errToken) + SetOAuthSessionError(state, "Token creation failed") return } @@ -2351,7 +2332,7 @@ func (h *Handler) RequestKiroToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Errorf("Failed to save authentication tokens: %v", errSave) - setOAuthStatus(state, "Failed to save authentication tokens") + SetOAuthSessionError(state, "Failed to save authentication tokens") return } @@ -2359,18 +2340,20 @@ func (h *Handler) RequestKiroToken(c *gin.Context) { if email != "" { fmt.Printf("Authenticated as: %s\n", email) } - deleteOAuthStatus(state) + CompleteOAuthSession(state) return } } - setOAuthStatus(state, "Authorization timed out") + SetOAuthSessionError(state, "Authorization timed out") }() // Return immediately with the state for polling - c.JSON(200, gin.H{"status": "ok", "state": state, "method": "device_code"}) + c.JSON(http.StatusOK, gin.H{"status": "ok", "state": state, "method": "device_code"}) case "google", "github": + RegisterOAuthSession(state, "kiro") + // Social auth uses protocol handler - for WEB UI we use a callback forwarder provider := "Google" if method == "github" { @@ -2400,10 +2383,10 @@ func (h *Handler) RequestKiroToken(c *gin.Context) { socialClient := kiroauth.NewSocialAuthClient(h.cfg) // Generate PKCE codes - codeVerifier, codeChallenge, err := generateKiroPKCE() - if err != nil { - log.Errorf("Failed to generate PKCE: %v", err) - setOAuthStatus(state, "Failed to generate PKCE") + codeVerifier, codeChallenge, errPKCE := generateKiroPKCE() + if errPKCE != nil { + log.Errorf("Failed to generate PKCE: %v", errPKCE) + SetOAuthSessionError(state, "Failed to generate PKCE") return } @@ -2416,9 +2399,9 @@ func (h *Handler) RequestKiroToken(c *gin.Context) { state, ) - // Store auth URL for frontend - // Using "|" as separator because URLs contain ":" - setOAuthStatus(state, "auth_url|"+authURL) + // Store auth URL for frontend. + // Using "|" as separator because URLs contain ":". + SetOAuthSessionError(state, "auth_url|"+authURL) // Wait for callback file waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-kiro-%s.oauth", state)) @@ -2427,27 +2410,27 @@ func (h *Handler) RequestKiroToken(c *gin.Context) { for { if time.Now().After(deadline) { log.Error("oauth flow timed out") - setOAuthStatus(state, "OAuth flow timed out") + SetOAuthSessionError(state, "OAuth flow timed out") return } - if data, errR := os.ReadFile(waitFile); errR == nil { + if data, errRead := os.ReadFile(waitFile); errRead == nil { var m map[string]string _ = json.Unmarshal(data, &m) _ = os.Remove(waitFile) if errStr := m["error"]; errStr != "" { log.Errorf("Authentication failed: %s", errStr) - setOAuthStatus(state, "Authentication failed") + SetOAuthSessionError(state, "Authentication failed") return } if m["state"] != state { log.Errorf("State mismatch") - setOAuthStatus(state, "State mismatch") + SetOAuthSessionError(state, "State mismatch") return } code := m["code"] if code == "" { log.Error("No authorization code received") - setOAuthStatus(state, "No authorization code received") + SetOAuthSessionError(state, "No authorization code received") return } @@ -2461,7 +2444,7 @@ func (h *Handler) RequestKiroToken(c *gin.Context) { tokenResp, errToken := socialClient.CreateToken(ctx, tokenReq) if errToken != nil { log.Errorf("Failed to exchange code for tokens: %v", errToken) - setOAuthStatus(state, "Failed to exchange code for tokens") + SetOAuthSessionError(state, "Failed to exchange code for tokens") return } @@ -2501,7 +2484,7 @@ func (h *Handler) RequestKiroToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Errorf("Failed to save authentication tokens: %v", errSave) - setOAuthStatus(state, "Failed to save authentication tokens") + SetOAuthSessionError(state, "Failed to save authentication tokens") return } @@ -2509,15 +2492,14 @@ func (h *Handler) RequestKiroToken(c *gin.Context) { if email != "" { fmt.Printf("Authenticated as: %s\n", email) } - deleteOAuthStatus(state) + CompleteOAuthSession(state) return } time.Sleep(500 * time.Millisecond) } }() - setOAuthStatus(state, "") - c.JSON(200, gin.H{"status": "ok", "state": state, "method": "social"}) + c.JSON(http.StatusOK, gin.H{"status": "ok", "state": state, "method": "social"}) default: c.JSON(http.StatusBadRequest, gin.H{"error": "invalid method, use 'aws', 'google', or 'github'"}) @@ -2527,8 +2509,8 @@ func (h *Handler) RequestKiroToken(c *gin.Context) { // generateKiroPKCE generates PKCE code verifier and challenge for Kiro OAuth. func generateKiroPKCE() (verifier, challenge string, err error) { b := make([]byte, 32) - if _, err := io.ReadFull(rand.Reader, b); err != nil { - return "", "", fmt.Errorf("failed to generate random bytes: %w", err) + if _, errRead := io.ReadFull(rand.Reader, b); errRead != nil { + return "", "", fmt.Errorf("failed to generate random bytes: %w", errRead) } verifier = base64.RawURLEncoding.EncodeToString(b) diff --git a/internal/api/handlers/management/oauth_callback.go b/internal/api/handlers/management/oauth_callback.go new file mode 100644 index 00000000..c69a332e --- /dev/null +++ b/internal/api/handlers/management/oauth_callback.go @@ -0,0 +1,100 @@ +package management + +import ( + "errors" + "net/http" + "net/url" + "strings" + + "github.com/gin-gonic/gin" +) + +type oauthCallbackRequest struct { + Provider string `json:"provider"` + RedirectURL string `json:"redirect_url"` + Code string `json:"code"` + State string `json:"state"` + Error string `json:"error"` +} + +func (h *Handler) PostOAuthCallback(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "handler not initialized"}) + return + } + + var req oauthCallbackRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid body"}) + return + } + + canonicalProvider, err := NormalizeOAuthProvider(req.Provider) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "unsupported provider"}) + return + } + + state := strings.TrimSpace(req.State) + code := strings.TrimSpace(req.Code) + errMsg := strings.TrimSpace(req.Error) + + if rawRedirect := strings.TrimSpace(req.RedirectURL); rawRedirect != "" { + u, errParse := url.Parse(rawRedirect) + if errParse != nil { + c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid redirect_url"}) + return + } + q := u.Query() + if state == "" { + state = strings.TrimSpace(q.Get("state")) + } + if code == "" { + code = strings.TrimSpace(q.Get("code")) + } + if errMsg == "" { + errMsg = strings.TrimSpace(q.Get("error")) + if errMsg == "" { + errMsg = strings.TrimSpace(q.Get("error_description")) + } + } + } + + if state == "" { + c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "state is required"}) + return + } + if err := ValidateOAuthState(state); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid state"}) + return + } + if code == "" && errMsg == "" { + c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "code or error is required"}) + return + } + + sessionProvider, sessionStatus, ok := GetOAuthSession(state) + if !ok { + c.JSON(http.StatusNotFound, gin.H{"status": "error", "error": "unknown or expired state"}) + return + } + if sessionStatus != "" { + c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"}) + return + } + if !strings.EqualFold(sessionProvider, canonicalProvider) { + c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "provider does not match state"}) + return + } + + if _, errWrite := WriteOAuthCallbackFileForPendingSession(h.cfg.AuthDir, canonicalProvider, state, code, errMsg); errWrite != nil { + if errors.Is(errWrite, errOAuthSessionNotPending) { + c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to persist oauth callback"}) + return + } + + c.JSON(http.StatusOK, gin.H{"status": "ok"}) +} diff --git a/internal/api/handlers/management/oauth_sessions.go b/internal/api/handlers/management/oauth_sessions.go new file mode 100644 index 00000000..aaa4b7da --- /dev/null +++ b/internal/api/handlers/management/oauth_sessions.go @@ -0,0 +1,265 @@ +package management + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "time" +) + +const ( + oauthSessionTTL = 10 * time.Minute + maxOAuthStateLength = 128 +) + +var ( + errInvalidOAuthState = errors.New("invalid oauth state") + errUnsupportedOAuthFlow = errors.New("unsupported oauth provider") + errOAuthSessionNotPending = errors.New("oauth session is not pending") +) + +type oauthSession struct { + Provider string + Status string + CreatedAt time.Time + ExpiresAt time.Time +} + +type oauthSessionStore struct { + mu sync.RWMutex + ttl time.Duration + sessions map[string]oauthSession +} + +func newOAuthSessionStore(ttl time.Duration) *oauthSessionStore { + if ttl <= 0 { + ttl = oauthSessionTTL + } + return &oauthSessionStore{ + ttl: ttl, + sessions: make(map[string]oauthSession), + } +} + +func (s *oauthSessionStore) purgeExpiredLocked(now time.Time) { + for state, session := range s.sessions { + if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) { + delete(s.sessions, state) + } + } +} + +func (s *oauthSessionStore) Register(state, provider string) { + state = strings.TrimSpace(state) + provider = strings.ToLower(strings.TrimSpace(provider)) + if state == "" || provider == "" { + return + } + now := time.Now() + + s.mu.Lock() + defer s.mu.Unlock() + + s.purgeExpiredLocked(now) + s.sessions[state] = oauthSession{ + Provider: provider, + Status: "", + CreatedAt: now, + ExpiresAt: now.Add(s.ttl), + } +} + +func (s *oauthSessionStore) SetError(state, message string) { + state = strings.TrimSpace(state) + message = strings.TrimSpace(message) + if state == "" { + return + } + if message == "" { + message = "Authentication failed" + } + now := time.Now() + + s.mu.Lock() + defer s.mu.Unlock() + + s.purgeExpiredLocked(now) + session, ok := s.sessions[state] + if !ok { + return + } + session.Status = message + session.ExpiresAt = now.Add(s.ttl) + s.sessions[state] = session +} + +func (s *oauthSessionStore) Complete(state string) { + state = strings.TrimSpace(state) + if state == "" { + return + } + now := time.Now() + + s.mu.Lock() + defer s.mu.Unlock() + + s.purgeExpiredLocked(now) + delete(s.sessions, state) +} + +func (s *oauthSessionStore) Get(state string) (oauthSession, bool) { + state = strings.TrimSpace(state) + now := time.Now() + + s.mu.Lock() + defer s.mu.Unlock() + + s.purgeExpiredLocked(now) + session, ok := s.sessions[state] + return session, ok +} + +func (s *oauthSessionStore) IsPending(state, provider string) bool { + state = strings.TrimSpace(state) + provider = strings.ToLower(strings.TrimSpace(provider)) + now := time.Now() + + s.mu.Lock() + defer s.mu.Unlock() + + s.purgeExpiredLocked(now) + session, ok := s.sessions[state] + if !ok { + return false + } + if session.Status != "" { + if !strings.EqualFold(session.Provider, "kiro") { + return false + } + if !strings.HasPrefix(session.Status, "device_code|") && !strings.HasPrefix(session.Status, "auth_url|") { + return false + } + } + if provider == "" { + return true + } + return strings.EqualFold(session.Provider, provider) +} + +var oauthSessions = newOAuthSessionStore(oauthSessionTTL) + +func RegisterOAuthSession(state, provider string) { oauthSessions.Register(state, provider) } + +func SetOAuthSessionError(state, message string) { oauthSessions.SetError(state, message) } + +func CompleteOAuthSession(state string) { oauthSessions.Complete(state) } + +func GetOAuthSession(state string) (provider string, status string, ok bool) { + session, ok := oauthSessions.Get(state) + if !ok { + return "", "", false + } + return session.Provider, session.Status, true +} + +func IsOAuthSessionPending(state, provider string) bool { + return oauthSessions.IsPending(state, provider) +} + +func ValidateOAuthState(state string) error { + trimmed := strings.TrimSpace(state) + if trimmed == "" { + return fmt.Errorf("%w: empty", errInvalidOAuthState) + } + if len(trimmed) > maxOAuthStateLength { + return fmt.Errorf("%w: too long", errInvalidOAuthState) + } + if strings.Contains(trimmed, "/") || strings.Contains(trimmed, "\\") { + return fmt.Errorf("%w: contains path separator", errInvalidOAuthState) + } + if strings.Contains(trimmed, "..") { + return fmt.Errorf("%w: contains '..'", errInvalidOAuthState) + } + for _, r := range trimmed { + switch { + case r >= 'a' && r <= 'z': + case r >= 'A' && r <= 'Z': + case r >= '0' && r <= '9': + case r == '-' || r == '_' || r == '.': + default: + return fmt.Errorf("%w: invalid character", errInvalidOAuthState) + } + } + return nil +} + +func NormalizeOAuthProvider(provider string) (string, error) { + switch strings.ToLower(strings.TrimSpace(provider)) { + case "anthropic", "claude": + return "anthropic", nil + case "codex", "openai": + return "codex", nil + case "gemini", "google": + return "gemini", nil + case "iflow", "i-flow": + return "iflow", nil + case "antigravity", "anti-gravity": + return "antigravity", nil + case "qwen": + return "qwen", nil + case "kiro": + return "kiro", nil + default: + return "", errUnsupportedOAuthFlow + } +} + +type oauthCallbackFilePayload struct { + Code string `json:"code"` + State string `json:"state"` + Error string `json:"error"` +} + +func WriteOAuthCallbackFile(authDir, provider, state, code, errorMessage string) (string, error) { + if strings.TrimSpace(authDir) == "" { + return "", fmt.Errorf("auth dir is empty") + } + canonicalProvider, err := NormalizeOAuthProvider(provider) + if err != nil { + return "", err + } + if err := ValidateOAuthState(state); err != nil { + return "", err + } + + fileName := fmt.Sprintf(".oauth-%s-%s.oauth", canonicalProvider, state) + filePath := filepath.Join(authDir, fileName) + payload := oauthCallbackFilePayload{ + Code: strings.TrimSpace(code), + State: strings.TrimSpace(state), + Error: strings.TrimSpace(errorMessage), + } + data, err := json.Marshal(payload) + if err != nil { + return "", fmt.Errorf("marshal oauth callback payload: %w", err) + } + if err := os.WriteFile(filePath, data, 0o600); err != nil { + return "", fmt.Errorf("write oauth callback file: %w", err) + } + return filePath, nil +} + +func WriteOAuthCallbackFileForPendingSession(authDir, provider, state, code, errorMessage string) (string, error) { + canonicalProvider, err := NormalizeOAuthProvider(provider) + if err != nil { + return "", err + } + if !IsOAuthSessionPending(state, canonicalProvider) { + return "", errOAuthSessionNotPending + } + return WriteOAuthCallbackFile(authDir, canonicalProvider, state, code, errorMessage) +} diff --git a/internal/api/modules/amp/routes.go b/internal/api/modules/amp/routes.go index 0abd943a..911d2b7d 100644 --- a/internal/api/modules/amp/routes.go +++ b/internal/api/modules/amp/routes.go @@ -95,6 +95,20 @@ func (m *AmpModule) managementAvailabilityMiddleware() gin.HandlerFunc { } } +// wrapManagementAuth skips auth for selected management paths while keeping authentication elsewhere. +func wrapManagementAuth(auth gin.HandlerFunc, prefixes ...string) gin.HandlerFunc { + return func(c *gin.Context) { + path := c.Request.URL.Path + for _, prefix := range prefixes { + if strings.HasPrefix(path, prefix) && (len(path) == len(prefix) || path[len(prefix)] == '/') { + c.Next() + return + } + } + auth(c) + } +} + // registerManagementRoutes registers Amp management proxy routes // These routes proxy through to the Amp control plane for OAuth, user management, etc. // Uses dynamic middleware and proxy getter for hot-reload support. @@ -109,8 +123,10 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha ampAPI.Use(m.localhostOnlyMiddleware()) // Apply authentication middleware - requires valid API key in Authorization header + var authWithBypass gin.HandlerFunc if auth != nil { ampAPI.Use(auth) + authWithBypass = wrapManagementAuth(auth, "/threads", "/auth") } // Dynamic proxy handler that uses m.getProxy() for hot-reload support @@ -156,8 +172,8 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha // Root-level routes that AMP CLI expects without /api prefix // These need the same security middleware as the /api/* routes (dynamic for hot-reload) rootMiddleware := []gin.HandlerFunc{m.managementAvailabilityMiddleware(), noCORSMiddleware(), m.localhostOnlyMiddleware()} - if auth != nil { - rootMiddleware = append(rootMiddleware, auth) + if authWithBypass != nil { + rootMiddleware = append(rootMiddleware, authWithBypass) } engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...) engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...) diff --git a/internal/api/server.go b/internal/api/server.go index 970371e0..f90a1e36 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -360,10 +360,11 @@ func (s *Server) setupRoutes() { code := c.Query("code") state := c.Query("state") errStr := c.Query("error") - // Persist to a temporary file keyed by state + if errStr == "" { + errStr = c.Query("error_description") + } if state != "" { - file := fmt.Sprintf("%s/.oauth-anthropic-%s.oauth", s.cfg.AuthDir, state) - _ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600) + _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "anthropic", state, code, errStr) } c.Header("Content-Type", "text/html; charset=utf-8") c.String(http.StatusOK, oauthCallbackSuccessHTML) @@ -373,9 +374,11 @@ func (s *Server) setupRoutes() { code := c.Query("code") state := c.Query("state") errStr := c.Query("error") + if errStr == "" { + errStr = c.Query("error_description") + } if state != "" { - file := fmt.Sprintf("%s/.oauth-codex-%s.oauth", s.cfg.AuthDir, state) - _ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600) + _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "codex", state, code, errStr) } c.Header("Content-Type", "text/html; charset=utf-8") c.String(http.StatusOK, oauthCallbackSuccessHTML) @@ -385,9 +388,11 @@ func (s *Server) setupRoutes() { code := c.Query("code") state := c.Query("state") errStr := c.Query("error") + if errStr == "" { + errStr = c.Query("error_description") + } if state != "" { - file := fmt.Sprintf("%s/.oauth-gemini-%s.oauth", s.cfg.AuthDir, state) - _ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600) + _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "gemini", state, code, errStr) } c.Header("Content-Type", "text/html; charset=utf-8") c.String(http.StatusOK, oauthCallbackSuccessHTML) @@ -397,9 +402,11 @@ func (s *Server) setupRoutes() { code := c.Query("code") state := c.Query("state") errStr := c.Query("error") + if errStr == "" { + errStr = c.Query("error_description") + } if state != "" { - file := fmt.Sprintf("%s/.oauth-iflow-%s.oauth", s.cfg.AuthDir, state) - _ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600) + _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "iflow", state, code, errStr) } c.Header("Content-Type", "text/html; charset=utf-8") c.String(http.StatusOK, oauthCallbackSuccessHTML) @@ -409,9 +416,11 @@ func (s *Server) setupRoutes() { code := c.Query("code") state := c.Query("state") errStr := c.Query("error") + if errStr == "" { + errStr = c.Query("error_description") + } if state != "" { - file := fmt.Sprintf("%s/.oauth-antigravity-%s.oauth", s.cfg.AuthDir, state) - _ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600) + _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "antigravity", state, code, errStr) } c.Header("Content-Type", "text/html; charset=utf-8") c.String(http.StatusOK, oauthCallbackSuccessHTML) @@ -421,9 +430,11 @@ func (s *Server) setupRoutes() { code := c.Query("code") state := c.Query("state") errStr := c.Query("error") + if errStr == "" { + errStr = c.Query("error_description") + } if state != "" { - file := fmt.Sprintf("%s/.oauth-kiro-%s.oauth", s.cfg.AuthDir, state) - _ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600) + _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "kiro", state, code, errStr) } c.Header("Content-Type", "text/html; charset=utf-8") c.String(http.StatusOK, oauthCallbackSuccessHTML) @@ -596,6 +607,7 @@ func (s *Server) registerManagementRoutes() { mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken) mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken) mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken) + mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback) mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus) } } diff --git a/internal/util/gemini_thinking.go b/internal/util/gemini_thinking.go index af244b60..5a5a29a9 100644 --- a/internal/util/gemini_thinking.go +++ b/internal/util/gemini_thinking.go @@ -242,7 +242,7 @@ func ThinkingBudgetToGemini3Level(model string, budget int) (string, bool) { var modelsWithDefaultThinking = map[string]bool{ "gemini-3-pro-preview": true, "gemini-3-pro-image-preview": true, - "gemini-3-flash-preview": true, + // "gemini-3-flash-preview": true, } // ModelHasDefaultThinking returns true if the model should have thinking enabled by default.