diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index 3fe6e678b..291f6ef1e 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -2081,7 +2081,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { log.Warnf("antigravity: failed to fetch project ID: %v", errProject) } else { projectID = fetchedProjectID - log.Infof("antigravity: obtained project ID %s", projectID) + log.Infof("antigravity: obtained project ID %s", util.HideAPIKey(projectID)) } } @@ -2125,7 +2125,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { CompleteOAuthSessionsByProvider("antigravity") fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) if projectID != "" { - fmt.Printf("Using GCP project: %s\n", projectID) + fmt.Printf("Using GCP project: %s\n", util.HideAPIKey(projectID)) } fmt.Println("You can now use Antigravity services through this CLI") }() diff --git a/internal/auth/antigravity/auth.go b/internal/auth/antigravity/auth.go index 7bee09bb6..e1fead36d 100644 --- a/internal/auth/antigravity/auth.go +++ b/internal/auth/antigravity/auth.go @@ -48,10 +48,76 @@ func NewAntigravityAuth(cfg *config.Config, httpClient *http.Client) *Antigravit } } -func (o *AntigravityAuth) loadCodeAssistUserAgent() string { +func (o *AntigravityAuth) shortUserAgent() string { + return misc.AntigravityRequestUserAgent("") +} + +func (o *AntigravityAuth) nodeUserAgent() string { return misc.AntigravityLoadCodeAssistUserAgent("") } +func antigravityLoadCodeAssistMetadata() map[string]string { + return map[string]string{ + "ideType": "ANTIGRAVITY", + } +} + +func antigravityControlPlaneMetadata(userAgent string) map[string]string { + return map[string]string{ + "ide_type": "ANTIGRAVITY", + "ide_version": misc.AntigravityVersionFromUserAgent(userAgent), + "ide_name": "antigravity", + } +} + +func extractCloudaicompanionProject(data map[string]any) string { + if data == nil { + return "" + } + for _, key := range []string{"cloudaicompanionProject", "projectId", "project"} { + switch value := data[key].(type) { + case string: + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed + } + case map[string]any: + if id, ok := value["id"].(string); ok { + if trimmed := strings.TrimSpace(id); trimmed != "" { + return trimmed + } + } + } + } + return "" +} + +func defaultAntigravityTierID(loadResp map[string]any) string { + if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers { + for _, rawTier := range tiers { + tier, okTier := rawTier.(map[string]any) + if !okTier { + continue + } + if isDefault, okDefault := tier["isDefault"].(bool); !okDefault || !isDefault { + continue + } + if id, okID := tier["id"].(string); okID { + if trimmed := strings.TrimSpace(id); trimmed != "" { + return trimmed + } + } + } + } + if currentTier, okTier := loadResp["currentTier"].(map[string]any); okTier { + if id, okID := currentTier["id"].(string); okID { + if trimmed := strings.TrimSpace(id); trimmed != "" { + return trimmed + } + } + } + return "free-tier" +} + // BuildAuthURL generates the OAuth authorization URL. func (o *AntigravityAuth) BuildAuthURL(state, redirectURI string) string { if strings.TrimSpace(redirectURI) == "" { @@ -123,7 +189,7 @@ func (o *AntigravityAuth) FetchUserInfo(ctx context.Context, accessToken string) return "", fmt.Errorf("antigravity userinfo: create request: %w", err) } req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("User-Agent", o.loadCodeAssistUserAgent()) + req.Header.Set("User-Agent", o.shortUserAgent()) resp, errDo := o.httpClient.Do(req) if errDo != nil { @@ -159,13 +225,9 @@ func (o *AntigravityAuth) FetchUserInfo(ctx context.Context, accessToken string) // FetchProjectID retrieves the project ID for the authenticated user via loadCodeAssist func (o *AntigravityAuth) FetchProjectID(ctx context.Context, accessToken string) (string, error) { - userAgent := o.loadCodeAssistUserAgent() + userAgent := o.shortUserAgent() loadReqBody := map[string]any{ - "metadata": map[string]string{ - "ide_type": "ANTIGRAVITY", - "ide_version": misc.AntigravityVersionFromUserAgent(userAgent), - "ide_name": "antigravity", - }, + "metadata": antigravityLoadCodeAssistMetadata(), } rawBody, errMarshal := json.Marshal(loadReqBody) @@ -179,9 +241,9 @@ func (o *AntigravityAuth) FetchProjectID(ctx context.Context, accessToken string return "", fmt.Errorf("create request: %w", err) } req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "*/*") req.Header.Set("Content-Type", "application/json") req.Header.Set("User-Agent", userAgent) - req.Header.Set("X-Goog-Api-Client", misc.AntigravityGoogAPIClientUA) resp, errDo := o.httpClient.Do(req) if errDo != nil { @@ -207,40 +269,16 @@ func (o *AntigravityAuth) FetchProjectID(ctx context.Context, accessToken string return "", fmt.Errorf("decode response: %w", errDecode) } - // Extract projectID from response - projectID := "" - if id, ok := loadResp["cloudaicompanionProject"].(string); ok { - projectID = strings.TrimSpace(id) - } - if projectID == "" { - if projectMap, ok := loadResp["cloudaicompanionProject"].(map[string]any); ok { - if id, okID := projectMap["id"].(string); okID { - projectID = strings.TrimSpace(id) - } - } - } + projectID := extractCloudaicompanionProject(loadResp) if projectID == "" { - tierID := "legacy-tier" - if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers { - for _, rawTier := range tiers { - tier, okTier := rawTier.(map[string]any) - if !okTier { - continue - } - if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault { - if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" { - tierID = strings.TrimSpace(id) - break - } - } - } - } - - projectID, err = o.OnboardUser(ctx, accessToken, tierID) + projectID, err = o.OnboardUser(ctx, accessToken, defaultAntigravityTierID(loadResp)) if err != nil { return "", err } + if projectID == "" { + return "", fmt.Errorf("project id not found in loadCodeAssist or onboardUser response") + } return projectID, nil } @@ -250,14 +288,10 @@ func (o *AntigravityAuth) FetchProjectID(ctx context.Context, accessToken string // OnboardUser attempts to fetch the project ID via onboardUser by polling for completion func (o *AntigravityAuth) OnboardUser(ctx context.Context, accessToken, tierID string) (string, error) { log.Infof("Antigravity: onboarding user with tier: %s", tierID) - userAgent := o.loadCodeAssistUserAgent() + userAgent := o.nodeUserAgent() requestBody := map[string]any{ - "tierId": tierID, - "metadata": map[string]string{ - "ide_type": "ANTIGRAVITY", - "ide_version": misc.AntigravityVersionFromUserAgent(userAgent), - "ide_name": "antigravity", - }, + "tier_id": tierID, + "metadata": antigravityControlPlaneMetadata(userAgent), } rawBody, errMarshal := json.Marshal(requestBody) @@ -276,13 +310,14 @@ func (o *AntigravityAuth) OnboardUser(ctx context.Context, accessToken, tierID s } reqCtx, cancel = context.WithTimeout(reqCtx, 30*time.Second) - endpointURL := fmt.Sprintf("%s/%s:onboardUser", APIEndpoint, APIVersion) + endpointURL := fmt.Sprintf("%s/%s:onboardUser", DailyAPIEndpoint, APIVersion) req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody))) if errRequest != nil { cancel() return "", fmt.Errorf("create request: %w", errRequest) } req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "*/*") req.Header.Set("Content-Type", "application/json") req.Header.Set("User-Agent", userAgent) req.Header.Set("X-Goog-Api-Client", misc.AntigravityGoogAPIClientUA) @@ -312,18 +347,11 @@ func (o *AntigravityAuth) OnboardUser(ctx context.Context, accessToken, tierID s if done, okDone := data["done"].(bool); okDone && done { projectID := "" if responseData, okResp := data["response"].(map[string]any); okResp { - switch projectValue := responseData["cloudaicompanionProject"].(type) { - case map[string]any: - if id, okID := projectValue["id"].(string); okID { - projectID = strings.TrimSpace(id) - } - case string: - projectID = strings.TrimSpace(projectValue) - } + projectID = extractCloudaicompanionProject(responseData) } if projectID != "" { - log.Infof("Successfully fetched project_id: %s", projectID) + log.Infof("Successfully fetched project_id: %s", util.HideAPIKey(projectID)) return projectID, nil } @@ -346,5 +374,5 @@ func (o *AntigravityAuth) OnboardUser(ctx context.Context, accessToken, tierID s return "", fmt.Errorf("http %d: %s", resp.StatusCode, responseErr) } - return "", nil + return "", fmt.Errorf("onboard user did not complete after %d attempts", maxAttempts) } diff --git a/internal/auth/antigravity/auth_test.go b/internal/auth/antigravity/auth_test.go new file mode 100644 index 000000000..ce1de8548 --- /dev/null +++ b/internal/auth/antigravity/auth_test.go @@ -0,0 +1,127 @@ +package antigravity + +import ( + "context" + "io" + "net/http" + "strings" + "testing" +) + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func TestFetchProjectIDFromLoadCodeAssist(t *testing.T) { + auth := NewAntigravityAuth(nil, &http.Client{Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.URL.String() != "https://cloudcode-pa.googleapis.com/v1internal:loadCodeAssist" { + t.Fatalf("unexpected request URL: %s", req.URL.String()) + } + assertLoadCodeAssistHeaders(t, req) + assertJSONContains(t, req, `"ideType":"ANTIGRAVITY"`) + return jsonResponse(`{"cloudaicompanionProject":"cogent-snow-4mnnp"}`), nil + })}) + + projectID, err := auth.FetchProjectID(context.Background(), "access-token") + if err != nil { + t.Fatalf("FetchProjectID error: %v", err) + } + if projectID != "cogent-snow-4mnnp" { + t.Fatalf("projectID = %q", projectID) + } +} + +func TestFetchProjectIDFallsBackToDailyOnboardUser(t *testing.T) { + var sawOnboard bool + auth := NewAntigravityAuth(nil, &http.Client{Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + switch req.URL.String() { + case "https://cloudcode-pa.googleapis.com/v1internal:loadCodeAssist": + assertLoadCodeAssistHeaders(t, req) + return jsonResponse(`{"allowedTiers":[{"id":"free-tier","isDefault":true}]}`), nil + case "https://daily-cloudcode-pa.googleapis.com/v1internal:onboardUser": + sawOnboard = true + assertOnboardUserHeaders(t, req) + assertJSONContains(t, req, `"tier_id":"free-tier"`) + assertJSONContains(t, req, `"ide_type":"ANTIGRAVITY"`) + return jsonResponse(`{ + "done": true, + "response": { + "cloudaicompanionProject": { + "id": "cogent-snow-4mnnp", + "name": "cogent-snow-4mnnp", + "projectNumber": "22597072101" + } + } + }`), nil + default: + t.Fatalf("unexpected request URL: %s", req.URL.String()) + return nil, nil + } + })}) + + projectID, err := auth.FetchProjectID(context.Background(), "access-token") + if err != nil { + t.Fatalf("FetchProjectID error: %v", err) + } + if !sawOnboard { + t.Fatalf("expected onboardUser fallback") + } + if projectID != "cogent-snow-4mnnp" { + t.Fatalf("projectID = %q", projectID) + } +} + +func assertLoadCodeAssistHeaders(t *testing.T, req *http.Request) { + t.Helper() + if got := req.Header.Get("Authorization"); got != "Bearer access-token" { + t.Fatalf("Authorization = %q", got) + } + if got := req.Header.Get("Accept"); got != "*/*" { + t.Fatalf("Accept = %q", got) + } + if got := req.Header.Get("X-Goog-Api-Client"); got != "" { + t.Fatalf("X-Goog-Api-Client = %q, want empty", got) + } + if got := req.Header.Get("User-Agent"); strings.Contains(got, "google-api-nodejs-client/") { + t.Fatalf("User-Agent = %q", got) + } +} + +func assertOnboardUserHeaders(t *testing.T, req *http.Request) { + t.Helper() + if got := req.Header.Get("Authorization"); got != "Bearer access-token" { + t.Fatalf("Authorization = %q", got) + } + if got := req.Header.Get("Accept"); got != "*/*" { + t.Fatalf("Accept = %q", got) + } + if got := req.Header.Get("X-Goog-Api-Client"); got != "gl-node/22.21.1" { + t.Fatalf("X-Goog-Api-Client = %q", got) + } + if got := req.Header.Get("User-Agent"); !strings.Contains(got, "google-api-nodejs-client/10.3.0") { + t.Fatalf("User-Agent = %q", got) + } +} + +func assertJSONContains(t *testing.T, req *http.Request, want string) { + t.Helper() + body, err := io.ReadAll(req.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + bodyText := string(body) + req.Body = io.NopCloser(strings.NewReader(bodyText)) + if !strings.Contains(bodyText, want) { + t.Fatalf("body missing %s: %s", want, bodyText) + } +} + +func jsonResponse(body string) *http.Response { + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } +} diff --git a/internal/auth/antigravity/constants.go b/internal/auth/antigravity/constants.go index 61e736971..2ba464d44 100644 --- a/internal/auth/antigravity/constants.go +++ b/internal/auth/antigravity/constants.go @@ -26,6 +26,7 @@ const ( // Antigravity API configuration const ( - APIEndpoint = "https://cloudcode-pa.googleapis.com" - APIVersion = "v1internal" + APIEndpoint = "https://cloudcode-pa.googleapis.com" + DailyAPIEndpoint = "https://daily-cloudcode-pa.googleapis.com" + APIVersion = "v1internal" ) diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index adbc5c9a2..5527bece9 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -1415,6 +1415,41 @@ func (e *AntigravityExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Au return updated, nil } +func (e *AntigravityExecutor) ShouldPrepareRequestAuth(auth *cliproxyauth.Auth) bool { + return antigravityProjectIDFromAuth(auth) == "" +} + +func (e *AntigravityExecutor) PrepareRequestAuth(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + if auth == nil || !e.ShouldPrepareRequestAuth(auth) { + return nil, nil + } + + updated := auth.Clone() + token, refreshedAuth, errToken := e.ensureAccessToken(ctx, updated) + if errToken != nil { + return nil, errToken + } + if refreshedAuth != nil { + updated = refreshedAuth + } + if antigravityProjectIDFromAuth(updated) != "" { + return updated, nil + } + + projectID, errProject := e.fetchAntigravityProjectID(ctx, updated, token) + if errProject != nil { + return nil, missingAntigravityProjectIDError(errProject) + } + if projectID == "" { + return nil, missingAntigravityProjectIDError(nil) + } + if updated.Metadata == nil { + updated.Metadata = make(map[string]any) + } + updated.Metadata["project_id"] = projectID + return updated, nil +} + // CountTokens counts tokens for the given request using the Antigravity API. func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { baseModel := thinking.ParseSuffix(req.Model).ModelName @@ -1752,34 +1787,67 @@ func (e *AntigravityExecutor) ensureAntigravityProjectID(ctx context.Context, au return nil } - if auth.Metadata["project_id"] != nil { + if antigravityProjectIDFromAuth(auth) != "" { return nil } - token := strings.TrimSpace(accessToken) - if token == "" { - token = metaStringValue(auth.Metadata, "access_token") - } - if token == "" { - return nil - } - - httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0) - projectID, errFetch := sdkAuth.FetchAntigravityProjectID(ctx, token, httpClient) + projectID, errFetch := e.fetchAntigravityProjectID(ctx, auth, accessToken) if errFetch != nil { return errFetch } - if strings.TrimSpace(projectID) == "" { + if projectID == "" { return nil } if auth.Metadata == nil { auth.Metadata = make(map[string]any) } - auth.Metadata["project_id"] = strings.TrimSpace(projectID) + auth.Metadata["project_id"] = projectID return nil } +func (e *AntigravityExecutor) fetchAntigravityProjectID(ctx context.Context, auth *cliproxyauth.Auth, accessToken string) (string, error) { + token := strings.TrimSpace(accessToken) + if token == "" { + token = metaStringValue(auth.Metadata, "access_token") + } + if token == "" { + return "", nil + } + + httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0) + projectID, errFetch := sdkAuth.FetchAntigravityProjectID(ctx, token, httpClient) + if errFetch != nil { + return "", errFetch + } + return strings.TrimSpace(projectID), nil +} + +func (e *AntigravityExecutor) projectIDForRequest(_ context.Context, auth *cliproxyauth.Auth, _ string) (string, error) { + if projectID := antigravityProjectIDFromAuth(auth); projectID != "" { + return projectID, nil + } + return "", missingAntigravityProjectIDError(nil) +} + +func antigravityProjectIDFromAuth(auth *cliproxyauth.Auth) string { + if auth == nil || auth.Metadata == nil { + return "" + } + if pid, ok := auth.Metadata["project_id"].(string); ok { + return strings.TrimSpace(pid) + } + return "" +} + +func missingAntigravityProjectIDError(cause error) statusErr { + msg := "antigravity auth missing project_id" + if cause != nil { + msg = fmt.Sprintf("%s: %v", msg, cause) + } + return statusErr{code: http.StatusBadRequest, msg: msg} +} + func (e *AntigravityExecutor) updateAntigravityCreditsBalance(ctx context.Context, auth *cliproxyauth.Auth, accessToken string) { if auth == nil || strings.TrimSpace(auth.ID) == "" { return @@ -1792,19 +1860,17 @@ func (e *AntigravityExecutor) updateAntigravityCreditsBalance(ctx context.Contex return } - userAgent := resolveLoadCodeAssistUserAgent(auth) + userAgent := resolveUserAgent(auth) loadReqBody, errMarshal := json.Marshal(map[string]any{ "metadata": map[string]string{ - "ide_type": "ANTIGRAVITY", - "ide_version": misc.AntigravityVersionFromUserAgent(userAgent), - "ide_name": "antigravity", + "ideType": "ANTIGRAVITY", }, }) if errMarshal != nil { log.Debugf("antigravity executor: marshal loadCodeAssist request error: %v", errMarshal) return } - baseURL := buildBaseURL(auth) + baseURL := antigravityLoadCodeAssistBaseURL(auth) endpointURL := strings.TrimSuffix(baseURL, "/") + "/v1internal:loadCodeAssist" httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, bytes.NewReader(loadReqBody)) if errReq != nil { @@ -1812,9 +1878,9 @@ func (e *AntigravityExecutor) updateAntigravityCreditsBalance(ctx context.Contex return } httpReq.Header.Set("Authorization", "Bearer "+token) + httpReq.Header.Set("Accept", "*/*") httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("User-Agent", userAgent) - httpReq.Header.Set("X-Goog-Api-Client", misc.AntigravityGoogAPIClientUA) httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0) httpResp, errDo := httpClient.Do(httpReq) @@ -1909,12 +1975,9 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau requestURL.WriteString(url.QueryEscape(alt)) } - // Extract project_id from auth metadata if available - projectID := "" - if auth != nil && auth.Metadata != nil { - if pid, ok := auth.Metadata["project_id"].(string); ok { - projectID = strings.TrimSpace(pid) - } + projectID, errProject := e.projectIDForRequest(ctx, auth, token) + if errProject != nil { + return nil, errProject } payload = geminiToAntigravity(modelName, payload, projectID) payload, _ = sjson.SetBytes(payload, "model", modelName) @@ -2100,6 +2163,13 @@ func buildBaseURL(auth *cliproxyauth.Auth) string { return antigravityBaseURLDaily } +func antigravityLoadCodeAssistBaseURL(auth *cliproxyauth.Auth) string { + if base := resolveCustomAntigravityBaseURL(auth); base != "" { + return base + } + return antigravityBaseURLProd +} + func resolveHost(base string) string { parsed, errParse := url.Parse(base) if errParse != nil { @@ -2338,11 +2408,10 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b } template, _ = sjson.SetBytes(template, "requestType", reqType) - // Use real project ID from auth if available, otherwise generate random (legacy fallback) if projectID != "" { template, _ = sjson.SetBytes(template, "project", projectID) } else { - template, _ = sjson.SetBytes(template, "project", generateProjectID()) + template, _ = sjson.DeleteBytes(template, "project") } if isImageModel { @@ -2391,14 +2460,3 @@ func generateStableSessionID(payload []byte) string { } return generateSessionID() } - -func generateProjectID() string { - adjectives := []string{"useful", "bright", "swift", "calm", "bold"} - nouns := []string{"fuze", "wave", "spark", "flow", "core"} - randSourceMutex.Lock() - adj := adjectives[randSource.Intn(len(adjectives))] - noun := nouns[randSource.Intn(len(nouns))] - randSourceMutex.Unlock() - randomPart := strings.ToLower(uuid.NewString())[:5] - return adj + "-" + noun + "-" + randomPart -} diff --git a/internal/runtime/executor/antigravity_executor_buildrequest_test.go b/internal/runtime/executor/antigravity_executor_buildrequest_test.go index f0711752e..e47a500b2 100644 --- a/internal/runtime/executor/antigravity_executor_buildrequest_test.go +++ b/internal/runtime/executor/antigravity_executor_buildrequest_test.go @@ -4,7 +4,10 @@ import ( "context" "encoding/json" "io" + "net/http" + "strings" "testing" + "time" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" ) @@ -90,6 +93,82 @@ func TestAntigravityBuildRequest_SkipsSchemaSanitizationWithEmptyToolsArray(t *t assertNonSchemaRequestPreserved(t, body) } +func TestAntigravityBuildRequest_UsesAuthProjectID(t *testing.T) { + body := buildRequestBodyFromRawPayload(t, "gemini-3.1-pro", []byte(`{ + "request": { + "contents": [ + { + "role": "user", + "parts": [{"text": "hello"}] + } + ] + } + }`)) + + if got, ok := body["project"].(string); !ok || got != "project-1" { + t.Fatalf("project should come from auth metadata, got=%v", body["project"]) + } +} + +func TestAntigravityPrepareRequestAuth_FetchesMissingProjectID(t *testing.T) { + executor := &AntigravityExecutor{} + auth := &cliproxyauth.Auth{Metadata: map[string]any{ + "access_token": "token", + "expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339), + }} + ctx := context.WithValue(context.Background(), "cliproxy.roundtripper", roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.URL.String() != "https://cloudcode-pa.googleapis.com/v1internal:loadCodeAssist" { + t.Fatalf("unexpected project discovery request: %s", req.URL.String()) + } + if got := req.Header.Get("X-Goog-Api-Client"); got != "" { + t.Fatalf("X-Goog-Api-Client = %q, want empty", got) + } + raw, errRead := io.ReadAll(req.Body) + if errRead != nil { + t.Fatalf("read discovery body: %v", errRead) + } + if !strings.Contains(string(raw), `"ideType":"ANTIGRAVITY"`) { + t.Fatalf("unexpected discovery body: %s", string(raw)) + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(`{"cloudaicompanionProject":"fetched-project"}`)), + }, nil + })) + + updated, err := executor.PrepareRequestAuth(ctx, auth) + if err != nil { + t.Fatalf("PrepareRequestAuth error: %v", err) + } + if updated == nil { + t.Fatalf("PrepareRequestAuth returned nil auth") + } + if _, ok := auth.Metadata["project_id"]; ok { + t.Fatalf("original auth metadata should not be mutated") + } + if got, ok := updated.Metadata["project_id"].(string); !ok || got != "fetched-project" { + t.Fatalf("updated auth metadata project_id = %v, want fetched-project", updated.Metadata["project_id"]) + } +} + +func TestAntigravityBuildRequest_RejectsMissingProjectID(t *testing.T) { + executor := &AntigravityExecutor{} + auth := &cliproxyauth.Auth{Metadata: map[string]any{}} + + _, err := executor.buildRequest(context.Background(), auth, "token", "gemini-3.1-pro", []byte(`{"request":{}}`), false, "", "https://example.com") + if err == nil { + t.Fatalf("buildRequest should fail when auth has no project_id") + } + status, ok := err.(interface{ StatusCode() int }) + if !ok { + t.Fatalf("error should expose status code, got %T", err) + } + if got := status.StatusCode(); got != http.StatusBadRequest { + t.Fatalf("status code = %d, want %d", got, http.StatusBadRequest) + } +} + func assertNonSchemaRequestPreserved(t *testing.T, body map[string]any) { t.Helper() @@ -172,13 +251,19 @@ func buildRequestBodyFromRawPayload(t *testing.T, modelName string, payload []by t.Helper() executor := &AntigravityExecutor{} - auth := &cliproxyauth.Auth{} + auth := &cliproxyauth.Auth{Metadata: map[string]any{"project_id": "project-1"}} req, err := executor.buildRequest(context.Background(), auth, "token", modelName, payload, false, "", "https://example.com") if err != nil { t.Fatalf("buildRequest error: %v", err) } + return requestBody(t, req) +} + +func requestBody(t *testing.T, req *http.Request) map[string]any { + t.Helper() + raw, err := io.ReadAll(req.Body) if err != nil { t.Fatalf("read request body error: %v", err) diff --git a/internal/runtime/executor/antigravity_executor_credits_test.go b/internal/runtime/executor/antigravity_executor_credits_test.go index e16e64434..ac523339d 100644 --- a/internal/runtime/executor/antigravity_executor_credits_test.go +++ b/internal/runtime/executor/antigravity_executor_credits_test.go @@ -444,24 +444,25 @@ func TestUpdateAntigravityCreditsBalance_LoadCodeAssistUserAgent(t *testing.T) { t.Cleanup(resetAntigravityCreditsRetryState) exec := NewAntigravityExecutor(&config.Config{}) - const userAgent = "antigravity/1.23.2 windows/amd64 google-api-nodejs-client/10.3.0" + const configuredUserAgent = "antigravity/1.23.2 windows/amd64 google-api-nodejs-client/10.3.0" + const loadCodeAssistUserAgent = "antigravity/1.23.2 windows/amd64" auth := &cliproxyauth.Auth{ ID: "auth-load-code-assist-ua", - Attributes: map[string]string{"user_agent": userAgent}, + Attributes: map[string]string{"user_agent": configuredUserAgent}, } ctx := context.WithValue(context.Background(), "cliproxy.roundtripper", roundTripperFunc(func(req *http.Request) (*http.Response, error) { if req.URL.String() != "https://cloudcode-pa.googleapis.com/v1internal:loadCodeAssist" { t.Fatalf("unexpected request url %s", req.URL.String()) } - if got := req.Header.Get("User-Agent"); got != userAgent { - t.Fatalf("User-Agent = %q, want %q", got, userAgent) + if got := req.Header.Get("User-Agent"); got != loadCodeAssistUserAgent { + t.Fatalf("User-Agent = %q, want %q", got, loadCodeAssistUserAgent) } - if got := req.Header.Get("X-Goog-Api-Client"); got != "gl-node/22.21.1" { - t.Fatalf("X-Goog-Api-Client = %q, want %q", got, "gl-node/22.21.1") + if got := req.Header.Get("X-Goog-Api-Client"); got != "" { + t.Fatalf("X-Goog-Api-Client = %q, want empty", got) } body, _ := io.ReadAll(req.Body) _ = req.Body.Close() - if string(body) != `{"metadata":{"ide_name":"antigravity","ide_type":"ANTIGRAVITY","ide_version":"1.23.2"}}` { + if string(body) != `{"metadata":{"ideType":"ANTIGRAVITY"}}` { t.Fatalf("loadCodeAssist body = %s", string(body)) } return &http.Response{ diff --git a/sdk/auth/antigravity.go b/sdk/auth/antigravity.go index 0a947b20f..73743df4e 100644 --- a/sdk/auth/antigravity.go +++ b/sdk/auth/antigravity.go @@ -177,12 +177,15 @@ waitForCallback: if accessToken != "" { fetchedProjectID, errProject := authSvc.FetchProjectID(ctx, accessToken) if errProject != nil { - log.Warnf("antigravity: failed to fetch project ID: %v", errProject) + return nil, fmt.Errorf("antigravity: failed to fetch project ID: %w", errProject) } else { projectID = fetchedProjectID - log.Infof("antigravity: obtained project ID %s", projectID) + log.Infof("antigravity: obtained project ID %s", util.HideAPIKey(projectID)) } } + if strings.TrimSpace(projectID) == "" { + return nil, fmt.Errorf("antigravity: project ID discovery returned empty project") + } now := time.Now() metadata := map[string]any{ @@ -208,7 +211,7 @@ waitForCallback: fmt.Println("Antigravity authentication successful") if projectID != "" { - fmt.Printf("Using GCP project: %s\n", projectID) + fmt.Printf("Using GCP project: %s\n", util.HideAPIKey(projectID)) } return &coreauth.Auth{ ID: fileName, diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 537f182ac..dfa165b5d 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -45,6 +45,13 @@ type ProviderExecutor interface { HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) } +// RequestAuthPreparer lets an executor update missing auth metadata immediately +// before a request. Manager serializes and persists returned updates. +type RequestAuthPreparer interface { + ShouldPrepareRequestAuth(auth *Auth) bool + PrepareRequestAuth(ctx context.Context, auth *Auth) (*Auth, error) +} + // ExecutionSessionCloser allows executors to release per-session runtime resources. type ExecutionSessionCloser interface { CloseExecutionSession(sessionID string) @@ -182,6 +189,8 @@ type Manager struct { // Auto refresh state refreshCancel context.CancelFunc refreshLoop *authAutoRefreshLoop + + requestPrepareLocks sync.Map } // NewManager constructs a manager with optional custom selector and hook. @@ -1365,6 +1374,17 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req continue } attempted[auth.ID] = struct{}{} + var errPrepare error + auth, errPrepare = m.prepareRequestAuth(execCtx, executor, auth) + if errPrepare != nil { + result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: &Error{Message: errPrepare.Error()}} + if se, ok := errors.AsType[cliproxyexecutor.StatusError](errPrepare); ok && se != nil { + result.Error.HTTPStatus = se.StatusCode() + } + m.MarkResult(execCtx, result) + lastErr = errPrepare + continue + } var authErr error for _, upstreamModel := range models { resultModel := m.stateModelForExecution(auth, routeModel, upstreamModel, pooled) @@ -1453,6 +1473,17 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, continue } attempted[auth.ID] = struct{}{} + var errPrepare error + auth, errPrepare = m.prepareRequestAuth(execCtx, executor, auth) + if errPrepare != nil { + result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: &Error{Message: errPrepare.Error()}} + if se, ok := errors.AsType[cliproxyexecutor.StatusError](errPrepare); ok && se != nil { + result.Error.HTTPStatus = se.StatusCode() + } + m.MarkResult(execCtx, result) + lastErr = errPrepare + continue + } var authErr error for _, upstreamModel := range models { resultModel := m.stateModelForExecution(auth, routeModel, upstreamModel, pooled) @@ -1539,6 +1570,17 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string continue } attempted[auth.ID] = struct{}{} + var errPrepare error + auth, errPrepare = m.prepareRequestAuth(execCtx, executor, auth) + if errPrepare != nil { + result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: &Error{Message: errPrepare.Error()}} + if se, ok := errors.AsType[cliproxyexecutor.StatusError](errPrepare); ok && se != nil { + result.Error.HTTPStatus = se.StatusCode() + } + m.MarkResult(execCtx, result) + lastErr = errPrepare + continue + } streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel, models, pooled) if errStream != nil { if errCtx := execCtx.Err(); errCtx != nil { @@ -1630,6 +1672,62 @@ func hasRequestedModelMetadata(meta map[string]any) bool { } } +type requestAuthPrepareLock struct { + mu sync.Mutex +} + +func (m *Manager) prepareRequestAuth(ctx context.Context, executor ProviderExecutor, auth *Auth) (*Auth, error) { + if m == nil || executor == nil || auth == nil { + return auth, nil + } + preparer, ok := executor.(RequestAuthPreparer) + if !ok || preparer == nil || !preparer.ShouldPrepareRequestAuth(auth) { + return auth, nil + } + + id := strings.TrimSpace(auth.ID) + if id == "" { + return preparer.PrepareRequestAuth(ctx, auth.Clone()) + } + + lockValue, _ := m.requestPrepareLocks.LoadOrStore(id, &requestAuthPrepareLock{}) + lock, ok := lockValue.(*requestAuthPrepareLock) + if !ok || lock == nil { + return preparer.PrepareRequestAuth(ctx, auth.Clone()) + } + + lock.mu.Lock() + defer lock.mu.Unlock() + + target := auth.Clone() + m.mu.RLock() + if current := m.auths[id]; current != nil { + target = current.Clone() + } + m.mu.RUnlock() + + if !preparer.ShouldPrepareRequestAuth(target) { + return target, nil + } + + updated, errPrepare := preparer.PrepareRequestAuth(ctx, target) + if errPrepare != nil { + return auth, errPrepare + } + if updated == nil { + return target, nil + } + + saved, errUpdate := m.Update(ctx, updated) + if errUpdate != nil { + return updated, errUpdate + } + if saved != nil { + return saved, nil + } + return updated, nil +} + func contextWithRequestedModelAlias(ctx context.Context, opts cliproxyexecutor.Options, fallback string) context.Context { alias := requestedModelAliasFromOptions(opts, fallback) ctx = coreusage.WithRequestedModelAlias(ctx, alias) @@ -3667,6 +3765,11 @@ func (m *Manager) tryAntigravityCreditsExecute(ctx context.Context, req cliproxy } creditsOpts := ensureRequestedModelMetadata(opts, routeModel) creditsCtx = contextWithRequestedModelAlias(creditsCtx, creditsOpts, routeModel) + preparedAuth, errPrepare := m.prepareRequestAuth(creditsCtx, c.executor, c.auth) + if errPrepare != nil { + continue + } + c.auth = preparedAuth publishSelectedAuthMetadata(creditsOpts.Metadata, c.auth.ID) models := m.executionModelCandidates(c.auth, routeModel) if len(models) == 0 { @@ -3709,6 +3812,11 @@ func (m *Manager) tryAntigravityCreditsExecuteStream(ctx context.Context, req cl creditsCtx = context.WithValue(creditsCtx, "cliproxy.roundtripper", rt) } creditsOpts := ensureRequestedModelMetadata(opts, routeModel) + preparedAuth, errPrepare := m.prepareRequestAuth(creditsCtx, c.executor, c.auth) + if errPrepare != nil { + continue + } + c.auth = preparedAuth publishSelectedAuthMetadata(creditsOpts.Metadata, c.auth.ID) models := m.executionModelCandidates(c.auth, routeModel) if len(models) == 0 { diff --git a/sdk/cliproxy/auth/request_auth_prepare_test.go b/sdk/cliproxy/auth/request_auth_prepare_test.go new file mode 100644 index 000000000..3c91efb5c --- /dev/null +++ b/sdk/cliproxy/auth/request_auth_prepare_test.go @@ -0,0 +1,146 @@ +package auth + +import ( + "context" + "net/http" + "strings" + "sync" + "sync/atomic" + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" +) + +type requestPrepareStore struct { + saveCount atomic.Int32 + mu sync.Mutex + last *Auth +} + +func (s *requestPrepareStore) List(context.Context) ([]*Auth, error) { return nil, nil } + +func (s *requestPrepareStore) Save(_ context.Context, auth *Auth) (string, error) { + s.saveCount.Add(1) + s.mu.Lock() + defer s.mu.Unlock() + s.last = auth.Clone() + return "", nil +} + +func (s *requestPrepareStore) Delete(context.Context, string) error { return nil } + +func (s *requestPrepareStore) lastAuth() *Auth { + s.mu.Lock() + defer s.mu.Unlock() + return s.last.Clone() +} + +type requestPrepareExecutor struct { + prepareCalls atomic.Int32 + executeCalls atomic.Int32 +} + +func (e *requestPrepareExecutor) Identifier() string { return "antigravity" } + +func (e *requestPrepareExecutor) ShouldPrepareRequestAuth(auth *Auth) bool { + return auth == nil || auth.Metadata == nil || testStringValue(auth.Metadata["project_id"]) == "" +} + +func (e *requestPrepareExecutor) PrepareRequestAuth(_ context.Context, auth *Auth) (*Auth, error) { + e.prepareCalls.Add(1) + updated := auth.Clone() + if updated.Metadata == nil { + updated.Metadata = make(map[string]any) + } + updated.Metadata["project_id"] = "prepared-project" + return updated, nil +} + +func (e *requestPrepareExecutor) Execute(_ context.Context, auth *Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + e.executeCalls.Add(1) + if got := testStringValue(auth.Metadata["project_id"]); got != "prepared-project" { + return cliproxyexecutor.Response{}, &Error{HTTPStatus: http.StatusBadRequest, Message: "missing prepared project"} + } + return cliproxyexecutor.Response{Payload: []byte("ok")}, nil +} + +func (e *requestPrepareExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) { + return nil, &Error{HTTPStatus: http.StatusNotImplemented, Message: "stream not implemented"} +} + +func (e *requestPrepareExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) { + return auth, nil +} + +func (e *requestPrepareExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{}, &Error{HTTPStatus: http.StatusNotImplemented, Message: "count not implemented"} +} + +func (e *requestPrepareExecutor) HttpRequest(context.Context, *Auth, *http.Request) (*http.Response, error) { + return nil, &Error{HTTPStatus: http.StatusNotImplemented, Message: "http not implemented"} +} + +func TestManagerExecute_PreparesAndPersistsMissingRequestAuthMetadata(t *testing.T) { + const model = "gemini-3.1-pro" + store := &requestPrepareStore{} + executor := &requestPrepareExecutor{} + manager := NewManager(store, nil, nil) + manager.RegisterExecutor(executor) + + auth := &Auth{ + ID: "auth-request-prepare", + Provider: "antigravity", + Metadata: map[string]any{"access_token": "token"}, + } + if _, errRegister := manager.Register(WithSkipPersist(context.Background()), auth); errRegister != nil { + t.Fatalf("register auth: %v", errRegister) + } + registry.GetGlobalRegistry().RegisterClient(auth.ID, "antigravity", []*registry.ModelInfo{{ID: model}}) + t.Cleanup(func() { registry.GetGlobalRegistry().UnregisterClient(auth.ID) }) + + resp, errExecute := manager.Execute(context.Background(), []string{"antigravity"}, cliproxyexecutor.Request{Model: model}, cliproxyexecutor.Options{}) + if errExecute != nil { + t.Fatalf("Execute error: %v", errExecute) + } + if string(resp.Payload) != "ok" { + t.Fatalf("payload = %q, want ok", string(resp.Payload)) + } + if got := executor.prepareCalls.Load(); got != 1 { + t.Fatalf("prepare calls = %d, want 1", got) + } + if got := store.saveCount.Load(); got < 1 { + t.Fatalf("save count = %d, want at least 1", got) + } + if got := testStringValue(store.lastAuth().Metadata["project_id"]); got != "prepared-project" { + t.Fatalf("persisted project_id = %q, want prepared-project", got) + } + current, ok := manager.GetByID(auth.ID) + if !ok { + t.Fatal("expected auth in manager") + } + if got := testStringValue(current.Metadata["project_id"]); got != "prepared-project" { + t.Fatalf("manager project_id = %q, want prepared-project", got) + } + + if _, errExecute = manager.Execute(context.Background(), []string{"antigravity"}, cliproxyexecutor.Request{Model: model}, cliproxyexecutor.Options{}); errExecute != nil { + t.Fatalf("second Execute error: %v", errExecute) + } + if got := executor.prepareCalls.Load(); got != 1 { + t.Fatalf("prepare calls after second execute = %d, want 1", got) + } +} + +func testStringValue(value any) string { + if value == nil { + return "" + } + switch typed := value.(type) { + case string: + return strings.TrimSpace(typed) + case []byte: + return strings.TrimSpace(string(typed)) + default: + return "" + } +}