mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-06-18 07:43:24 +08:00
- Introduced `singleflight.Group` to prevent redundant token refresh calls across multiple auth implementations (`antigravity`, `kimi`, `xai`, `codex`).
- Added tests to verify shared upstream calls during concurrent refresh requests.
- Refactored token refresh logic to centralize and standardize deduplication mechanisms.
153 lines
4.6 KiB
Go
153 lines
4.6 KiB
Go
package codex
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
|
"golang.org/x/sync/singleflight"
|
|
)
|
|
|
|
// roundTripFunc lets a plain function be used wherever an http.RoundTripper
// is expected, which keeps the fake transports in these tests inline.
type roundTripFunc func(*http.Request) (*http.Response, error)

// RoundTrip implements http.RoundTripper by invoking the wrapped function
// with the request and returning its result unchanged.
func (fn roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
	resp, err := fn(r)
	return resp, err
}
|
|
|
|
func resetCodexRefreshGroupForTest() {
|
|
codexRefreshGroup = singleflight.Group{}
|
|
}
|
|
|
|
func TestRefreshTokensWithRetry_NonRetryableOnlyAttemptsOnce(t *testing.T) {
|
|
var calls int32
|
|
auth := &CodexAuth{
|
|
httpClient: &http.Client{
|
|
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
|
atomic.AddInt32(&calls, 1)
|
|
return &http.Response{
|
|
StatusCode: http.StatusBadRequest,
|
|
Body: io.NopCloser(strings.NewReader(`{"error":"invalid_grant","code":"refresh_token_reused"}`)),
|
|
Header: make(http.Header),
|
|
Request: req,
|
|
}, nil
|
|
}),
|
|
},
|
|
}
|
|
|
|
_, err := auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3)
|
|
if err == nil {
|
|
t.Fatalf("expected error for non-retryable refresh failure")
|
|
}
|
|
if !strings.Contains(strings.ToLower(err.Error()), "refresh_token_reused") {
|
|
t.Fatalf("expected refresh_token_reused in error, got: %v", err)
|
|
}
|
|
if got := atomic.LoadInt32(&calls); got != 1 {
|
|
t.Fatalf("expected 1 refresh attempt, got %d", got)
|
|
}
|
|
}
|
|
|
|
func TestRefreshTokens_DeduplicatesConcurrentRefreshAcrossInstances(t *testing.T) {
|
|
resetCodexRefreshGroupForTest()
|
|
t.Cleanup(resetCodexRefreshGroupForTest)
|
|
|
|
var calls int32
|
|
started := make(chan struct{})
|
|
release := make(chan struct{})
|
|
var once sync.Once
|
|
|
|
transport := roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
|
atomic.AddInt32(&calls, 1)
|
|
once.Do(func() { close(started) })
|
|
<-release
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Body: io.NopCloser(strings.NewReader(`{
|
|
"access_token":"new-access",
|
|
"refresh_token":"new-refresh",
|
|
"token_type":"Bearer",
|
|
"expires_in":3600
|
|
}`)),
|
|
Header: make(http.Header),
|
|
Request: req,
|
|
}, nil
|
|
})
|
|
authA := &CodexAuth{httpClient: &http.Client{Transport: transport}}
|
|
authB := &CodexAuth{httpClient: &http.Client{Transport: transport}}
|
|
|
|
results := make(chan *CodexTokenData, 2)
|
|
errs := make(chan error, 2)
|
|
runRefresh := func(auth *CodexAuth, launched chan<- struct{}) {
|
|
if launched != nil {
|
|
close(launched)
|
|
}
|
|
tokenData, errRefresh := auth.RefreshTokens(context.Background(), "shared-refresh-token")
|
|
results <- tokenData
|
|
errs <- errRefresh
|
|
}
|
|
|
|
go runRefresh(authA, nil)
|
|
<-started
|
|
|
|
secondLaunched := make(chan struct{})
|
|
go runRefresh(authB, secondLaunched)
|
|
<-secondLaunched
|
|
time.Sleep(20 * time.Millisecond)
|
|
if got := atomic.LoadInt32(&calls); got != 1 {
|
|
t.Fatalf("expected concurrent refresh to share a single upstream call, got %d", got)
|
|
}
|
|
close(release)
|
|
|
|
for i := 0; i < 2; i++ {
|
|
if errRefresh := <-errs; errRefresh != nil {
|
|
t.Fatalf("expected refresh to succeed, got %v", errRefresh)
|
|
}
|
|
tokenData := <-results
|
|
if tokenData == nil || tokenData.AccessToken != "new-access" || tokenData.RefreshToken != "new-refresh" {
|
|
t.Fatalf("unexpected token data: %#v", tokenData)
|
|
}
|
|
}
|
|
if got := atomic.LoadInt32(&calls); got != 1 {
|
|
t.Fatalf("expected both refresh callers to share a single upstream call, got %d", got)
|
|
}
|
|
}
|
|
|
|
func TestNewCodexAuthWithProxyURL_OverrideDirectDisablesProxy(t *testing.T) {
|
|
cfg := &config.Config{SDKConfig: config.SDKConfig{ProxyURL: "http://proxy.example.com:8080"}}
|
|
auth := NewCodexAuthWithProxyURL(cfg, "direct")
|
|
|
|
transport, ok := auth.httpClient.Transport.(*http.Transport)
|
|
if !ok || transport == nil {
|
|
t.Fatalf("expected http.Transport, got %T", auth.httpClient.Transport)
|
|
}
|
|
if transport.Proxy != nil {
|
|
t.Fatal("expected direct transport to disable proxy function")
|
|
}
|
|
}
|
|
|
|
func TestNewCodexAuthWithProxyURL_OverrideProxyTakesPrecedence(t *testing.T) {
|
|
cfg := &config.Config{SDKConfig: config.SDKConfig{ProxyURL: "http://global.example.com:8080"}}
|
|
auth := NewCodexAuthWithProxyURL(cfg, "http://override.example.com:8081")
|
|
|
|
transport, ok := auth.httpClient.Transport.(*http.Transport)
|
|
if !ok || transport == nil {
|
|
t.Fatalf("expected http.Transport, got %T", auth.httpClient.Transport)
|
|
}
|
|
req, errReq := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
|
if errReq != nil {
|
|
t.Fatalf("new request: %v", errReq)
|
|
}
|
|
proxyURL, errProxy := transport.Proxy(req)
|
|
if errProxy != nil {
|
|
t.Fatalf("proxy func: %v", errProxy)
|
|
}
|
|
if proxyURL == nil || proxyURL.String() != "http://override.example.com:8081" {
|
|
t.Fatalf("proxy URL = %v, want http://override.example.com:8081", proxyURL)
|
|
}
|
|
}
|