From c9dc6bd62803a5de98f70130991040a2c9fbaa5f Mon Sep 17 00:00:00 2001 From: sususu98 Date: Tue, 2 Jun 2026 13:43:07 +0800 Subject: [PATCH] Fix Home auth refresh retry handling Parse Home refresh auth envelopes so refreshed access tokens are used instead of returning missing access token. Stop retrying when Home dispatch returns an auth that already failed within the same request. --- .../runtime/executor/helps/home_refresh.go | 44 ++++++++- .../executor/helps/home_refresh_test.go | 80 ++++++++++++++++ sdk/cliproxy/auth/conductor.go | 31 +++++- sdk/cliproxy/auth/home_retry_loop_test.go | 96 +++++++++++++++++++ 4 files changed, 246 insertions(+), 5 deletions(-) create mode 100644 sdk/cliproxy/auth/home_retry_loop_test.go diff --git a/internal/runtime/executor/helps/home_refresh.go b/internal/runtime/executor/helps/home_refresh.go index dc0270401..7c9719927 100644 --- a/internal/runtime/executor/helps/home_refresh.go +++ b/internal/runtime/executor/helps/home_refresh.go @@ -30,12 +30,26 @@ type homeErrorEnvelope struct { Error *homeErrorDetail `json:"error"` } +type homeRefreshAuthEnvelope struct { + Auth cliproxyauth.Auth `json:"auth"` + AuthIndex string `json:"auth_index"` +} + type homeErrorDetail struct { Type string `json:"type"` Message string `json:"message"` Code string `json:"code,omitempty"` } +type homeRefreshClient interface { + HeartbeatOK() bool + GetRefreshAuth(ctx context.Context, authIndex string) ([]byte, error) +} + +var currentHomeRefreshClient = func() homeRefreshClient { + return home.Current() +} + // RefreshAuthViaHome replaces local refresh logic when home control plane integration is enabled. // It returns (updatedAuth, true, nil) when home refresh succeeds; (nil, true, err) when home is // enabled but refresh fails; and (nil, false, nil) when home is disabled. @@ -50,7 +64,7 @@ func RefreshAuthViaHome(ctx context.Context, cfg *config.Config, auth *cliproxya return nil, true, homeStatusErr{code: http.StatusInternalServerError, msg: "home refresh: auth is nil"} } - client := home.Current() + client := currentHomeRefreshClient() if client == nil || !client.HeartbeatOK() { return nil, true, homeStatusErr{code: http.StatusServiceUnavailable, msg: "home control center unavailable"} } @@ -81,13 +95,35 @@ func RefreshAuthViaHome(ctx context.Context, cfg *config.Config, auth *cliproxya return nil, true, homeStatusErr{code: statusFromHomeErrorCode(code), msg: msg} } - var updated cliproxyauth.Auth - if errUnmarshal := json.Unmarshal(raw, &updated); errUnmarshal != nil { + updated, returnedIndex, errParse := parseHomeRefreshAuth(raw) + if errParse != nil { return nil, true, homeStatusErr{code: http.StatusBadGateway, msg: "home returned invalid auth payload"} } + if returnedIndex != "" { + authIndex = returnedIndex + } updated.Index = authIndex updated.EnsureIndex() - return &updated, true, nil + return updated, true, nil +} + +func parseHomeRefreshAuth(raw []byte) (*cliproxyauth.Auth, string, error) { + var rawObject map[string]json.RawMessage + if errUnmarshal := json.Unmarshal(raw, &rawObject); errUnmarshal != nil { + return nil, "", errUnmarshal + } + if _, ok := rawObject["auth"]; ok { + var envelope homeRefreshAuthEnvelope + if errUnmarshal := json.Unmarshal(raw, &envelope); errUnmarshal != nil { + return nil, "", errUnmarshal + } + return &envelope.Auth, strings.TrimSpace(envelope.AuthIndex), nil + } + var updated cliproxyauth.Auth + if errUnmarshal := json.Unmarshal(raw, &updated); errUnmarshal != nil { + return nil, "", errUnmarshal + } + return &updated, "", nil } func statusFromHomeErrorCode(code string) int { diff --git a/internal/runtime/executor/helps/home_refresh_test.go b/internal/runtime/executor/helps/home_refresh_test.go index c4507fdcc..e87c2b415 100644 --- a/internal/runtime/executor/helps/home_refresh_test.go +++ b/internal/runtime/executor/helps/home_refresh_test.go @@ -1,8 +1,14 @@ package helps import ( + "context" + "encoding/json" "net/http" + "sync/atomic" "testing" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" ) func TestStatusFromHomeErrorCodeMapsAuthenticationErrorToUnauthorized(t *testing.T) { @@ -13,3 +19,77 @@ func TestStatusFromHomeErrorCodeMapsAuthenticationErrorToUnauthorized(t *testing t.Fatalf("statusFromHomeErrorCode(unauthorized) = %d, want %d", got, http.StatusUnauthorized) } } + +type fakeHomeRefreshClient struct { + calls atomic.Int32 + authIndex string + raw []byte +} + +func (c *fakeHomeRefreshClient) HeartbeatOK() bool { + return true +} + +func (c *fakeHomeRefreshClient) GetRefreshAuth(_ context.Context, authIndex string) ([]byte, error) { + c.calls.Add(1) + c.authIndex = authIndex + return c.raw, nil +} + +func TestRefreshAuthViaHomeAcceptsAuthEnvelope(t *testing.T) { + raw, errMarshal := json.Marshal(struct { + Auth cliproxyauth.Auth `json:"auth"` + AuthIndex string `json:"auth_index"` + }{ + Auth: cliproxyauth.Auth{ + ID: "home-auth-1", + Provider: "antigravity", + Metadata: map[string]any{ + "access_token": "new-access-token", + }, + }, + AuthIndex: "home-index-1", + }) + if errMarshal != nil { + t.Fatalf("marshal home envelope: %v", errMarshal) + } + + client := &fakeHomeRefreshClient{raw: raw} + oldCurrentHomeRefreshClient := currentHomeRefreshClient + currentHomeRefreshClient = func() homeRefreshClient { + return client + } + t.Cleanup(func() { + currentHomeRefreshClient = oldCurrentHomeRefreshClient + }) + + cfg := &config.Config{Home: config.HomeConfig{Enabled: true}} + auth := &cliproxyauth.Auth{ + ID: "home-auth-1", + Provider: "antigravity", + Index: "home-index-1", + Metadata: map[string]any{ + "refresh_token": "refresh-token", + }, + } + + updated, handled, err := RefreshAuthViaHome(context.Background(), cfg, auth) + if err != nil { + t.Fatalf("RefreshAuthViaHome error: %v", err) + } + if !handled { + t.Fatal("RefreshAuthViaHome handled = false, want true") + } + if got := client.calls.Load(); got != 1 { + t.Fatalf("home refresh calls = %d, want 1", got) + } + if client.authIndex != "home-index-1" { + t.Fatalf("home refresh auth_index = %q, want home-index-1", client.authIndex) + } + if updated == nil { + t.Fatal("updated auth = nil") + } + if got := updated.Metadata["access_token"]; got != "new-access-token" { + t.Fatalf("updated access_token = %q, want new-access-token", got) + } +} diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 33116fba8..c5c7e3f94 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -3368,6 +3368,23 @@ func shouldReturnLastErrorOnPickFailure(homeMode bool, lastErr error, errPick er return isHomeRequestRetryExceededError(errPick) } +func homeAuthAlreadyTried(tried map[string]struct{}, authID string) bool { + authID = strings.TrimSpace(authID) + if authID == "" || len(tried) == 0 { + return false + } + _, ok := tried[authID] + return ok +} + +func repeatedHomeAuthError() *Error { + return &Error{ + Code: homeRequestRetryExceededErrorCode, + Message: "home returned a previously tried auth", + HTTPStatus: http.StatusServiceUnavailable, + } +} + type homeAuthDispatchResponse struct { Model string `json:"model"` Provider string `json:"provider"` @@ -3376,6 +3393,15 @@ type homeAuthDispatchResponse struct { Auth Auth `json:"auth"` } +type homeAuthDispatcher interface { + HeartbeatOK() bool + RPopAuth(ctx context.Context, requestedModel string, sessionID string, headers http.Header, count int) ([]byte, error) +} + +var currentHomeDispatcher = func() homeAuthDispatcher { + return home.Current() +} + func setHomeUserAPIKeyOnGinContext(ctx context.Context, apiKey string) { apiKey = strings.TrimSpace(apiKey) if apiKey == "" || ctx == nil { @@ -3575,7 +3601,7 @@ func (m *Manager) pickNextViaHome(ctx context.Context, model string, opts clipro } } - client := home.Current() + client := currentHomeDispatcher() if client == nil || !client.HeartbeatOK() { return nil, nil, "", &Error{Code: "home_unavailable", Message: "home control center unavailable", HTTPStatus: http.StatusServiceUnavailable} } @@ -3630,6 +3656,9 @@ func (m *Manager) pickNextViaHome(ctx context.Context, model string, opts clipro if strings.TrimSpace(auth.ID) == "" { return nil, nil, "", &Error{Code: "invalid_auth", Message: "home returned auth without id", HTTPStatus: http.StatusBadGateway} } + if homeAuthAlreadyTried(tried, auth.ID) { + return nil, nil, "", repeatedHomeAuthError() + } providerKey := strings.ToLower(strings.TrimSpace(auth.Provider)) if providerKey == "" { return nil, nil, "", &Error{Code: "invalid_auth", Message: "home returned auth without provider", HTTPStatus: http.StatusBadGateway} diff --git a/sdk/cliproxy/auth/home_retry_loop_test.go b/sdk/cliproxy/auth/home_retry_loop_test.go new file mode 100644 index 000000000..16f6e824b --- /dev/null +++ b/sdk/cliproxy/auth/home_retry_loop_test.go @@ -0,0 +1,96 @@ +package auth + +import ( + "context" + "encoding/json" + "net/http" + "sync/atomic" + "testing" + "time" + + internalconfig "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor" +) + +type repeatedHomeAuthDispatcher struct { + calls atomic.Int32 +} + +func (d *repeatedHomeAuthDispatcher) HeartbeatOK() bool { + return true +} + +func (d *repeatedHomeAuthDispatcher) RPopAuth(context.Context, string, string, http.Header, int) ([]byte, error) { + d.calls.Add(1) + raw, _ := json.Marshal(homeAuthDispatchResponse{ + Auth: Auth{ + ID: "home-auth-1", + Provider: "home-loop-test", + Status: StatusActive, + Metadata: map[string]any{"email": "loop@example.com"}, + }, + }) + return raw, nil +} + +type unauthorizedHomeExecutor struct { + calls atomic.Int32 +} + +func (e *unauthorizedHomeExecutor) Identifier() string { return "home-loop-test" } + +func (e *unauthorizedHomeExecutor) Execute(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + e.calls.Add(1) + return cliproxyexecutor.Response{}, &Error{HTTPStatus: http.StatusUnauthorized, Message: "missing access token"} +} + +func (e *unauthorizedHomeExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) { + e.calls.Add(1) + return nil, &Error{HTTPStatus: http.StatusUnauthorized, Message: "missing access token"} +} + +func (e *unauthorizedHomeExecutor) Refresh(context.Context, *Auth) (*Auth, error) { + return nil, &Error{HTTPStatus: http.StatusUnauthorized, Message: "missing access token"} +} + +func (e *unauthorizedHomeExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + e.calls.Add(1) + return cliproxyexecutor.Response{}, &Error{HTTPStatus: http.StatusUnauthorized, Message: "missing access token"} +} + +func (e *unauthorizedHomeExecutor) HttpRequest(context.Context, *Auth, *http.Request) (*http.Response, error) { + return nil, &Error{HTTPStatus: http.StatusUnauthorized, Message: "missing access token"} +} + +func TestManagerExecuteHomeStopsWhenDispatchRepeatsTriedAuth(t *testing.T) { + dispatcher := &repeatedHomeAuthDispatcher{} + oldCurrentHomeDispatcher := currentHomeDispatcher + currentHomeDispatcher = func() homeAuthDispatcher { + return dispatcher + } + t.Cleanup(func() { + currentHomeDispatcher = oldCurrentHomeDispatcher + }) + + executor := &unauthorizedHomeExecutor{} + manager := NewManager(nil, nil, nil) + manager.SetConfig(&internalconfig.Config{Home: internalconfig.HomeConfig{Enabled: true}}) + manager.RegisterExecutor(executor) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + _, err := manager.Execute(ctx, []string{"home-loop-test"}, cliproxyexecutor.Request{Model: "gemini-3.5-flash-low"}, cliproxyexecutor.Options{}) + if err == nil { + t.Fatal("Execute error = nil, want missing access token") + } + if statusCodeFromError(err) != http.StatusUnauthorized { + t.Fatalf("Execute error status = %d, want 401 (%v)", statusCodeFromError(err), err) + } + if got := executor.calls.Load(); got != 1 { + t.Fatalf("executor calls = %d, want 1", got) + } + if got := dispatcher.calls.Load(); got != 2 { + t.Fatalf("home dispatch calls = %d, want 2", got) + } +}