diff --git a/internal/runtime/executor/helps/home_refresh.go b/internal/runtime/executor/helps/home_refresh.go index dc0270401..7c9719927 100644 --- a/internal/runtime/executor/helps/home_refresh.go +++ b/internal/runtime/executor/helps/home_refresh.go @@ -30,12 +30,26 @@ type homeErrorEnvelope struct { Error *homeErrorDetail `json:"error"` } +type homeRefreshAuthEnvelope struct { + Auth cliproxyauth.Auth `json:"auth"` + AuthIndex string `json:"auth_index"` +} + type homeErrorDetail struct { Type string `json:"type"` Message string `json:"message"` Code string `json:"code,omitempty"` } +type homeRefreshClient interface { + HeartbeatOK() bool + GetRefreshAuth(ctx context.Context, authIndex string) ([]byte, error) +} + +var currentHomeRefreshClient = func() homeRefreshClient { + return home.Current() +} + // RefreshAuthViaHome replaces local refresh logic when home control plane integration is enabled. // It returns (updatedAuth, true, nil) when home refresh succeeds; (nil, true, err) when home is // enabled but refresh fails; and (nil, false, nil) when home is disabled. @@ -50,7 +64,7 @@ func RefreshAuthViaHome(ctx context.Context, cfg *config.Config, auth *cliproxya return nil, true, homeStatusErr{code: http.StatusInternalServerError, msg: "home refresh: auth is nil"} } - client := home.Current() + client := currentHomeRefreshClient() if client == nil || !client.HeartbeatOK() { return nil, true, homeStatusErr{code: http.StatusServiceUnavailable, msg: "home control center unavailable"} } @@ -81,13 +95,35 @@ func RefreshAuthViaHome(ctx context.Context, cfg *config.Config, auth *cliproxya return nil, true, homeStatusErr{code: statusFromHomeErrorCode(code), msg: msg} } - var updated cliproxyauth.Auth - if errUnmarshal := json.Unmarshal(raw, &updated); errUnmarshal != nil { + updated, returnedIndex, errParse := parseHomeRefreshAuth(raw) + if errParse != nil { return nil, true, homeStatusErr{code: http.StatusBadGateway, msg: "home returned invalid auth payload"} } + if returnedIndex != "" { + authIndex = returnedIndex + } updated.Index = authIndex updated.EnsureIndex() - return &updated, true, nil + return updated, true, nil +} + +func parseHomeRefreshAuth(raw []byte) (*cliproxyauth.Auth, string, error) { + var rawObject map[string]json.RawMessage + if errUnmarshal := json.Unmarshal(raw, &rawObject); errUnmarshal != nil { + return nil, "", errUnmarshal + } + if _, ok := rawObject["auth"]; ok { + var envelope homeRefreshAuthEnvelope + if errUnmarshal := json.Unmarshal(raw, &envelope); errUnmarshal != nil { + return nil, "", errUnmarshal + } + return &envelope.Auth, strings.TrimSpace(envelope.AuthIndex), nil + } + var updated cliproxyauth.Auth + if errUnmarshal := json.Unmarshal(raw, &updated); errUnmarshal != nil { + return nil, "", errUnmarshal + } + return &updated, "", nil } func statusFromHomeErrorCode(code string) int { diff --git a/internal/runtime/executor/helps/home_refresh_test.go b/internal/runtime/executor/helps/home_refresh_test.go index c4507fdcc..e87c2b415 100644 --- a/internal/runtime/executor/helps/home_refresh_test.go +++ b/internal/runtime/executor/helps/home_refresh_test.go @@ -1,8 +1,14 @@ package helps import ( + "context" + "encoding/json" "net/http" + "sync/atomic" "testing" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" ) func TestStatusFromHomeErrorCodeMapsAuthenticationErrorToUnauthorized(t *testing.T) { @@ -13,3 +19,77 @@ func TestStatusFromHomeErrorCodeMapsAuthenticationErrorToUnauthorized(t *testing t.Fatalf("statusFromHomeErrorCode(unauthorized) = %d, want %d", got, http.StatusUnauthorized) } } + +type fakeHomeRefreshClient struct { + calls atomic.Int32 + authIndex string + raw []byte +} + +func (c *fakeHomeRefreshClient) HeartbeatOK() bool { + return true +} + +func (c *fakeHomeRefreshClient) GetRefreshAuth(_ context.Context, authIndex string) ([]byte, error) { + c.calls.Add(1) + c.authIndex = authIndex + return c.raw, nil +} + +func TestRefreshAuthViaHomeAcceptsAuthEnvelope(t *testing.T) { + raw, errMarshal := json.Marshal(struct { + Auth cliproxyauth.Auth `json:"auth"` + AuthIndex string `json:"auth_index"` + }{ + Auth: cliproxyauth.Auth{ + ID: "home-auth-1", + Provider: "antigravity", + Metadata: map[string]any{ + "access_token": "new-access-token", + }, + }, + AuthIndex: "home-index-1", + }) + if errMarshal != nil { + t.Fatalf("marshal home envelope: %v", errMarshal) + } + + client := &fakeHomeRefreshClient{raw: raw} + oldCurrentHomeRefreshClient := currentHomeRefreshClient + currentHomeRefreshClient = func() homeRefreshClient { + return client + } + t.Cleanup(func() { + currentHomeRefreshClient = oldCurrentHomeRefreshClient + }) + + cfg := &config.Config{Home: config.HomeConfig{Enabled: true}} + auth := &cliproxyauth.Auth{ + ID: "home-auth-1", + Provider: "antigravity", + Index: "home-index-1", + Metadata: map[string]any{ + "refresh_token": "refresh-token", + }, + } + + updated, handled, err := RefreshAuthViaHome(context.Background(), cfg, auth) + if err != nil { + t.Fatalf("RefreshAuthViaHome error: %v", err) + } + if !handled { + t.Fatal("RefreshAuthViaHome handled = false, want true") + } + if got := client.calls.Load(); got != 1 { + t.Fatalf("home refresh calls = %d, want 1", got) + } + if client.authIndex != "home-index-1" { + t.Fatalf("home refresh auth_index = %q, want home-index-1", client.authIndex) + } + if updated == nil { + t.Fatal("updated auth = nil") + } + if got := updated.Metadata["access_token"]; got != "new-access-token" { + t.Fatalf("updated access_token = %q, want new-access-token", got) + } +} diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 33116fba8..c5c7e3f94 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -3368,6 +3368,23 @@ func shouldReturnLastErrorOnPickFailure(homeMode bool, lastErr error, errPick er return isHomeRequestRetryExceededError(errPick) } +func homeAuthAlreadyTried(tried map[string]struct{}, authID string) bool { + authID = strings.TrimSpace(authID) + if authID == "" || len(tried) == 0 { + return false + } + _, ok := tried[authID] + return ok +} + +func repeatedHomeAuthError() *Error { + return &Error{ + Code: homeRequestRetryExceededErrorCode, + Message: "home returned a previously tried auth", + HTTPStatus: http.StatusServiceUnavailable, + } +} + type homeAuthDispatchResponse struct { Model string `json:"model"` Provider string `json:"provider"` @@ -3376,6 +3393,15 @@ type homeAuthDispatchResponse struct { Auth Auth `json:"auth"` } +type homeAuthDispatcher interface { + HeartbeatOK() bool + RPopAuth(ctx context.Context, requestedModel string, sessionID string, headers http.Header, count int) ([]byte, error) +} + +var currentHomeDispatcher = func() homeAuthDispatcher { + return home.Current() +} + func setHomeUserAPIKeyOnGinContext(ctx context.Context, apiKey string) { apiKey = strings.TrimSpace(apiKey) if apiKey == "" || ctx == nil { @@ -3575,7 +3601,7 @@ func (m *Manager) pickNextViaHome(ctx context.Context, model string, opts clipro } } - client := home.Current() + client := currentHomeDispatcher() if client == nil || !client.HeartbeatOK() { return nil, nil, "", &Error{Code: "home_unavailable", Message: "home control center unavailable", HTTPStatus: http.StatusServiceUnavailable} } @@ -3630,6 +3656,9 @@ func (m *Manager) pickNextViaHome(ctx context.Context, model string, opts clipro if strings.TrimSpace(auth.ID) == "" { return nil, nil, "", &Error{Code: "invalid_auth", Message: "home returned auth without id", HTTPStatus: http.StatusBadGateway} } + if homeAuthAlreadyTried(tried, auth.ID) { + return nil, nil, "", repeatedHomeAuthError() + } providerKey := strings.ToLower(strings.TrimSpace(auth.Provider)) if providerKey == "" { return nil, nil, "", &Error{Code: "invalid_auth", Message: "home returned auth without provider", HTTPStatus: http.StatusBadGateway} diff --git a/sdk/cliproxy/auth/home_retry_loop_test.go b/sdk/cliproxy/auth/home_retry_loop_test.go new file mode 100644 index 000000000..16f6e824b --- /dev/null +++ b/sdk/cliproxy/auth/home_retry_loop_test.go @@ -0,0 +1,96 @@ +package auth + +import ( + "context" + "encoding/json" + "net/http" + "sync/atomic" + "testing" + "time" + + internalconfig "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor" +) + +type repeatedHomeAuthDispatcher struct { + calls atomic.Int32 +} + +func (d *repeatedHomeAuthDispatcher) HeartbeatOK() bool { + return true +} + +func (d *repeatedHomeAuthDispatcher) RPopAuth(context.Context, string, string, http.Header, int) ([]byte, error) { + d.calls.Add(1) + raw, _ := json.Marshal(homeAuthDispatchResponse{ + Auth: Auth{ + ID: "home-auth-1", + Provider: "home-loop-test", + Status: StatusActive, + Metadata: map[string]any{"email": "loop@example.com"}, + }, + }) + return raw, nil +} + +type unauthorizedHomeExecutor struct { + calls atomic.Int32 +} + +func (e *unauthorizedHomeExecutor) Identifier() string { return "home-loop-test" } + +func (e *unauthorizedHomeExecutor) Execute(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + e.calls.Add(1) + return cliproxyexecutor.Response{}, &Error{HTTPStatus: http.StatusUnauthorized, Message: "missing access token"} +} + +func (e *unauthorizedHomeExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) { + e.calls.Add(1) + return nil, &Error{HTTPStatus: http.StatusUnauthorized, Message: "missing access token"} +} + +func (e *unauthorizedHomeExecutor) Refresh(context.Context, *Auth) (*Auth, error) { + return nil, &Error{HTTPStatus: http.StatusUnauthorized, Message: "missing access token"} +} + +func (e *unauthorizedHomeExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + e.calls.Add(1) + return cliproxyexecutor.Response{}, &Error{HTTPStatus: http.StatusUnauthorized, Message: "missing access token"} +} + +func (e *unauthorizedHomeExecutor) HttpRequest(context.Context, *Auth, *http.Request) (*http.Response, error) { + return nil, &Error{HTTPStatus: http.StatusUnauthorized, Message: "missing access token"} +} + +func TestManagerExecuteHomeStopsWhenDispatchRepeatsTriedAuth(t *testing.T) { + dispatcher := &repeatedHomeAuthDispatcher{} + oldCurrentHomeDispatcher := currentHomeDispatcher + currentHomeDispatcher = func() homeAuthDispatcher { + return dispatcher + } + t.Cleanup(func() { + currentHomeDispatcher = oldCurrentHomeDispatcher + }) + + executor := &unauthorizedHomeExecutor{} + manager := NewManager(nil, nil, nil) + manager.SetConfig(&internalconfig.Config{Home: internalconfig.HomeConfig{Enabled: true}}) + manager.RegisterExecutor(executor) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + _, err := manager.Execute(ctx, []string{"home-loop-test"}, cliproxyexecutor.Request{Model: "gemini-3.5-flash-low"}, cliproxyexecutor.Options{}) + if err == nil { + t.Fatal("Execute error = nil, want missing access token") + } + if statusCodeFromError(err) != http.StatusUnauthorized { + t.Fatalf("Execute error status = %d, want 401 (%v)", statusCodeFromError(err), err) + } + if got := executor.calls.Load(); got != 1 { + t.Fatalf("executor calls = %d, want 1", got) + } + if got := dispatcher.calls.Load(); got != 2 { + t.Fatalf("home dispatch calls = %d, want 2", got) + } +}