diff --git a/README.md b/README.md index 82617c9db..393ff63cf 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,6 @@ VisionCoder is also offering our users a limited-time ", ...]}. -// If "value" is an empty array, clears all entries. -// If JSON is invalid or "value" is missing/null, returns 400 and does not persist any change. -func (h *Handler) DeleteAmpUpstreamAPIKeys(c *gin.Context) { - var body struct { - Value []string `json:"value"` - } - if err := c.ShouldBindJSON(&body); err != nil { - c.JSON(400, gin.H{"error": "invalid body"}) - return - } - - if body.Value == nil { - c.JSON(400, gin.H{"error": "missing value"}) - return - } - - // Empty array means clear all - if len(body.Value) == 0 { - h.cfg.AmpCode.UpstreamAPIKeys = nil - h.persist(c) - return - } - - toRemove := make(map[string]bool) - for _, key := range body.Value { - trimmed := strings.TrimSpace(key) - if trimmed == "" { - continue - } - toRemove[trimmed] = true - } - if len(toRemove) == 0 { - c.JSON(400, gin.H{"error": "empty value"}) - return - } - - newEntries := make([]config.AmpUpstreamAPIKeyEntry, 0, len(h.cfg.AmpCode.UpstreamAPIKeys)) - for _, entry := range h.cfg.AmpCode.UpstreamAPIKeys { - if !toRemove[strings.TrimSpace(entry.UpstreamAPIKey)] { - newEntries = append(newEntries, entry) - } - } - h.cfg.AmpCode.UpstreamAPIKeys = newEntries - h.persist(c) -} - -// normalizeAmpUpstreamAPIKeyEntries normalizes a list of upstream API key entries. -func normalizeAmpUpstreamAPIKeyEntries(entries []config.AmpUpstreamAPIKeyEntry) []config.AmpUpstreamAPIKeyEntry { - if len(entries) == 0 { - return nil - } - out := make([]config.AmpUpstreamAPIKeyEntry, 0, len(entries)) - for _, entry := range entries { - upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey) - if upstreamKey == "" { - continue - } - apiKeys := normalizeAPIKeysList(entry.APIKeys) - out = append(out, config.AmpUpstreamAPIKeyEntry{ - UpstreamAPIKey: upstreamKey, - APIKeys: apiKeys, - }) - } - if len(out) == 0 { - return nil - } - return out -} - -// normalizeAPIKeysList trims and filters empty strings from a list of API keys. -func normalizeAPIKeysList(keys []string) []string { - if len(keys) == 0 { - return nil - } - out := make([]string, 0, len(keys)) - for _, k := range keys { - trimmed := strings.TrimSpace(k) - if trimmed != "" { - out = append(out, trimmed) - } - } - if len(out) == 0 { - return nil - } - return out -} diff --git a/internal/api/middleware/request_logging.go b/internal/api/middleware/request_logging.go index 0ee849ae4..7108390b5 100644 --- a/internal/api/middleware/request_logging.go +++ b/internal/api/middleware/request_logging.go @@ -241,9 +241,5 @@ func shouldLogRequest(path string) bool { return false } - if strings.HasPrefix(path, "/api") { - return strings.HasPrefix(path, "/api/provider") - } - return true } diff --git a/internal/api/modules/amp/amp.go b/internal/api/modules/amp/amp.go deleted file mode 100644 index 18c8ac1ef..000000000 --- a/internal/api/modules/amp/amp.go +++ /dev/null @@ -1,427 +0,0 @@ -// Package amp implements the Amp CLI routing module, providing OAuth-based -// integration with Amp CLI for ChatGPT and Anthropic subscriptions. -package amp - -import ( - "fmt" - "net/http/httputil" - "strings" - "sync" - - "github.com/gin-gonic/gin" - "github.com/router-for-me/CLIProxyAPI/v7/internal/api/modules" - "github.com/router-for-me/CLIProxyAPI/v7/internal/config" - sdkaccess "github.com/router-for-me/CLIProxyAPI/v7/sdk/access" - log "github.com/sirupsen/logrus" -) - -// Option configures the AmpModule. -type Option func(*AmpModule) - -// AmpModule implements the RouteModuleV2 interface for Amp CLI integration. -// It provides: -// - Reverse proxy to Amp control plane for OAuth/management -// - Provider-specific route aliases (/api/provider/{provider}/...) -// - Automatic gzip decompression for misconfigured upstreams -// - Model mapping for routing unavailable models to alternatives -type AmpModule struct { - secretSource SecretSource - proxy *httputil.ReverseProxy - proxyMu sync.RWMutex // protects proxy for hot-reload - accessManager *sdkaccess.Manager - authMiddleware_ gin.HandlerFunc - modelMapper *DefaultModelMapper - enabled bool - registerOnce sync.Once - - // restrictToLocalhost controls localhost-only access for management routes (hot-reloadable) - restrictToLocalhost bool - restrictMu sync.RWMutex - - // configMu protects lastConfig for partial reload comparison - configMu sync.RWMutex - lastConfig *config.AmpCode -} - -// New creates a new Amp routing module with the given options. -// This is the preferred constructor using the Option pattern. -// -// Example: -// -// ampModule := amp.New( -// amp.WithAccessManager(accessManager), -// amp.WithAuthMiddleware(authMiddleware), -// amp.WithSecretSource(customSecret), -// ) -func New(opts ...Option) *AmpModule { - m := &AmpModule{ - secretSource: nil, // Will be created on demand if not provided - } - for _, opt := range opts { - opt(m) - } - return m -} - -// NewLegacy creates a new Amp routing module using the legacy constructor signature. -// This is provided for backwards compatibility. -// -// DEPRECATED: Use New with options instead. -func NewLegacy(accessManager *sdkaccess.Manager, authMiddleware gin.HandlerFunc) *AmpModule { - return New( - WithAccessManager(accessManager), - WithAuthMiddleware(authMiddleware), - ) -} - -// WithSecretSource sets a custom secret source for the module. -func WithSecretSource(source SecretSource) Option { - return func(m *AmpModule) { - m.secretSource = source - } -} - -// WithAccessManager sets the access manager for the module. -func WithAccessManager(am *sdkaccess.Manager) Option { - return func(m *AmpModule) { - m.accessManager = am - } -} - -// WithAuthMiddleware sets the authentication middleware for provider routes. -func WithAuthMiddleware(middleware gin.HandlerFunc) Option { - return func(m *AmpModule) { - m.authMiddleware_ = middleware - } -} - -// Name returns the module identifier -func (m *AmpModule) Name() string { - return "amp-routing" -} - -// forceModelMappings returns whether model mappings should take precedence over local API keys -func (m *AmpModule) forceModelMappings() bool { - m.configMu.RLock() - defer m.configMu.RUnlock() - if m.lastConfig == nil { - return false - } - return m.lastConfig.ForceModelMappings -} - -// Register sets up Amp routes if configured. -// This implements the RouteModuleV2 interface with Context. -// Routes are registered only once via sync.Once for idempotent behavior. -func (m *AmpModule) Register(ctx modules.Context) error { - settings := ctx.Config.AmpCode - upstreamURL := strings.TrimSpace(settings.UpstreamURL) - - // Determine auth middleware (from module or context) - auth := m.getAuthMiddleware(ctx) - - // Use registerOnce to ensure routes are only registered once - var regErr error - m.registerOnce.Do(func() { - // Initialize model mapper from config (for routing unavailable models to alternatives) - m.modelMapper = NewModelMapper(settings.ModelMappings) - - // Store initial config for partial reload comparison - m.lastConfig = new(settings) - - // Initialize localhost restriction setting (hot-reloadable) - m.setRestrictToLocalhost(settings.RestrictManagementToLocalhost) - - // Always register provider aliases - these work without an upstream - m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth) - - // Register management proxy routes once; middleware will gate access when upstream is unavailable. - // Pass auth middleware to require valid API key for all management routes. - m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler, auth) - - // If no upstream URL, skip proxy routes but provider aliases are still available - if upstreamURL == "" { - log.Debug("amp upstream proxy disabled (no upstream URL configured)") - log.Debug("amp provider alias routes registered") - m.enabled = false - return - } - - if err := m.enableUpstreamProxy(upstreamURL, &settings); err != nil { - regErr = fmt.Errorf("failed to create amp proxy: %w", err) - return - } - - log.Debug("amp provider alias routes registered") - }) - - return regErr -} - -// getAuthMiddleware returns the authentication middleware, preferring the -// module's configured middleware, then the context middleware, then a fallback. -func (m *AmpModule) getAuthMiddleware(ctx modules.Context) gin.HandlerFunc { - if m.authMiddleware_ != nil { - return m.authMiddleware_ - } - if ctx.AuthMiddleware != nil { - return ctx.AuthMiddleware - } - // Fallback: no authentication (should not happen in production) - log.Warn("amp module: no auth middleware provided, allowing all requests") - return func(c *gin.Context) { - c.Next() - } -} - -// OnConfigUpdated handles configuration updates with partial reload support. -// Only updates components that have actually changed to avoid unnecessary work. -// Supports hot-reload for: model-mappings, upstream-api-key, upstream-url, restrict-management-to-localhost. -func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error { - newSettings := cfg.AmpCode - - // Get previous config for comparison - m.configMu.RLock() - oldSettings := m.lastConfig - m.configMu.RUnlock() - - if oldSettings != nil && oldSettings.RestrictManagementToLocalhost != newSettings.RestrictManagementToLocalhost { - m.setRestrictToLocalhost(newSettings.RestrictManagementToLocalhost) - } - - newUpstreamURL := strings.TrimSpace(newSettings.UpstreamURL) - oldUpstreamURL := "" - if oldSettings != nil { - oldUpstreamURL = strings.TrimSpace(oldSettings.UpstreamURL) - } - - if !m.enabled && newUpstreamURL != "" { - if err := m.enableUpstreamProxy(newUpstreamURL, &newSettings); err != nil { - log.Errorf("amp config: failed to enable upstream proxy for %s: %v", newUpstreamURL, err) - } - } - - // Check model mappings change - modelMappingsChanged := m.hasModelMappingsChanged(oldSettings, &newSettings) - if modelMappingsChanged { - if m.modelMapper != nil { - m.modelMapper.UpdateMappings(newSettings.ModelMappings) - } else if m.enabled { - log.Warnf("amp model mapper not initialized, skipping model mapping update") - } - } - - if m.enabled { - // Check upstream URL change - now supports hot-reload - if newUpstreamURL == "" && oldUpstreamURL != "" { - m.setProxy(nil) - m.enabled = false - } else if oldUpstreamURL != "" && newUpstreamURL != oldUpstreamURL && newUpstreamURL != "" { - // Recreate proxy with new URL - proxy, err := createReverseProxy(newUpstreamURL, m.secretSource) - if err != nil { - log.Errorf("amp config: failed to create proxy for new upstream URL %s: %v", newUpstreamURL, err) - } else { - m.setProxy(proxy) - } - } - - // Check API key change (both default and per-client mappings) - apiKeyChanged := m.hasAPIKeyChanged(oldSettings, &newSettings) - upstreamAPIKeysChanged := m.hasUpstreamAPIKeysChanged(oldSettings, &newSettings) - if apiKeyChanged || upstreamAPIKeysChanged { - if m.secretSource != nil { - if ms, ok := m.secretSource.(*MappedSecretSource); ok { - if apiKeyChanged { - ms.UpdateDefaultExplicitKey(newSettings.UpstreamAPIKey) - ms.InvalidateCache() - } - if upstreamAPIKeysChanged { - ms.UpdateMappings(newSettings.UpstreamAPIKeys) - } - } else if ms, ok := m.secretSource.(*MultiSourceSecret); ok { - ms.UpdateExplicitKey(newSettings.UpstreamAPIKey) - ms.InvalidateCache() - } - } - } - - } - - // Store current config for next comparison - m.configMu.Lock() - settingsCopy := newSettings // copy struct - m.lastConfig = &settingsCopy - m.configMu.Unlock() - - return nil -} - -func (m *AmpModule) enableUpstreamProxy(upstreamURL string, settings *config.AmpCode) error { - if m.secretSource == nil { - // Create MultiSourceSecret as the default source, then wrap with MappedSecretSource - defaultSource := NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */) - mappedSource := NewMappedSecretSource(defaultSource) - mappedSource.UpdateMappings(settings.UpstreamAPIKeys) - m.secretSource = mappedSource - } else if ms, ok := m.secretSource.(*MappedSecretSource); ok { - ms.UpdateDefaultExplicitKey(settings.UpstreamAPIKey) - ms.InvalidateCache() - ms.UpdateMappings(settings.UpstreamAPIKeys) - } else if ms, ok := m.secretSource.(*MultiSourceSecret); ok { - // Legacy path: wrap existing MultiSourceSecret with MappedSecretSource - ms.UpdateExplicitKey(settings.UpstreamAPIKey) - ms.InvalidateCache() - mappedSource := NewMappedSecretSource(ms) - mappedSource.UpdateMappings(settings.UpstreamAPIKeys) - m.secretSource = mappedSource - } - - proxy, err := createReverseProxy(upstreamURL, m.secretSource) - if err != nil { - return err - } - - m.setProxy(proxy) - m.enabled = true - - log.Infof("amp upstream proxy enabled for: %s", upstreamURL) - return nil -} - -// hasModelMappingsChanged compares old and new model mappings. -func (m *AmpModule) hasModelMappingsChanged(old *config.AmpCode, new *config.AmpCode) bool { - if old == nil { - return len(new.ModelMappings) > 0 - } - - if len(old.ModelMappings) != len(new.ModelMappings) { - return true - } - - // Build map for efficient and robust comparison - type mappingInfo struct { - to string - regex bool - } - oldMap := make(map[string]mappingInfo, len(old.ModelMappings)) - for _, mapping := range old.ModelMappings { - oldMap[strings.TrimSpace(mapping.From)] = mappingInfo{ - to: strings.TrimSpace(mapping.To), - regex: mapping.Regex, - } - } - - for _, mapping := range new.ModelMappings { - from := strings.TrimSpace(mapping.From) - to := strings.TrimSpace(mapping.To) - if oldVal, exists := oldMap[from]; !exists || oldVal.to != to || oldVal.regex != mapping.Regex { - return true - } - } - - return false -} - -// hasAPIKeyChanged compares old and new API keys. -func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) bool { - oldKey := "" - if old != nil { - oldKey = strings.TrimSpace(old.UpstreamAPIKey) - } - newKey := strings.TrimSpace(new.UpstreamAPIKey) - return oldKey != newKey -} - -// hasUpstreamAPIKeysChanged compares old and new per-client upstream API key mappings. -func (m *AmpModule) hasUpstreamAPIKeysChanged(old *config.AmpCode, new *config.AmpCode) bool { - if old == nil { - return len(new.UpstreamAPIKeys) > 0 - } - - if len(old.UpstreamAPIKeys) != len(new.UpstreamAPIKeys) { - return true - } - - // Build map for comparison: upstreamKey -> set of clientKeys - type entryInfo struct { - upstreamKey string - clientKeys map[string]struct{} - } - oldEntries := make([]entryInfo, len(old.UpstreamAPIKeys)) - for i, entry := range old.UpstreamAPIKeys { - clientKeys := make(map[string]struct{}, len(entry.APIKeys)) - for _, k := range entry.APIKeys { - trimmed := strings.TrimSpace(k) - if trimmed == "" { - continue - } - clientKeys[trimmed] = struct{}{} - } - oldEntries[i] = entryInfo{ - upstreamKey: strings.TrimSpace(entry.UpstreamAPIKey), - clientKeys: clientKeys, - } - } - - for i, newEntry := range new.UpstreamAPIKeys { - if i >= len(oldEntries) { - return true - } - oldE := oldEntries[i] - if strings.TrimSpace(newEntry.UpstreamAPIKey) != oldE.upstreamKey { - return true - } - newKeys := make(map[string]struct{}, len(newEntry.APIKeys)) - for _, k := range newEntry.APIKeys { - trimmed := strings.TrimSpace(k) - if trimmed == "" { - continue - } - newKeys[trimmed] = struct{}{} - } - if len(newKeys) != len(oldE.clientKeys) { - return true - } - for k := range newKeys { - if _, ok := oldE.clientKeys[k]; !ok { - return true - } - } - } - - return false -} - -// GetModelMapper returns the model mapper instance (for testing/debugging). -func (m *AmpModule) GetModelMapper() *DefaultModelMapper { - return m.modelMapper -} - -// getProxy returns the current proxy instance (thread-safe for hot-reload). -func (m *AmpModule) getProxy() *httputil.ReverseProxy { - m.proxyMu.RLock() - defer m.proxyMu.RUnlock() - return m.proxy -} - -// setProxy updates the proxy instance (thread-safe for hot-reload). -func (m *AmpModule) setProxy(proxy *httputil.ReverseProxy) { - m.proxyMu.Lock() - defer m.proxyMu.Unlock() - m.proxy = proxy -} - -// IsRestrictedToLocalhost returns whether management routes are restricted to localhost. -func (m *AmpModule) IsRestrictedToLocalhost() bool { - m.restrictMu.RLock() - defer m.restrictMu.RUnlock() - return m.restrictToLocalhost -} - -// setRestrictToLocalhost updates the localhost restriction setting. -func (m *AmpModule) setRestrictToLocalhost(restrict bool) { - m.restrictMu.Lock() - defer m.restrictMu.Unlock() - m.restrictToLocalhost = restrict -} diff --git a/internal/api/modules/amp/amp_test.go b/internal/api/modules/amp/amp_test.go deleted file mode 100644 index 5ca01754a..000000000 --- a/internal/api/modules/amp/amp_test.go +++ /dev/null @@ -1,352 +0,0 @@ -package amp - -import ( - "context" - "net/http/httptest" - "os" - "path/filepath" - "testing" - "time" - - "github.com/gin-gonic/gin" - "github.com/router-for-me/CLIProxyAPI/v7/internal/api/modules" - "github.com/router-for-me/CLIProxyAPI/v7/internal/config" - sdkaccess "github.com/router-for-me/CLIProxyAPI/v7/sdk/access" - "github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers" -) - -func TestAmpModule_Name(t *testing.T) { - m := New() - if m.Name() != "amp-routing" { - t.Fatalf("want amp-routing, got %s", m.Name()) - } -} - -func TestAmpModule_New(t *testing.T) { - accessManager := sdkaccess.NewManager() - authMiddleware := func(c *gin.Context) { c.Next() } - - m := NewLegacy(accessManager, authMiddleware) - - if m.accessManager != accessManager { - t.Fatal("accessManager not set") - } - if m.authMiddleware_ == nil { - t.Fatal("authMiddleware not set") - } - if m.enabled { - t.Fatal("enabled should be false initially") - } - if m.proxy != nil { - t.Fatal("proxy should be nil initially") - } -} - -func TestAmpModule_Register_WithUpstream(t *testing.T) { - gin.SetMode(gin.TestMode) - r := gin.New() - - // Fake upstream to ensure URL is valid - upstream := httptest.NewServer(nil) - defer upstream.Close() - - accessManager := sdkaccess.NewManager() - base := &handlers.BaseAPIHandler{} - - m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() }) - - cfg := &config.Config{ - AmpCode: config.AmpCode{ - UpstreamURL: upstream.URL, - UpstreamAPIKey: "test-key", - }, - } - - ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }} - if err := m.Register(ctx); err != nil { - t.Fatalf("register error: %v", err) - } - - if !m.enabled { - t.Fatal("module should be enabled with upstream URL") - } - if m.proxy == nil { - t.Fatal("proxy should be initialized") - } - if m.secretSource == nil { - t.Fatal("secretSource should be initialized") - } -} - -func TestAmpModule_Register_WithoutUpstream(t *testing.T) { - gin.SetMode(gin.TestMode) - r := gin.New() - - accessManager := sdkaccess.NewManager() - base := &handlers.BaseAPIHandler{} - - m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() }) - - cfg := &config.Config{ - AmpCode: config.AmpCode{ - UpstreamURL: "", // No upstream - }, - } - - ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }} - if err := m.Register(ctx); err != nil { - t.Fatalf("register should not error without upstream: %v", err) - } - - if m.enabled { - t.Fatal("module should be disabled without upstream URL") - } - if m.proxy != nil { - t.Fatal("proxy should not be initialized without upstream") - } - - // But provider aliases should still be registered - req := httptest.NewRequest("GET", "/api/provider/openai/models", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code == 404 { - t.Fatal("provider aliases should be registered even without upstream") - } -} - -func TestAmpModule_Register_InvalidUpstream(t *testing.T) { - gin.SetMode(gin.TestMode) - r := gin.New() - - accessManager := sdkaccess.NewManager() - base := &handlers.BaseAPIHandler{} - - m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() }) - - cfg := &config.Config{ - AmpCode: config.AmpCode{ - UpstreamURL: "://invalid-url", - }, - } - - ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }} - if err := m.Register(ctx); err == nil { - t.Fatal("expected error for invalid upstream URL") - } -} - -func TestAmpModule_OnConfigUpdated_CacheInvalidation(t *testing.T) { - tmpDir := t.TempDir() - p := filepath.Join(tmpDir, "secrets.json") - if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v1"}`), 0600); err != nil { - t.Fatal(err) - } - - m := &AmpModule{enabled: true} - ms := NewMultiSourceSecretWithPath("", p, time.Minute) - m.secretSource = ms - m.lastConfig = &config.AmpCode{ - UpstreamAPIKey: "old-key", - } - - // Warm the cache - if _, err := ms.Get(context.Background()); err != nil { - t.Fatal(err) - } - - if ms.cache == nil { - t.Fatal("expected cache to be set") - } - - // Update config - should invalidate cache - if err := m.OnConfigUpdated(&config.Config{AmpCode: config.AmpCode{UpstreamURL: "http://x", UpstreamAPIKey: "new-key"}}); err != nil { - t.Fatal(err) - } - - if ms.cache != nil { - t.Fatal("expected cache to be invalidated") - } -} - -func TestAmpModule_OnConfigUpdated_NotEnabled(t *testing.T) { - m := &AmpModule{enabled: false} - - // Should not error or panic when disabled - if err := m.OnConfigUpdated(&config.Config{}); err != nil { - t.Fatalf("unexpected error: %v", err) - } -} - -func TestAmpModule_OnConfigUpdated_URLRemoved(t *testing.T) { - m := &AmpModule{enabled: true} - ms := NewMultiSourceSecret("", 0) - m.secretSource = ms - - // Config update with empty URL - should log warning but not error - cfg := &config.Config{AmpCode: config.AmpCode{UpstreamURL: ""}} - - if err := m.OnConfigUpdated(cfg); err != nil { - t.Fatalf("unexpected error: %v", err) - } -} - -func TestAmpModule_OnConfigUpdated_NonMultiSourceSecret(t *testing.T) { - // Test that OnConfigUpdated doesn't panic with StaticSecretSource - m := &AmpModule{enabled: true} - m.secretSource = NewStaticSecretSource("static-key") - - cfg := &config.Config{AmpCode: config.AmpCode{UpstreamURL: "http://example.com"}} - - // Should not error or panic - if err := m.OnConfigUpdated(cfg); err != nil { - t.Fatalf("unexpected error: %v", err) - } -} - -func TestAmpModule_AuthMiddleware_Fallback(t *testing.T) { - gin.SetMode(gin.TestMode) - r := gin.New() - - // Create module with no auth middleware - m := &AmpModule{authMiddleware_: nil} - - // Get the fallback middleware via getAuthMiddleware - ctx := modules.Context{Engine: r, AuthMiddleware: nil} - middleware := m.getAuthMiddleware(ctx) - - if middleware == nil { - t.Fatal("getAuthMiddleware should return a fallback, not nil") - } - - // Test that it works - called := false - r.GET("/test", middleware, func(c *gin.Context) { - called = true - c.String(200, "ok") - }) - - req := httptest.NewRequest("GET", "/test", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if !called { - t.Fatal("fallback middleware should allow requests through") - } -} - -func TestAmpModule_SecretSource_FromConfig(t *testing.T) { - gin.SetMode(gin.TestMode) - r := gin.New() - - upstream := httptest.NewServer(nil) - defer upstream.Close() - - accessManager := sdkaccess.NewManager() - base := &handlers.BaseAPIHandler{} - - m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() }) - - // Config with explicit API key - cfg := &config.Config{ - AmpCode: config.AmpCode{ - UpstreamURL: upstream.URL, - UpstreamAPIKey: "config-key", - }, - } - - ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }} - if err := m.Register(ctx); err != nil { - t.Fatalf("register error: %v", err) - } - - // Secret source should be MultiSourceSecret with config key - if m.secretSource == nil { - t.Fatal("secretSource should be set") - } - - // Verify it returns the config key - key, err := m.secretSource.Get(context.Background()) - if err != nil { - t.Fatalf("Get error: %v", err) - } - if key != "config-key" { - t.Fatalf("want config-key, got %s", key) - } -} - -func TestAmpModule_ProviderAliasesAlwaysRegistered(t *testing.T) { - gin.SetMode(gin.TestMode) - - scenarios := []struct { - name string - configURL string - }{ - {"with_upstream", "http://example.com"}, - {"without_upstream", ""}, - } - - for _, scenario := range scenarios { - t.Run(scenario.name, func(t *testing.T) { - r := gin.New() - accessManager := sdkaccess.NewManager() - base := &handlers.BaseAPIHandler{} - - m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() }) - - cfg := &config.Config{AmpCode: config.AmpCode{UpstreamURL: scenario.configURL}} - - ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }} - if err := m.Register(ctx); err != nil && scenario.configURL != "" { - t.Fatalf("register error: %v", err) - } - - // Provider aliases should always be available - req := httptest.NewRequest("GET", "/api/provider/openai/models", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code == 404 { - t.Fatal("provider aliases should be registered") - } - }) - } -} - -func TestAmpModule_hasUpstreamAPIKeysChanged_DetectsRemovedKeyWithDuplicateInput(t *testing.T) { - m := &AmpModule{} - - oldCfg := &config.AmpCode{ - UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{ - {UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}}, - }, - } - newCfg := &config.AmpCode{ - UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{ - {UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k1"}}, - }, - } - - if !m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) { - t.Fatal("expected change to be detected when k2 is removed but new list contains duplicates") - } -} - -func TestAmpModule_hasUpstreamAPIKeysChanged_IgnoresEmptyAndWhitespaceKeys(t *testing.T) { - m := &AmpModule{} - - oldCfg := &config.AmpCode{ - UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{ - {UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}}, - }, - } - newCfg := &config.AmpCode{ - UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{ - {UpstreamAPIKey: "u1", APIKeys: []string{" k1 ", "", "k2", " "}}, - }, - } - - if m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) { - t.Fatal("expected no change when only whitespace/empty entries differ") - } -} diff --git a/internal/api/modules/amp/fallback_handlers.go b/internal/api/modules/amp/fallback_handlers.go deleted file mode 100644 index 4949ef7a4..000000000 --- a/internal/api/modules/amp/fallback_handlers.go +++ /dev/null @@ -1,343 +0,0 @@ -package amp - -import ( - "bytes" - "io" - "net/http/httputil" - "strings" - "time" - - "github.com/gin-gonic/gin" - "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking" - "github.com/router-for-me/CLIProxyAPI/v7/internal/util" - log "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" -) - -// AmpRouteType represents the type of routing decision made for an Amp request -type AmpRouteType string - -const ( - // RouteTypeLocalProvider indicates the request is handled by a local OAuth provider (free) - RouteTypeLocalProvider AmpRouteType = "LOCAL_PROVIDER" - // RouteTypeModelMapping indicates the request was remapped to another available model (free) - RouteTypeModelMapping AmpRouteType = "MODEL_MAPPING" - // RouteTypeAmpCredits indicates the request is forwarded to ampcode.com (uses Amp credits) - RouteTypeAmpCredits AmpRouteType = "AMP_CREDITS" - // RouteTypeNoProvider indicates no provider or fallback available - RouteTypeNoProvider AmpRouteType = "NO_PROVIDER" -) - -// MappedModelContextKey is the Gin context key for passing mapped model names. -const MappedModelContextKey = "mapped_model" - -// logAmpRouting logs the routing decision for an Amp request with structured fields -func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) { - fields := log.Fields{ - "component": "amp-routing", - "route_type": string(routeType), - "requested_model": requestedModel, - "path": path, - "timestamp": time.Now().Format(time.RFC3339), - } - - if resolvedModel != "" && resolvedModel != requestedModel { - fields["resolved_model"] = resolvedModel - } - if provider != "" { - fields["provider"] = provider - } - - switch routeType { - case RouteTypeLocalProvider: - fields["cost"] = "free" - fields["source"] = "local_oauth" - log.WithFields(fields).Debugf("amp using local provider for model: %s", requestedModel) - - case RouteTypeModelMapping: - fields["cost"] = "free" - fields["source"] = "local_oauth" - fields["mapping"] = requestedModel + " -> " + resolvedModel - // model mapping already logged in mapper; avoid duplicate here - - case RouteTypeAmpCredits: - fields["cost"] = "amp_credits" - fields["source"] = "ampcode.com" - fields["model_id"] = requestedModel // Explicit model_id for easy config reference - log.WithFields(fields).Warnf("forwarding to ampcode.com (uses amp credits) - model_id: %s | To use local provider, add to config: ampcode.model-mappings: [{from: \"%s\", to: \"\"}]", requestedModel, requestedModel) - - case RouteTypeNoProvider: - fields["cost"] = "none" - fields["source"] = "error" - fields["model_id"] = requestedModel // Explicit model_id for easy config reference - log.WithFields(fields).Warnf("no provider available for model_id: %s", requestedModel) - } -} - -// FallbackHandler wraps a standard handler with fallback logic to ampcode.com -// when the model's provider is not available in CLIProxyAPI -type FallbackHandler struct { - getProxy func() *httputil.ReverseProxy - modelMapper ModelMapper - forceModelMappings func() bool -} - -// NewFallbackHandler creates a new fallback handler wrapper -// The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes) -func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler { - return &FallbackHandler{ - getProxy: getProxy, - forceModelMappings: func() bool { return false }, - } -} - -// NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support -func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, forceModelMappings func() bool) *FallbackHandler { - if forceModelMappings == nil { - forceModelMappings = func() bool { return false } - } - return &FallbackHandler{ - getProxy: getProxy, - modelMapper: mapper, - forceModelMappings: forceModelMappings, - } -} - -// SetModelMapper sets the model mapper for this handler (allows late binding) -func (fh *FallbackHandler) SetModelMapper(mapper ModelMapper) { - fh.modelMapper = mapper -} - -// WrapHandler wraps a gin.HandlerFunc with fallback logic -// If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com -func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc { - return func(c *gin.Context) { - requestPath := c.Request.URL.Path - - // Read the request body to extract the model name - bodyBytes, err := io.ReadAll(c.Request.Body) - if err != nil { - log.Errorf("amp fallback: failed to read request body: %v", err) - handler(c) - return - } - - // Sanitize request body: remove thinking blocks with invalid signatures - // to prevent upstream API 400 errors - bodyBytes = SanitizeAmpRequestBody(bodyBytes) - - // Restore the body for the handler to read - c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - - // Try to extract model from request body or URL path (for Gemini) - modelName := extractModelFromRequest(bodyBytes, c) - if modelName == "" { - // Can't determine model, proceed with normal handler - handler(c) - return - } - - // Normalize model (handles dynamic thinking suffixes) - suffixResult := thinking.ParseSuffix(modelName) - normalizedModel := suffixResult.ModelName - thinkingSuffix := "" - if suffixResult.HasSuffix { - thinkingSuffix = "(" + suffixResult.RawSuffix + ")" - } - - resolveMappedModel := func() (string, []string) { - if fh.modelMapper == nil { - return "", nil - } - - mappedModel := fh.modelMapper.MapModel(modelName) - if mappedModel == "" { - mappedModel = fh.modelMapper.MapModel(normalizedModel) - } - mappedModel = strings.TrimSpace(mappedModel) - if mappedModel == "" { - return "", nil - } - - // Preserve dynamic thinking suffix (e.g. "(xhigh)") when mapping applies, unless the target - // already specifies its own thinking suffix. - if thinkingSuffix != "" { - mappedSuffixResult := thinking.ParseSuffix(mappedModel) - if !mappedSuffixResult.HasSuffix { - mappedModel += thinkingSuffix - } - } - - mappedBaseModel := thinking.ParseSuffix(mappedModel).ModelName - mappedProviders := util.GetProviderName(mappedBaseModel) - if len(mappedProviders) == 0 { - return "", nil - } - - return mappedModel, mappedProviders - } - - // Track resolved model for logging (may change if mapping is applied) - resolvedModel := normalizedModel - usedMapping := false - var providers []string - - // Check if model mappings should be forced ahead of local API keys - forceMappings := fh.forceModelMappings != nil && fh.forceModelMappings() - - if forceMappings { - // FORCE MODE: Check model mappings FIRST (takes precedence over local API keys) - // This allows users to route Amp requests to their preferred OAuth providers - if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" { - // Mapping found and provider available - rewrite the model in request body - bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) - c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - // Store mapped model in context for handlers that check it (like gemini bridge) - c.Set(MappedModelContextKey, mappedModel) - resolvedModel = mappedModel - usedMapping = true - providers = mappedProviders - } - - // If no mapping applied, check for local providers - if !usedMapping { - providers = util.GetProviderName(normalizedModel) - } - } else { - // DEFAULT MODE: Check local providers first, then mappings as fallback - providers = util.GetProviderName(normalizedModel) - - if len(providers) == 0 { - // No providers configured - check if we have a model mapping - if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" { - // Mapping found and provider available - rewrite the model in request body - bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) - c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - // Store mapped model in context for handlers that check it (like gemini bridge) - c.Set(MappedModelContextKey, mappedModel) - resolvedModel = mappedModel - usedMapping = true - providers = mappedProviders - } - } - } - - // If no providers available, fallback to ampcode.com - if len(providers) == 0 { - proxy := fh.getProxy() - if proxy != nil { - // Log: Forwarding to ampcode.com (uses Amp credits) - logAmpRouting(RouteTypeAmpCredits, modelName, "", "", requestPath) - - // Restore body again for the proxy - c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - - // Forward to ampcode.com - proxy.ServeHTTP(c.Writer, c.Request) - return - } - - // No proxy available, let the normal handler return the error - logAmpRouting(RouteTypeNoProvider, modelName, "", "", requestPath) - } - - // Log the routing decision - providerName := "" - if len(providers) > 0 { - providerName = providers[0] - } - - if usedMapping { - // Log: Model was mapped to another model - log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel) - logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath) - rewriter := NewResponseRewriterForRequest(c.Writer, modelName, bodyBytes) - rewriter.suppressThinking = true - c.Writer = rewriter - // Filter Anthropic-Beta header only for local handling paths - filterAntropicBetaHeader(c) - c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - handler(c) - rewriter.Flush() - log.Debugf("amp model mapping: response %s -> %s", resolvedModel, modelName) - } else if len(providers) > 0 { - // Log: Using local provider (free) - logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath) - // Wrap with ResponseRewriter for local providers too, because upstream - // proxies (e.g. NewAPI) may return a different model name and lack - // Amp-required fields like thinking.signature. - rewriter := NewResponseRewriterForRequest(c.Writer, modelName, bodyBytes) - rewriter.suppressThinking = providerName != "claude" - c.Writer = rewriter - // Filter Anthropic-Beta header only for local handling paths - filterAntropicBetaHeader(c) - c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - handler(c) - rewriter.Flush() - } else { - // No provider, no mapping, no proxy: fall back to the wrapped handler so it can return an error response - c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - handler(c) - } - } -} - -// filterAntropicBetaHeader filters Anthropic-Beta header to remove features requiring special subscription -// This is needed when using local providers (bypassing the Amp proxy) -func filterAntropicBetaHeader(c *gin.Context) { - if betaHeader := c.Request.Header.Get("Anthropic-Beta"); betaHeader != "" { - if filtered := filterBetaFeatures(betaHeader, "context-1m-2025-08-07"); filtered != "" { - c.Request.Header.Set("Anthropic-Beta", filtered) - } else { - c.Request.Header.Del("Anthropic-Beta") - } - } -} - -// rewriteModelInRequest replaces the model name in a JSON request body -func rewriteModelInRequest(body []byte, newModel string) []byte { - if !gjson.GetBytes(body, "model").Exists() { - return body - } - result, err := sjson.SetBytes(body, "model", newModel) - if err != nil { - log.Warnf("amp model mapping: failed to rewrite model in request body: %v", err) - return body - } - return result -} - -// extractModelFromRequest attempts to extract the model name from various request formats -func extractModelFromRequest(body []byte, c *gin.Context) string { - // First try to parse from JSON body (OpenAI, Claude, etc.) - // Check common model field names - if result := gjson.GetBytes(body, "model"); result.Exists() && result.Type == gjson.String { - return result.String() - } - - // For Gemini requests, model is in the URL path - // Standard format: /models/{model}:generateContent -> :action parameter - if action := c.Param("action"); action != "" { - // Split by colon to get model name (e.g., "gemini-pro:generateContent" -> "gemini-pro") - parts := strings.Split(action, ":") - if len(parts) > 0 && parts[0] != "" { - return parts[0] - } - } - - // AMP CLI format: /publishers/google/models/{model}:method -> *path parameter - // Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent - if path := c.Param("path"); path != "" { - // Look for /models/{model}:method pattern - if idx := strings.Index(path, "/models/"); idx >= 0 { - modelPart := path[idx+8:] // Skip "/models/" - // Split by colon to get model name - if colonIdx := strings.Index(modelPart, ":"); colonIdx > 0 { - return modelPart[:colonIdx] - } - } - } - - return "" -} diff --git a/internal/api/modules/amp/fallback_handlers_test.go b/internal/api/modules/amp/fallback_handlers_test.go deleted file mode 100644 index 7e6f10a2f..000000000 --- a/internal/api/modules/amp/fallback_handlers_test.go +++ /dev/null @@ -1,105 +0,0 @@ -package amp - -import ( - "bytes" - "encoding/json" - "net/http" - "net/http/httptest" - "net/http/httputil" - "testing" - - "github.com/gin-gonic/gin" - "github.com/router-for-me/CLIProxyAPI/v7/internal/config" - "github.com/router-for-me/CLIProxyAPI/v7/internal/registry" -) - -func TestFallbackHandler_RequestToolCasing_RewritesStreamingResponse(t *testing.T) { - gin.SetMode(gin.TestMode) - - reg := registry.GetGlobalRegistry() - reg.RegisterClient("test-client-amp-tool-casing", "codex", []*registry.ModelInfo{ - {ID: "test/gpt-tool-casing", OwnedBy: "openai", Type: "codex"}, - }) - defer reg.UnregisterClient("test-client-amp-tool-casing") - - fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return nil }, nil, nil) - handler := func(c *gin.Context) { - c.Writer.Header().Set("Content-Type", "text/event-stream") - _, _ = c.Writer.Write([]byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"name\":\"glob\",\"id\":\"toolu_01\",\"input\":{}}}\n\n")) - } - - r := gin.New() - r.POST("/messages", fallback.WrapHandler(handler)) - - reqBody := []byte(`{"model":"test/gpt-tool-casing","tools":[{"name":"Glob","input_schema":{"type":"object"}}]}`) - req := httptest.NewRequest(http.MethodPost, "/messages", bytes.NewReader(reqBody)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("Expected status 200, got %d", w.Code) - } - if !bytes.Contains(w.Body.Bytes(), []byte(`"name":"Glob"`)) { - t.Fatalf("expected streaming response to restore glob->Glob, got %s", w.Body.String()) - } -} - -func TestFallbackHandler_ModelMapping_PreservesThinkingSuffixAndRewritesResponse(t *testing.T) { - gin.SetMode(gin.TestMode) - - reg := registry.GetGlobalRegistry() - reg.RegisterClient("test-client-amp-fallback", "codex", []*registry.ModelInfo{ - {ID: "test/gpt-5.2", OwnedBy: "openai", Type: "codex"}, - }) - defer reg.UnregisterClient("test-client-amp-fallback") - - mapper := NewModelMapper([]config.AmpModelMapping{ - {From: "gpt-5.2", To: "test/gpt-5.2"}, - }) - - fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return nil }, mapper, nil) - - handler := func(c *gin.Context) { - var req struct { - Model string `json:"model"` - } - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{ - "model": req.Model, - "seen_model": req.Model, - }) - } - - r := gin.New() - r.POST("/chat/completions", fallback.WrapHandler(handler)) - - reqBody := []byte(`{"model":"gpt-5.2(xhigh)"}`) - req := httptest.NewRequest(http.MethodPost, "/chat/completions", bytes.NewReader(reqBody)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("Expected status 200, got %d", w.Code) - } - - var resp struct { - Model string `json:"model"` - SeenModel string `json:"seen_model"` - } - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("Failed to parse response JSON: %v", err) - } - - if resp.Model != "gpt-5.2(xhigh)" { - t.Errorf("Expected response model gpt-5.2(xhigh), got %s", resp.Model) - } - if resp.SeenModel != "test/gpt-5.2(xhigh)" { - t.Errorf("Expected handler to see test/gpt-5.2(xhigh), got %s", resp.SeenModel) - } -} diff --git a/internal/api/modules/amp/gemini_bridge.go b/internal/api/modules/amp/gemini_bridge.go deleted file mode 100644 index d6ad8f797..000000000 --- a/internal/api/modules/amp/gemini_bridge.go +++ /dev/null @@ -1,59 +0,0 @@ -package amp - -import ( - "strings" - - "github.com/gin-gonic/gin" -) - -// createGeminiBridgeHandler creates a handler that bridges AMP CLI's non-standard Gemini paths -// to our standard Gemini handler by rewriting the request context. -// -// AMP CLI format: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent -// Standard format: /models/gemini-3-pro-preview:streamGenerateContent -// -// This extracts the model+method from the AMP path and sets it as the :action parameter -// so the standard Gemini handler can process it. -// -// The handler parameter should be a Gemini-compatible handler that expects the :action param. -func createGeminiBridgeHandler(handler gin.HandlerFunc) gin.HandlerFunc { - return func(c *gin.Context) { - // Get the full path from the catch-all parameter - path := c.Param("path") - - // Extract model:method from AMP CLI path format - // Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent - const modelsPrefix = "/models/" - if idx := strings.Index(path, modelsPrefix); idx >= 0 { - // Extract everything after modelsPrefix - actionPart := path[idx+len(modelsPrefix):] - - // Check if model was mapped by FallbackHandler - if mappedModel, exists := c.Get(MappedModelContextKey); exists { - if strModel, ok := mappedModel.(string); ok && strModel != "" { - // Replace the model part in the action - // actionPart is like "model-name:method" - if colonIdx := strings.Index(actionPart, ":"); colonIdx > 0 { - method := actionPart[colonIdx:] // ":method" - actionPart = strModel + method - } - } - } - - // Set this as the :action parameter that the Gemini handler expects - c.Params = append(c.Params, gin.Param{ - Key: "action", - Value: actionPart, - }) - - // Call the handler - handler(c) - return - } - - // If we can't parse the path, return 400 - c.JSON(400, gin.H{ - "error": "Invalid Gemini API path format", - }) - } -} diff --git a/internal/api/modules/amp/gemini_bridge_test.go b/internal/api/modules/amp/gemini_bridge_test.go deleted file mode 100644 index 347456c38..000000000 --- a/internal/api/modules/amp/gemini_bridge_test.go +++ /dev/null @@ -1,93 +0,0 @@ -package amp - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/gin-gonic/gin" -) - -func TestCreateGeminiBridgeHandler_ActionParameterExtraction(t *testing.T) { - gin.SetMode(gin.TestMode) - - tests := []struct { - name string - path string - mappedModel string // empty string means no mapping - expectedAction string - }{ - { - name: "no_mapping_uses_url_model", - path: "/publishers/google/models/gemini-pro:generateContent", - mappedModel: "", - expectedAction: "gemini-pro:generateContent", - }, - { - name: "mapped_model_replaces_url_model", - path: "/publishers/google/models/gemini-exp:generateContent", - mappedModel: "gemini-2.0-flash", - expectedAction: "gemini-2.0-flash:generateContent", - }, - { - name: "mapping_preserves_method", - path: "/publishers/google/models/gemini-2.5-preview:streamGenerateContent", - mappedModel: "gemini-flash", - expectedAction: "gemini-flash:streamGenerateContent", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var capturedAction string - - mockGeminiHandler := func(c *gin.Context) { - capturedAction = c.Param("action") - c.JSON(http.StatusOK, gin.H{"captured": capturedAction}) - } - - // Use the actual createGeminiBridgeHandler function - bridgeHandler := createGeminiBridgeHandler(mockGeminiHandler) - - r := gin.New() - if tt.mappedModel != "" { - r.Use(func(c *gin.Context) { - c.Set(MappedModelContextKey, tt.mappedModel) - c.Next() - }) - } - r.POST("/api/provider/google/v1beta1/*path", bridgeHandler) - - req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1"+tt.path, nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("Expected status 200, got %d", w.Code) - } - if capturedAction != tt.expectedAction { - t.Errorf("Expected action '%s', got '%s'", tt.expectedAction, capturedAction) - } - }) - } -} - -func TestCreateGeminiBridgeHandler_InvalidPath(t *testing.T) { - gin.SetMode(gin.TestMode) - - mockHandler := func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"ok": true}) - } - bridgeHandler := createGeminiBridgeHandler(mockHandler) - - r := gin.New() - r.POST("/api/provider/google/v1beta1/*path", bridgeHandler) - - req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1/invalid/path", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusBadRequest { - t.Errorf("Expected status 400 for invalid path, got %d", w.Code) - } -} diff --git a/internal/api/modules/amp/model_mapping.go b/internal/api/modules/amp/model_mapping.go deleted file mode 100644 index 2b68866ed..000000000 --- a/internal/api/modules/amp/model_mapping.go +++ /dev/null @@ -1,171 +0,0 @@ -// Package amp provides model mapping functionality for routing Amp CLI requests -// to alternative models when the requested model is not available locally. -package amp - -import ( - "regexp" - "strings" - "sync" - - "github.com/router-for-me/CLIProxyAPI/v7/internal/config" - "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking" - "github.com/router-for-me/CLIProxyAPI/v7/internal/util" - log "github.com/sirupsen/logrus" -) - -// ModelMapper provides model name mapping/aliasing for Amp CLI requests. -// When an Amp request comes in for a model that isn't available locally, -// this mapper can redirect it to an alternative model that IS available. -type ModelMapper interface { - // MapModel returns the target model name if a mapping exists and the target - // model has available providers. Returns empty string if no mapping applies. - MapModel(requestedModel string) string - - // UpdateMappings refreshes the mapping configuration (for hot-reload). - UpdateMappings(mappings []config.AmpModelMapping) -} - -// DefaultModelMapper implements ModelMapper with thread-safe mapping storage. -type DefaultModelMapper struct { - mu sync.RWMutex - mappings map[string]string // exact: from -> to (normalized lowercase keys) - regexps []regexMapping // regex rules evaluated in order -} - -// NewModelMapper creates a new model mapper with the given initial mappings. -func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper { - m := &DefaultModelMapper{ - mappings: make(map[string]string), - regexps: nil, - } - m.UpdateMappings(mappings) - return m -} - -// MapModel checks if a mapping exists for the requested model and if the -// target model has available local providers. Returns the mapped model name -// or empty string if no valid mapping exists. -// -// If the requested model contains a thinking suffix (e.g., "g25p(8192)"), -// the suffix is preserved in the returned model name (e.g., "gemini-2.5-pro(8192)"). -// However, if the mapping target already contains a suffix, the config suffix -// takes priority over the user's suffix. -func (m *DefaultModelMapper) MapModel(requestedModel string) string { - if requestedModel == "" { - return "" - } - - m.mu.RLock() - defer m.mu.RUnlock() - - // Extract thinking suffix from requested model using ParseSuffix - requestResult := thinking.ParseSuffix(requestedModel) - baseModel := requestResult.ModelName - - // Normalize the base model for lookup (case-insensitive) - normalizedBase := strings.ToLower(strings.TrimSpace(baseModel)) - - // Check for direct mapping using base model name - targetModel, exists := m.mappings[normalizedBase] - if !exists { - // Try regex mappings in order using base model only - // (suffix is handled separately via ParseSuffix) - for _, rm := range m.regexps { - if rm.re.MatchString(baseModel) { - targetModel = rm.to - exists = true - break - } - } - if !exists { - return "" - } - } - - // Check if target model already has a thinking suffix (config priority) - targetResult := thinking.ParseSuffix(targetModel) - - // Verify target model has available providers (use base model for lookup) - providers := util.GetProviderName(targetResult.ModelName) - if len(providers) == 0 { - log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel) - return "" - } - - // Suffix handling: config suffix takes priority, otherwise preserve user suffix - if targetResult.HasSuffix { - // Config's "to" already contains a suffix - use it as-is (config priority) - return targetModel - } - - // Preserve user's thinking suffix on the mapped model - // (skip empty suffixes to avoid returning "model()") - if requestResult.HasSuffix && requestResult.RawSuffix != "" { - return targetModel + "(" + requestResult.RawSuffix + ")" - } - - // Note: Detailed routing log is handled by logAmpRouting in fallback_handlers.go - return targetModel -} - -// UpdateMappings refreshes the mapping configuration from config. -// This is called during initialization and on config hot-reload. -func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) { - m.mu.Lock() - defer m.mu.Unlock() - - // Clear and rebuild mappings - m.mappings = make(map[string]string, len(mappings)) - m.regexps = make([]regexMapping, 0, len(mappings)) - - for _, mapping := range mappings { - from := strings.TrimSpace(mapping.From) - to := strings.TrimSpace(mapping.To) - - if from == "" || to == "" { - log.Warnf("amp model mapping: skipping invalid mapping (from=%q, to=%q)", from, to) - continue - } - - if mapping.Regex { - // Compile case-insensitive regex; wrap with (?i) to match behavior of exact lookups - pattern := "(?i)" + from - re, err := regexp.Compile(pattern) - if err != nil { - log.Warnf("amp model mapping: invalid regex %q: %v", from, err) - continue - } - m.regexps = append(m.regexps, regexMapping{re: re, to: to}) - log.Debugf("amp model regex mapping registered: /%s/ -> %s", from, to) - } else { - // Store with normalized lowercase key for case-insensitive lookup - normalizedFrom := strings.ToLower(from) - m.mappings[normalizedFrom] = to - log.Debugf("amp model mapping registered: %s -> %s", from, to) - } - } - - if len(m.mappings) > 0 { - log.Infof("amp model mapping: loaded %d mapping(s)", len(m.mappings)) - } - if n := len(m.regexps); n > 0 { - log.Infof("amp model mapping: loaded %d regex mapping(s)", n) - } -} - -// GetMappings returns a copy of current mappings (for debugging/status). -func (m *DefaultModelMapper) GetMappings() map[string]string { - m.mu.RLock() - defer m.mu.RUnlock() - - result := make(map[string]string, len(m.mappings)) - for k, v := range m.mappings { - result[k] = v - } - return result -} - -type regexMapping struct { - re *regexp.Regexp - to string -} diff --git a/internal/api/modules/amp/model_mapping_test.go b/internal/api/modules/amp/model_mapping_test.go deleted file mode 100644 index dcfb07ee5..000000000 --- a/internal/api/modules/amp/model_mapping_test.go +++ /dev/null @@ -1,375 +0,0 @@ -package amp - -import ( - "testing" - - "github.com/router-for-me/CLIProxyAPI/v7/internal/config" - "github.com/router-for-me/CLIProxyAPI/v7/internal/registry" -) - -func TestNewModelMapper(t *testing.T) { - mappings := []config.AmpModelMapping{ - {From: "claude-opus-4.5", To: "claude-sonnet-4"}, - {From: "gpt-5", To: "gemini-2.5-pro"}, - } - - mapper := NewModelMapper(mappings) - if mapper == nil { - t.Fatal("Expected non-nil mapper") - } - - result := mapper.GetMappings() - if len(result) != 2 { - t.Errorf("Expected 2 mappings, got %d", len(result)) - } -} - -func TestNewModelMapper_Empty(t *testing.T) { - mapper := NewModelMapper(nil) - if mapper == nil { - t.Fatal("Expected non-nil mapper") - } - - result := mapper.GetMappings() - if len(result) != 0 { - t.Errorf("Expected 0 mappings, got %d", len(result)) - } -} - -func TestModelMapper_MapModel_NoProvider(t *testing.T) { - mappings := []config.AmpModelMapping{ - {From: "claude-opus-4.5", To: "claude-sonnet-4"}, - } - - mapper := NewModelMapper(mappings) - - // Without a registered provider for the target, mapping should return empty - result := mapper.MapModel("claude-opus-4.5") - if result != "" { - t.Errorf("Expected empty result when target has no provider, got %s", result) - } -} - -func TestModelMapper_MapModel_WithProvider(t *testing.T) { - // Register a mock provider for the target model - reg := registry.GetGlobalRegistry() - reg.RegisterClient("test-client", "claude", []*registry.ModelInfo{ - {ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"}, - }) - defer reg.UnregisterClient("test-client") - - mappings := []config.AmpModelMapping{ - {From: "claude-opus-4.5", To: "claude-sonnet-4"}, - } - - mapper := NewModelMapper(mappings) - - // With a registered provider, mapping should work - result := mapper.MapModel("claude-opus-4.5") - if result != "claude-sonnet-4" { - t.Errorf("Expected claude-sonnet-4, got %s", result) - } -} - -func TestModelMapper_MapModel_TargetWithThinkingSuffix(t *testing.T) { - reg := registry.GetGlobalRegistry() - reg.RegisterClient("test-client-thinking", "codex", []*registry.ModelInfo{ - {ID: "gpt-5.2", OwnedBy: "openai", Type: "codex"}, - }) - defer reg.UnregisterClient("test-client-thinking") - - mappings := []config.AmpModelMapping{ - {From: "gpt-5.2-alias", To: "gpt-5.2(xhigh)"}, - } - - mapper := NewModelMapper(mappings) - - result := mapper.MapModel("gpt-5.2-alias") - if result != "gpt-5.2(xhigh)" { - t.Errorf("Expected gpt-5.2(xhigh), got %s", result) - } -} - -func TestModelMapper_MapModel_CaseInsensitive(t *testing.T) { - reg := registry.GetGlobalRegistry() - reg.RegisterClient("test-client2", "claude", []*registry.ModelInfo{ - {ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"}, - }) - defer reg.UnregisterClient("test-client2") - - mappings := []config.AmpModelMapping{ - {From: "Claude-Opus-4.5", To: "claude-sonnet-4"}, - } - - mapper := NewModelMapper(mappings) - - // Should match case-insensitively - result := mapper.MapModel("claude-opus-4.5") - if result != "claude-sonnet-4" { - t.Errorf("Expected claude-sonnet-4, got %s", result) - } -} - -func TestModelMapper_MapModel_NotFound(t *testing.T) { - mappings := []config.AmpModelMapping{ - {From: "claude-opus-4.5", To: "claude-sonnet-4"}, - } - - mapper := NewModelMapper(mappings) - - // Unknown model should return empty - result := mapper.MapModel("unknown-model") - if result != "" { - t.Errorf("Expected empty for unknown model, got %s", result) - } -} - -func TestModelMapper_MapModel_EmptyInput(t *testing.T) { - mappings := []config.AmpModelMapping{ - {From: "claude-opus-4.5", To: "claude-sonnet-4"}, - } - - mapper := NewModelMapper(mappings) - - result := mapper.MapModel("") - if result != "" { - t.Errorf("Expected empty for empty input, got %s", result) - } -} - -func TestModelMapper_UpdateMappings(t *testing.T) { - mapper := NewModelMapper(nil) - - // Initially empty - if len(mapper.GetMappings()) != 0 { - t.Error("Expected 0 initial mappings") - } - - // Update with new mappings - mapper.UpdateMappings([]config.AmpModelMapping{ - {From: "model-a", To: "model-b"}, - {From: "model-c", To: "model-d"}, - }) - - result := mapper.GetMappings() - if len(result) != 2 { - t.Errorf("Expected 2 mappings after update, got %d", len(result)) - } - - // Update again should replace, not append - mapper.UpdateMappings([]config.AmpModelMapping{ - {From: "model-x", To: "model-y"}, - }) - - result = mapper.GetMappings() - if len(result) != 1 { - t.Errorf("Expected 1 mapping after second update, got %d", len(result)) - } -} - -func TestModelMapper_UpdateMappings_SkipsInvalid(t *testing.T) { - mapper := NewModelMapper(nil) - - mapper.UpdateMappings([]config.AmpModelMapping{ - {From: "", To: "model-b"}, // Invalid: empty from - {From: "model-a", To: ""}, // Invalid: empty to - {From: " ", To: "model-b"}, // Invalid: whitespace from - {From: "model-c", To: "model-d"}, // Valid - }) - - result := mapper.GetMappings() - if len(result) != 1 { - t.Errorf("Expected 1 valid mapping, got %d", len(result)) - } -} - -func TestModelMapper_GetMappings_ReturnsCopy(t *testing.T) { - mappings := []config.AmpModelMapping{ - {From: "model-a", To: "model-b"}, - } - - mapper := NewModelMapper(mappings) - - // Get mappings and modify the returned map - result := mapper.GetMappings() - result["new-key"] = "new-value" - - // Original should be unchanged - original := mapper.GetMappings() - if len(original) != 1 { - t.Errorf("Expected original to have 1 mapping, got %d", len(original)) - } - if _, exists := original["new-key"]; exists { - t.Error("Original map was modified") - } -} - -func TestModelMapper_Regex_MatchBaseWithoutParens(t *testing.T) { - reg := registry.GetGlobalRegistry() - reg.RegisterClient("test-client-regex-1", "gemini", []*registry.ModelInfo{ - {ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"}, - }) - defer reg.UnregisterClient("test-client-regex-1") - - mappings := []config.AmpModelMapping{ - {From: "^gpt-5$", To: "gemini-2.5-pro", Regex: true}, - } - - mapper := NewModelMapper(mappings) - - // Incoming model has reasoning suffix, regex matches base, suffix is preserved - result := mapper.MapModel("gpt-5(high)") - if result != "gemini-2.5-pro(high)" { - t.Errorf("Expected gemini-2.5-pro(high), got %s", result) - } -} - -func TestModelMapper_Regex_ExactPrecedence(t *testing.T) { - reg := registry.GetGlobalRegistry() - reg.RegisterClient("test-client-regex-2", "claude", []*registry.ModelInfo{ - {ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"}, - }) - reg.RegisterClient("test-client-regex-3", "gemini", []*registry.ModelInfo{ - {ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"}, - }) - defer reg.UnregisterClient("test-client-regex-2") - defer reg.UnregisterClient("test-client-regex-3") - - mappings := []config.AmpModelMapping{ - {From: "gpt-5", To: "claude-sonnet-4"}, // exact - {From: "^gpt-5.*$", To: "gemini-2.5-pro", Regex: true}, // regex - } - - mapper := NewModelMapper(mappings) - - // Exact match should win over regex - result := mapper.MapModel("gpt-5") - if result != "claude-sonnet-4" { - t.Errorf("Expected claude-sonnet-4, got %s", result) - } -} - -func TestModelMapper_Regex_InvalidPattern_Skipped(t *testing.T) { - // Invalid regex should be skipped and not cause panic - mappings := []config.AmpModelMapping{ - {From: "(", To: "target", Regex: true}, - } - - mapper := NewModelMapper(mappings) - - result := mapper.MapModel("anything") - if result != "" { - t.Errorf("Expected empty result due to invalid regex, got %s", result) - } -} - -func TestModelMapper_Regex_CaseInsensitive(t *testing.T) { - reg := registry.GetGlobalRegistry() - reg.RegisterClient("test-client-regex-4", "claude", []*registry.ModelInfo{ - {ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"}, - }) - defer reg.UnregisterClient("test-client-regex-4") - - mappings := []config.AmpModelMapping{ - {From: "^CLAUDE-OPUS-.*$", To: "claude-sonnet-4", Regex: true}, - } - - mapper := NewModelMapper(mappings) - - result := mapper.MapModel("claude-opus-4.5") - if result != "claude-sonnet-4" { - t.Errorf("Expected claude-sonnet-4, got %s", result) - } -} - -func TestModelMapper_SuffixPreservation(t *testing.T) { - reg := registry.GetGlobalRegistry() - - // Register test models - reg.RegisterClient("test-client-suffix", "gemini", []*registry.ModelInfo{ - {ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"}, - }) - reg.RegisterClient("test-client-suffix-2", "claude", []*registry.ModelInfo{ - {ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"}, - }) - defer reg.UnregisterClient("test-client-suffix") - defer reg.UnregisterClient("test-client-suffix-2") - - tests := []struct { - name string - mappings []config.AmpModelMapping - input string - want string - }{ - { - name: "numeric suffix preserved", - mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}}, - input: "g25p(8192)", - want: "gemini-2.5-pro(8192)", - }, - { - name: "level suffix preserved", - mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}}, - input: "g25p(high)", - want: "gemini-2.5-pro(high)", - }, - { - name: "no suffix unchanged", - mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}}, - input: "g25p", - want: "gemini-2.5-pro", - }, - { - name: "config suffix takes priority", - mappings: []config.AmpModelMapping{{From: "alias", To: "gemini-2.5-pro(medium)"}}, - input: "alias(high)", - want: "gemini-2.5-pro(medium)", - }, - { - name: "regex with suffix preserved", - mappings: []config.AmpModelMapping{{From: "^g25.*", To: "gemini-2.5-pro", Regex: true}}, - input: "g25p(8192)", - want: "gemini-2.5-pro(8192)", - }, - { - name: "auto suffix preserved", - mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}}, - input: "g25p(auto)", - want: "gemini-2.5-pro(auto)", - }, - { - name: "none suffix preserved", - mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}}, - input: "g25p(none)", - want: "gemini-2.5-pro(none)", - }, - { - name: "case insensitive base lookup with suffix", - mappings: []config.AmpModelMapping{{From: "G25P", To: "gemini-2.5-pro"}}, - input: "g25p(high)", - want: "gemini-2.5-pro(high)", - }, - { - name: "empty suffix filtered out", - mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}}, - input: "g25p()", - want: "gemini-2.5-pro", - }, - { - name: "incomplete suffix treated as no suffix", - mappings: []config.AmpModelMapping{{From: "g25p(high", To: "gemini-2.5-pro"}}, - input: "g25p(high", - want: "gemini-2.5-pro", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mapper := NewModelMapper(tt.mappings) - got := mapper.MapModel(tt.input) - if got != tt.want { - t.Errorf("MapModel(%q) = %q, want %q", tt.input, got, tt.want) - } - }) - } -} diff --git a/internal/api/modules/amp/proxy.go b/internal/api/modules/amp/proxy.go deleted file mode 100644 index 54f4b734b..000000000 --- a/internal/api/modules/amp/proxy.go +++ /dev/null @@ -1,240 +0,0 @@ -package amp - -import ( - "bytes" - "compress/gzip" - "context" - "errors" - "fmt" - "io" - "net/http" - "net/http/httputil" - "net/url" - "strconv" - "strings" - - "github.com/gin-gonic/gin" - "github.com/router-for-me/CLIProxyAPI/v7/internal/misc" - log "github.com/sirupsen/logrus" -) - -func removeQueryValuesMatching(req *http.Request, key string, match string) { - if req == nil || req.URL == nil || match == "" { - return - } - - q := req.URL.Query() - values, ok := q[key] - if !ok || len(values) == 0 { - return - } - - kept := make([]string, 0, len(values)) - for _, v := range values { - if v == match { - continue - } - kept = append(kept, v) - } - - if len(kept) == 0 { - q.Del(key) - } else { - q[key] = kept - } - req.URL.RawQuery = q.Encode() -} - -// readCloser wraps a reader and forwards Close to a separate closer. -// Used to restore peeked bytes while preserving upstream body Close behavior. -type readCloser struct { - r io.Reader - c io.Closer -} - -func (rc *readCloser) Read(p []byte) (int, error) { return rc.r.Read(p) } -func (rc *readCloser) Close() error { return rc.c.Close() } - -// createReverseProxy creates a reverse proxy handler for Amp upstream -// with automatic gzip decompression via ModifyResponse -func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputil.ReverseProxy, error) { - parsed, err := url.Parse(upstreamURL) - if err != nil { - return nil, fmt.Errorf("invalid amp upstream url: %w", err) - } - - proxy := httputil.NewSingleHostReverseProxy(parsed) - originalDirector := proxy.Director - - // Modify outgoing requests to inject API key and fix routing - proxy.Director = func(req *http.Request) { - originalDirector(req) - req.Host = parsed.Host - - // Remove client's Authorization header - it was only used for CLI Proxy API authentication - // We will set our own Authorization using the configured upstream-api-key - req.Header.Del("Authorization") - req.Header.Del("X-Api-Key") - req.Header.Del("X-Goog-Api-Key") - - // Remove proxy, client identity, and browser fingerprint headers - misc.ScrubProxyAndFingerprintHeaders(req) - - // Remove query-based credentials if they match the authenticated client API key. - // This prevents leaking client auth material to the Amp upstream while avoiding - // breaking unrelated upstream query parameters. - clientKey := getClientAPIKeyFromContext(req.Context()) - removeQueryValuesMatching(req, "key", clientKey) - removeQueryValuesMatching(req, "auth_token", clientKey) - - // Preserve correlation headers for debugging - if req.Header.Get("X-Request-ID") == "" { - // Could generate one here if needed - } - - // Note: We do NOT filter Anthropic-Beta headers in the proxy path - // Users going through ampcode.com proxy are paying for the service and should get all features - // including 1M context window (context-1m-2025-08-07) - - // Inject API key from secret source (only uses upstream-api-key from config) - if key, err := secretSource.Get(req.Context()); err == nil && key != "" { - req.Header.Set("X-Api-Key", key) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key)) - } else if err != nil { - log.Warnf("amp secret source error (continuing without auth): %v", err) - } - } - - // Modify incoming responses to handle gzip without Content-Encoding - // This addresses the same issue as inline handler gzip handling, but at the proxy level - proxy.ModifyResponse = func(resp *http.Response) error { - // Skip if already marked as gzip (Content-Encoding set) - if resp.Header.Get("Content-Encoding") != "" { - return nil - } - - // Skip streaming responses (SSE, chunked) - if isStreamingResponse(resp) { - return nil - } - - // Save reference to original upstream body for proper cleanup - originalBody := resp.Body - - // Peek at first 2 bytes to detect gzip magic bytes - header := make([]byte, 2) - n, _ := io.ReadFull(originalBody, header) - - // Check for gzip magic bytes (0x1f 0x8b) - // If n < 2, we didn't get enough bytes, so it's not gzip - if n >= 2 && header[0] == 0x1f && header[1] == 0x8b { - // It's gzip - read the rest of the body - rest, err := io.ReadAll(originalBody) - if err != nil { - // Restore what we read and return original body (preserve Close behavior) - resp.Body = &readCloser{ - r: io.MultiReader(bytes.NewReader(header[:n]), originalBody), - c: originalBody, - } - return nil - } - - // Reconstruct complete gzipped data - gzippedData := append(header[:n], rest...) - - // Decompress - gzipReader, err := gzip.NewReader(bytes.NewReader(gzippedData)) - if err != nil { - log.Warnf("amp proxy: gzip header detected but decompress failed: %v", err) - // Close original body and return in-memory copy - _ = originalBody.Close() - resp.Body = io.NopCloser(bytes.NewReader(gzippedData)) - return nil - } - - decompressed, err := io.ReadAll(gzipReader) - _ = gzipReader.Close() - if err != nil { - log.Warnf("amp proxy: gzip decompress error: %v", err) - // Close original body and return in-memory copy - _ = originalBody.Close() - resp.Body = io.NopCloser(bytes.NewReader(gzippedData)) - return nil - } - - // Close original body since we're replacing with in-memory decompressed content - _ = originalBody.Close() - - // Replace body with decompressed content - resp.Body = io.NopCloser(bytes.NewReader(decompressed)) - resp.ContentLength = int64(len(decompressed)) - - // Update headers to reflect decompressed state - resp.Header.Del("Content-Encoding") // No longer compressed - resp.Header.Del("Content-Length") // Remove stale compressed length - resp.Header.Set("Content-Length", strconv.FormatInt(resp.ContentLength, 10)) // Set decompressed length - - log.Debugf("amp proxy: decompressed gzip response (%d -> %d bytes)", len(gzippedData), len(decompressed)) - } else { - // Not gzip - restore peeked bytes while preserving Close behavior - // Handle edge cases: n might be 0, 1, or 2 depending on EOF - resp.Body = &readCloser{ - r: io.MultiReader(bytes.NewReader(header[:n]), originalBody), - c: originalBody, - } - } - - return nil - } - - // Error handler for proxy failures - proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) { - // Client-side cancellations are common during polling; suppress logging in this case - if errors.Is(err, context.Canceled) { - return - } - log.Errorf("amp upstream proxy error for %s %s: %v", req.Method, req.URL.Path, err) - rw.Header().Set("Content-Type", "application/json") - rw.WriteHeader(http.StatusBadGateway) - _, _ = rw.Write([]byte(`{"error":"amp_upstream_proxy_error","message":"Failed to reach Amp upstream"}`)) - } - - return proxy, nil -} - -// isStreamingResponse detects if the response is streaming (SSE only) -// Note: We only treat text/event-stream as streaming. Chunked transfer encoding -// is a transport-level detail and doesn't mean we can't decompress the full response. -// Many JSON APIs use chunked encoding for normal responses. -func isStreamingResponse(resp *http.Response) bool { - contentType := resp.Header.Get("Content-Type") - - // Only Server-Sent Events are true streaming responses - if strings.Contains(contentType, "text/event-stream") { - return true - } - - return false -} - -// proxyHandler converts httputil.ReverseProxy to gin.HandlerFunc -func proxyHandler(proxy *httputil.ReverseProxy) gin.HandlerFunc { - return func(c *gin.Context) { - proxy.ServeHTTP(c.Writer, c.Request) - } -} - -// filterBetaFeatures removes a specific beta feature from comma-separated list -func filterBetaFeatures(header, featureToRemove string) string { - features := strings.Split(header, ",") - filtered := make([]string, 0, len(features)) - - for _, feature := range features { - trimmed := strings.TrimSpace(feature) - if trimmed != "" && trimmed != featureToRemove { - filtered = append(filtered, trimmed) - } - } - - return strings.Join(filtered, ",") -} diff --git a/internal/api/modules/amp/proxy_test.go b/internal/api/modules/amp/proxy_test.go deleted file mode 100644 index 2852efde3..000000000 --- a/internal/api/modules/amp/proxy_test.go +++ /dev/null @@ -1,681 +0,0 @@ -package amp - -import ( - "bytes" - "compress/gzip" - "context" - "fmt" - "io" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/router-for-me/CLIProxyAPI/v7/internal/config" -) - -// Helper: compress data with gzip -func gzipBytes(b []byte) []byte { - var buf bytes.Buffer - zw := gzip.NewWriter(&buf) - zw.Write(b) - zw.Close() - return buf.Bytes() -} - -// Helper: create a mock http.Response -func mkResp(status int, hdr http.Header, body []byte) *http.Response { - if hdr == nil { - hdr = http.Header{} - } - return &http.Response{ - StatusCode: status, - Header: hdr, - Body: io.NopCloser(bytes.NewReader(body)), - ContentLength: int64(len(body)), - } -} - -func TestCreateReverseProxy_ValidURL(t *testing.T) { - proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("key")) - if err != nil { - t.Fatalf("expected no error, got: %v", err) - } - if proxy == nil { - t.Fatal("expected proxy to be created") - } -} - -func TestCreateReverseProxy_InvalidURL(t *testing.T) { - _, err := createReverseProxy("://invalid", NewStaticSecretSource("key")) - if err == nil { - t.Fatal("expected error for invalid URL") - } -} - -func TestModifyResponse_GzipScenarios(t *testing.T) { - proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k")) - if err != nil { - t.Fatal(err) - } - - goodJSON := []byte(`{"ok":true}`) - good := gzipBytes(goodJSON) - truncated := good[:10] - corrupted := append([]byte{0x1f, 0x8b}, []byte("notgzip")...) - - cases := []struct { - name string - header http.Header - body []byte - status int - wantBody []byte - wantCE string - }{ - { - name: "decompresses_valid_gzip_no_header", - header: http.Header{}, - body: good, - status: 200, - wantBody: goodJSON, - wantCE: "", - }, - { - name: "skips_when_ce_present", - header: http.Header{"Content-Encoding": []string{"gzip"}}, - body: good, - status: 200, - wantBody: good, - wantCE: "gzip", - }, - { - name: "passes_truncated_unchanged", - header: http.Header{}, - body: truncated, - status: 200, - wantBody: truncated, - wantCE: "", - }, - { - name: "passes_corrupted_unchanged", - header: http.Header{}, - body: corrupted, - status: 200, - wantBody: corrupted, - wantCE: "", - }, - { - name: "non_gzip_unchanged", - header: http.Header{}, - body: []byte("plain"), - status: 200, - wantBody: []byte("plain"), - wantCE: "", - }, - { - name: "empty_body", - header: http.Header{}, - body: []byte{}, - status: 200, - wantBody: []byte{}, - wantCE: "", - }, - { - name: "single_byte_body", - header: http.Header{}, - body: []byte{0x1f}, - status: 200, - wantBody: []byte{0x1f}, - wantCE: "", - }, - { - name: "decompresses_non_2xx_status_when_gzip_detected", - header: http.Header{}, - body: good, - status: 404, - wantBody: goodJSON, - wantCE: "", - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - resp := mkResp(tc.status, tc.header, tc.body) - if err := proxy.ModifyResponse(resp); err != nil { - t.Fatalf("ModifyResponse error: %v", err) - } - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("ReadAll error: %v", err) - } - if !bytes.Equal(got, tc.wantBody) { - t.Fatalf("body mismatch:\nwant: %q\ngot: %q", tc.wantBody, got) - } - if ce := resp.Header.Get("Content-Encoding"); ce != tc.wantCE { - t.Fatalf("Content-Encoding: want %q, got %q", tc.wantCE, ce) - } - }) - } -} - -func TestModifyResponse_UpdatesContentLengthHeader(t *testing.T) { - proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k")) - if err != nil { - t.Fatal(err) - } - - goodJSON := []byte(`{"message":"test response"}`) - gzipped := gzipBytes(goodJSON) - - // Simulate upstream response with gzip body AND Content-Length header - // (this is the scenario the bot flagged - stale Content-Length after decompression) - resp := mkResp(200, http.Header{ - "Content-Length": []string{fmt.Sprintf("%d", len(gzipped))}, // Compressed size - }, gzipped) - - if err := proxy.ModifyResponse(resp); err != nil { - t.Fatalf("ModifyResponse error: %v", err) - } - - // Verify body is decompressed - got, _ := io.ReadAll(resp.Body) - if !bytes.Equal(got, goodJSON) { - t.Fatalf("body should be decompressed, got: %q, want: %q", got, goodJSON) - } - - // Verify Content-Length header is updated to decompressed size - wantCL := fmt.Sprintf("%d", len(goodJSON)) - gotCL := resp.Header.Get("Content-Length") - if gotCL != wantCL { - t.Fatalf("Content-Length header mismatch: want %q (decompressed), got %q", wantCL, gotCL) - } - - // Verify struct field also matches - if resp.ContentLength != int64(len(goodJSON)) { - t.Fatalf("resp.ContentLength mismatch: want %d, got %d", len(goodJSON), resp.ContentLength) - } -} - -func TestModifyResponse_SkipsStreamingResponses(t *testing.T) { - proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k")) - if err != nil { - t.Fatal(err) - } - - goodJSON := []byte(`{"ok":true}`) - gzipped := gzipBytes(goodJSON) - - t.Run("sse_skips_decompression", func(t *testing.T) { - resp := mkResp(200, http.Header{"Content-Type": []string{"text/event-stream"}}, gzipped) - if err := proxy.ModifyResponse(resp); err != nil { - t.Fatalf("ModifyResponse error: %v", err) - } - // SSE should NOT be decompressed - got, _ := io.ReadAll(resp.Body) - if !bytes.Equal(got, gzipped) { - t.Fatal("SSE response should not be decompressed") - } - }) -} - -func TestModifyResponse_DecompressesChunkedJSON(t *testing.T) { - proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k")) - if err != nil { - t.Fatal(err) - } - - goodJSON := []byte(`{"ok":true}`) - gzipped := gzipBytes(goodJSON) - - t.Run("chunked_json_decompresses", func(t *testing.T) { - // Chunked JSON responses (like thread APIs) should be decompressed - resp := mkResp(200, http.Header{"Transfer-Encoding": []string{"chunked"}}, gzipped) - if err := proxy.ModifyResponse(resp); err != nil { - t.Fatalf("ModifyResponse error: %v", err) - } - // Should decompress because it's not SSE - got, _ := io.ReadAll(resp.Body) - if !bytes.Equal(got, goodJSON) { - t.Fatalf("chunked JSON should be decompressed, got: %q, want: %q", got, goodJSON) - } - }) -} - -func TestReverseProxy_InjectsHeaders(t *testing.T) { - gotHeaders := make(chan http.Header, 1) - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotHeaders <- r.Header.Clone() - w.WriteHeader(200) - w.Write([]byte(`ok`)) - })) - defer upstream.Close() - - proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("secret")) - if err != nil { - t.Fatal(err) - } - - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - proxy.ServeHTTP(w, r) - })) - defer srv.Close() - - res, err := http.Get(srv.URL + "/test") - if err != nil { - t.Fatal(err) - } - res.Body.Close() - - hdr := <-gotHeaders - if hdr.Get("X-Api-Key") != "secret" { - t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key")) - } - if hdr.Get("Authorization") != "Bearer secret" { - t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization")) - } -} - -func TestReverseProxy_EmptySecret(t *testing.T) { - gotHeaders := make(chan http.Header, 1) - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotHeaders <- r.Header.Clone() - w.WriteHeader(200) - w.Write([]byte(`ok`)) - })) - defer upstream.Close() - - proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("")) - if err != nil { - t.Fatal(err) - } - - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - proxy.ServeHTTP(w, r) - })) - defer srv.Close() - - res, err := http.Get(srv.URL + "/test") - if err != nil { - t.Fatal(err) - } - res.Body.Close() - - hdr := <-gotHeaders - // Should NOT inject headers when secret is empty - if hdr.Get("X-Api-Key") != "" { - t.Fatalf("X-Api-Key should not be set, got: %q", hdr.Get("X-Api-Key")) - } - if authVal := hdr.Get("Authorization"); authVal != "" && authVal != "Bearer " { - t.Fatalf("Authorization should not be set, got: %q", authVal) - } -} - -func TestReverseProxy_StripsClientCredentialsFromHeadersAndQuery(t *testing.T) { - type captured struct { - headers http.Header - query string - } - got := make(chan captured, 1) - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - got <- captured{headers: r.Header.Clone(), query: r.URL.RawQuery} - w.WriteHeader(200) - w.Write([]byte(`ok`)) - })) - defer upstream.Close() - - proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("upstream")) - if err != nil { - t.Fatal(err) - } - - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Simulate clientAPIKeyMiddleware injection (per-request) - ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "client-key") - proxy.ServeHTTP(w, r.WithContext(ctx)) - })) - defer srv.Close() - - req, err := http.NewRequest(http.MethodGet, srv.URL+"/test?key=client-key&key=keep&auth_token=client-key&foo=bar", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Authorization", "Bearer client-key") - req.Header.Set("X-Api-Key", "client-key") - req.Header.Set("X-Goog-Api-Key", "client-key") - - res, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatal(err) - } - res.Body.Close() - - c := <-got - - // These are client-provided credentials and must not reach the upstream. - if v := c.headers.Get("X-Goog-Api-Key"); v != "" { - t.Fatalf("X-Goog-Api-Key should be stripped, got: %q", v) - } - - // We inject upstream Authorization/X-Api-Key, so the client auth must not survive. - if v := c.headers.Get("Authorization"); v != "Bearer upstream" { - t.Fatalf("Authorization should be upstream-injected, got: %q", v) - } - if v := c.headers.Get("X-Api-Key"); v != "upstream" { - t.Fatalf("X-Api-Key should be upstream-injected, got: %q", v) - } - - // Query-based credentials should be stripped only when they match the authenticated client key. - // Should keep unrelated values and parameters. - if strings.Contains(c.query, "auth_token=client-key") || strings.Contains(c.query, "key=client-key") { - t.Fatalf("query credentials should be stripped, got raw query: %q", c.query) - } - if !strings.Contains(c.query, "key=keep") || !strings.Contains(c.query, "foo=bar") { - t.Fatalf("expected query to keep non-credential params, got raw query: %q", c.query) - } -} - -func TestReverseProxy_InjectsMappedSecret_FromRequestContext(t *testing.T) { - gotHeaders := make(chan http.Header, 1) - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotHeaders <- r.Header.Clone() - w.WriteHeader(200) - w.Write([]byte(`ok`)) - })) - defer upstream.Close() - - defaultSource := NewStaticSecretSource("default") - mapped := NewMappedSecretSource(defaultSource) - mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{ - { - UpstreamAPIKey: "u1", - APIKeys: []string{"k1"}, - }, - }) - - proxy, err := createReverseProxy(upstream.URL, mapped) - if err != nil { - t.Fatal(err) - } - - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Simulate clientAPIKeyMiddleware injection (per-request) - ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k1") - proxy.ServeHTTP(w, r.WithContext(ctx)) - })) - defer srv.Close() - - res, err := http.Get(srv.URL + "/test") - if err != nil { - t.Fatal(err) - } - res.Body.Close() - - hdr := <-gotHeaders - if hdr.Get("X-Api-Key") != "u1" { - t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key")) - } - if hdr.Get("Authorization") != "Bearer u1" { - t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization")) - } -} - -func TestReverseProxy_MappedSecret_FallsBackToDefault(t *testing.T) { - gotHeaders := make(chan http.Header, 1) - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotHeaders <- r.Header.Clone() - w.WriteHeader(200) - w.Write([]byte(`ok`)) - })) - defer upstream.Close() - - defaultSource := NewStaticSecretSource("default") - mapped := NewMappedSecretSource(defaultSource) - mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{ - { - UpstreamAPIKey: "u1", - APIKeys: []string{"k1"}, - }, - }) - - proxy, err := createReverseProxy(upstream.URL, mapped) - if err != nil { - t.Fatal(err) - } - - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k2") - proxy.ServeHTTP(w, r.WithContext(ctx)) - })) - defer srv.Close() - - res, err := http.Get(srv.URL + "/test") - if err != nil { - t.Fatal(err) - } - res.Body.Close() - - hdr := <-gotHeaders - if hdr.Get("X-Api-Key") != "default" { - t.Fatalf("X-Api-Key fallback missing or wrong, got: %q", hdr.Get("X-Api-Key")) - } - if hdr.Get("Authorization") != "Bearer default" { - t.Fatalf("Authorization fallback missing or wrong, got: %q", hdr.Get("Authorization")) - } -} - -func TestReverseProxy_ErrorHandler(t *testing.T) { - // Point proxy to a non-routable address to trigger error - proxy, err := createReverseProxy("http://127.0.0.1:1", NewStaticSecretSource("")) - if err != nil { - t.Fatal(err) - } - - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - proxy.ServeHTTP(w, r) - })) - defer srv.Close() - - res, err := http.Get(srv.URL + "/any") - if err != nil { - t.Fatal(err) - } - body, _ := io.ReadAll(res.Body) - res.Body.Close() - - if res.StatusCode != http.StatusBadGateway { - t.Fatalf("want 502, got %d", res.StatusCode) - } - if !bytes.Contains(body, []byte(`"amp_upstream_proxy_error"`)) { - t.Fatalf("unexpected body: %s", body) - } - if ct := res.Header.Get("Content-Type"); ct != "application/json" { - t.Fatalf("content-type: want application/json, got %s", ct) - } -} - -func TestReverseProxy_ErrorHandler_ContextCanceled(t *testing.T) { - // Test that context.Canceled errors return 499 without generic error response - proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("")) - if err != nil { - t.Fatal(err) - } - - // Create a canceled context to trigger the cancellation path - ctx, cancel := context.WithCancel(context.Background()) - cancel() // Cancel immediately - - req := httptest.NewRequest(http.MethodGet, "/test", nil).WithContext(ctx) - rr := httptest.NewRecorder() - - // Directly invoke the ErrorHandler with context.Canceled - proxy.ErrorHandler(rr, req, context.Canceled) - - // Body should be empty for canceled requests (no JSON error response) - body := rr.Body.Bytes() - if len(body) > 0 { - t.Fatalf("expected empty body for canceled context, got: %s", body) - } -} - -func TestReverseProxy_FullRoundTrip_Gzip(t *testing.T) { - // Upstream returns gzipped JSON without Content-Encoding header - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) - w.Write(gzipBytes([]byte(`{"upstream":"ok"}`))) - })) - defer upstream.Close() - - proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("key")) - if err != nil { - t.Fatal(err) - } - - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - proxy.ServeHTTP(w, r) - })) - defer srv.Close() - - res, err := http.Get(srv.URL + "/test") - if err != nil { - t.Fatal(err) - } - body, _ := io.ReadAll(res.Body) - res.Body.Close() - - expected := []byte(`{"upstream":"ok"}`) - if !bytes.Equal(body, expected) { - t.Fatalf("want decompressed JSON, got: %s", body) - } -} - -func TestReverseProxy_FullRoundTrip_PlainJSON(t *testing.T) { - // Upstream returns plain JSON - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(200) - w.Write([]byte(`{"plain":"json"}`)) - })) - defer upstream.Close() - - proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("key")) - if err != nil { - t.Fatal(err) - } - - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - proxy.ServeHTTP(w, r) - })) - defer srv.Close() - - res, err := http.Get(srv.URL + "/test") - if err != nil { - t.Fatal(err) - } - body, _ := io.ReadAll(res.Body) - res.Body.Close() - - expected := []byte(`{"plain":"json"}`) - if !bytes.Equal(body, expected) { - t.Fatalf("want plain JSON unchanged, got: %s", body) - } -} - -func TestIsStreamingResponse(t *testing.T) { - cases := []struct { - name string - header http.Header - want bool - }{ - { - name: "sse", - header: http.Header{"Content-Type": []string{"text/event-stream"}}, - want: true, - }, - { - name: "chunked_not_streaming", - header: http.Header{"Transfer-Encoding": []string{"chunked"}}, - want: false, // Chunked is transport-level, not streaming - }, - { - name: "normal_json", - header: http.Header{"Content-Type": []string{"application/json"}}, - want: false, - }, - { - name: "empty", - header: http.Header{}, - want: false, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - resp := &http.Response{Header: tc.header} - got := isStreamingResponse(resp) - if got != tc.want { - t.Fatalf("want %v, got %v", tc.want, got) - } - }) - } -} - -func TestFilterBetaFeatures(t *testing.T) { - tests := []struct { - name string - header string - featureToRemove string - expected string - }{ - { - name: "Remove context-1m from middle", - header: "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07,oauth-2025-04-20", - featureToRemove: "context-1m-2025-08-07", - expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20", - }, - { - name: "Remove context-1m from start", - header: "context-1m-2025-08-07,fine-grained-tool-streaming-2025-05-14", - featureToRemove: "context-1m-2025-08-07", - expected: "fine-grained-tool-streaming-2025-05-14", - }, - { - name: "Remove context-1m from end", - header: "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07", - featureToRemove: "context-1m-2025-08-07", - expected: "fine-grained-tool-streaming-2025-05-14", - }, - { - name: "Feature not present", - header: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20", - featureToRemove: "context-1m-2025-08-07", - expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20", - }, - { - name: "Only feature to remove", - header: "context-1m-2025-08-07", - featureToRemove: "context-1m-2025-08-07", - expected: "", - }, - { - name: "Empty header", - header: "", - featureToRemove: "context-1m-2025-08-07", - expected: "", - }, - { - name: "Header with spaces", - header: "fine-grained-tool-streaming-2025-05-14, context-1m-2025-08-07 , oauth-2025-04-20", - featureToRemove: "context-1m-2025-08-07", - expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := filterBetaFeatures(tt.header, tt.featureToRemove) - if result != tt.expected { - t.Errorf("filterBetaFeatures() = %q, want %q", result, tt.expected) - } - }) - } -} diff --git a/internal/api/modules/amp/response_rewriter.go b/internal/api/modules/amp/response_rewriter.go deleted file mode 100644 index 86318119e..000000000 --- a/internal/api/modules/amp/response_rewriter.go +++ /dev/null @@ -1,472 +0,0 @@ -package amp - -import ( - "bytes" - "encoding/json" - "fmt" - "net/http" - "strings" - - "github.com/gin-gonic/gin" - log "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" -) - -// ResponseRewriter wraps a gin.ResponseWriter to intercept and modify the response body -// It is used to rewrite model names in responses when model mapping is used -// and to keep Amp-compatible response shapes. -type ResponseRewriter struct { - gin.ResponseWriter - body *bytes.Buffer - originalModel string - isStreaming bool - suppressThinking bool - requestToolNames map[string]string -} - -// NewResponseRewriter creates a new response rewriter for model name substitution. -func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRewriter { - return &ResponseRewriter{ - ResponseWriter: w, - body: &bytes.Buffer{}, - originalModel: originalModel, - } -} - -func NewResponseRewriterForRequest(w gin.ResponseWriter, originalModel string, requestBody []byte) *ResponseRewriter { - rw := NewResponseRewriter(w, originalModel) - rw.requestToolNames = collectRequestToolNames(requestBody) - return rw -} - -const maxBufferedResponseBytes = 2 * 1024 * 1024 // 2MB safety cap - -func looksLikeSSEChunk(data []byte) bool { - for _, line := range bytes.Split(data, []byte("\n")) { - trimmed := bytes.TrimSpace(line) - if bytes.HasPrefix(trimmed, []byte("data:")) || - bytes.HasPrefix(trimmed, []byte("event:")) { - return true - } - } - return false -} - -func (rw *ResponseRewriter) enableStreaming(reason string) error { - if rw.isStreaming { - return nil - } - rw.isStreaming = true - - if rw.body != nil && rw.body.Len() > 0 { - buf := rw.body.Bytes() - toFlush := make([]byte, len(buf)) - copy(toFlush, buf) - rw.body.Reset() - - if _, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(toFlush)); err != nil { - return err - } - if flusher, ok := rw.ResponseWriter.(http.Flusher); ok { - flusher.Flush() - } - } - - log.Debugf("amp response rewriter: switched to streaming (%s)", reason) - return nil -} - -func (rw *ResponseRewriter) Write(data []byte) (int, error) { - if !rw.isStreaming && rw.body.Len() == 0 { - contentType := rw.Header().Get("Content-Type") - rw.isStreaming = strings.Contains(contentType, "text/event-stream") || - strings.Contains(contentType, "stream") - } - - if !rw.isStreaming { - if looksLikeSSEChunk(data) { - if err := rw.enableStreaming("sse heuristic"); err != nil { - return 0, err - } - } else if rw.body.Len()+len(data) > maxBufferedResponseBytes { - log.Warnf("amp response rewriter: buffer exceeded %d bytes, switching to streaming", maxBufferedResponseBytes) - if err := rw.enableStreaming("buffer limit"); err != nil { - return 0, err - } - } - } - - if rw.isStreaming { - rewritten := rw.rewriteStreamChunk(data) - n, err := rw.ResponseWriter.Write(rewritten) - if err == nil { - if flusher, ok := rw.ResponseWriter.(http.Flusher); ok { - flusher.Flush() - } - } - return n, err - } - return rw.body.Write(data) -} - -func (rw *ResponseRewriter) Flush() { - if rw.isStreaming { - if flusher, ok := rw.ResponseWriter.(http.Flusher); ok { - flusher.Flush() - } - return - } - if rw.body.Len() > 0 { - rewritten := rw.rewriteModelInResponse(rw.body.Bytes()) - // Update Content-Length to match the rewritten body size, since - // signature injection and model name changes alter the payload length. - rw.ResponseWriter.Header().Set("Content-Length", fmt.Sprintf("%d", len(rewritten))) - if _, err := rw.ResponseWriter.Write(rewritten); err != nil { - log.Warnf("amp response rewriter: failed to write rewritten response: %v", err) - } - } -} - -var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"} - -// ampCanonicalToolNames maps tool names to the exact casing expected by the -// Amp mode tool whitelist (case-sensitive match). -var ampCanonicalToolNames = map[string]string{ - "bash": "Bash", - "read": "Read", - "grep": "Grep", - "glob": "glob", - "task": "Task", - "check": "Check", -} - -func collectRequestToolNames(data []byte) map[string]string { - if len(data) == 0 { - return nil - } - parsed := gjson.ParseBytes(data) - names := map[string]string{} - conflicts := map[string]bool{} - record := func(name string) { - if name == "" { - return - } - key := strings.ToLower(name) - if conflicts[key] { - return - } - if existing, exists := names[key]; exists { - if existing != name { - names[key] = "" - conflicts[key] = true - } - return - } - names[key] = name - } - - for _, tool := range parsed.Get("tools").Array() { - record(tool.Get("name").String()) - } - if parsed.Get("tool_choice.type").String() == "tool" { - record(parsed.Get("tool_choice.name").String()) - } - if len(names) == 0 { - return nil - } - return names -} - -func canonicalAmpToolName(name string, requestToolNames map[string]string) (string, bool) { - key := strings.ToLower(name) - if canonical, ok := requestToolNames[key]; ok { - if canonical == "" { - return "", false - } - return canonical, true - } - canonical, ok := ampCanonicalToolNames[key] - return canonical, ok -} - -// normalizeAmpToolNames fixes tool_use block names to match Amp's canonical casing. -// Some upstream models return lowercase tool names (e.g. "bash" instead of "Bash") -// which causes Amp's case-sensitive mode whitelist to reject them. -func normalizeAmpToolNames(data []byte) []byte { - return normalizeAmpToolNamesForRequest(data, nil) -} - -func normalizeAmpToolNamesForRequest(data []byte, requestToolNames map[string]string) []byte { - // Non-streaming: content[].name in tool_use blocks - for index, block := range gjson.GetBytes(data, "content").Array() { - if block.Get("type").String() != "tool_use" { - continue - } - name := block.Get("name").String() - if canonical, ok := canonicalAmpToolName(name, requestToolNames); ok && name != canonical { - path := fmt.Sprintf("content.%d.name", index) - var err error - data, err = sjson.SetBytes(data, path, canonical) - if err != nil { - log.Warnf("Amp ResponseRewriter: failed to normalize tool name %q to %q: %v", name, canonical, err) - } - } - } - - // Streaming: content_block.name in content_block_start events - if gjson.GetBytes(data, "content_block.type").String() == "tool_use" { - name := gjson.GetBytes(data, "content_block.name").String() - if canonical, ok := canonicalAmpToolName(name, requestToolNames); ok && name != canonical { - var err error - data, err = sjson.SetBytes(data, "content_block.name", canonical) - if err != nil { - log.Warnf("Amp ResponseRewriter: failed to normalize streaming tool name %q to %q: %v", name, canonical, err) - } - } - } - - return data -} - -func (rw *ResponseRewriter) normalizeToolNames(data []byte) []byte { - return normalizeAmpToolNamesForRequest(data, rw.requestToolNames) -} - -// ensureAmpSignature injects empty signature fields into tool_use/thinking blocks -// in API responses so that the Amp TUI does not crash on P.signature.length. -func ensureAmpSignature(data []byte) []byte { - for index, block := range gjson.GetBytes(data, "content").Array() { - blockType := block.Get("type").String() - if blockType != "tool_use" && blockType != "thinking" { - continue - } - signaturePath := fmt.Sprintf("content.%d.signature", index) - if gjson.GetBytes(data, signaturePath).Exists() { - continue - } - var err error - data, err = sjson.SetBytes(data, signaturePath, "") - if err != nil { - log.Warnf("Amp ResponseRewriter: failed to add empty signature to %s block: %v", blockType, err) - break - } - } - - contentBlockType := gjson.GetBytes(data, "content_block.type").String() - if (contentBlockType == "tool_use" || contentBlockType == "thinking") && !gjson.GetBytes(data, "content_block.signature").Exists() { - var err error - data, err = sjson.SetBytes(data, "content_block.signature", "") - if err != nil { - log.Warnf("Amp ResponseRewriter: failed to add empty signature to streaming %s block: %v", contentBlockType, err) - } - } - - return data -} - -func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte { - if !rw.suppressThinking { - return data - } - if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() { - filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`) - if filtered.Exists() { - originalCount := gjson.GetBytes(data, "content.#").Int() - filteredCount := filtered.Get("#").Int() - if originalCount > filteredCount { - var err error - data, err = sjson.SetBytes(data, "content", filtered.Value()) - if err != nil { - log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err) - } - } - } - } - - return data -} - -func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte { - data = ensureAmpSignature(data) - data = rw.normalizeToolNames(data) - data = rw.suppressAmpThinking(data) - if len(data) == 0 { - return data - } - - if rw.originalModel == "" { - return data - } - for _, path := range modelFieldPaths { - if gjson.GetBytes(data, path).Exists() { - data, _ = sjson.SetBytes(data, path, rw.originalModel) - } - } - return data -} - -func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte { - lines := bytes.Split(chunk, []byte("\n")) - var out [][]byte - - i := 0 - for i < len(lines) { - line := lines[i] - trimmed := bytes.TrimSpace(line) - - // Case 1: "event:" line - look ahead for its "data:" line - if bytes.HasPrefix(trimmed, []byte("event: ")) { - // Scan forward past blank lines to find the data: line - dataIdx := -1 - for j := i + 1; j < len(lines); j++ { - t := bytes.TrimSpace(lines[j]) - if len(t) == 0 { - continue - } - if bytes.HasPrefix(t, []byte("data: ")) { - dataIdx = j - } - break - } - - if dataIdx >= 0 { - // Found event+data pair - process through rewriter - jsonData := bytes.TrimPrefix(bytes.TrimSpace(lines[dataIdx]), []byte("data: ")) - if len(jsonData) > 0 && jsonData[0] == '{' { - rewritten := rw.rewriteStreamEvent(jsonData) - if rewritten == nil { - i = dataIdx + 1 - continue - } - // Emit event line - out = append(out, line) - // Emit blank lines between event and data - for k := i + 1; k < dataIdx; k++ { - out = append(out, lines[k]) - } - // Emit rewritten data - out = append(out, append([]byte("data: "), rewritten...)) - i = dataIdx + 1 - continue - } - } - - // No data line found (orphan event from cross-chunk split) - // Pass it through as-is - the data will arrive in the next chunk - out = append(out, line) - i++ - continue - } - - // Case 2: standalone "data:" line (no preceding event: in this chunk) - if bytes.HasPrefix(trimmed, []byte("data: ")) { - jsonData := bytes.TrimPrefix(trimmed, []byte("data: ")) - if len(jsonData) > 0 && jsonData[0] == '{' { - rewritten := rw.rewriteStreamEvent(jsonData) - if rewritten != nil { - out = append(out, append([]byte("data: "), rewritten...)) - } - i++ - continue - } - } - - // Case 3: everything else - out = append(out, line) - i++ - } - - return bytes.Join(out, []byte("\n")) -} - -// rewriteStreamEvent processes a single JSON event in the SSE stream. -// It rewrites model names and ensures signature fields exist. -// NOTE: streaming mode does NOT suppress thinking blocks - they are -// passed through with signature injection to avoid breaking SSE index -// alignment and TUI rendering. -func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte { - // Inject empty signature where needed - data = ensureAmpSignature(data) - - // Normalize tool names to canonical casing - data = rw.normalizeToolNames(data) - - // Rewrite model name - if rw.originalModel != "" { - for _, path := range modelFieldPaths { - if gjson.GetBytes(data, path).Exists() { - data, _ = sjson.SetBytes(data, path, rw.originalModel) - } - } - } - - return data -} - -// SanitizeAmpRequestBody removes thinking blocks with empty/missing/invalid signatures -// and strips the proxy-injected "signature" field from tool_use blocks in the messages -// array before forwarding to the upstream API. -// This prevents 400 errors from the API which requires valid signatures on thinking -// blocks and does not accept a signature field on tool_use blocks. -func SanitizeAmpRequestBody(body []byte) []byte { - messages := gjson.GetBytes(body, "messages") - if !messages.Exists() || !messages.IsArray() { - return body - } - - modified := false - for msgIdx, msg := range messages.Array() { - if msg.Get("role").String() != "assistant" { - continue - } - content := msg.Get("content") - if !content.Exists() || !content.IsArray() { - continue - } - - var keepBlocks []interface{} - contentModified := false - - for _, block := range content.Array() { - blockType := block.Get("type").String() - if blockType == "thinking" { - sig := block.Get("signature") - if !sig.Exists() || sig.Type != gjson.String || strings.TrimSpace(sig.String()) == "" { - contentModified = true - continue - } - } - - // Use raw JSON to prevent float64 rounding of large integers in tool_use inputs - blockRaw := []byte(block.Raw) - if blockType == "tool_use" && block.Get("signature").Exists() { - blockRaw, _ = sjson.DeleteBytes(blockRaw, "signature") - contentModified = true - } - - // sjson.SetBytes supports raw JSON strings if wrapped in gjson.Raw - keepBlocks = append(keepBlocks, json.RawMessage(blockRaw)) - } - - if contentModified { - contentPath := fmt.Sprintf("messages.%d.content", msgIdx) - var err error - if len(keepBlocks) == 0 { - body, err = sjson.SetBytes(body, contentPath, []interface{}{}) - } else { - body, err = sjson.SetBytes(body, contentPath, keepBlocks) - } - if err != nil { - log.Warnf("Amp RequestSanitizer: failed to sanitize message %d: %v", msgIdx, err) - continue - } - modified = true - } - } - - if modified { - log.Debugf("Amp RequestSanitizer: sanitized request body") - } - return body -} diff --git a/internal/api/modules/amp/response_rewriter_test.go b/internal/api/modules/amp/response_rewriter_test.go deleted file mode 100644 index 609942edd..000000000 --- a/internal/api/modules/amp/response_rewriter_test.go +++ /dev/null @@ -1,326 +0,0 @@ -package amp - -import ( - "strings" - "testing" -) - -func TestRewriteModelInResponse_TopLevel(t *testing.T) { - rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"} - - input := []byte(`{"id":"resp_1","model":"gpt-5.3-codex","output":[]}`) - result := rw.rewriteModelInResponse(input) - - expected := `{"id":"resp_1","model":"gpt-5.2-codex","output":[]}` - if string(result) != expected { - t.Errorf("expected %s, got %s", expected, string(result)) - } -} - -func TestRewriteModelInResponse_ResponseModel(t *testing.T) { - rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"} - - input := []byte(`{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.3-codex","status":"completed"}}`) - result := rw.rewriteModelInResponse(input) - - expected := `{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.2-codex","status":"completed"}}` - if string(result) != expected { - t.Errorf("expected %s, got %s", expected, string(result)) - } -} - -func TestRewriteModelInResponse_ResponseCreated(t *testing.T) { - rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"} - - input := []byte(`{"type":"response.created","response":{"id":"resp_1","model":"gpt-5.3-codex","status":"in_progress"}}`) - result := rw.rewriteModelInResponse(input) - - expected := `{"type":"response.created","response":{"id":"resp_1","model":"gpt-5.2-codex","status":"in_progress"}}` - if string(result) != expected { - t.Errorf("expected %s, got %s", expected, string(result)) - } -} - -func TestRewriteModelInResponse_NoModelField(t *testing.T) { - rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"} - - input := []byte(`{"type":"response.output_item.added","item":{"id":"item_1","type":"message"}}`) - result := rw.rewriteModelInResponse(input) - - if string(result) != string(input) { - t.Errorf("expected no modification, got %s", string(result)) - } -} - -func TestRewriteModelInResponse_EmptyOriginalModel(t *testing.T) { - rw := &ResponseRewriter{originalModel: ""} - - input := []byte(`{"model":"gpt-5.3-codex"}`) - result := rw.rewriteModelInResponse(input) - - if string(result) != string(input) { - t.Errorf("expected no modification when originalModel is empty, got %s", string(result)) - } -} - -func TestRewriteStreamChunk_SSEWithResponseModel(t *testing.T) { - rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"} - - chunk := []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-5.3-codex\",\"status\":\"completed\"}}\n\n") - result := rw.rewriteStreamChunk(chunk) - - expected := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-5.2-codex\",\"status\":\"completed\"}}\n\n" - if string(result) != expected { - t.Errorf("expected %s, got %s", expected, string(result)) - } -} - -func TestRewriteStreamChunk_MultipleEvents(t *testing.T) { - rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"} - - chunk := []byte("data: {\"type\":\"response.created\",\"response\":{\"model\":\"gpt-5.3-codex\"}}\n\ndata: {\"type\":\"response.output_item.added\",\"item\":{\"id\":\"item_1\"}}\n\n") - result := rw.rewriteStreamChunk(chunk) - - if string(result) == string(chunk) { - t.Error("expected response.model to be rewritten in SSE stream") - } - if !contains(result, []byte(`"model":"gpt-5.2-codex"`)) { - t.Errorf("expected rewritten model in output, got %s", string(result)) - } -} - -func TestRewriteStreamChunk_MessageModel(t *testing.T) { - rw := &ResponseRewriter{originalModel: "claude-opus-4.5"} - - chunk := []byte("data: {\"message\":{\"model\":\"claude-sonnet-4\",\"role\":\"assistant\"}}\n\n") - result := rw.rewriteStreamChunk(chunk) - - expected := "data: {\"message\":{\"model\":\"claude-opus-4.5\",\"role\":\"assistant\"}}\n\n" - if string(result) != expected { - t.Errorf("expected %s, got %s", expected, string(result)) - } -} - -func TestRewriteStreamChunk_PreservesThinkingWithSignatureInjection(t *testing.T) { - rw := &ResponseRewriter{} - - chunk := []byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"thinking\",\"thinking\":\"\"}}\n\nevent: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"thinking_delta\",\"thinking\":\"abc\"}}\n\nevent: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\nevent: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"name\":\"bash\",\"input\":{}}}\n\n") - result := rw.rewriteStreamChunk(chunk) - - // Streaming mode preserves thinking blocks (does NOT suppress them) - // to avoid breaking SSE index alignment and TUI rendering - if !contains(result, []byte(`"content_block":{"type":"thinking"`)) { - t.Fatalf("expected thinking content_block_start to be preserved, got %s", string(result)) - } - if !contains(result, []byte(`"delta":{"type":"thinking_delta"`)) { - t.Fatalf("expected thinking_delta to be preserved, got %s", string(result)) - } - if !contains(result, []byte(`"type":"content_block_stop","index":0`)) { - t.Fatalf("expected content_block_stop for thinking block to be preserved, got %s", string(result)) - } - if !contains(result, []byte(`"content_block":{"type":"tool_use"`)) { - t.Fatalf("expected tool_use content_block frame to remain, got %s", string(result)) - } - // Signature should be injected into both thinking and tool_use blocks - if count := strings.Count(string(result), `"signature":""`); count != 2 { - t.Fatalf("expected 2 signature injections, but got %d in %s", count, string(result)) - } -} - -func TestSanitizeAmpRequestBody_RemovesWhitespaceAndNonStringSignatures(t *testing.T) { - input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"drop-whitespace","signature":" "},{"type":"thinking","thinking":"drop-number","signature":123},{"type":"thinking","thinking":"keep-valid","signature":"valid-signature"},{"type":"text","text":"keep-text"}]}]}`) - result := SanitizeAmpRequestBody(input) - - if contains(result, []byte("drop-whitespace")) { - t.Fatalf("expected whitespace-only signature block to be removed, got %s", string(result)) - } - if contains(result, []byte("drop-number")) { - t.Fatalf("expected non-string signature block to be removed, got %s", string(result)) - } - if !contains(result, []byte("keep-valid")) { - t.Fatalf("expected valid thinking block to remain, got %s", string(result)) - } - if !contains(result, []byte("keep-text")) { - t.Fatalf("expected non-thinking content to remain, got %s", string(result)) - } -} - -func TestSanitizeAmpRequestBody_StripsSignatureFromToolUseBlocks(t *testing.T) { - input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"thought","signature":"valid-sig"},{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"},"signature":""}]}]}`) - result := SanitizeAmpRequestBody(input) - - if contains(result, []byte(`"signature":""`)) { - t.Fatalf("expected signature to be stripped from tool_use block, got %s", string(result)) - } - if !contains(result, []byte(`"valid-sig"`)) { - t.Fatalf("expected thinking signature to remain, got %s", string(result)) - } - if !contains(result, []byte(`"tool_use"`)) { - t.Fatalf("expected tool_use block to remain, got %s", string(result)) - } -} - -func TestSanitizeAmpRequestBody_MixedInvalidThinkingAndToolUseSignature(t *testing.T) { - input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"drop-me","signature":""},{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"},"signature":""}]}]}`) - result := SanitizeAmpRequestBody(input) - - if contains(result, []byte("drop-me")) { - t.Fatalf("expected invalid thinking block to be removed, got %s", string(result)) - } - if contains(result, []byte(`"signature"`)) { - t.Fatalf("expected signature to be stripped from tool_use block, got %s", string(result)) - } - if !contains(result, []byte(`"tool_use"`)) { - t.Fatalf("expected tool_use block to remain, got %s", string(result)) - } -} - -func TestNormalizeAmpToolNames_NonStreaming(t *testing.T) { - input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"bash","input":{"cmd":"ls"}},{"type":"tool_use","id":"toolu_02","name":"read","input":{"path":"/tmp"}},{"type":"text","text":"hello"}]}`) - result := normalizeAmpToolNames(input) - - if !contains(result, []byte(`"name":"Bash"`)) { - t.Errorf("expected bash->Bash, got %s", string(result)) - } - if !contains(result, []byte(`"name":"Read"`)) { - t.Errorf("expected read->Read, got %s", string(result)) - } - if contains(result, []byte(`"name":"bash"`)) { - t.Errorf("expected lowercase bash to be replaced, got %s", string(result)) - } -} - -func TestNormalizeAmpToolNames_Streaming(t *testing.T) { - input := []byte(`{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","name":"grep","id":"toolu_01","input":{}}}`) - result := normalizeAmpToolNames(input) - - if !contains(result, []byte(`"name":"Grep"`)) { - t.Errorf("expected grep->Grep in streaming, got %s", string(result)) - } -} - -func TestNormalizeAmpToolNames_AlreadyCorrect(t *testing.T) { - input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`) - result := normalizeAmpToolNames(input) - - if string(result) != string(input) { - t.Errorf("expected no modification for correctly-cased tool, got %s", string(result)) - } -} - -func TestNormalizeAmpToolNames_GlobPreserved(t *testing.T) { - input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`) - result := normalizeAmpToolNames(input) - - if string(result) != string(input) { - t.Errorf("expected glob to remain lowercase, got %s", string(result)) - } -} - -func TestNormalizeAmpToolNames_RequestToolCasing_NonStreaming(t *testing.T) { - input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`) - result := normalizeAmpToolNamesForRequest(input, map[string]string{"glob": "Glob"}) - - if !contains(result, []byte(`"name":"Glob"`)) { - t.Errorf("expected glob->Glob when request advertised Glob, got %s", string(result)) - } -} - -func TestNormalizeAmpToolNames_RequestToolCasing_Streaming(t *testing.T) { - input := []byte(`{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","name":"glob","id":"toolu_01","input":{}}}`) - result := normalizeAmpToolNamesForRequest(input, map[string]string{"glob": "Glob"}) - - if !contains(result, []byte(`"name":"Glob"`)) { - t.Errorf("expected glob->Glob in streaming when request advertised Glob, got %s", string(result)) - } -} - -func TestResponseRewriter_RequestToolCasingFromBody(t *testing.T) { - requestBody := []byte(`{"tools":[{"name":"Glob","input_schema":{"type":"object"}}]}`) - rw := &ResponseRewriter{requestToolNames: collectRequestToolNames(requestBody)} - input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`) - - result := rw.rewriteModelInResponse(input) - - if !contains(result, []byte(`"name":"Glob"`)) { - t.Errorf("expected request body casing to restore glob->Glob, got %s", string(result)) - } -} - -func TestResponseRewriter_LowercaseNativeRequestPreserved(t *testing.T) { - requestBody := []byte(`{"tools":[{"name":"glob","input_schema":{"type":"object"}}]}`) - rw := &ResponseRewriter{requestToolNames: collectRequestToolNames(requestBody)} - input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`) - - result := rw.rewriteModelInResponse(input) - - if string(result) == string(input) { - return - } - if !contains(result, []byte(`"name":"glob"`)) { - t.Errorf("expected lowercase-native request to preserve glob, got %s", string(result)) - } -} - -func TestCollectRequestToolNames_CollisionIgnored(t *testing.T) { - tests := []struct { - requestBody []byte - input []byte - forbidden []byte - }{ - { - requestBody: []byte(`{"tools":[{"name":"Glob","input_schema":{"type":"object"}},{"name":"glob","input_schema":{"type":"object"}}]}`), - input: []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`), - forbidden: []byte(`"name":"Glob"`), - }, - { - requestBody: []byte(`{"tools":[{"name":"glob","input_schema":{"type":"object"}},{"name":"Glob","input_schema":{"type":"object"}}]}`), - input: []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`), - forbidden: []byte(`"name":"Glob"`), - }, - { - requestBody: []byte(`{"tools":[{"name":"Bash","input_schema":{"type":"object"}},{"name":"bash","input_schema":{"type":"object"}}]}`), - input: []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"bash","input":{"cmd":"ls"}}]}`), - forbidden: []byte(`"name":"Bash"`), - }, - } - - for _, tt := range tests { - rw := &ResponseRewriter{requestToolNames: collectRequestToolNames(tt.requestBody)} - result := rw.rewriteModelInResponse(tt.input) - - if contains(result, tt.forbidden) { - t.Errorf("expected conflicting tool casing not to force %s, got %s", string(tt.forbidden), string(result)) - } - } -} - -func TestResponseRewriter_RequestToolCasingFromBody_Streaming(t *testing.T) { - requestBody := []byte(`{"tools":[{"name":"Glob","input_schema":{"type":"object"}}]}`) - rw := &ResponseRewriter{requestToolNames: collectRequestToolNames(requestBody)} - input := []byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"name\":\"glob\",\"id\":\"toolu_01\",\"input\":{}}}\n\n") - - result := rw.rewriteStreamChunk(input) - - if !contains(result, []byte(`"name":"Glob"`)) { - t.Errorf("expected streaming response to restore glob->Glob from request body, got %s", string(result)) - } -} - -func TestNormalizeAmpToolNames_UnknownToolUntouched(t *testing.T) { - input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"edit_file","input":{"path":"/tmp/x"}}]}`) - result := normalizeAmpToolNames(input) - - if string(result) != string(input) { - t.Errorf("expected no modification for unknown tool, got %s", string(result)) - } -} - -func contains(data, substr []byte) bool { - for i := 0; i <= len(data)-len(substr); i++ { - if string(data[i:i+len(substr)]) == string(substr) { - return true - } - } - return false -} diff --git a/internal/api/modules/amp/routes.go b/internal/api/modules/amp/routes.go deleted file mode 100644 index 84023d156..000000000 --- a/internal/api/modules/amp/routes.go +++ /dev/null @@ -1,335 +0,0 @@ -package amp - -import ( - "context" - "errors" - "net" - "net/http" - "net/http/httputil" - "strings" - - "github.com/gin-gonic/gin" - "github.com/router-for-me/CLIProxyAPI/v7/internal/logging" - "github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers" - "github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers/claude" - "github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers/gemini" - "github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers/openai" - log "github.com/sirupsen/logrus" -) - -// clientAPIKeyContextKey is the context key used to pass the client API key -// from gin.Context to the request context for SecretSource lookup. -type clientAPIKeyContextKey struct{} - -// clientAPIKeyMiddleware injects the authenticated client API key from gin.Context["userApiKey"] -// into the request context so that SecretSource can look it up for per-client upstream routing. -func clientAPIKeyMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - // Extract the client API key from gin context (set by AuthMiddleware) - if apiKey, exists := c.Get("userApiKey"); exists { - if keyStr, ok := apiKey.(string); ok && keyStr != "" { - // Inject into request context for SecretSource.Get(ctx) to read - ctx := context.WithValue(c.Request.Context(), clientAPIKeyContextKey{}, keyStr) - c.Request = c.Request.WithContext(ctx) - } - } - c.Next() - } -} - -// getClientAPIKeyFromContext retrieves the client API key from request context. -// Returns empty string if not present. -func getClientAPIKeyFromContext(ctx context.Context) string { - if val := ctx.Value(clientAPIKeyContextKey{}); val != nil { - if keyStr, ok := val.(string); ok { - return keyStr - } - } - return "" -} - -// localhostOnlyMiddleware returns a middleware that dynamically checks the module's -// localhost restriction setting. This allows hot-reload of the restriction without restarting. -func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - // Check current setting (hot-reloadable) - if !m.IsRestrictedToLocalhost() { - c.Next() - return - } - - // Use actual TCP connection address (RemoteAddr) to prevent header spoofing - // This cannot be forged by X-Forwarded-For or other client-controlled headers - remoteAddr := c.Request.RemoteAddr - - // RemoteAddr format is "IP:port" or "[IPv6]:port", extract just the IP - host, _, err := net.SplitHostPort(remoteAddr) - if err != nil { - // Try parsing as raw IP (shouldn't happen with standard HTTP, but be defensive) - host = remoteAddr - } - - // Parse the IP to handle both IPv4 and IPv6 - ip := net.ParseIP(host) - if ip == nil { - log.Warnf("amp management: invalid RemoteAddr %s, denying access", remoteAddr) - c.AbortWithStatusJSON(403, gin.H{ - "error": "Access denied: management routes restricted to localhost", - }) - return - } - - // Check if IP is loopback (127.0.0.1 or ::1) - if !ip.IsLoopback() { - log.Warnf("amp management: non-localhost connection from %s attempted access, denying", remoteAddr) - c.AbortWithStatusJSON(403, gin.H{ - "error": "Access denied: management routes restricted to localhost", - }) - return - } - - c.Next() - } -} - -// noCORSMiddleware disables CORS for management routes to prevent browser-based attacks. -// This overwrites any global CORS headers set by the server. -func noCORSMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - // Remove CORS headers to prevent cross-origin access from browsers - c.Header("Access-Control-Allow-Origin", "") - c.Header("Access-Control-Allow-Methods", "") - c.Header("Access-Control-Allow-Headers", "") - c.Header("Access-Control-Allow-Credentials", "") - - // For OPTIONS preflight, deny with 403 - if c.Request.Method == "OPTIONS" { - c.AbortWithStatus(403) - return - } - - c.Next() - } -} - -// managementAvailabilityMiddleware short-circuits management routes when the upstream -// proxy is disabled, preventing noisy localhost warnings and accidental exposure. -func (m *AmpModule) managementAvailabilityMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - if m.getProxy() == nil { - logging.SkipGinRequestLogging(c) - c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{ - "error": "amp upstream proxy not available", - }) - return - } - c.Next() - } -} - -// wrapManagementAuth skips auth for selected management paths while keeping authentication elsewhere. -func wrapManagementAuth(auth gin.HandlerFunc, prefixes ...string) gin.HandlerFunc { - return func(c *gin.Context) { - path := c.Request.URL.Path - for _, prefix := range prefixes { - if strings.HasPrefix(path, prefix) && (len(path) == len(prefix) || path[len(prefix)] == '/') { - c.Next() - return - } - } - auth(c) - } -} - -// registerManagementRoutes registers Amp management proxy routes -// These routes proxy through to the Amp control plane for OAuth, user management, etc. -// Uses dynamic middleware and proxy getter for hot-reload support. -// The auth middleware validates Authorization header against configured API keys. -func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, auth gin.HandlerFunc) { - ampAPI := engine.Group("/api") - - // Always disable CORS for management routes to prevent browser-based attacks - ampAPI.Use(m.managementAvailabilityMiddleware(), noCORSMiddleware()) - - // Apply dynamic localhost-only restriction (hot-reloadable via m.IsRestrictedToLocalhost()) - ampAPI.Use(m.localhostOnlyMiddleware()) - - // Apply authentication middleware - requires valid API key in Authorization header - var authWithBypass gin.HandlerFunc - if auth != nil { - ampAPI.Use(auth) - authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs", "/settings") - } - - // Inject client API key into request context for per-client upstream routing - ampAPI.Use(clientAPIKeyMiddleware()) - - // Dynamic proxy handler that uses m.getProxy() for hot-reload support - proxyHandler := func(c *gin.Context) { - // Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces - defer func() { - if rec := recover(); rec != nil { - if err, ok := rec.(error); ok && errors.Is(err, http.ErrAbortHandler) { - // Upstream already wrote the status (often 404) before the client/stream ended. - return - } - panic(rec) - } - }() - - proxy := m.getProxy() - if proxy == nil { - c.JSON(503, gin.H{"error": "amp upstream proxy not available"}) - return - } - proxy.ServeHTTP(c.Writer, c.Request) - } - - // Management routes - these are proxied directly to Amp upstream - ampAPI.Any("/internal", proxyHandler) - ampAPI.Any("/internal/*path", proxyHandler) - ampAPI.Any("/user", proxyHandler) - ampAPI.Any("/user/*path", proxyHandler) - ampAPI.Any("/auth", proxyHandler) - ampAPI.Any("/auth/*path", proxyHandler) - ampAPI.Any("/meta", proxyHandler) - ampAPI.Any("/meta/*path", proxyHandler) - ampAPI.Any("/ads", proxyHandler) - ampAPI.Any("/telemetry", proxyHandler) - ampAPI.Any("/telemetry/*path", proxyHandler) - ampAPI.Any("/threads", proxyHandler) - ampAPI.Any("/threads/*path", proxyHandler) - ampAPI.Any("/thread-actors", proxyHandler) - ampAPI.Any("/otel", proxyHandler) - ampAPI.Any("/otel/*path", proxyHandler) - ampAPI.Any("/tab", proxyHandler) - ampAPI.Any("/tab/*path", proxyHandler) - - // Root-level routes that AMP CLI expects without /api prefix - // These need the same security middleware as the /api/* routes (dynamic for hot-reload) - rootMiddleware := []gin.HandlerFunc{m.managementAvailabilityMiddleware(), noCORSMiddleware(), m.localhostOnlyMiddleware()} - if authWithBypass != nil { - rootMiddleware = append(rootMiddleware, authWithBypass) - } - // Add clientAPIKeyMiddleware after auth for per-client upstream routing - rootMiddleware = append(rootMiddleware, clientAPIKeyMiddleware()) - engine.GET("/threads", append(rootMiddleware, proxyHandler)...) - engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...) - engine.GET("/docs", append(rootMiddleware, proxyHandler)...) - engine.GET("/docs/*path", append(rootMiddleware, proxyHandler)...) - engine.GET("/settings", append(rootMiddleware, proxyHandler)...) - engine.GET("/settings/*path", append(rootMiddleware, proxyHandler)...) - - engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...) - engine.GET("/news.rss", append(rootMiddleware, proxyHandler)...) - - // Root-level auth routes for CLI login flow - // Amp uses multiple auth routes: /auth/cli-login, /auth/callback, /auth/sign-in, /auth/logout - // We proxy all /auth/* to support the complete OAuth flow - engine.Any("/auth", append(rootMiddleware, proxyHandler)...) - engine.Any("/auth/*path", append(rootMiddleware, proxyHandler)...) - - // Google v1beta1 passthrough with OAuth fallback - // AMP CLI uses non-standard paths like /publishers/google/models/... - // We bridge these to our standard Gemini handler to enable local OAuth. - // If no local OAuth is available, falls back to ampcode.com proxy. - geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler) - geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler) - geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { - return m.getProxy() - }, m.modelMapper, m.forceModelMappings) - geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge) - - // Route POST model calls through Gemini bridge with FallbackHandler. - // FallbackHandler checks provider -> mapping -> proxy fallback automatically. - // All other methods (e.g., GET model listing) always proxy to upstream to preserve Amp CLI behavior. - ampAPI.Any("/provider/google/v1beta1/*path", func(c *gin.Context) { - if c.Request.Method == "POST" { - if path := c.Param("path"); strings.Contains(path, "/models/") { - // POST with /models/ path -> use Gemini bridge with fallback handler - // FallbackHandler will check provider/mapping and proxy if needed - geminiV1Beta1Handler(c) - return - } - } - // Non-POST or no local provider available -> proxy upstream - proxyHandler(c) - }) -} - -// registerProviderAliases registers /api/provider/{provider}/... routes -// These allow Amp CLI to route requests like: -// -// /api/provider/openai/v1/chat/completions -// /api/provider/anthropic/v1/messages -// /api/provider/google/v1beta/models -func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, auth gin.HandlerFunc) { - // Create handler instances for different providers - openaiHandlers := openai.NewOpenAIAPIHandler(baseHandler) - geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler) - claudeCodeHandlers := claude.NewClaudeCodeAPIHandler(baseHandler) - openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler) - - // Create fallback handler wrapper that forwards to ampcode.com when provider not found - // Uses m.getProxy() for hot-reload support (proxy can be updated at runtime) - // Also includes model mapping support for routing unavailable models to alternatives - fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { - return m.getProxy() - }, m.modelMapper, m.forceModelMappings) - - // Provider-specific routes under /api/provider/:provider - ampProviders := engine.Group("/api/provider") - if auth != nil { - ampProviders.Use(auth) - } - // Inject client API key into request context for per-client upstream routing - ampProviders.Use(clientAPIKeyMiddleware()) - - provider := ampProviders.Group("/:provider") - - // Dynamic models handler - routes to appropriate provider based on path parameter - ampModelsHandler := func(c *gin.Context) { - providerName := strings.ToLower(c.Param("provider")) - - switch providerName { - case "anthropic": - claudeCodeHandlers.ClaudeModels(c) - case "google": - geminiHandlers.GeminiModels(c) - default: - // Default to OpenAI-compatible (works for openai, groq, cerebras, etc.) - openaiHandlers.OpenAIModels(c) - } - } - - // Root-level routes (for providers that omit /v1, like groq/cerebras) - // Wrap handlers with fallback logic to forward to ampcode.com when provider not found - provider.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback (no body to check) - provider.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions)) - provider.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions)) - provider.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses)) - - // /v1 routes (OpenAI/Claude-compatible endpoints) - v1Amp := provider.Group("/v1") - { - v1Amp.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback - - // OpenAI-compatible endpoints with fallback - v1Amp.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions)) - v1Amp.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions)) - v1Amp.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses)) - - // Claude/Anthropic-compatible endpoints with fallback - v1Amp.POST("/messages", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeMessages)) - v1Amp.POST("/messages/count_tokens", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeCountTokens)) - } - - // /v1beta routes (Gemini native API) - // Note: Gemini handler extracts model from URL path, so fallback logic needs special handling - v1betaAmp := provider.Group("/v1beta") - { - v1betaAmp.GET("/models", geminiHandlers.GeminiModels) - v1betaAmp.POST("/models/*action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler)) - v1betaAmp.GET("/models/*action", geminiHandlers.GeminiGetHandler) - } -} diff --git a/internal/api/modules/amp/routes_test.go b/internal/api/modules/amp/routes_test.go deleted file mode 100644 index a500f8150..000000000 --- a/internal/api/modules/amp/routes_test.go +++ /dev/null @@ -1,382 +0,0 @@ -package amp - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/gin-gonic/gin" - "github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers" -) - -func TestRegisterManagementRoutes(t *testing.T) { - gin.SetMode(gin.TestMode) - r := gin.New() - - // Create module with proxy for testing - m := &AmpModule{ - restrictToLocalhost: false, // disable localhost restriction for tests - } - - // Create a mock proxy that tracks calls - proxyCalled := false - mockProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - proxyCalled = true - w.WriteHeader(200) - w.Write([]byte("proxied")) - })) - defer mockProxy.Close() - - // Create real proxy to mock server - proxy, _ := createReverseProxy(mockProxy.URL, NewStaticSecretSource("")) - m.setProxy(proxy) - - base := &handlers.BaseAPIHandler{} - m.registerManagementRoutes(r, base, nil) - srv := httptest.NewServer(r) - defer srv.Close() - - managementPaths := []struct { - path string - method string - }{ - {"/api/internal", http.MethodGet}, - {"/api/internal/some/path", http.MethodGet}, - {"/api/user", http.MethodGet}, - {"/api/user/profile", http.MethodGet}, - {"/api/auth", http.MethodGet}, - {"/api/auth/login", http.MethodGet}, - {"/api/meta", http.MethodGet}, - {"/api/telemetry", http.MethodGet}, - {"/api/threads", http.MethodGet}, - {"/api/thread-actors", http.MethodPost}, - {"/threads/", http.MethodGet}, - {"/threads.rss", http.MethodGet}, // Root-level route (no /api prefix) - {"/api/otel", http.MethodGet}, - {"/api/tab", http.MethodGet}, - {"/api/tab/some/path", http.MethodGet}, - {"/auth", http.MethodGet}, // Root-level auth route - {"/auth/cli-login", http.MethodGet}, // CLI login flow - {"/auth/callback", http.MethodGet}, // OAuth callback - // Google v1beta1 bridge should still proxy non-model requests (GET) and allow POST - {"/api/provider/google/v1beta1/models", http.MethodGet}, - {"/api/provider/google/v1beta1/models", http.MethodPost}, - } - - for _, path := range managementPaths { - t.Run(path.path, func(t *testing.T) { - proxyCalled = false - req, err := http.NewRequest(path.method, srv.URL+path.path, nil) - if err != nil { - t.Fatalf("failed to build request: %v", err) - } - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("request failed: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode == http.StatusNotFound { - t.Fatalf("route %s not registered", path.path) - } - if !proxyCalled { - t.Fatalf("proxy handler not called for %s", path.path) - } - }) - } -} - -func TestRegisterProviderAliases_AllProvidersRegistered(t *testing.T) { - gin.SetMode(gin.TestMode) - r := gin.New() - - // Minimal base handler setup (no need to initialize, just check routing) - base := &handlers.BaseAPIHandler{} - - // Track if auth middleware was called - authCalled := false - authMiddleware := func(c *gin.Context) { - authCalled = true - c.Header("X-Auth", "ok") - // Abort with success to avoid calling the actual handler (which needs full setup) - c.AbortWithStatus(http.StatusOK) - } - - m := &AmpModule{authMiddleware_: authMiddleware} - m.registerProviderAliases(r, base, authMiddleware) - - paths := []struct { - path string - method string - }{ - {"/api/provider/openai/models", http.MethodGet}, - {"/api/provider/anthropic/models", http.MethodGet}, - {"/api/provider/google/models", http.MethodGet}, - {"/api/provider/groq/models", http.MethodGet}, - {"/api/provider/openai/chat/completions", http.MethodPost}, - {"/api/provider/anthropic/v1/messages", http.MethodPost}, - {"/api/provider/google/v1beta/models", http.MethodGet}, - } - - for _, tc := range paths { - t.Run(tc.path, func(t *testing.T) { - authCalled = false - req := httptest.NewRequest(tc.method, tc.path, nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code == http.StatusNotFound { - t.Fatalf("route %s %s not registered", tc.method, tc.path) - } - if !authCalled { - t.Fatalf("auth middleware not executed for %s", tc.path) - } - if w.Header().Get("X-Auth") != "ok" { - t.Fatalf("auth middleware header not set for %s", tc.path) - } - }) - } -} - -func TestRegisterProviderAliases_DynamicModelsHandler(t *testing.T) { - gin.SetMode(gin.TestMode) - r := gin.New() - - base := &handlers.BaseAPIHandler{} - - m := &AmpModule{authMiddleware_: func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }} - m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }) - - providers := []string{"openai", "anthropic", "google", "groq", "cerebras"} - - for _, provider := range providers { - t.Run(provider, func(t *testing.T) { - path := "/api/provider/" + provider + "/models" - req := httptest.NewRequest(http.MethodGet, path, nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - // Should not 404 - if w.Code == http.StatusNotFound { - t.Fatalf("models route not found for provider: %s", provider) - } - }) - } -} - -func TestRegisterProviderAliases_V1Routes(t *testing.T) { - gin.SetMode(gin.TestMode) - r := gin.New() - - base := &handlers.BaseAPIHandler{} - - m := &AmpModule{authMiddleware_: func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }} - m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }) - - v1Paths := []struct { - path string - method string - }{ - {"/api/provider/openai/v1/models", http.MethodGet}, - {"/api/provider/openai/v1/chat/completions", http.MethodPost}, - {"/api/provider/openai/v1/completions", http.MethodPost}, - {"/api/provider/anthropic/v1/messages", http.MethodPost}, - {"/api/provider/anthropic/v1/messages/count_tokens", http.MethodPost}, - } - - for _, tc := range v1Paths { - t.Run(tc.path, func(t *testing.T) { - req := httptest.NewRequest(tc.method, tc.path, nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code == http.StatusNotFound { - t.Fatalf("v1 route %s %s not registered", tc.method, tc.path) - } - }) - } -} - -func TestRegisterProviderAliases_V1BetaRoutes(t *testing.T) { - gin.SetMode(gin.TestMode) - r := gin.New() - - base := &handlers.BaseAPIHandler{} - - m := &AmpModule{authMiddleware_: func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }} - m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }) - - v1betaPaths := []struct { - path string - method string - }{ - {"/api/provider/google/v1beta/models", http.MethodGet}, - {"/api/provider/google/v1beta/models/generateContent", http.MethodPost}, - } - - for _, tc := range v1betaPaths { - t.Run(tc.path, func(t *testing.T) { - req := httptest.NewRequest(tc.method, tc.path, nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code == http.StatusNotFound { - t.Fatalf("v1beta route %s %s not registered", tc.method, tc.path) - } - }) - } -} - -func TestRegisterProviderAliases_NoAuthMiddleware(t *testing.T) { - // Test that routes still register even if auth middleware is nil (fallback behavior) - gin.SetMode(gin.TestMode) - r := gin.New() - - base := &handlers.BaseAPIHandler{} - - m := &AmpModule{authMiddleware_: nil} // No auth middleware - m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }) - - req := httptest.NewRequest(http.MethodGet, "/api/provider/openai/models", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - // Should still work (with fallback no-op auth) - if w.Code == http.StatusNotFound { - t.Fatal("routes should register even without auth middleware") - } -} - -func TestLocalhostOnlyMiddleware_PreventsSpoofing(t *testing.T) { - gin.SetMode(gin.TestMode) - r := gin.New() - - // Create module with localhost restriction enabled - m := &AmpModule{ - restrictToLocalhost: true, - } - - // Apply dynamic localhost-only middleware - r.Use(m.localhostOnlyMiddleware()) - r.GET("/test", func(c *gin.Context) { - c.String(http.StatusOK, "ok") - }) - - tests := []struct { - name string - remoteAddr string - forwardedFor string - expectedStatus int - description string - }{ - { - name: "spoofed_header_remote_connection", - remoteAddr: "192.168.1.100:12345", - forwardedFor: "127.0.0.1", - expectedStatus: http.StatusForbidden, - description: "Spoofed X-Forwarded-For header should be ignored", - }, - { - name: "real_localhost_ipv4", - remoteAddr: "127.0.0.1:54321", - forwardedFor: "", - expectedStatus: http.StatusOK, - description: "Real localhost IPv4 connection should work", - }, - { - name: "real_localhost_ipv6", - remoteAddr: "[::1]:54321", - forwardedFor: "", - expectedStatus: http.StatusOK, - description: "Real localhost IPv6 connection should work", - }, - { - name: "remote_ipv4", - remoteAddr: "203.0.113.42:8080", - forwardedFor: "", - expectedStatus: http.StatusForbidden, - description: "Remote IPv4 connection should be blocked", - }, - { - name: "remote_ipv6", - remoteAddr: "[2001:db8::1]:9090", - forwardedFor: "", - expectedStatus: http.StatusForbidden, - description: "Remote IPv6 connection should be blocked", - }, - { - name: "spoofed_localhost_ipv6", - remoteAddr: "203.0.113.42:8080", - forwardedFor: "::1", - expectedStatus: http.StatusForbidden, - description: "Spoofed X-Forwarded-For with IPv6 localhost should be ignored", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/test", nil) - req.RemoteAddr = tt.remoteAddr - if tt.forwardedFor != "" { - req.Header.Set("X-Forwarded-For", tt.forwardedFor) - } - - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != tt.expectedStatus { - t.Errorf("%s: expected status %d, got %d", tt.description, tt.expectedStatus, w.Code) - } - }) - } -} - -func TestLocalhostOnlyMiddleware_HotReload(t *testing.T) { - gin.SetMode(gin.TestMode) - r := gin.New() - - // Create module with localhost restriction initially enabled - m := &AmpModule{ - restrictToLocalhost: true, - } - - // Apply dynamic localhost-only middleware - r.Use(m.localhostOnlyMiddleware()) - r.GET("/test", func(c *gin.Context) { - c.String(http.StatusOK, "ok") - }) - - // Test 1: Remote IP should be blocked when restriction is enabled - req := httptest.NewRequest(http.MethodGet, "/test", nil) - req.RemoteAddr = "192.168.1.100:12345" - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusForbidden { - t.Errorf("Expected 403 when restriction enabled, got %d", w.Code) - } - - // Test 2: Hot-reload - disable restriction - m.setRestrictToLocalhost(false) - - req = httptest.NewRequest(http.MethodGet, "/test", nil) - req.RemoteAddr = "192.168.1.100:12345" - w = httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Errorf("Expected 200 after disabling restriction, got %d", w.Code) - } - - // Test 3: Hot-reload - re-enable restriction - m.setRestrictToLocalhost(true) - - req = httptest.NewRequest(http.MethodGet, "/test", nil) - req.RemoteAddr = "192.168.1.100:12345" - w = httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusForbidden { - t.Errorf("Expected 403 after re-enabling restriction, got %d", w.Code) - } -} diff --git a/internal/api/modules/amp/secret.go b/internal/api/modules/amp/secret.go deleted file mode 100644 index 512d263d0..000000000 --- a/internal/api/modules/amp/secret.go +++ /dev/null @@ -1,248 +0,0 @@ -package amp - -import ( - "context" - "encoding/json" - "fmt" - "os" - "path/filepath" - "strings" - "sync" - "time" - - "github.com/router-for-me/CLIProxyAPI/v7/internal/config" - log "github.com/sirupsen/logrus" -) - -// SecretSource provides Amp API keys with configurable precedence and caching -type SecretSource interface { - Get(ctx context.Context) (string, error) -} - -// cachedSecret holds a secret value with expiration -type cachedSecret struct { - value string - expiresAt time.Time -} - -// MultiSourceSecret implements precedence-based secret lookup: -// 1. Explicit config value (highest priority) -// 2. Environment variable AMP_API_KEY -// 3. File-based secret (lowest priority) -type MultiSourceSecret struct { - explicitKey string - envKey string - filePath string - cacheTTL time.Duration - - mu sync.RWMutex - cache *cachedSecret -} - -// NewMultiSourceSecret creates a secret source with precedence and caching -func NewMultiSourceSecret(explicitKey string, cacheTTL time.Duration) *MultiSourceSecret { - if cacheTTL == 0 { - cacheTTL = 5 * time.Minute // Default 5 minute cache - } - - home, _ := os.UserHomeDir() - filePath := filepath.Join(home, ".local", "share", "amp", "secrets.json") - - return &MultiSourceSecret{ - explicitKey: strings.TrimSpace(explicitKey), - envKey: "AMP_API_KEY", - filePath: filePath, - cacheTTL: cacheTTL, - } -} - -// NewMultiSourceSecretWithPath creates a secret source with a custom file path (for testing) -func NewMultiSourceSecretWithPath(explicitKey string, filePath string, cacheTTL time.Duration) *MultiSourceSecret { - if cacheTTL == 0 { - cacheTTL = 5 * time.Minute - } - - return &MultiSourceSecret{ - explicitKey: strings.TrimSpace(explicitKey), - envKey: "AMP_API_KEY", - filePath: filePath, - cacheTTL: cacheTTL, - } -} - -// Get retrieves the Amp API key using precedence: config > env > file -// Results are cached for cacheTTL duration to avoid excessive file reads -func (s *MultiSourceSecret) Get(ctx context.Context) (string, error) { - // Precedence 1: Explicit config key (highest priority, no caching needed) - if s.explicitKey != "" { - return s.explicitKey, nil - } - - // Precedence 2: Environment variable - if envValue := strings.TrimSpace(os.Getenv(s.envKey)); envValue != "" { - return envValue, nil - } - - // Precedence 3: File-based secret (lowest priority, cached) - // Check cache first - s.mu.RLock() - if s.cache != nil && time.Now().Before(s.cache.expiresAt) { - value := s.cache.value - s.mu.RUnlock() - return value, nil - } - s.mu.RUnlock() - - // Cache miss or expired - read from file - key, err := s.readFromFile() - if err != nil { - // Cache empty result to avoid repeated file reads on missing files - s.updateCache("") - return "", err - } - - // Cache the result - s.updateCache(key) - return key, nil -} - -// readFromFile reads the Amp API key from the secrets file -func (s *MultiSourceSecret) readFromFile() (string, error) { - content, err := os.ReadFile(s.filePath) - if err != nil { - if os.IsNotExist(err) { - return "", nil // Missing file is not an error, just no key available - } - return "", fmt.Errorf("failed to read amp secrets from %s: %w", s.filePath, err) - } - - var secrets map[string]string - if err := json.Unmarshal(content, &secrets); err != nil { - return "", fmt.Errorf("failed to parse amp secrets from %s: %w", s.filePath, err) - } - - key := strings.TrimSpace(secrets["apiKey@https://ampcode.com/"]) - return key, nil -} - -// updateCache updates the cached secret value -func (s *MultiSourceSecret) updateCache(value string) { - s.mu.Lock() - defer s.mu.Unlock() - s.cache = &cachedSecret{ - value: value, - expiresAt: time.Now().Add(s.cacheTTL), - } -} - -// InvalidateCache clears the cached secret, forcing a fresh read on next Get -func (s *MultiSourceSecret) InvalidateCache() { - s.mu.Lock() - defer s.mu.Unlock() - s.cache = nil -} - -// UpdateExplicitKey refreshes the config-provided key and clears cache. -func (s *MultiSourceSecret) UpdateExplicitKey(key string) { - if s == nil { - return - } - s.mu.Lock() - s.explicitKey = strings.TrimSpace(key) - s.cache = nil - s.mu.Unlock() -} - -// StaticSecretSource returns a fixed API key (for testing) -type StaticSecretSource struct { - key string -} - -// NewStaticSecretSource creates a secret source with a fixed key -func NewStaticSecretSource(key string) *StaticSecretSource { - return &StaticSecretSource{key: strings.TrimSpace(key)} -} - -// Get returns the static API key -func (s *StaticSecretSource) Get(ctx context.Context) (string, error) { - return s.key, nil -} - -// MappedSecretSource wraps a default SecretSource and adds per-client API key mapping. -// When a request context contains a client API key that matches a configured mapping, -// the corresponding upstream key is returned. Otherwise, falls back to the default source. -type MappedSecretSource struct { - defaultSource SecretSource - mu sync.RWMutex - lookup map[string]string // clientKey -> upstreamKey -} - -// NewMappedSecretSource creates a MappedSecretSource wrapping the given default source. -func NewMappedSecretSource(defaultSource SecretSource) *MappedSecretSource { - return &MappedSecretSource{ - defaultSource: defaultSource, - lookup: make(map[string]string), - } -} - -// Get retrieves the Amp API key, checking per-client mappings first. -// If the request context contains a client API key that matches a configured mapping, -// returns the corresponding upstream key. Otherwise, falls back to the default source. -func (s *MappedSecretSource) Get(ctx context.Context) (string, error) { - // Try to get client API key from request context - clientKey := getClientAPIKeyFromContext(ctx) - if clientKey != "" { - s.mu.RLock() - if upstreamKey, ok := s.lookup[clientKey]; ok && upstreamKey != "" { - s.mu.RUnlock() - return upstreamKey, nil - } - s.mu.RUnlock() - } - - // Fall back to default source - return s.defaultSource.Get(ctx) -} - -// UpdateMappings rebuilds the client-to-upstream key mapping from configuration entries. -// If the same client key appears in multiple entries, logs a warning and uses the first one. -func (s *MappedSecretSource) UpdateMappings(entries []config.AmpUpstreamAPIKeyEntry) { - newLookup := make(map[string]string) - - for _, entry := range entries { - upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey) - if upstreamKey == "" { - continue - } - for _, clientKey := range entry.APIKeys { - trimmedKey := strings.TrimSpace(clientKey) - if trimmedKey == "" { - continue - } - if _, exists := newLookup[trimmedKey]; exists { - // Log warning for duplicate client key, first one wins - log.Warnf("amp upstream-api-keys: client API key appears in multiple entries; using first mapping.") - continue - } - newLookup[trimmedKey] = upstreamKey - } - } - - s.mu.Lock() - s.lookup = newLookup - s.mu.Unlock() -} - -// UpdateDefaultExplicitKey updates the explicit key on the underlying MultiSourceSecret (if applicable). -func (s *MappedSecretSource) UpdateDefaultExplicitKey(key string) { - if ms, ok := s.defaultSource.(*MultiSourceSecret); ok { - ms.UpdateExplicitKey(key) - } -} - -// InvalidateCache invalidates cache on the underlying MultiSourceSecret (if applicable). -func (s *MappedSecretSource) InvalidateCache() { - if ms, ok := s.defaultSource.(*MultiSourceSecret); ok { - ms.InvalidateCache() - } -} diff --git a/internal/api/modules/amp/secret_test.go b/internal/api/modules/amp/secret_test.go deleted file mode 100644 index 17a75b15d..000000000 --- a/internal/api/modules/amp/secret_test.go +++ /dev/null @@ -1,366 +0,0 @@ -package amp - -import ( - "context" - "encoding/json" - "os" - "path/filepath" - "sync" - "testing" - "time" - - "github.com/router-for-me/CLIProxyAPI/v7/internal/config" - log "github.com/sirupsen/logrus" - "github.com/sirupsen/logrus/hooks/test" -) - -func TestMultiSourceSecret_PrecedenceOrder(t *testing.T) { - ctx := context.Background() - - cases := []struct { - name string - configKey string - envKey string - fileJSON string - want string - }{ - {"config_wins", "cfg", "env", `{"apiKey@https://ampcode.com/":"file"}`, "cfg"}, - {"env_wins_when_no_cfg", "", "env", `{"apiKey@https://ampcode.com/":"file"}`, "env"}, - {"file_when_no_cfg_env", "", "", `{"apiKey@https://ampcode.com/":"file"}`, "file"}, - {"empty_cfg_trims_then_env", " ", "env", `{"apiKey@https://ampcode.com/":"file"}`, "env"}, - {"empty_env_then_file", "", " ", `{"apiKey@https://ampcode.com/":"file"}`, "file"}, - {"missing_file_returns_empty", "", "", "", ""}, - {"all_empty_returns_empty", " ", " ", `{"apiKey@https://ampcode.com/":" "}`, ""}, - } - - for _, tc := range cases { - tc := tc // capture range variable - t.Run(tc.name, func(t *testing.T) { - tmpDir := t.TempDir() - secretsPath := filepath.Join(tmpDir, "secrets.json") - - if tc.fileJSON != "" { - if err := os.WriteFile(secretsPath, []byte(tc.fileJSON), 0600); err != nil { - t.Fatal(err) - } - } - - t.Setenv("AMP_API_KEY", tc.envKey) - - s := NewMultiSourceSecretWithPath(tc.configKey, secretsPath, 100*time.Millisecond) - got, err := s.Get(ctx) - if err != nil && tc.fileJSON != "" && json.Valid([]byte(tc.fileJSON)) { - t.Fatalf("unexpected error: %v", err) - } - if got != tc.want { - t.Fatalf("want %q, got %q", tc.want, got) - } - }) - } -} - -func TestMultiSourceSecret_CacheBehavior(t *testing.T) { - ctx := context.Background() - tmpDir := t.TempDir() - p := filepath.Join(tmpDir, "secrets.json") - - // Initial value - if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v1"}`), 0600); err != nil { - t.Fatal(err) - } - - s := NewMultiSourceSecretWithPath("", p, 50*time.Millisecond) - - // First read - should return v1 - got1, err := s.Get(ctx) - if err != nil { - t.Fatalf("Get failed: %v", err) - } - if got1 != "v1" { - t.Fatalf("expected v1, got %s", got1) - } - - // Change file; within TTL we should still see v1 (cached) - if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v2"}`), 0600); err != nil { - t.Fatal(err) - } - got2, _ := s.Get(ctx) - if got2 != "v1" { - t.Fatalf("cache hit expected v1, got %s", got2) - } - - // After TTL expires, should see v2 - time.Sleep(60 * time.Millisecond) - got3, _ := s.Get(ctx) - if got3 != "v2" { - t.Fatalf("cache miss expected v2, got %s", got3) - } - - // Invalidate forces re-read immediately - if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v3"}`), 0600); err != nil { - t.Fatal(err) - } - s.InvalidateCache() - got4, _ := s.Get(ctx) - if got4 != "v3" { - t.Fatalf("invalidate expected v3, got %s", got4) - } -} - -func TestMultiSourceSecret_FileHandling(t *testing.T) { - ctx := context.Background() - - t.Run("missing_file_no_error", func(t *testing.T) { - s := NewMultiSourceSecretWithPath("", "/nonexistent/path/secrets.json", 100*time.Millisecond) - got, err := s.Get(ctx) - if err != nil { - t.Fatalf("expected no error for missing file, got: %v", err) - } - if got != "" { - t.Fatalf("expected empty string, got %q", got) - } - }) - - t.Run("invalid_json", func(t *testing.T) { - tmpDir := t.TempDir() - p := filepath.Join(tmpDir, "secrets.json") - if err := os.WriteFile(p, []byte(`{invalid json`), 0600); err != nil { - t.Fatal(err) - } - - s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond) - _, err := s.Get(ctx) - if err == nil { - t.Fatal("expected error for invalid JSON") - } - }) - - t.Run("missing_key_in_json", func(t *testing.T) { - tmpDir := t.TempDir() - p := filepath.Join(tmpDir, "secrets.json") - if err := os.WriteFile(p, []byte(`{"other":"value"}`), 0600); err != nil { - t.Fatal(err) - } - - s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond) - got, err := s.Get(ctx) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got != "" { - t.Fatalf("expected empty string for missing key, got %q", got) - } - }) - - t.Run("empty_key_value", func(t *testing.T) { - tmpDir := t.TempDir() - p := filepath.Join(tmpDir, "secrets.json") - if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":" "}`), 0600); err != nil { - t.Fatal(err) - } - - s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond) - got, _ := s.Get(ctx) - if got != "" { - t.Fatalf("expected empty after trim, got %q", got) - } - }) -} - -func TestMultiSourceSecret_Concurrency(t *testing.T) { - tmpDir := t.TempDir() - p := filepath.Join(tmpDir, "secrets.json") - if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"concurrent"}`), 0600); err != nil { - t.Fatal(err) - } - - s := NewMultiSourceSecretWithPath("", p, 5*time.Second) - ctx := context.Background() - - // Spawn many goroutines calling Get concurrently - const goroutines = 50 - const iterations = 100 - - var wg sync.WaitGroup - errors := make(chan error, goroutines) - - for i := 0; i < goroutines; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for j := 0; j < iterations; j++ { - val, err := s.Get(ctx) - if err != nil { - errors <- err - return - } - if val != "concurrent" { - errors <- err - return - } - } - }() - } - - wg.Wait() - close(errors) - - for err := range errors { - t.Errorf("concurrency error: %v", err) - } -} - -func TestStaticSecretSource(t *testing.T) { - ctx := context.Background() - - t.Run("returns_provided_key", func(t *testing.T) { - s := NewStaticSecretSource("test-key-123") - got, err := s.Get(ctx) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got != "test-key-123" { - t.Fatalf("want test-key-123, got %q", got) - } - }) - - t.Run("trims_whitespace", func(t *testing.T) { - s := NewStaticSecretSource(" test-key ") - got, err := s.Get(ctx) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got != "test-key" { - t.Fatalf("want test-key, got %q", got) - } - }) - - t.Run("empty_string", func(t *testing.T) { - s := NewStaticSecretSource("") - got, err := s.Get(ctx) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got != "" { - t.Fatalf("want empty string, got %q", got) - } - }) -} - -func TestMultiSourceSecret_CacheEmptyResult(t *testing.T) { - // Test that missing file results are cached to avoid repeated file reads - tmpDir := t.TempDir() - p := filepath.Join(tmpDir, "nonexistent.json") - - s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond) - ctx := context.Background() - - // First call - file doesn't exist, should cache empty result - got1, err := s.Get(ctx) - if err != nil { - t.Fatalf("expected no error for missing file, got: %v", err) - } - if got1 != "" { - t.Fatalf("expected empty string, got %q", got1) - } - - // Create the file now - if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"new-value"}`), 0600); err != nil { - t.Fatal(err) - } - - // Second call - should still return empty (cached), not read the new file - got2, _ := s.Get(ctx) - if got2 != "" { - t.Fatalf("cache should return empty, got %q", got2) - } - - // After TTL expires, should see the new value - time.Sleep(110 * time.Millisecond) - got3, _ := s.Get(ctx) - if got3 != "new-value" { - t.Fatalf("after cache expiry, expected new-value, got %q", got3) - } -} - -func TestMappedSecretSource_UsesMappingFromContext(t *testing.T) { - defaultSource := NewStaticSecretSource("default") - s := NewMappedSecretSource(defaultSource) - s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{ - { - UpstreamAPIKey: "u1", - APIKeys: []string{"k1"}, - }, - }) - - ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1") - got, err := s.Get(ctx) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got != "u1" { - t.Fatalf("want u1, got %q", got) - } - - ctx = context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k2") - got, err = s.Get(ctx) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got != "default" { - t.Fatalf("want default fallback, got %q", got) - } -} - -func TestMappedSecretSource_DuplicateClientKey_FirstWins(t *testing.T) { - defaultSource := NewStaticSecretSource("default") - s := NewMappedSecretSource(defaultSource) - s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{ - { - UpstreamAPIKey: "u1", - APIKeys: []string{"k1"}, - }, - { - UpstreamAPIKey: "u2", - APIKeys: []string{"k1"}, - }, - }) - - ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1") - got, err := s.Get(ctx) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got != "u1" { - t.Fatalf("want u1 (first wins), got %q", got) - } -} - -func TestMappedSecretSource_DuplicateClientKey_LogsWarning(t *testing.T) { - hook := test.NewLocal(log.StandardLogger()) - defer hook.Reset() - - defaultSource := NewStaticSecretSource("default") - s := NewMappedSecretSource(defaultSource) - s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{ - { - UpstreamAPIKey: "u1", - APIKeys: []string{"k1"}, - }, - { - UpstreamAPIKey: "u2", - APIKeys: []string{"k1"}, - }, - }) - - foundWarning := false - for _, entry := range hook.AllEntries() { - if entry.Level == log.WarnLevel && entry.Message == "amp upstream-api-keys: client API key appears in multiple entries; using first mapping." { - foundWarning = true - break - } - } - if !foundWarning { - t.Fatal("expected warning log for duplicate client key, but none was found") - } -} diff --git a/internal/api/server.go b/internal/api/server.go index 1d7bd28b9..834604abc 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -15,7 +15,6 @@ import ( "net/http" "os" "path/filepath" - "reflect" "sort" "strings" "sync" @@ -26,8 +25,6 @@ import ( "github.com/router-for-me/CLIProxyAPI/v7/internal/access" managementHandlers "github.com/router-for-me/CLIProxyAPI/v7/internal/api/handlers/management" "github.com/router-for-me/CLIProxyAPI/v7/internal/api/middleware" - "github.com/router-for-me/CLIProxyAPI/v7/internal/api/modules" - ampmodule "github.com/router-for-me/CLIProxyAPI/v7/internal/api/modules/amp" "github.com/router-for-me/CLIProxyAPI/v7/internal/cache" "github.com/router-for-me/CLIProxyAPI/v7/internal/config" "github.com/router-for-me/CLIProxyAPI/v7/internal/home" @@ -222,9 +219,6 @@ type Server struct { // management handler mgmt *managementHandlers.Handler - // ampModule is the Amp routing module for model mapping hot-reload - ampModule *ampmodule.AmpModule - // pluginHost owns dynamic plugin Management API route dispatch. pluginHost *pluginhost.Host @@ -358,18 +352,6 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk // Setup routes s.setupRoutes() - // Register Amp module using V2 interface with Context - s.ampModule = ampmodule.NewLegacy(accessManager, AuthMiddleware(accessManager)) - ctx := modules.Context{ - Engine: engine, - BaseHandler: s.handlers, - Config: cfg, - AuthMiddleware: AuthMiddleware(accessManager), - } - if err := modules.RegisterModule(ctx, s.ampModule); err != nil { - log.Errorf("Failed to register Amp module: %v", err) - } - // Apply additional router configurators from options if optionState.routerConfigurator != nil { optionState.routerConfigurator(engine, s.handlers, cfg) @@ -692,30 +674,6 @@ func (s *Server) registerManagementRoutes() { mgmt.PUT("/ws-auth", s.mgmt.PutWebsocketAuth) mgmt.PATCH("/ws-auth", s.mgmt.PutWebsocketAuth) - mgmt.GET("/ampcode", s.mgmt.GetAmpCode) - mgmt.GET("/ampcode/upstream-url", s.mgmt.GetAmpUpstreamURL) - mgmt.PUT("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL) - mgmt.PATCH("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL) - mgmt.DELETE("/ampcode/upstream-url", s.mgmt.DeleteAmpUpstreamURL) - mgmt.GET("/ampcode/upstream-api-key", s.mgmt.GetAmpUpstreamAPIKey) - mgmt.PUT("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey) - mgmt.PATCH("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey) - mgmt.DELETE("/ampcode/upstream-api-key", s.mgmt.DeleteAmpUpstreamAPIKey) - mgmt.GET("/ampcode/restrict-management-to-localhost", s.mgmt.GetAmpRestrictManagementToLocalhost) - mgmt.PUT("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost) - mgmt.PATCH("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost) - mgmt.GET("/ampcode/model-mappings", s.mgmt.GetAmpModelMappings) - mgmt.PUT("/ampcode/model-mappings", s.mgmt.PutAmpModelMappings) - mgmt.PATCH("/ampcode/model-mappings", s.mgmt.PatchAmpModelMappings) - mgmt.DELETE("/ampcode/model-mappings", s.mgmt.DeleteAmpModelMappings) - mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings) - mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings) - mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings) - mgmt.GET("/ampcode/upstream-api-keys", s.mgmt.GetAmpUpstreamAPIKeys) - mgmt.PUT("/ampcode/upstream-api-keys", s.mgmt.PutAmpUpstreamAPIKeys) - mgmt.PATCH("/ampcode/upstream-api-keys", s.mgmt.PatchAmpUpstreamAPIKeys) - mgmt.DELETE("/ampcode/upstream-api-keys", s.mgmt.DeleteAmpUpstreamAPIKeys) - mgmt.GET("/request-retry", s.mgmt.GetRequestRetry) mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry) mgmt.PATCH("/request-retry", s.mgmt.PutRequestRetry) @@ -1627,19 +1585,6 @@ func (s *Server) UpdateClients(cfg *config.Config) { } s.refreshPluginManagementRoutes() - // Notify Amp module only when Amp config has changed. - ampConfigChanged := oldCfg == nil || !reflect.DeepEqual(oldCfg.AmpCode, cfg.AmpCode) - if ampConfigChanged { - if s.ampModule != nil { - log.Debugf("triggering amp module config update") - if err := s.ampModule.OnConfigUpdated(cfg); err != nil { - log.Errorf("failed to update Amp module config: %v", err) - } - } else { - log.Warnf("amp module is nil, skipping config update") - } - } - // Count client sources from configuration and auth store. authEntries := 0 if cfg != nil && !cfg.Home.Enabled { diff --git a/internal/api/server_test.go b/internal/api/server_test.go index f71deacd7..a694883f5 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -293,72 +293,6 @@ func TestHomeEnabledHidesManagementEndpointsAndControlPanel(t *testing.T) { }) } -func TestAmpProviderModelRoutes(t *testing.T) { - testCases := []struct { - name string - path string - wantStatus int - wantContains string - }{ - { - name: "openai root models", - path: "/api/provider/openai/models", - wantStatus: http.StatusOK, - wantContains: `"object":"list"`, - }, - { - name: "groq root models", - path: "/api/provider/groq/models", - wantStatus: http.StatusOK, - wantContains: `"object":"list"`, - }, - { - name: "openai models", - path: "/api/provider/openai/v1/models", - wantStatus: http.StatusOK, - wantContains: `"object":"list"`, - }, - { - name: "anthropic models", - path: "/api/provider/anthropic/v1/models", - wantStatus: http.StatusOK, - wantContains: `"data"`, - }, - { - name: "google models v1", - path: "/api/provider/google/v1/models", - wantStatus: http.StatusOK, - wantContains: `"models"`, - }, - { - name: "google models v1beta", - path: "/api/provider/google/v1beta/models", - wantStatus: http.StatusOK, - wantContains: `"models"`, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - server := newTestServer(t) - - req := httptest.NewRequest(http.MethodGet, tc.path, nil) - req.Header.Set("Authorization", "Bearer test-key") - - rr := httptest.NewRecorder() - server.engine.ServeHTTP(rr, req) - - if rr.Code != tc.wantStatus { - t.Fatalf("unexpected status code for %s: got %d want %d; body=%s", tc.path, rr.Code, tc.wantStatus, rr.Body.String()) - } - if body := rr.Body.String(); !strings.Contains(body, tc.wantContains) { - t.Fatalf("response body for %s missing %q: %s", tc.path, tc.wantContains, body) - } - }) - } -} - func TestModelsWithClientVersionReturnsCodexCatalog(t *testing.T) { modelRegistry := registry.GetGlobalRegistry() clientID := "test-client-version-catalog" diff --git a/internal/config/config.go b/internal/config/config.go index 38283e14e..ffcb9c9c3 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -135,9 +135,6 @@ type Config struct { // Used for services that use Vertex AI-style paths but with simple API key authentication. VertexCompatAPIKey []VertexCompatKey `yaml:"vertex-api-key" json:"vertex-api-key"` - // AmpCode contains Amp CLI upstream configuration, management restrictions, and model mappings. - AmpCode AmpCode `yaml:"ampcode" json:"ampcode"` - // OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries. OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"` @@ -146,7 +143,7 @@ type Config struct { // gemini-cli, vertex, aistudio, antigravity, claude, codex, kimi, xai. // // NOTE: This does not apply to existing per-credential model alias features under: - // gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode. + // gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, and vertex-api-key. OAuthModelAlias map[string][]OAuthModelAlias `yaml:"oauth-model-alias,omitempty" json:"oauth-model-alias,omitempty"` // Payload defines default and override rules for provider payload parameters. @@ -322,8 +319,7 @@ type RoutingConfig struct { // SessionAffinity enables universal session-sticky routing for all clients. // Session IDs are extracted from multiple sources: // metadata.user_id (Claude Code session format), X-Session-ID, Session_id (Codex), - // X-Amp-Thread-Id (Amp CLI thread), X-Client-Request-Id (PI), metadata.user_id, - // conversation_id, or message hash. + // X-Client-Request-Id (PI), metadata.user_id, conversation_id, or message hash. // Automatic failover is always enabled when bound auth becomes unavailable. SessionAffinity bool `yaml:"session-affinity,omitempty" json:"session-affinity,omitempty"` @@ -342,63 +338,6 @@ type OAuthModelAlias struct { Fork bool `yaml:"fork,omitempty" json:"fork,omitempty"` } -// AmpModelMapping defines a model name mapping for Amp CLI requests. -// When Amp requests a model that isn't available locally, this mapping -// allows routing to an alternative model that IS available. -type AmpModelMapping struct { - // From is the model name that Amp CLI requests (e.g., "claude-opus-4.5"). - From string `yaml:"from" json:"from"` - - // To is the target model name to route to (e.g., "claude-sonnet-4"). - // The target model must have available providers in the registry. - To string `yaml:"to" json:"to"` - - // Regex indicates whether the 'from' field should be interpreted as a regular - // expression for matching model names. When true, this mapping is evaluated - // after exact matches and in the order provided. Defaults to false (exact match). - Regex bool `yaml:"regex,omitempty" json:"regex,omitempty"` -} - -// AmpCode groups Amp CLI integration settings including upstream routing, -// optional overrides, management route restrictions, and model fallback mappings. -type AmpCode struct { - // UpstreamURL defines the upstream Amp control plane used for non-provider calls. - UpstreamURL string `yaml:"upstream-url" json:"upstream-url"` - - // UpstreamAPIKey optionally overrides the Authorization header when proxying Amp upstream calls. - UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"` - - // UpstreamAPIKeys maps client API keys (from top-level api-keys) to upstream API keys. - // When a request is authenticated with one of the APIKeys, the corresponding UpstreamAPIKey - // is used for the upstream Amp request. - UpstreamAPIKeys []AmpUpstreamAPIKeyEntry `yaml:"upstream-api-keys,omitempty" json:"upstream-api-keys,omitempty"` - - // RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.) - // to only accept connections from localhost (127.0.0.1, ::1). When true, prevents drive-by - // browser attacks and remote access to management endpoints. Default: false (API key auth is sufficient). - RestrictManagementToLocalhost bool `yaml:"restrict-management-to-localhost" json:"restrict-management-to-localhost"` - - // ModelMappings defines model name mappings for Amp CLI requests. - // When Amp requests a model that isn't available locally, these mappings - // allow routing to an alternative model that IS available. - ModelMappings []AmpModelMapping `yaml:"model-mappings" json:"model-mappings"` - - // ForceModelMappings when true, model mappings take precedence over local API keys. - // When false (default), local API keys are used first if available. - ForceModelMappings bool `yaml:"force-model-mappings" json:"force-model-mappings"` -} - -// AmpUpstreamAPIKeyEntry maps a set of client API keys to a specific upstream API key. -// When a request is authenticated with one of the APIKeys, the corresponding UpstreamAPIKey -// is used for the upstream Amp request. -type AmpUpstreamAPIKeyEntry struct { - // UpstreamAPIKey is the API key to use when proxying to the Amp upstream. - UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"` - - // APIKeys are the client API keys (from top-level api-keys) that map to this upstream key. - APIKeys []string `yaml:"api-keys" json:"api-keys"` -} - // PayloadConfig defines default and override parameter rules applied to provider payloads. type PayloadConfig struct { // Default defines rules that only set parameters when they are missing in the payload. @@ -740,7 +679,6 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { cfg.DisableImageGeneration = DisableImageGenerationOff cfg.Pprof.Enable = false cfg.Pprof.Addr = DefaultPprofAddr - cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository if err = yaml.Unmarshal(data, &cfg); err != nil { if optional { @@ -763,9 +701,6 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { // if cfg.migrateLegacyOpenAICompatibilityKeys(legacy.OpenAICompat) { // cfg.legacyMigrationPending = true // } - // if cfg.migrateLegacyAmpConfig(&legacy) { - // cfg.legacyMigrationPending = true - // } // } // Hash remote management key if plaintext is detected (nested) @@ -1216,7 +1151,7 @@ func SaveConfigPreserveComments(configFile string, cfg *Config) error { // Remove deprecated sections before merging back the sanitized config. removeLegacyAuthBlock(original.Content[0]) removeLegacyOpenAICompatAPIKeys(original.Content[0]) - removeLegacyAmpKeys(original.Content[0]) + removeRemovedIntegrationKeys(original.Content[0]) removeLegacyGenerativeLanguageKeys(original.Content[0]) pruneMappingToGeneratedKeys(original.Content[0], generated.Content[0], "oauth-excluded-models") @@ -1894,12 +1829,8 @@ func normalizeCollectionNodeStyles(node *yaml.Node) { // Legacy migration helpers (move deprecated config keys into structured fields). type legacyConfigData struct { - LegacyGeminiKeys []string `yaml:"generative-language-api-key"` - OpenAICompat []legacyOpenAICompatibility `yaml:"openai-compatibility"` - AmpUpstreamURL string `yaml:"amp-upstream-url"` - AmpUpstreamAPIKey string `yaml:"amp-upstream-api-key"` - AmpRestrictManagement *bool `yaml:"amp-restrict-management-to-localhost"` - AmpModelMappings []AmpModelMapping `yaml:"amp-model-mappings"` + LegacyGeminiKeys []string `yaml:"generative-language-api-key"` + OpenAICompat []legacyOpenAICompatibility `yaml:"openai-compatibility"` } type legacyOpenAICompatibility struct { @@ -2012,34 +1943,6 @@ func findOpenAICompatTarget(entries []OpenAICompatibility, legacyName, legacyBas return nil } -func (cfg *Config) migrateLegacyAmpConfig(legacy *legacyConfigData) bool { - if cfg == nil || legacy == nil { - return false - } - changed := false - if cfg.AmpCode.UpstreamURL == "" { - if val := strings.TrimSpace(legacy.AmpUpstreamURL); val != "" { - cfg.AmpCode.UpstreamURL = val - changed = true - } - } - if cfg.AmpCode.UpstreamAPIKey == "" { - if val := strings.TrimSpace(legacy.AmpUpstreamAPIKey); val != "" { - cfg.AmpCode.UpstreamAPIKey = val - changed = true - } - } - if legacy.AmpRestrictManagement != nil { - cfg.AmpCode.RestrictManagementToLocalhost = *legacy.AmpRestrictManagement - changed = true - } - if len(cfg.AmpCode.ModelMappings) == 0 && len(legacy.AmpModelMappings) > 0 { - cfg.AmpCode.ModelMappings = append([]AmpModelMapping(nil), legacy.AmpModelMappings...) - changed = true - } - return changed -} - func removeLegacyOpenAICompatAPIKeys(root *yaml.Node) { if root == nil || root.Kind != yaml.MappingNode { return @@ -2059,10 +1962,11 @@ func removeLegacyOpenAICompatAPIKeys(root *yaml.Node) { } } -func removeLegacyAmpKeys(root *yaml.Node) { +func removeRemovedIntegrationKeys(root *yaml.Node) { if root == nil || root.Kind != yaml.MappingNode { return } + removeMapKey(root, "ampcode") removeMapKey(root, "amp-upstream-url") removeMapKey(root, "amp-upstream-api-key") removeMapKey(root, "amp-restrict-management-to-localhost") diff --git a/internal/config/parse.go b/internal/config/parse.go index 393b629ce..b097976c0 100644 --- a/internal/config/parse.go +++ b/internal/config/parse.go @@ -28,7 +28,6 @@ func ParseConfigBytes(data []byte) (*Config, error) { cfg.DisableImageGeneration = DisableImageGenerationOff cfg.Pprof.Enable = false cfg.Pprof.Addr = DefaultPprofAddr - cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository if err := yaml.Unmarshal(data, &cfg); err != nil { diff --git a/internal/logging/gin_logger.go b/internal/logging/gin_logger.go index 689ea13a9..a4c9aa085 100644 --- a/internal/logging/gin_logger.go +++ b/internal/logging/gin_logger.go @@ -25,7 +25,6 @@ var aiAPIPrefixes = []string{ "/v1/messages", "/v1/responses", "/v1beta/models/", - "/api/provider/", "/backend-api/codex/", } diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index b306b5a76..22de9183d 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -101,10 +101,9 @@ var oauthToolRenameMap = map[string]string{ // The reverse map is now computed per-request in remapOAuthToolNames so that // only names the client actually caused us to rewrite are restored on the // response. A global reverse map — as used previously — corrupted responses -// for clients that sent mixed casing (e.g. Amp CLI sends `Bash` TitleCase -// alongside `glob` lowercase; the request flagged renames via `glob→Glob`, -// then the global reverse map incorrectly rewrote every `Bash` in the -// response to `bash`, causing Amp to reject the tool_use as unknown). +// for clients that sent mixed casing (e.g. `Bash` TitleCase alongside `glob` +// lowercase; the request flagged renames via `glob` -> `Glob`, then the global +// reverse map incorrectly rewrote every `Bash` in the response to `bash`). // oauthToolsToRemove lists tool names that must be stripped from OAuth requests // even after remapping. Currently empty — all tools are mapped instead of removed. @@ -212,7 +211,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r // Enforce Anthropic's cache_control block limit (max 4 breakpoints per request). // Cloaking and ensureCacheControl may push the total over 4 when the client - // (e.g. Amp CLI) already sends multiple cache_control blocks. + // already sends multiple cache_control blocks. body = enforceCacheControlLimit(body, 4) // Normalize TTL values to prevent ordering violations under prompt-caching-scope-2026-01-05. @@ -1135,9 +1134,9 @@ func restoreClaudeOAuthToolNamesFromStreamLine(line []byte, prefix string, prefi // client-supplied original name. Callers MUST pass this map to the reverse // functions so only names the client actually caused us to rewrite are restored // on the response. A global reverse map (the previous implementation) incorrectly -// rewrote names the client originally sent in TitleCase (e.g. Amp CLI's `Bash`) +// rewrote names the client originally sent in TitleCase (e.g. `Bash`) // when any OTHER tool in the same request triggered a forward rename (e.g. -// Amp's `glob`→`Glob`), because the global reverse map contained `Bash`→`bash` +// `glob` -> `Glob`), because the global reverse map contained `Bash` -> `bash` // regardless of what the client originally sent. func remapOAuthToolNames(body []byte) ([]byte, map[string]string) { reverseMap := make(map[string]string, len(oauthToolRenameMap)) diff --git a/internal/runtime/executor/claude_executor_test.go b/internal/runtime/executor/claude_executor_test.go index 2ac32ebde..c54ea598a 100644 --- a/internal/runtime/executor/claude_executor_test.go +++ b/internal/runtime/executor/claude_executor_test.go @@ -2191,8 +2191,7 @@ func TestRemapOAuthToolNames_Lowercase_ReverseApplied(t *testing.T) { // must pass through unchanged) and a lowercase tool that we forward-rename. // Before the fix, triggering ANY forward rename caused the reverse pass to // lowercase every TitleCase tool in the response using a global reverse map, -// corrupting tool names the client originally sent in TitleCase (notably Amp -// CLI's `Bash`, which its registry lookup cannot find as `bash`). +// corrupting tool names the client originally sent in TitleCase. func TestRemapOAuthToolNames_MixedCase_OnlyRenamedToolsReversed(t *testing.T) { body := []byte(`{"tools":[` + `{"name":"Bash","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}},` + diff --git a/internal/translator/antigravity/gemini/antigravity_gemini_request_test.go b/internal/translator/antigravity/gemini/antigravity_gemini_request_test.go index 9707f39cf..3009c1f76 100644 --- a/internal/translator/antigravity/gemini/antigravity_gemini_request_test.go +++ b/internal/translator/antigravity/gemini/antigravity_gemini_request_test.go @@ -388,8 +388,7 @@ func TestFixCLIToolResponse_PreservesFunctionResponseParts(t *testing.T) { } func TestFixCLIToolResponse_BackfillsEmptyFunctionResponseName(t *testing.T) { - // When the Amp client sends functionResponse with an empty name, - // fixCLIToolResponse should backfill it from the corresponding functionCall. + // Empty functionResponse names are backfilled from the corresponding functionCall. input := `{ "model": "gemini-3-pro-preview", "request": { diff --git a/internal/translator/gemini/gemini/gemini_gemini_request.go b/internal/translator/gemini/gemini/gemini_gemini_request.go index 6c36dfd80..4d7e0b7d3 100644 --- a/internal/translator/gemini/gemini/gemini_gemini_request.go +++ b/internal/translator/gemini/gemini/gemini_gemini_request.go @@ -87,7 +87,7 @@ func ConvertGeminiRequestToGemini(_ string, inputRawJSON []byte, _ bool) []byte } // Backfill empty functionResponse.name from the preceding functionCall.name. - // Amp may send function responses with empty names; the Gemini API rejects these. + // Some clients send function responses with empty names; the Gemini API rejects these. out = backfillEmptyFunctionResponseNames(out) out = common.AttachDefaultSafetySettings(out, "safetySettings") diff --git a/internal/tui/config_tab.go b/internal/tui/config_tab.go index ff9ad040e..6ac42639b 100644 --- a/internal/tui/config_tab.go +++ b/internal/tui/config_tab.go @@ -356,22 +356,10 @@ func (m configTabModel) parseConfig(cfg map[string]any) []configField { // WebSocket auth fields = append(fields, configField{"WebSocket Auth", "ws-auth", "bool", fmt.Sprintf("%v", getBool(cfg, "ws-auth")), nil}) - // AMP settings - if amp, ok := cfg["ampcode"].(map[string]any); ok { - upstreamURL := getString(amp, "upstream-url") - upstreamAPIKey := getString(amp, "upstream-api-key") - fields = append(fields, configField{"AMP Upstream URL", "ampcode/upstream-url", "string", upstreamURL, upstreamURL}) - fields = append(fields, configField{"AMP Upstream API Key", "ampcode/upstream-api-key", "string", maskIfNotEmpty(upstreamAPIKey), upstreamAPIKey}) - fields = append(fields, configField{"AMP Restrict Mgmt Localhost", "ampcode/restrict-management-to-localhost", "bool", fmt.Sprintf("%v", getBool(amp, "restrict-management-to-localhost")), nil}) - } - return fields } func fieldSection(apiPath string) string { - if strings.HasPrefix(apiPath, "ampcode/") { - return T("section_ampcode") - } if strings.HasPrefix(apiPath, "quota-exceeded/") { return T("section_quota") } @@ -404,10 +392,3 @@ func getBoolNested(m map[string]any, keys ...string) bool { } return false } - -func maskIfNotEmpty(s string) string { - if s == "" { - return T("not_set") - } - return maskKey(s) -} diff --git a/internal/tui/i18n.go b/internal/tui/i18n.go index a4c0ac165..64227b34f 100644 --- a/internal/tui/i18n.go +++ b/internal/tui/i18n.go @@ -131,7 +131,6 @@ var zhStrings = map[string]string{ "section_quota": "配额超限处理", "section_routing": "路由", "section_websocket": "WebSocket", - "section_ampcode": "AMP Code", "section_other": "其他", // ── Auth Files ── @@ -283,7 +282,6 @@ var enStrings = map[string]string{ "section_quota": "Quota Exceeded Handling", "section_routing": "Routing", "section_websocket": "WebSocket", - "section_ampcode": "AMP Code", "section_other": "Other", // ── Auth Files ── diff --git a/internal/watcher/diff/config_diff.go b/internal/watcher/diff/config_diff.go index 023b2f0be..0efc42bfe 100644 --- a/internal/watcher/diff/config_diff.go +++ b/internal/watcher/diff/config_diff.go @@ -228,39 +228,6 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { } } - // AmpCode settings (redacted where needed) - oldAmpURL := strings.TrimSpace(oldCfg.AmpCode.UpstreamURL) - newAmpURL := strings.TrimSpace(newCfg.AmpCode.UpstreamURL) - if oldAmpURL != newAmpURL { - changes = append(changes, fmt.Sprintf("ampcode.upstream-url: %s -> %s", oldAmpURL, newAmpURL)) - } - oldAmpKey := strings.TrimSpace(oldCfg.AmpCode.UpstreamAPIKey) - newAmpKey := strings.TrimSpace(newCfg.AmpCode.UpstreamAPIKey) - switch { - case oldAmpKey == "" && newAmpKey != "": - changes = append(changes, "ampcode.upstream-api-key: added") - case oldAmpKey != "" && newAmpKey == "": - changes = append(changes, "ampcode.upstream-api-key: removed") - case oldAmpKey != newAmpKey: - changes = append(changes, "ampcode.upstream-api-key: updated") - } - if oldCfg.AmpCode.RestrictManagementToLocalhost != newCfg.AmpCode.RestrictManagementToLocalhost { - changes = append(changes, fmt.Sprintf("ampcode.restrict-management-to-localhost: %t -> %t", oldCfg.AmpCode.RestrictManagementToLocalhost, newCfg.AmpCode.RestrictManagementToLocalhost)) - } - oldMappings := SummarizeAmpModelMappings(oldCfg.AmpCode.ModelMappings) - newMappings := SummarizeAmpModelMappings(newCfg.AmpCode.ModelMappings) - if oldMappings.hash != newMappings.hash { - changes = append(changes, fmt.Sprintf("ampcode.model-mappings: updated (%d -> %d entries)", oldMappings.count, newMappings.count)) - } - if oldCfg.AmpCode.ForceModelMappings != newCfg.AmpCode.ForceModelMappings { - changes = append(changes, fmt.Sprintf("ampcode.force-model-mappings: %t -> %t", oldCfg.AmpCode.ForceModelMappings, newCfg.AmpCode.ForceModelMappings)) - } - oldUpstreamAPIKeysCount := len(oldCfg.AmpCode.UpstreamAPIKeys) - newUpstreamAPIKeysCount := len(newCfg.AmpCode.UpstreamAPIKeys) - if !equalUpstreamAPIKeys(oldCfg.AmpCode.UpstreamAPIKeys, newCfg.AmpCode.UpstreamAPIKeys) { - changes = append(changes, fmt.Sprintf("ampcode.upstream-api-keys: updated (%d -> %d entries)", oldUpstreamAPIKeysCount, newUpstreamAPIKeysCount)) - } - if entries, _ := DiffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 { changes = append(changes, entries...) } @@ -410,43 +377,3 @@ func formatProxyURL(raw string) string { } return scheme + "://" + host } - -func equalStringSet(a, b []string) bool { - if len(a) == 0 && len(b) == 0 { - return true - } - aSet := make(map[string]struct{}, len(a)) - for _, k := range a { - aSet[strings.TrimSpace(k)] = struct{}{} - } - bSet := make(map[string]struct{}, len(b)) - for _, k := range b { - bSet[strings.TrimSpace(k)] = struct{}{} - } - if len(aSet) != len(bSet) { - return false - } - for k := range aSet { - if _, ok := bSet[k]; !ok { - return false - } - } - return true -} - -// equalUpstreamAPIKeys compares two slices of AmpUpstreamAPIKeyEntry for equality. -// Comparison is done by count and content (upstream key and client keys). -func equalUpstreamAPIKeys(a, b []config.AmpUpstreamAPIKeyEntry) bool { - if len(a) != len(b) { - return false - } - for i := range a { - if strings.TrimSpace(a[i].UpstreamAPIKey) != strings.TrimSpace(b[i].UpstreamAPIKey) { - return false - } - if !equalStringSet(a[i].APIKeys, b[i].APIKeys) { - return false - } - } - return true -} diff --git a/internal/watcher/diff/config_diff_test.go b/internal/watcher/diff/config_diff_test.go index 192791ea7..e80bf0176 100644 --- a/internal/watcher/diff/config_diff_test.go +++ b/internal/watcher/diff/config_diff_test.go @@ -14,11 +14,6 @@ func TestBuildConfigChangeDetails(t *testing.T) { GeminiKey: []config.GeminiKey{ {APIKey: "old", BaseURL: "http://old", ExcludedModels: []string{"old-model"}}, }, - AmpCode: config.AmpCode{ - UpstreamURL: "http://old-upstream", - ModelMappings: []config.AmpModelMapping{{From: "from-old", To: "to-old"}}, - RestrictManagementToLocalhost: false, - }, RemoteManagement: config.RemoteManagement{ AllowRemote: false, SecretKey: "old", @@ -46,14 +41,6 @@ func TestBuildConfigChangeDetails(t *testing.T) { GeminiKey: []config.GeminiKey{ {APIKey: "old", BaseURL: "http://old", ExcludedModels: []string{"old-model", "extra"}}, }, - AmpCode: config.AmpCode{ - UpstreamURL: "http://new-upstream", - RestrictManagementToLocalhost: true, - ModelMappings: []config.AmpModelMapping{ - {From: "from-old", To: "to-old"}, - {From: "from-new", To: "to-new"}, - }, - }, RemoteManagement: config.RemoteManagement{ AllowRemote: true, SecretKey: "new", @@ -87,8 +74,6 @@ func TestBuildConfigChangeDetails(t *testing.T) { expectContains(t, details, "port: 8080 -> 9090") expectContains(t, details, "auth-dir: /tmp/auth-old -> /tmp/auth-new") expectContains(t, details, "gemini[0].excluded-models: updated (1 -> 2 entries)") - expectContains(t, details, "ampcode.upstream-url: http://old-upstream -> http://new-upstream") - expectContains(t, details, "ampcode.model-mappings: updated (1 -> 2 entries)") expectContains(t, details, "remote-management.allow-remote: false -> true") expectContains(t, details, "remote-management.disable-auto-update-panel: false -> true") expectContains(t, details, "remote-management.secret-key: updated") @@ -108,7 +93,7 @@ func TestBuildConfigChangeDetails_NoChanges(t *testing.T) { } } -func TestBuildConfigChangeDetails_GeminiVertexHeadersAndForceMappings(t *testing.T) { +func TestBuildConfigChangeDetails_GeminiVertexHeaders(t *testing.T) { oldCfg := &config.Config{ GeminiKey: []config.GeminiKey{ {APIKey: "g1", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"a"}}, @@ -116,10 +101,6 @@ func TestBuildConfigChangeDetails_GeminiVertexHeadersAndForceMappings(t *testing VertexCompatAPIKey: []config.VertexCompatKey{ {APIKey: "v1", BaseURL: "http://v-old", Models: []config.VertexCompatModel{{Name: "m1"}}}, }, - AmpCode: config.AmpCode{ - ModelMappings: []config.AmpModelMapping{{From: "a", To: "b"}}, - ForceModelMappings: false, - }, } newCfg := &config.Config{ GeminiKey: []config.GeminiKey{ @@ -128,17 +109,11 @@ func TestBuildConfigChangeDetails_GeminiVertexHeadersAndForceMappings(t *testing VertexCompatAPIKey: []config.VertexCompatKey{ {APIKey: "v1", BaseURL: "http://v-new", Models: []config.VertexCompatModel{{Name: "m1"}, {Name: "m2"}}}, }, - AmpCode: config.AmpCode{ - ModelMappings: []config.AmpModelMapping{{From: "a", To: "c"}}, - ForceModelMappings: true, - }, } details := BuildConfigChangeDetails(oldCfg, newCfg) expectContains(t, details, "gemini[0].headers: updated") expectContains(t, details, "gemini[0].excluded-models: updated (1 -> 2 entries)") - expectContains(t, details, "ampcode.model-mappings: updated (1 -> 1 entries)") - expectContains(t, details, "ampcode.force-model-mappings: false -> true") } func TestBuildConfigChangeDetails_ModelPrefixes(t *testing.T) { @@ -192,9 +167,6 @@ func TestBuildConfigChangeDetails_SecretsAndCounts(t *testing.T) { SDKConfig: sdkconfig.SDKConfig{ APIKeys: []string{"a"}, }, - AmpCode: config.AmpCode{ - UpstreamAPIKey: "", - }, RemoteManagement: config.RemoteManagement{ SecretKey: "", }, @@ -203,9 +175,6 @@ func TestBuildConfigChangeDetails_SecretsAndCounts(t *testing.T) { SDKConfig: sdkconfig.SDKConfig{ APIKeys: []string{"a", "b", "c"}, }, - AmpCode: config.AmpCode{ - UpstreamAPIKey: "new-key", - }, RemoteManagement: config.RemoteManagement{ SecretKey: "new-secret", }, @@ -213,7 +182,6 @@ func TestBuildConfigChangeDetails_SecretsAndCounts(t *testing.T) { details := BuildConfigChangeDetails(oldCfg, newCfg) expectContains(t, details, "api-keys count: 1 -> 3") - expectContains(t, details, "ampcode.upstream-api-key: added") expectContains(t, details, "remote-management.secret-key: created") } @@ -232,7 +200,6 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) { QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false, AntigravityCredits: false}, ClaudeKey: []config.ClaudeKey{{APIKey: "c1"}}, CodexKey: []config.CodexKey{{APIKey: "x1"}}, - AmpCode: config.AmpCode{UpstreamAPIKey: "keep", RestrictManagementToLocalhost: false}, RemoteManagement: config.RemoteManagement{DisableControlPanel: false, PanelGitHubRepository: "old/repo", SecretKey: "keep"}, SDKConfig: sdkconfig.SDKConfig{ RequestLog: false, @@ -262,11 +229,6 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) { {APIKey: "x1", BaseURL: "http://x", ProxyURL: "http://px", Headers: map[string]string{"H": "2"}, ExcludedModels: []string{"b"}}, {APIKey: "x2"}, }, - AmpCode: config.AmpCode{ - UpstreamAPIKey: "", - RestrictManagementToLocalhost: true, - ModelMappings: []config.AmpModelMapping{{From: "a", To: "b"}}, - }, RemoteManagement: config.RemoteManagement{ DisableControlPanel: true, DisableAutoUpdatePanel: true, @@ -303,8 +265,6 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) { expectContains(t, details, "api-keys count: 1 -> 2") expectContains(t, details, "claude-api-key count: 1 -> 2") expectContains(t, details, "codex-api-key count: 1 -> 2") - expectContains(t, details, "ampcode.restrict-management-to-localhost: false -> true") - expectContains(t, details, "ampcode.upstream-api-key: removed") expectContains(t, details, "remote-management.disable-control-panel: false -> true") expectContains(t, details, "remote-management.disable-auto-update-panel: false -> true") expectContains(t, details, "remote-management.panel-github-repository: old/repo -> new/repo") @@ -336,13 +296,6 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) { VertexCompatAPIKey: []config.VertexCompatKey{ {APIKey: "v-old", BaseURL: "http://v-old", ProxyURL: "http://vp-old", Headers: map[string]string{"H": "1"}, Models: []config.VertexCompatModel{{Name: "m1"}}}, }, - AmpCode: config.AmpCode{ - UpstreamURL: "http://amp-old", - UpstreamAPIKey: "old-key", - RestrictManagementToLocalhost: false, - ModelMappings: []config.AmpModelMapping{{From: "a", To: "b"}}, - ForceModelMappings: false, - }, RemoteManagement: config.RemoteManagement{ AllowRemote: false, DisableControlPanel: false, @@ -390,13 +343,6 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) { VertexCompatAPIKey: []config.VertexCompatKey{ {APIKey: "v-new", BaseURL: "http://v-new", ProxyURL: "http://vp-new", Headers: map[string]string{"H": "2"}, Models: []config.VertexCompatModel{{Name: "m1"}, {Name: "m2"}}}, }, - AmpCode: config.AmpCode{ - UpstreamURL: "http://amp-new", - UpstreamAPIKey: "", - RestrictManagementToLocalhost: true, - ModelMappings: []config.AmpModelMapping{{From: "a", To: "c"}}, - ForceModelMappings: true, - }, RemoteManagement: config.RemoteManagement{ AllowRemote: true, DisableControlPanel: true, @@ -464,11 +410,6 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) { expectContains(t, changes, "vertex[0].api-key: updated") expectContains(t, changes, "vertex[0].models: updated (1 -> 2 entries)") expectContains(t, changes, "vertex[0].headers: updated") - expectContains(t, changes, "ampcode.upstream-url: http://amp-old -> http://amp-new") - expectContains(t, changes, "ampcode.upstream-api-key: removed") - expectContains(t, changes, "ampcode.restrict-management-to-localhost: false -> true") - expectContains(t, changes, "ampcode.model-mappings: updated (1 -> 1 entries)") - expectContains(t, changes, "ampcode.force-model-mappings: false -> true") expectContains(t, changes, "oauth-excluded-models[p1]: updated (1 -> 2 entries)") expectContains(t, changes, "oauth-excluded-models[p2]: added (1 entries)") expectContains(t, changes, "remote-management.allow-remote: false -> true") @@ -503,26 +444,19 @@ func TestFormatProxyURL(t *testing.T) { } } -func TestBuildConfigChangeDetails_SecretAndUpstreamUpdates(t *testing.T) { +func TestBuildConfigChangeDetails_RemoteManagementSecretUpdated(t *testing.T) { oldCfg := &config.Config{ - AmpCode: config.AmpCode{ - UpstreamAPIKey: "old", - }, RemoteManagement: config.RemoteManagement{ SecretKey: "old", }, } newCfg := &config.Config{ - AmpCode: config.AmpCode{ - UpstreamAPIKey: "new", - }, RemoteManagement: config.RemoteManagement{ SecretKey: "new", }, } changes := BuildConfigChangeDetails(oldCfg, newCfg) - expectContains(t, changes, "ampcode.upstream-api-key: updated") expectContains(t, changes, "remote-management.secret-key: updated") } diff --git a/internal/watcher/diff/oauth_excluded.go b/internal/watcher/diff/oauth_excluded.go index d63206284..05cc3ffa8 100644 --- a/internal/watcher/diff/oauth_excluded.go +++ b/internal/watcher/diff/oauth_excluded.go @@ -1,13 +1,9 @@ package diff import ( - "crypto/sha256" - "encoding/hex" "fmt" "sort" "strings" - - "github.com/router-for-me/CLIProxyAPI/v7/internal/config" ) type ExcludedModelsSummary struct { @@ -86,33 +82,3 @@ func DiffOAuthExcludedModelChanges(oldMap, newMap map[string][]string) ([]string sort.Strings(affected) return changes, affected } - -type AmpModelMappingsSummary struct { - hash string - count int -} - -// SummarizeAmpModelMappings hashes Amp model mappings for change detection. -func SummarizeAmpModelMappings(mappings []config.AmpModelMapping) AmpModelMappingsSummary { - if len(mappings) == 0 { - return AmpModelMappingsSummary{} - } - entries := make([]string, 0, len(mappings)) - for _, mapping := range mappings { - from := strings.TrimSpace(mapping.From) - to := strings.TrimSpace(mapping.To) - if from == "" && to == "" { - continue - } - entries = append(entries, from+"->"+to) - } - if len(entries) == 0 { - return AmpModelMappingsSummary{} - } - sort.Strings(entries) - sum := sha256.Sum256([]byte(strings.Join(entries, "|"))) - return AmpModelMappingsSummary{ - hash: hex.EncodeToString(sum[:]), - count: len(entries), - } -} diff --git a/internal/watcher/diff/oauth_excluded_test.go b/internal/watcher/diff/oauth_excluded_test.go index 8643f5944..72beac7ee 100644 --- a/internal/watcher/diff/oauth_excluded_test.go +++ b/internal/watcher/diff/oauth_excluded_test.go @@ -39,26 +39,6 @@ func TestDiffOAuthExcludedModelChanges(t *testing.T) { } } -func TestSummarizeAmpModelMappings(t *testing.T) { - summary := SummarizeAmpModelMappings([]config.AmpModelMapping{ - {From: "a", To: "A"}, - {From: "b", To: "B"}, - {From: " ", To: " "}, // ignored - }) - if summary.count != 2 { - t.Fatalf("expected 2 entries, got %d", summary.count) - } - if summary.hash == "" { - t.Fatal("expected non-empty hash") - } - if empty := SummarizeAmpModelMappings(nil); empty.count != 0 || empty.hash != "" { - t.Fatalf("expected empty summary for nil input, got %+v", empty) - } - if blank := SummarizeAmpModelMappings([]config.AmpModelMapping{{From: " ", To: " "}}); blank.count != 0 || blank.hash != "" { - t.Fatalf("expected blank mappings ignored, got %+v", blank) - } -} - func TestSummarizeOAuthExcludedModels_NormalizesKeys(t *testing.T) { out := SummarizeOAuthExcludedModels(map[string][]string{ "ProvA": {"X"}, diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 7842295c5..911e489bd 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -289,7 +289,7 @@ func setServiceTierMetadata(meta map[string]any, rawJSON []byte) { // headersFromContext extracts the original HTTP request headers from the gin context // embedded in the provided context. This allows session affinity selectors to read -// client headers like X-Amp-Thread-Id. +// client-provided session headers. func headersFromContext(ctx context.Context) http.Header { if ctx == nil { return nil diff --git a/sdk/cliproxy/auth/selector.go b/sdk/cliproxy/auth/selector.go index 19d1843fe..0dcb32d93 100644 --- a/sdk/cliproxy/auth/selector.go +++ b/sdk/cliproxy/auth/selector.go @@ -472,11 +472,10 @@ func NewSessionAffinitySelectorWithConfig(cfg SessionAffinityConfig) *SessionAff // 1. metadata.user_id (Claude Code format with _session_{uuid}) - highest priority // 2. X-Session-ID header // 3. Session_id header (Codex) -// 4. X-Amp-Thread-Id header (Amp CLI thread ID) -// 5. X-Client-Request-Id header (PI) -// 6. metadata.user_id (non-Claude Code format) -// 7. conversation_id field in request body -// 8. Stable hash from first few messages content (fallback) +// 4. X-Client-Request-Id header (PI) +// 5. metadata.user_id (non-Claude Code format) +// 6. conversation_id field in request body +// 7. Stable hash from first few messages content (fallback) // // Note: The cache key includes provider, session ID, and model to handle cases where // a session uses multiple models (e.g., gemini-2.5-pro and gemini-3-flash-preview) @@ -574,11 +573,10 @@ func (s *SessionAffinitySelector) InvalidateAuth(authID string) { // 1. metadata.user_id (Claude Code format with _session_{uuid}) - highest priority for Claude Code clients // 2. X-Session-ID header // 3. Session_id header (Codex) -// 4. X-Amp-Thread-Id header (Amp CLI thread ID) -// 5. X-Client-Request-Id header (PI) -// 6. metadata.user_id (non-Claude Code format) -// 7. conversation_id field in request body -// 8. Stable hash from first few messages content (fallback) +// 4. X-Client-Request-Id header (PI) +// 5. metadata.user_id (non-Claude Code format) +// 6. conversation_id field in request body +// 7. Stable hash from first few messages content (fallback) func ExtractSessionID(headers http.Header, payload []byte, metadata map[string]any) string { primary, _ := extractSessionIDs(headers, payload, metadata) return primary @@ -624,14 +622,7 @@ func extractSessionIDs(headers http.Header, payload []byte, metadata map[string] } } - // 4. X-Amp-Thread-Id header (Amp CLI thread ID) - if headers != nil { - if tid := headers.Get("X-Amp-Thread-Id"); tid != "" { - return "amp:" + tid, "" - } - } - - // 5. X-Client-Request-Id header (PI) + // 4. X-Client-Request-Id header (PI) if headers != nil { if rid := headers.Get("X-Client-Request-Id"); rid != "" { return "clientreq:" + rid, "" diff --git a/sdk/cliproxy/auth/selector_test.go b/sdk/cliproxy/auth/selector_test.go index 99231bdf7..c2d752a49 100644 --- a/sdk/cliproxy/auth/selector_test.go +++ b/sdk/cliproxy/auth/selector_test.go @@ -816,60 +816,6 @@ func TestExtractSessionID_CodexSessionIDPriorityOverClientRequestID(t *testing.T } } -func TestExtractSessionID_AmpThreadId(t *testing.T) { - t.Parallel() - - headers := make(http.Header) - headers.Set("X-Amp-Thread-Id", "T-7873e6bd-6354-4a9a-be2c-c7702c6e1b64") - - got := ExtractSessionID(headers, nil, nil) - want := "amp:T-7873e6bd-6354-4a9a-be2c-c7702c6e1b64" - if got != want { - t.Errorf("ExtractSessionID() with X-Amp-Thread-Id = %q, want %q", got, want) - } -} - -func TestExtractSessionID_AmpThreadIdPriorityOverClientRequestID(t *testing.T) { - t.Parallel() - - headers := make(http.Header) - headers.Set("X-Amp-Thread-Id", "T-priority-test") - headers.Set("X-Client-Request-Id", "pi-session-123") - - got := ExtractSessionID(headers, nil, nil) - want := "amp:T-priority-test" - if got != want { - t.Errorf("ExtractSessionID() = %q, want %q (X-Amp-Thread-Id should take priority over X-Client-Request-Id)", got, want) - } -} - -// TestExtractSessionID_AmpThreadIdLowerPriority verifies X-Amp-Thread-Id is lower -// priority than Claude Code metadata.user_id but higher than conversation_id. -func TestExtractSessionID_AmpThreadIdPriority(t *testing.T) { - t.Parallel() - - // X-Amp-Thread-Id should be used when no Claude Code user_id is present - headers := make(http.Header) - headers.Set("X-Amp-Thread-Id", "T-priority-test") - - payload := []byte(`{"conversation_id":"conv-12345"}`) - got := ExtractSessionID(headers, payload, nil) - want := "amp:T-priority-test" - if got != want { - t.Errorf("ExtractSessionID() = %q, want %q (Amp thread ID should take priority over conversation_id)", got, want) - } - - // Claude Code user_id should take priority over X-Amp-Thread-Id - headers2 := make(http.Header) - headers2.Set("X-Amp-Thread-Id", "T-priority-test") - payload2 := []byte(`{"metadata":{"user_id":"user_xxx_account__session_ac980658-63bd-4fb3-97ba-8da64cb1e344"}}`) - got2 := ExtractSessionID(headers2, payload2, nil) - want2 := "claude:ac980658-63bd-4fb3-97ba-8da64cb1e344" - if got2 != want2 { - t.Errorf("ExtractSessionID() = %q, want %q (Claude Code should take priority over Amp thread ID)", got2, want2) - } -} - // TestExtractSessionID_IdempotencyKey verifies that idempotency_key is intentionally // ignored for session affinity (it's auto-generated per-request, causing cache misses). func TestExtractSessionID_IdempotencyKey(t *testing.T) { diff --git a/sdk/config/config.go b/sdk/config/config.go index d39e512de..0be8c8b5f 100644 --- a/sdk/config/config.go +++ b/sdk/config/config.go @@ -13,7 +13,6 @@ type Config = internalconfig.Config type StreamingConfig = internalconfig.StreamingConfig type TLSConfig = internalconfig.TLSConfig type RemoteManagement = internalconfig.RemoteManagement -type AmpCode = internalconfig.AmpCode type OAuthModelAlias = internalconfig.OAuthModelAlias type PayloadConfig = internalconfig.PayloadConfig type PayloadRule = internalconfig.PayloadRule diff --git a/test/amp_management_test.go b/test/amp_management_test.go deleted file mode 100644 index 6c694db6f..000000000 --- a/test/amp_management_test.go +++ /dev/null @@ -1,915 +0,0 @@ -package test - -import ( - "bytes" - "encoding/json" - "net/http" - "net/http/httptest" - "os" - "path/filepath" - "testing" - - "github.com/gin-gonic/gin" - "github.com/router-for-me/CLIProxyAPI/v7/internal/api/handlers/management" - "github.com/router-for-me/CLIProxyAPI/v7/internal/config" -) - -func init() { - gin.SetMode(gin.TestMode) -} - -// newAmpTestHandler creates a test handler with default ampcode configuration. -func newAmpTestHandler(t *testing.T) (*management.Handler, string) { - t.Helper() - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config.yaml") - - cfg := &config.Config{ - AmpCode: config.AmpCode{ - UpstreamURL: "https://example.com", - UpstreamAPIKey: "test-api-key-12345", - RestrictManagementToLocalhost: true, - ForceModelMappings: false, - ModelMappings: []config.AmpModelMapping{ - {From: "gpt-4", To: "gemini-pro"}, - }, - }, - } - - if err := os.WriteFile(configPath, []byte("port: 8080\n"), 0644); err != nil { - t.Fatalf("failed to write config file: %v", err) - } - - h := management.NewHandler(cfg, configPath, nil) - return h, configPath -} - -// setupAmpRouter creates a test router with all ampcode management endpoints. -func setupAmpRouter(h *management.Handler) *gin.Engine { - r := gin.New() - mgmt := r.Group("/v0/management") - { - mgmt.GET("/ampcode", h.GetAmpCode) - mgmt.GET("/ampcode/upstream-url", h.GetAmpUpstreamURL) - mgmt.PUT("/ampcode/upstream-url", h.PutAmpUpstreamURL) - mgmt.DELETE("/ampcode/upstream-url", h.DeleteAmpUpstreamURL) - mgmt.GET("/ampcode/upstream-api-key", h.GetAmpUpstreamAPIKey) - mgmt.PUT("/ampcode/upstream-api-key", h.PutAmpUpstreamAPIKey) - mgmt.DELETE("/ampcode/upstream-api-key", h.DeleteAmpUpstreamAPIKey) - mgmt.GET("/ampcode/upstream-api-keys", h.GetAmpUpstreamAPIKeys) - mgmt.PUT("/ampcode/upstream-api-keys", h.PutAmpUpstreamAPIKeys) - mgmt.PATCH("/ampcode/upstream-api-keys", h.PatchAmpUpstreamAPIKeys) - mgmt.DELETE("/ampcode/upstream-api-keys", h.DeleteAmpUpstreamAPIKeys) - mgmt.GET("/ampcode/restrict-management-to-localhost", h.GetAmpRestrictManagementToLocalhost) - mgmt.PUT("/ampcode/restrict-management-to-localhost", h.PutAmpRestrictManagementToLocalhost) - mgmt.GET("/ampcode/model-mappings", h.GetAmpModelMappings) - mgmt.PUT("/ampcode/model-mappings", h.PutAmpModelMappings) - mgmt.PATCH("/ampcode/model-mappings", h.PatchAmpModelMappings) - mgmt.DELETE("/ampcode/model-mappings", h.DeleteAmpModelMappings) - mgmt.GET("/ampcode/force-model-mappings", h.GetAmpForceModelMappings) - mgmt.PUT("/ampcode/force-model-mappings", h.PutAmpForceModelMappings) - } - return r -} - -// TestGetAmpCode verifies GET /v0/management/ampcode returns full ampcode config. -func TestGetAmpCode(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) - } - - var resp map[string]config.AmpCode - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to unmarshal response: %v", err) - } - - ampcode := resp["ampcode"] - if ampcode.UpstreamURL != "https://example.com" { - t.Errorf("expected upstream-url %q, got %q", "https://example.com", ampcode.UpstreamURL) - } - if len(ampcode.ModelMappings) != 1 { - t.Errorf("expected 1 model mapping, got %d", len(ampcode.ModelMappings)) - } -} - -// TestGetAmpUpstreamURL verifies GET /v0/management/ampcode/upstream-url returns the upstream URL. -func TestGetAmpUpstreamURL(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) - } - - var resp map[string]string - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to unmarshal response: %v", err) - } - - if resp["upstream-url"] != "https://example.com" { - t.Errorf("expected %q, got %q", "https://example.com", resp["upstream-url"]) - } -} - -// TestPutAmpUpstreamURL verifies PUT /v0/management/ampcode/upstream-url updates the upstream URL. -func TestPutAmpUpstreamURL(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - body := `{"value": "https://new-upstream.com"}` - req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-url", bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String()) - } -} - -// TestDeleteAmpUpstreamURL verifies DELETE /v0/management/ampcode/upstream-url clears the upstream URL. -func TestDeleteAmpUpstreamURL(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-url", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) - } -} - -// TestGetAmpUpstreamAPIKey verifies GET /v0/management/ampcode/upstream-api-key returns the API key. -func TestGetAmpUpstreamAPIKey(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) - } - - var resp map[string]any - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to unmarshal response: %v", err) - } - - key := resp["upstream-api-key"].(string) - if key != "test-api-key-12345" { - t.Errorf("expected key %q, got %q", "test-api-key-12345", key) - } -} - -// TestPutAmpUpstreamAPIKey verifies PUT /v0/management/ampcode/upstream-api-key updates the API key. -func TestPutAmpUpstreamAPIKey(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - body := `{"value": "new-secret-key"}` - req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-key", bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) - } -} - -func TestPutAmpUpstreamAPIKeys_PersistsAndReturns(t *testing.T) { - h, configPath := newAmpTestHandler(t) - r := setupAmpRouter(h) - - body := `{"value":[{"upstream-api-key":" u1 ","api-keys":[" k1 ","","k2"]}]}` - req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-keys", bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String()) - } - - // Verify it was persisted to disk - loaded, err := config.LoadConfig(configPath) - if err != nil { - t.Fatalf("failed to load config from disk: %v", err) - } - if len(loaded.AmpCode.UpstreamAPIKeys) != 1 { - t.Fatalf("expected 1 upstream-api-keys entry, got %d", len(loaded.AmpCode.UpstreamAPIKeys)) - } - entry := loaded.AmpCode.UpstreamAPIKeys[0] - if entry.UpstreamAPIKey != "u1" { - t.Fatalf("expected upstream-api-key u1, got %q", entry.UpstreamAPIKey) - } - if len(entry.APIKeys) != 2 || entry.APIKeys[0] != "k1" || entry.APIKeys[1] != "k2" { - t.Fatalf("expected api-keys [k1 k2], got %#v", entry.APIKeys) - } - - // Verify it is returned by GET /ampcode - req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil) - w = httptest.NewRecorder() - r.ServeHTTP(w, req) - if w.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) - } - var resp map[string]config.AmpCode - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to unmarshal response: %v", err) - } - if got := resp["ampcode"].UpstreamAPIKeys; len(got) != 1 || got[0].UpstreamAPIKey != "u1" { - t.Fatalf("expected upstream-api-keys to be present after update, got %#v", got) - } -} - -func TestDeleteAmpUpstreamAPIKeys_ClearsAll(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - // Seed with one entry - putBody := `{"value":[{"upstream-api-key":"u1","api-keys":["k1"]}]}` - req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-keys", bytes.NewBufferString(putBody)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - if w.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String()) - } - - deleteBody := `{"value":[]}` - req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-keys", bytes.NewBufferString(deleteBody)) - req.Header.Set("Content-Type", "application/json") - w = httptest.NewRecorder() - r.ServeHTTP(w, req) - if w.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) - } - - req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-keys", nil) - w = httptest.NewRecorder() - r.ServeHTTP(w, req) - if w.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) - } - var resp map[string][]config.AmpUpstreamAPIKeyEntry - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to unmarshal response: %v", err) - } - if resp["upstream-api-keys"] != nil && len(resp["upstream-api-keys"]) != 0 { - t.Fatalf("expected cleared list, got %#v", resp["upstream-api-keys"]) - } -} - -// TestDeleteAmpUpstreamAPIKey verifies DELETE /v0/management/ampcode/upstream-api-key clears the API key. -func TestDeleteAmpUpstreamAPIKey(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-key", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) - } -} - -// TestGetAmpRestrictManagementToLocalhost verifies GET returns the localhost restriction setting. -func TestGetAmpRestrictManagementToLocalhost(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/restrict-management-to-localhost", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) - } - - var resp map[string]bool - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to unmarshal response: %v", err) - } - - if resp["restrict-management-to-localhost"] != true { - t.Error("expected restrict-management-to-localhost to be true") - } -} - -// TestPutAmpRestrictManagementToLocalhost verifies PUT updates the localhost restriction setting. -func TestPutAmpRestrictManagementToLocalhost(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - body := `{"value": false}` - req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/restrict-management-to-localhost", bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) - } -} - -// TestGetAmpModelMappings verifies GET /v0/management/ampcode/model-mappings returns all mappings. -func TestGetAmpModelMappings(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) - } - - var resp map[string][]config.AmpModelMapping - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to unmarshal response: %v", err) - } - - mappings := resp["model-mappings"] - if len(mappings) != 1 { - t.Fatalf("expected 1 mapping, got %d", len(mappings)) - } - if mappings[0].From != "gpt-4" || mappings[0].To != "gemini-pro" { - t.Errorf("unexpected mapping: %+v", mappings[0]) - } -} - -// TestPutAmpModelMappings verifies PUT /v0/management/ampcode/model-mappings replaces all mappings. -func TestPutAmpModelMappings(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - body := `{"value": [{"from": "claude-3", "to": "gpt-4o"}, {"from": "gemini", "to": "claude"}]}` - req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String()) - } -} - -// TestPatchAmpModelMappings verifies PATCH updates existing mappings and adds new ones. -func TestPatchAmpModelMappings(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - body := `{"value": [{"from": "gpt-4", "to": "updated-model"}, {"from": "new-model", "to": "target"}]}` - req := httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String()) - } -} - -// TestDeleteAmpModelMappings_Specific verifies DELETE removes specified mappings by "from" field. -func TestDeleteAmpModelMappings_Specific(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - body := `{"value": ["gpt-4"]}` - req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) - } -} - -// TestDeleteAmpModelMappings_All verifies DELETE with empty body removes all mappings. -func TestDeleteAmpModelMappings_All(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) - } -} - -// TestGetAmpForceModelMappings verifies GET returns the force-model-mappings setting. -func TestGetAmpForceModelMappings(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/force-model-mappings", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) - } - - var resp map[string]bool - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to unmarshal response: %v", err) - } - - if resp["force-model-mappings"] != false { - t.Error("expected force-model-mappings to be false") - } -} - -// TestPutAmpForceModelMappings verifies PUT updates the force-model-mappings setting. -func TestPutAmpForceModelMappings(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - body := `{"value": true}` - req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) - } -} - -// TestPutAmpModelMappings_VerifyState verifies PUT replaces mappings and state is persisted. -func TestPutAmpModelMappings_VerifyState(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - body := `{"value": [{"from": "model-a", "to": "model-b"}, {"from": "model-c", "to": "model-d"}, {"from": "model-e", "to": "model-f"}]}` - req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("PUT failed: status %d, body: %s", w.Code, w.Body.String()) - } - - req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) - w = httptest.NewRecorder() - r.ServeHTTP(w, req) - - var resp map[string][]config.AmpModelMapping - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to unmarshal: %v", err) - } - - mappings := resp["model-mappings"] - if len(mappings) != 3 { - t.Fatalf("expected 3 mappings, got %d", len(mappings)) - } - - expected := map[string]string{"model-a": "model-b", "model-c": "model-d", "model-e": "model-f"} - for _, m := range mappings { - if expected[m.From] != m.To { - t.Errorf("mapping %q -> expected %q, got %q", m.From, expected[m.From], m.To) - } - } -} - -// TestPatchAmpModelMappings_VerifyState verifies PATCH merges mappings correctly. -func TestPatchAmpModelMappings_VerifyState(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - body := `{"value": [{"from": "gpt-4", "to": "updated-target"}, {"from": "new-model", "to": "new-target"}]}` - req := httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("PATCH failed: status %d", w.Code) - } - - req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) - w = httptest.NewRecorder() - r.ServeHTTP(w, req) - - var resp map[string][]config.AmpModelMapping - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to unmarshal: %v", err) - } - - mappings := resp["model-mappings"] - if len(mappings) != 2 { - t.Fatalf("expected 2 mappings (1 updated + 1 new), got %d", len(mappings)) - } - - found := make(map[string]string) - for _, m := range mappings { - found[m.From] = m.To - } - - if found["gpt-4"] != "updated-target" { - t.Errorf("gpt-4 should map to updated-target, got %q", found["gpt-4"]) - } - if found["new-model"] != "new-target" { - t.Errorf("new-model should map to new-target, got %q", found["new-model"]) - } -} - -// TestDeleteAmpModelMappings_VerifyState verifies DELETE removes specific mappings and keeps others. -func TestDeleteAmpModelMappings_VerifyState(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - putBody := `{"value": [{"from": "a", "to": "1"}, {"from": "b", "to": "2"}, {"from": "c", "to": "3"}]}` - req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(putBody)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - delBody := `{"value": ["a", "c"]}` - req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody)) - req.Header.Set("Content-Type", "application/json") - w = httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("DELETE failed: status %d", w.Code) - } - - req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) - w = httptest.NewRecorder() - r.ServeHTTP(w, req) - - var resp map[string][]config.AmpModelMapping - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to unmarshal: %v", err) - } - - mappings := resp["model-mappings"] - if len(mappings) != 1 { - t.Fatalf("expected 1 mapping remaining, got %d", len(mappings)) - } - if mappings[0].From != "b" || mappings[0].To != "2" { - t.Errorf("expected b->2, got %s->%s", mappings[0].From, mappings[0].To) - } -} - -// TestDeleteAmpModelMappings_NonExistent verifies DELETE with non-existent mapping doesn't affect existing ones. -func TestDeleteAmpModelMappings_NonExistent(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - delBody := `{"value": ["non-existent-model"]}` - req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) - } - - req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) - w = httptest.NewRecorder() - r.ServeHTTP(w, req) - - var resp map[string][]config.AmpModelMapping - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to unmarshal: %v", err) - } - - if len(resp["model-mappings"]) != 1 { - t.Errorf("original mapping should remain, got %d mappings", len(resp["model-mappings"])) - } -} - -// TestPutAmpModelMappings_Empty verifies PUT with empty array clears all mappings. -func TestPutAmpModelMappings_Empty(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - body := `{"value": []}` - req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) - } - - req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) - w = httptest.NewRecorder() - r.ServeHTTP(w, req) - - var resp map[string][]config.AmpModelMapping - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to unmarshal: %v", err) - } - - if len(resp["model-mappings"]) != 0 { - t.Errorf("expected 0 mappings, got %d", len(resp["model-mappings"])) - } -} - -// TestPutAmpUpstreamURL_VerifyState verifies PUT updates upstream URL and persists state. -func TestPutAmpUpstreamURL_VerifyState(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - body := `{"value": "https://new-api.example.com"}` - req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-url", bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("PUT failed: status %d", w.Code) - } - - req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil) - w = httptest.NewRecorder() - r.ServeHTTP(w, req) - - var resp map[string]string - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to unmarshal: %v", err) - } - - if resp["upstream-url"] != "https://new-api.example.com" { - t.Errorf("expected %q, got %q", "https://new-api.example.com", resp["upstream-url"]) - } -} - -// TestDeleteAmpUpstreamURL_VerifyState verifies DELETE clears upstream URL. -func TestDeleteAmpUpstreamURL_VerifyState(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-url", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("DELETE failed: status %d", w.Code) - } - - req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil) - w = httptest.NewRecorder() - r.ServeHTTP(w, req) - - var resp map[string]string - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to unmarshal: %v", err) - } - - if resp["upstream-url"] != "" { - t.Errorf("expected empty string, got %q", resp["upstream-url"]) - } -} - -// TestPutAmpUpstreamAPIKey_VerifyState verifies PUT updates API key and persists state. -func TestPutAmpUpstreamAPIKey_VerifyState(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - body := `{"value": "new-secret-api-key-xyz"}` - req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-key", bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("PUT failed: status %d", w.Code) - } - - req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil) - w = httptest.NewRecorder() - r.ServeHTTP(w, req) - - var resp map[string]string - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to unmarshal: %v", err) - } - - if resp["upstream-api-key"] != "new-secret-api-key-xyz" { - t.Errorf("expected %q, got %q", "new-secret-api-key-xyz", resp["upstream-api-key"]) - } -} - -// TestDeleteAmpUpstreamAPIKey_VerifyState verifies DELETE clears API key. -func TestDeleteAmpUpstreamAPIKey_VerifyState(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-key", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("DELETE failed: status %d", w.Code) - } - - req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil) - w = httptest.NewRecorder() - r.ServeHTTP(w, req) - - var resp map[string]string - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to unmarshal: %v", err) - } - - if resp["upstream-api-key"] != "" { - t.Errorf("expected empty string, got %q", resp["upstream-api-key"]) - } -} - -// TestPutAmpRestrictManagementToLocalhost_VerifyState verifies PUT updates localhost restriction. -func TestPutAmpRestrictManagementToLocalhost_VerifyState(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - body := `{"value": false}` - req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/restrict-management-to-localhost", bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("PUT failed: status %d", w.Code) - } - - req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/restrict-management-to-localhost", nil) - w = httptest.NewRecorder() - r.ServeHTTP(w, req) - - var resp map[string]bool - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to unmarshal: %v", err) - } - - if resp["restrict-management-to-localhost"] != false { - t.Error("expected false after update") - } -} - -// TestPutAmpForceModelMappings_VerifyState verifies PUT updates force-model-mappings setting. -func TestPutAmpForceModelMappings_VerifyState(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - body := `{"value": true}` - req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("PUT failed: status %d", w.Code) - } - - req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/force-model-mappings", nil) - w = httptest.NewRecorder() - r.ServeHTTP(w, req) - - var resp map[string]bool - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to unmarshal: %v", err) - } - - if resp["force-model-mappings"] != true { - t.Error("expected true after update") - } -} - -// TestPutBoolField_EmptyObject verifies PUT with empty object returns 400. -func TestPutBoolField_EmptyObject(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - body := `{}` - req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusBadRequest { - t.Fatalf("expected status %d for empty object, got %d", http.StatusBadRequest, w.Code) - } -} - -// TestComplexMappingsWorkflow tests a full workflow: PUT, PATCH, DELETE, and GET. -func TestComplexMappingsWorkflow(t *testing.T) { - h, _ := newAmpTestHandler(t) - r := setupAmpRouter(h) - - putBody := `{"value": [{"from": "m1", "to": "t1"}, {"from": "m2", "to": "t2"}, {"from": "m3", "to": "t3"}, {"from": "m4", "to": "t4"}]}` - req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(putBody)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - patchBody := `{"value": [{"from": "m2", "to": "t2-updated"}, {"from": "m5", "to": "t5"}]}` - req = httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(patchBody)) - req.Header.Set("Content-Type", "application/json") - w = httptest.NewRecorder() - r.ServeHTTP(w, req) - - delBody := `{"value": ["m1", "m3"]}` - req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody)) - req.Header.Set("Content-Type", "application/json") - w = httptest.NewRecorder() - r.ServeHTTP(w, req) - - req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) - w = httptest.NewRecorder() - r.ServeHTTP(w, req) - - var resp map[string][]config.AmpModelMapping - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to unmarshal: %v", err) - } - - mappings := resp["model-mappings"] - if len(mappings) != 3 { - t.Fatalf("expected 3 mappings (m2, m4, m5), got %d", len(mappings)) - } - - expected := map[string]string{"m2": "t2-updated", "m4": "t4", "m5": "t5"} - found := make(map[string]string) - for _, m := range mappings { - found[m.From] = m.To - } - - for from, to := range expected { - if found[from] != to { - t.Errorf("mapping %s: expected %q, got %q", from, to, found[from]) - } - } -} - -// TestNilHandlerGetAmpCode verifies handler works with empty config. -func TestNilHandlerGetAmpCode(t *testing.T) { - cfg := &config.Config{} - h := management.NewHandler(cfg, "", nil) - r := setupAmpRouter(h) - - req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) - } -} - -// TestEmptyConfigGetAmpModelMappings verifies GET returns empty array for fresh config. -func TestEmptyConfigGetAmpModelMappings(t *testing.T) { - cfg := &config.Config{} - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config.yaml") - if err := os.WriteFile(configPath, []byte("port: 8080\n"), 0644); err != nil { - t.Fatalf("failed to write config: %v", err) - } - - h := management.NewHandler(cfg, configPath, nil) - r := setupAmpRouter(h) - - req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) - } - - var resp map[string][]config.AmpModelMapping - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to unmarshal: %v", err) - } - - if len(resp["model-mappings"]) != 0 { - t.Errorf("expected 0 mappings, got %d", len(resp["model-mappings"])) - } -}