Files
CLIProxyAPI/internal/auth/codex/openai_auth_test.go
Luis Pater 8e52c403f7 feat(auth): deduplicate concurrent refresh token requests with singleflight
- Introduced `singleflight.Group` to prevent redundant token refresh calls across multiple auth implementations (`antigravity`, `kimi`, `xai`, `codex`).
- Added tests to verify shared upstream calls during concurrent refresh requests.
- Refactored token refresh logic to centralize and standardize deduplication mechanisms.
2026-06-10 03:19:26 +08:00

153 lines
4.6 KiB
Go

package codex
import (
"context"
"io"
"net/http"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
"golang.org/x/sync/singleflight"
)
// roundTripFunc adapts a plain function to the http.RoundTripper interface so
// tests can stub the transport of an http.Client with a closure.
type roundTripFunc func(*http.Request) (*http.Response, error)

// Compile-time proof that roundTripFunc can be used as an http.Client transport.
var _ http.RoundTripper = roundTripFunc(nil)

// RoundTrip hands the request to the wrapped function and returns its result
// unchanged.
func (fn roundTripFunc) RoundTrip(request *http.Request) (*http.Response, error) {
	response, err := fn(request)
	return response, err
}
// resetCodexRefreshGroupForTest replaces the package-level codexRefreshGroup
// with a zero-value singleflight.Group. The deduplication test calls it both
// before running and as a cleanup, so that test starts from, and leaves
// behind, an empty group.
func resetCodexRefreshGroupForTest() {
codexRefreshGroup = singleflight.Group{}
}
func TestRefreshTokensWithRetry_NonRetryableOnlyAttemptsOnce(t *testing.T) {
var calls int32
auth := &CodexAuth{
httpClient: &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
atomic.AddInt32(&calls, 1)
return &http.Response{
StatusCode: http.StatusBadRequest,
Body: io.NopCloser(strings.NewReader(`{"error":"invalid_grant","code":"refresh_token_reused"}`)),
Header: make(http.Header),
Request: req,
}, nil
}),
},
}
_, err := auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3)
if err == nil {
t.Fatalf("expected error for non-retryable refresh failure")
}
if !strings.Contains(strings.ToLower(err.Error()), "refresh_token_reused") {
t.Fatalf("expected refresh_token_reused in error, got: %v", err)
}
if got := atomic.LoadInt32(&calls); got != 1 {
t.Fatalf("expected 1 refresh attempt, got %d", got)
}
}
// TestRefreshTokens_DeduplicatesConcurrentRefreshAcrossInstances verifies that
// two CodexAuth instances refreshing the same refresh token at the same time
// share a single upstream HTTP call and both receive its token data.
//
// The stub transport blocks the first upstream call until the test releases
// it, which keeps that call in flight while the second caller starts.
func TestRefreshTokens_DeduplicatesConcurrentRefreshAcrossInstances(t *testing.T) {
	resetCodexRefreshGroupForTest()
	t.Cleanup(resetCodexRefreshGroupForTest)

	// waitLimit bounds every blocking wait so a regression fails fast instead
	// of hanging until the global go test timeout.
	const waitLimit = 5 * time.Second

	var calls int32
	started := make(chan struct{})
	release := make(chan struct{})
	var startOnce, releaseOnce sync.Once
	// releaseUpstream is idempotent so both the happy path and the cleanup
	// may call it without a double-close panic.
	releaseUpstream := func() { releaseOnce.Do(func() { close(release) }) }

	transport := roundTripFunc(func(req *http.Request) (*http.Response, error) {
		atomic.AddInt32(&calls, 1)
		startOnce.Do(func() { close(started) })
		<-release
		return &http.Response{
			StatusCode: http.StatusOK,
			Body: io.NopCloser(strings.NewReader(`{
"access_token":"new-access",
"refresh_token":"new-refresh",
"token_type":"Bearer",
"expires_in":3600
}`)),
			Header:  make(http.Header),
			Request: req,
		}, nil
	})
	authA := &CodexAuth{httpClient: &http.Client{Transport: transport}}
	authB := &CodexAuth{httpClient: &http.Client{Transport: transport}}

	// Each caller reports its token data and error together so a result is
	// never paired with another caller's error.
	type refreshOutcome struct {
		tokenData *CodexTokenData
		err       error
	}
	outcomes := make(chan refreshOutcome, 2)

	var wg sync.WaitGroup
	runRefresh := func(auth *CodexAuth, launched chan<- struct{}) {
		defer wg.Done()
		if launched != nil {
			close(launched)
		}
		tokenData, errRefresh := auth.RefreshTokens(context.Background(), "shared-refresh-token")
		outcomes <- refreshOutcome{tokenData: tokenData, err: errRefresh}
	}

	// Cleanups run last-in-first-out, so this one runs before the group reset
	// registered above. It unblocks the transport and waits for the refresh
	// goroutines, so a failed assertion does not leave them blocked forever.
	t.Cleanup(func() {
		releaseUpstream()
		finished := make(chan struct{})
		go func() {
			wg.Wait()
			close(finished)
		}()
		select {
		case <-finished:
		case <-time.After(waitLimit):
			t.Log("refresh goroutines did not finish before cleanup timeout")
		}
	})

	wg.Add(1)
	go runRefresh(authA, nil)
	select {
	case <-started:
	case early := <-outcomes:
		t.Fatalf("first refresh returned before reaching upstream: data=%#v err=%v", early.tokenData, early.err)
	case <-time.After(waitLimit):
		t.Fatal("timed out waiting for first refresh to reach upstream")
	}

	secondLaunched := make(chan struct{})
	wg.Add(1)
	go runRefresh(authB, secondLaunched)
	<-secondLaunched
	// There is no hook to observe the second caller joining the in-flight
	// refresh, so give it a short grace period before inspecting the counter.
	time.Sleep(20 * time.Millisecond)
	if got := atomic.LoadInt32(&calls); got != 1 {
		t.Fatalf("expected concurrent refresh to share a single upstream call, got %d", got)
	}

	releaseUpstream()
	for i := 0; i < 2; i++ {
		select {
		case outcome := <-outcomes:
			if outcome.err != nil {
				t.Fatalf("expected refresh to succeed, got %v", outcome.err)
			}
			tokenData := outcome.tokenData
			if tokenData == nil || tokenData.AccessToken != "new-access" || tokenData.RefreshToken != "new-refresh" {
				t.Fatalf("unexpected token data: %#v", tokenData)
			}
		case <-time.After(waitLimit):
			t.Fatalf("timed out waiting for refresh result %d of 2", i+1)
		}
	}

	if got := atomic.LoadInt32(&calls); got != 1 {
		t.Fatalf("expected both refresh callers to share a single upstream call, got %d", got)
	}
}
// TestNewCodexAuthWithProxyURL_OverrideDirectDisablesProxy checks that passing
// "direct" as the proxy override yields an *http.Transport with no Proxy
// function, even when the config carries a global proxy URL.
func TestNewCodexAuthWithProxyURL_OverrideDirectDisablesProxy(t *testing.T) {
	globalCfg := &config.Config{
		SDKConfig: config.SDKConfig{ProxyURL: "http://proxy.example.com:8080"},
	}
	codexAuth := NewCodexAuthWithProxyURL(globalCfg, "direct")

	httpTransport, isHTTPTransport := codexAuth.httpClient.Transport.(*http.Transport)
	if httpTransport == nil || !isHTTPTransport {
		t.Fatalf("expected http.Transport, got %T", codexAuth.httpClient.Transport)
	}

	// A nil Proxy function is what this test treats as "proxy disabled".
	if proxyFn := httpTransport.Proxy; proxyFn != nil {
		t.Fatal("expected direct transport to disable proxy function")
	}
}
// TestNewCodexAuthWithProxyURL_OverrideProxyTakesPrecedence checks that an
// explicit proxy URL override is the one the transport resolves for a request,
// rather than the global ProxyURL from the config.
func TestNewCodexAuthWithProxyURL_OverrideProxyTakesPrecedence(t *testing.T) {
	const overrideURL = "http://override.example.com:8081"

	cfg := &config.Config{SDKConfig: config.SDKConfig{ProxyURL: "http://global.example.com:8080"}}
	auth := NewCodexAuthWithProxyURL(cfg, overrideURL)

	transport, ok := auth.httpClient.Transport.(*http.Transport)
	if !ok || transport == nil {
		t.Fatalf("expected http.Transport, got %T", auth.httpClient.Transport)
	}
	// Calling a nil Proxy func would panic and abort the whole test binary;
	// fail this test cleanly instead.
	if transport.Proxy == nil {
		t.Fatalf("expected transport to have a proxy function for override %s", overrideURL)
	}

	req, errReq := http.NewRequest(http.MethodGet, "https://example.com", nil)
	if errReq != nil {
		t.Fatalf("new request: %v", errReq)
	}
	proxyURL, errProxy := transport.Proxy(req)
	if errProxy != nil {
		t.Fatalf("proxy func: %v", errProxy)
	}
	if proxyURL == nil || proxyURL.String() != overrideURL {
		t.Fatalf("proxy URL = %v, want %s", proxyURL, overrideURL)
	}
}