diff --git a/README_CN.md b/README_CN.md index bea12aff0..9db41b2b7 100644 --- a/README_CN.md +++ b/README_CN.md @@ -218,6 +218,10 @@ OmniRoute 是一个面向多供应商大语言模型的 AI 网关:它提供兼 一个公开的 CLIProxyAPI 兼容二开版本和配套管理面板,尽量保持与上游一致的使用方式,同时恢复内置使用量统计,并补充缓存命中率、首字响应时间、TPS 记录和面向 Docker 自托管的安装说明。 +### [Codex Switch](https://github.com/9ycrooked/CodexSwitch) + +这是一个使用 Tauri 2 + Vue 3 构建的工具,用于管理多个 OpenAI Codex 桌面账户。它可以在已保存的 ChatGPT/Codex 认证配置之间切换,实时查看 5 小时和每周配额使用情况,验证 token 健康状态,查看当前账户详情,并在无需手动复制的情况下导入或保存 auth.json 文件。 + > [!NOTE] > 如果你开发了 CLIProxyAPI 的移植或衍生项目,请提交 PR 将其添加到此列表中。 diff --git a/README_JA.md b/README_JA.md index d432b4845..2f95398d2 100644 --- a/README_JA.md +++ b/README_JA.md @@ -217,6 +217,10 @@ OmniRouteはマルチプロバイダーLLM向けのAIゲートウェイです: 上流に近い使い方を維持する公開CLIProxyAPI互換フォーク兼管理パネルです。内蔵の使用量統計を復元し、キャッシュヒット率、初回バイト待ち時間、TPSの記録、Docker向けのセルフホスト手順を追加しています。 +### [Codex Switch](https://github.com/9ycrooked/CodexSwitch) + +Tauri 2 + Vue 3で構築された、複数のOpenAI Codexデスクトップアカウントを管理するためのツールです。保存済みのChatGPT/Codex認証プロファイルを切り替え、5時間および週次クォータ使用量をリアルタイムで確認し、tokenの状態を検証し、現在のアカウント詳細を表示し、手動コピーなしでauth.jsonファイルをインポートまたは保存できます。 + > [!NOTE] > CLIProxyAPIの移植版またはそれに触発されたプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。 diff --git a/cmd/server/home_flag_test.go b/cmd/server/home_flag_test.go deleted file mode 100644 index e98d85f17..000000000 --- a/cmd/server/home_flag_test.go +++ /dev/null @@ -1,77 +0,0 @@ -package main - -import "testing" - -func TestParseHomeFlagConfigHostPort(t *testing.T) { - cfg, err := parseHomeFlagConfig("home.example.com:8327", "secret") - if err != nil { - t.Fatalf("parseHomeFlagConfig() error = %v", err) - } - - if !cfg.Enabled { - t.Fatal("Enabled = false, want true") - } - if cfg.Host != "home.example.com" { - t.Fatalf("Host = %q, want home.example.com", cfg.Host) - } - if cfg.Port != 8327 { - t.Fatalf("Port = %d, want 8327", cfg.Port) - } - if cfg.Password != "secret" { - t.Fatalf("Password = %q, want secret", cfg.Password) - } - if cfg.TLS.Enable { - t.Fatal("TLS.Enable = true, want false") - } -} - -func TestParseHomeFlagConfigRediss(t *testing.T) { - cfg, err := parseHomeFlagConfig("rediss://:url-secret@home.example.com:444?server-name=home.example.com&skip_verify=true&ca-cert=C%3A%2Fcerts%2Fca.pem", "") - if err != nil { - t.Fatalf("parseHomeFlagConfig() error = %v", err) - } - - if cfg.Host != "home.example.com" { - t.Fatalf("Host = %q, want home.example.com", cfg.Host) - } - if cfg.Port != 444 { - t.Fatalf("Port = %d, want 444", cfg.Port) - } - if cfg.Password != "url-secret" { - t.Fatalf("Password = %q, want url-secret", cfg.Password) - } - if !cfg.TLS.Enable { - t.Fatal("TLS.Enable = false, want true") - } - if cfg.TLS.ServerName != "home.example.com" { - t.Fatalf("TLS.ServerName = %q, want home.example.com", cfg.TLS.ServerName) - } - if !cfg.TLS.InsecureSkipVerify { - t.Fatal("TLS.InsecureSkipVerify = false, want true") - } - if cfg.TLS.CACert != "C:/certs/ca.pem" { - t.Fatalf("TLS.CACert = %q, want C:/certs/ca.pem", cfg.TLS.CACert) - } -} - -func TestParseHomeFlagConfigPasswordFlagOverridesURLPassword(t *testing.T) { - cfg, err := parseHomeFlagConfig("rediss://:url-secret@home.example.com:444", "flag-secret") - if err != nil { - t.Fatalf("parseHomeFlagConfig() error = %v", err) - } - - if cfg.Password != "flag-secret" { - t.Fatalf("Password = %q, want flag-secret", cfg.Password) - } -} - -func TestParseHomeFlagConfigDisableClusterDiscovery(t *testing.T) { - cfg, err := parseHomeFlagConfig("redis://home.example.com:8327?disable-cluster-discovery=true", "") - if err != nil { - t.Fatalf("parseHomeFlagConfig() error = %v", err) - } - - if !cfg.DisableClusterDiscovery { - t.Fatal("DisableClusterDiscovery = false, want true") - } -} diff --git a/cmd/server/main.go b/cmd/server/main.go index a42a73242..4181faeca 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -10,11 +10,9 @@ import ( "fmt" "io" "io/fs" - "net" "net/url" "os" "path/filepath" - "strconv" "strings" "time" @@ -53,120 +51,6 @@ func init() { buildinfo.BuildDate = BuildDate } -func parseHomeFlagConfig(rawAddr string, password string) (config.HomeConfig, error) { - rawAddr = strings.TrimSpace(rawAddr) - if rawAddr == "" { - return config.HomeConfig{}, fmt.Errorf("address is empty") - } - - if strings.Contains(rawAddr, "://") { - return parseHomeURLConfig(rawAddr, password) - } - - host, portStr, errSplit := net.SplitHostPort(rawAddr) - if errSplit != nil { - return config.HomeConfig{}, fmt.Errorf("expected host:port, redis://host:port, or rediss://host:port: %w", errSplit) - } - - host = strings.TrimSpace(host) - if host == "" { - return config.HomeConfig{}, fmt.Errorf("host is empty") - } - - port, errPort := parseHomePort(portStr) - if errPort != nil { - return config.HomeConfig{}, errPort - } - - return config.HomeConfig{ - Enabled: true, - Host: host, - Port: port, - Password: password, - }, nil -} - -func parseHomeURLConfig(rawAddr string, password string) (config.HomeConfig, error) { - parsed, errParse := url.Parse(rawAddr) - if errParse != nil { - return config.HomeConfig{}, fmt.Errorf("parse URL: %w", errParse) - } - - scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme)) - if scheme != "redis" && scheme != "rediss" { - return config.HomeConfig{}, fmt.Errorf("unsupported URL scheme %q", parsed.Scheme) - } - - host := strings.TrimSpace(parsed.Hostname()) - if host == "" { - return config.HomeConfig{}, fmt.Errorf("host is empty") - } - - port, errPort := parseHomePort(parsed.Port()) - if errPort != nil { - return config.HomeConfig{}, errPort - } - - if password == "" && parsed.User != nil { - if urlPassword, ok := parsed.User.Password(); ok { - password = urlPassword - } - } - - homeCfg := config.HomeConfig{ - Enabled: true, - Host: host, - Port: port, - Password: password, - } - query := parsed.Query() - homeCfg.DisableClusterDiscovery = parseHomeBoolQuery(query, "disable-cluster-discovery", "disable_cluster_discovery") - - if scheme == "rediss" { - homeCfg.TLS.Enable = true - homeCfg.TLS.ServerName = strings.TrimSpace(firstHomeQueryValue(query, "server-name", "server_name")) - homeCfg.TLS.InsecureSkipVerify = parseHomeBoolQuery(query, "insecure-skip-verify", "insecure_skip_verify", "skip_verify") - homeCfg.TLS.CACert = strings.TrimSpace(firstHomeQueryValue(query, "ca-cert", "ca_cert")) - } - - return homeCfg, nil -} - -func parseHomePort(rawPort string) (int, error) { - rawPort = strings.TrimSpace(rawPort) - if rawPort == "" { - return 0, fmt.Errorf("port is empty") - } - - port, errPort := strconv.Atoi(rawPort) - if errPort != nil || port <= 0 || port > 65535 { - return 0, fmt.Errorf("invalid port %q", rawPort) - } - - return port, nil -} - -func firstHomeQueryValue(values url.Values, keys ...string) string { - for _, key := range keys { - if value := values.Get(key); value != "" { - return value - } - } - return "" -} - -func parseHomeBoolQuery(values url.Values, keys ...string) bool { - for _, key := range keys { - value := strings.TrimSpace(values.Get(key)) - if value == "" { - continue - } - parsed, errParse := strconv.ParseBool(value) - return errParse == nil && parsed - } - return false -} - // main is the entry point of the application. // It parses command-line flags, loads configuration, and starts the appropriate // service based on the provided flags (login, codex-login, or server mode). @@ -188,8 +72,6 @@ func main() { var vertexImportPrefix string var configPath string var password string - var homeAddr string - var homePassword string var homeJWT string var homeDisableClusterDiscovery bool var tuiMode bool @@ -211,10 +93,8 @@ func main() { flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file") flag.StringVar(&vertexImportPrefix, "vertex-import-prefix", "", "Prefix for Vertex model namespacing (use with -vertex-import)") flag.StringVar(&password, "password", "", "") - flag.StringVar(&homeAddr, "home", "", "Home control plane address in host:port, redis://host:port, or rediss://host:port format (loads config from home and skips local config file)") - flag.StringVar(&homePassword, "home-password", "", "Home control plane password (Redis AUTH)") flag.StringVar(&homeJWT, "home-jwt", "", "Home control plane JWT for mTLS certificate bootstrap and connection") - flag.BoolVar(&homeDisableClusterDiscovery, "home-disable-cluster-discovery", false, "Disable Home CLUSTER NODES discovery and keep using the configured -home address") + flag.BoolVar(&homeDisableClusterDiscovery, "home-disable-cluster-discovery", false, "Disable Home CLUSTER NODES discovery and keep using the configured -home-jwt address") flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI") flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server") flag.BoolVar(&localModel, "local-model", false, "Use embedded model catalog only, skip remote model fetching") @@ -302,17 +182,6 @@ func main() { } writableBase := util.WritablePath() - // Allow env var fallback for home flags so they can be configured without command args. - if strings.TrimSpace(homeAddr) == "" { - if v, ok := lookupEnv("HOME_ADDR", "home_addr"); ok { - homeAddr = v - } - } - if strings.TrimSpace(homePassword) == "" { - if v, ok := lookupEnv("HOME_PASSWORD", "home_password"); ok { - homePassword = v - } - } if strings.TrimSpace(homeJWT) == "" { if v, ok := lookupEnv("HOME_JWT", "home_jwt"); ok { homeJWT = v @@ -426,53 +295,6 @@ func main() { configFilePath = filepath.Join(wd, "config.yaml") } - // Local stores are intentionally disabled when config is loaded from home. - usePostgresStore = false - useObjectStore = false - useGitStore = false - } else if strings.TrimSpace(homeAddr) != "" { - configLoadedFromHome = true - trimmedHomePassword := strings.TrimSpace(homePassword) - homeCfg, errHomeCfg := parseHomeFlagConfig(homeAddr, trimmedHomePassword) - if errHomeCfg != nil { - log.Errorf("invalid -home address %q: %v", homeAddr, errHomeCfg) - return - } - if homeDisableClusterDiscovery { - homeCfg.DisableClusterDiscovery = true - } - homeClient := home.New(homeCfg) - defer homeClient.Close() - - ctxHome, cancelHome := context.WithTimeout(context.Background(), 30*time.Second) - raw, errGetConfig := homeClient.GetConfig(ctxHome) - cancelHome() - if errGetConfig != nil { - log.Errorf("failed to fetch config from home: %v", errGetConfig) - return - } - - parsed, errParseConfig := config.ParseConfigBytes(raw) - if errParseConfig != nil { - log.Errorf("failed to parse config payload from home: %v", errParseConfig) - return - } - if parsed == nil { - parsed = &config.Config{} - } - parsed.Home = homeCfg - parsed.Port = 8317 // Default to 8317 for home mode, can be overridden by home config - parsed.UsageStatisticsEnabled = true - cfg = parsed - - // Keep a non-empty config path for downstream components (log paths, management assets, etc), - // but do not require the file to exist when loading config from home. - if strings.TrimSpace(configPath) != "" { - configFilePath = configPath - } else { - configFilePath = filepath.Join(wd, "config.yaml") - } - // Local stores are intentionally disabled when config is loaded from home. usePostgresStore = false useObjectStore = false diff --git a/config.example.yaml b/config.example.yaml index 5327d8e4a..959f1f401 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -11,26 +11,6 @@ tls: cert: "" key: "" -# Optional "home" control plane integration over Redis protocol. -home: - enabled: false - host: "127.0.0.1" - port: 6379 - password: "" - # Keep CPA pinned to the configured home address instead of switching to CLUSTER NODES entries. - # Useful when Home is behind NAT, Docker networking, or a reverse proxy. - disable-cluster-discovery: false - # Optional TLS for the outbound Redis connection to the home control plane. - # Enable this when connecting through rediss:// or an SSL stream proxy. - tls: - enable: false - # Optional SNI/certificate name override. Leave empty to use the configured home host. - server-name: "" - # Trust a private CA bundle in addition to system roots. - ca-cert: "" - # Only for testing self-signed endpoints; disables certificate verification. - insecure-skip-verify: false - # Management API settings remote-management: # Whether to allow remote (non-localhost) management access. @@ -86,8 +66,8 @@ error-logs-max-files: 10 # When false, disable in-memory usage statistics aggregation usage-statistics-enabled: false -# How long (in seconds) Redis usage queue items are retained in memory for the RESP interface (LPOP/RPOP). -# Note: the in-process Redis RESP usage output is disabled when home.enabled is true. +# How long (in seconds) usage queue items are retained in memory for the Management API. +# The local Redis RESP usage output is disabled. # Default: 60. Max: 3600. redis-usage-queue-retention-seconds: 60 diff --git a/internal/api/protocol_multiplexer.go b/internal/api/protocol_multiplexer.go index 607d55a7c..42665ac68 100644 --- a/internal/api/protocol_multiplexer.go +++ b/internal/api/protocol_multiplexer.go @@ -103,20 +103,8 @@ func (s *Server) routeMuxConnection(conn net.Conn, httpListener *muxListener) { } if isRedisRESPPrefix(prefix[0]) { - if s.cfg != nil && s.cfg.Home.Enabled { - if errClose := conn.Close(); errClose != nil { - log.Errorf("failed to close redis connection while home mode is enabled: %v", errClose) - } - return - } - if !s.managementRoutesEnabled.Load() { - if errClose := conn.Close(); errClose != nil { - log.Errorf("failed to close redis connection while management is disabled: %v", errClose) - } - return - } _ = conn.SetReadDeadline(time.Time{}) - s.handleRedisConnection(conn, reader) + s.handleRedisConnection(conn) return } diff --git a/internal/api/redis_queue_protocol.go b/internal/api/redis_queue_protocol.go index f9d412d98..2e86c773f 100644 --- a/internal/api/redis_queue_protocol.go +++ b/internal/api/redis_queue_protocol.go @@ -2,25 +2,11 @@ package api import ( "bufio" - "errors" - "fmt" - "io" "net" - "net/http" - "strconv" - "strings" - "github.com/router-for-me/CLIProxyAPI/v7/internal/redisqueue" log "github.com/sirupsen/logrus" ) -const redisUsageChannel = "usage" - -type redisSubscriptionCommand struct { - args []string - err error -} - func isRedisRESPPrefix(prefix byte) bool { switch prefix { case '*', '$', '+', '-', ':': @@ -30,13 +16,11 @@ func isRedisRESPPrefix(prefix byte) bool { } } -func (s *Server) handleRedisConnection(conn net.Conn, reader *bufio.Reader) { - if s == nil || conn == nil || reader == nil { +func (s *Server) handleRedisConnection(conn net.Conn) { + if s == nil || conn == nil { return } - clientIP, localClient := resolveRemoteIP(conn.RemoteAddr()) - authed := false writer := bufio.NewWriter(conn) defer func() { if errClose := conn.Close(); errClose != nil { @@ -44,432 +28,10 @@ func (s *Server) handleRedisConnection(conn net.Conn, reader *bufio.Reader) { } }() - flush := func() bool { - if errFlush := writer.Flush(); errFlush != nil { - log.Errorf("redis protocol flush error: %v", errFlush) - return false - } - return true + _ = writeRedisError(writer, "ERR RESP AUTH disabled; use mTLS") + if errFlush := writer.Flush(); errFlush != nil { + log.Errorf("redis protocol flush error: %v", errFlush) } - - if s.cfg != nil && s.cfg.Home.Enabled { - _ = writeRedisError(writer, "ERR redis usage output disabled in home mode") - _ = writer.Flush() - return - } - - for { - if !s.managementRoutesEnabled.Load() { - return - } - - args, err := readRESPArray(reader) - if err != nil { - if !errors.Is(err, io.EOF) { - _ = writeRedisError(writer, "ERR "+err.Error()) - _ = writer.Flush() - } - return - } - if len(args) == 0 { - _ = writeRedisError(writer, "ERR empty command") - if !flush() { - return - } - continue - } - - cmd := strings.ToUpper(strings.TrimSpace(args[0])) - - if cmd != "AUTH" && !authed { - if s.mgmt != nil { - _, statusCode, errMsg := s.mgmt.AuthenticateManagementKey(clientIP, localClient, "") - if statusCode == http.StatusForbidden && strings.HasPrefix(errMsg, "IP banned due to too many failed attempts") { - _ = writeRedisError(writer, "ERR "+errMsg) - } else { - _ = writeRedisError(writer, "NOAUTH Authentication required.") - } - } else { - _ = writeRedisError(writer, "NOAUTH Authentication required.") - } - if !flush() { - return - } - continue - } - - switch cmd { - case "AUTH": - password, ok := parseAuthPassword(args) - if !ok { - if s.mgmt != nil { - _, statusCode, errMsg := s.mgmt.AuthenticateManagementKey(clientIP, localClient, "") - if statusCode == http.StatusForbidden && strings.HasPrefix(errMsg, "IP banned due to too many failed attempts") { - _ = writeRedisError(writer, "ERR "+errMsg) - if !flush() { - return - } - continue - } - } - _ = writeRedisError(writer, "ERR wrong number of arguments for 'auth' command") - if !flush() { - return - } - continue - } - if s.mgmt == nil { - _ = writeRedisError(writer, "ERR remote management disabled") - if !flush() { - return - } - continue - } - allowed, _, errMsg := s.mgmt.AuthenticateManagementKey(clientIP, localClient, password) - if !allowed { - _ = writeRedisError(writer, "ERR "+errMsg) - if !flush() { - return - } - continue - } - authed = true - _ = writeRedisSimpleString(writer, "OK") - if !flush() { - return - } - case "SUBSCRIBE": - if !authed { - _ = writeRedisError(writer, "NOAUTH Authentication required.") - if !flush() { - return - } - continue - } - channel, ok := parseSubscribeChannel(args) - if !ok { - _ = writeRedisError(writer, "ERR wrong number of arguments for 'subscribe' command") - if !flush() { - return - } - continue - } - if !strings.EqualFold(channel, redisUsageChannel) { - _ = writeRedisError(writer, fmt.Sprintf("ERR unsupported channel '%s'", channel)) - if !flush() { - return - } - continue - } - messages, unsubscribe := redisqueue.SubscribeUsage() - if errWrite := writeRedisPubSubSubscribe(writer, redisUsageChannel, 1); errWrite != nil { - unsubscribe() - log.Errorf("redis protocol subscribe response error: %v", errWrite) - return - } - if !flush() { - unsubscribe() - return - } - s.streamRedisUsageSubscription(reader, writer, messages, unsubscribe) - return - case "LPOP", "RPOP": - if !authed { - _ = writeRedisError(writer, "NOAUTH Authentication required.") - if !flush() { - return - } - continue - } - count, hasCount, ok := parsePopCount(args) - if !ok { - _ = writeRedisError(writer, "ERR wrong number of arguments for '"+strings.ToLower(cmd)+"' command") - if !flush() { - return - } - continue - } - if count <= 0 { - _ = writeRedisError(writer, "ERR value is not an integer or out of range") - if !flush() { - return - } - continue - } - items := redisqueue.PopOldest(count) - if hasCount { - _ = writeRedisArrayOfBulkStrings(writer, items) - if !flush() { - return - } - continue - } - if len(items) == 0 { - _ = writeRedisNilBulkString(writer) - if !flush() { - return - } - continue - } - _ = writeRedisBulkString(writer, items[0]) - if !flush() { - return - } - default: - _ = writeRedisError(writer, fmt.Sprintf("ERR unknown command '%s'", strings.ToLower(cmd))) - if !flush() { - return - } - } - } -} - -func (s *Server) streamRedisUsageSubscription(reader *bufio.Reader, writer *bufio.Writer, messages <-chan []byte, unsubscribe func()) { - if unsubscribe == nil { - return - } - defer unsubscribe() - - done := make(chan struct{}) - defer close(done) - - commands := make(chan redisSubscriptionCommand, 1) - go readRedisSubscriptionCommands(reader, commands, done) - - for { - select { - case msg, ok := <-messages: - if !ok { - return - } - if errWrite := writeRedisPubSubMessage(writer, redisUsageChannel, msg); errWrite != nil { - log.Errorf("redis protocol publish message error: %v", errWrite) - return - } - if errFlush := writer.Flush(); errFlush != nil { - log.Errorf("redis protocol flush error: %v", errFlush) - return - } - case command, ok := <-commands: - if !ok { - return - } - keepOpen := handleRedisSubscriptionCommand(writer, command) - if errFlush := writer.Flush(); errFlush != nil { - log.Errorf("redis protocol flush error: %v", errFlush) - return - } - if !keepOpen { - return - } - } - } -} - -func readRedisSubscriptionCommands(reader *bufio.Reader, commands chan<- redisSubscriptionCommand, done <-chan struct{}) { - defer close(commands) - - for { - args, err := readRESPArray(reader) - if err != nil { - if !errors.Is(err, io.EOF) { - select { - case commands <- redisSubscriptionCommand{err: err}: - case <-done: - } - } - return - } - select { - case commands <- redisSubscriptionCommand{args: args}: - case <-done: - return - } - } -} - -func handleRedisSubscriptionCommand(writer *bufio.Writer, command redisSubscriptionCommand) bool { - if command.err != nil { - _ = writeRedisError(writer, "ERR "+command.err.Error()) - return false - } - if len(command.args) == 0 { - _ = writeRedisError(writer, "ERR empty command") - return true - } - - cmd := strings.ToUpper(strings.TrimSpace(command.args[0])) - switch cmd { - case "PING": - payload := []byte(nil) - if len(command.args) > 1 { - payload = []byte(command.args[1]) - } - _ = writeRedisPubSubPong(writer, payload) - return true - case "UNSUBSCRIBE": - _ = writeRedisPubSubUnsubscribe(writer, redisUsageChannel, 0) - return false - case "QUIT": - _ = writeRedisSimpleString(writer, "OK") - return false - default: - _ = writeRedisError(writer, fmt.Sprintf("ERR unknown command '%s'", strings.ToLower(cmd))) - return true - } -} - -func resolveRemoteIP(addr net.Addr) (ip string, localClient bool) { - if addr == nil { - return "", false - } - - var host string - switch a := addr.(type) { - case *net.TCPAddr: - if a != nil && a.IP != nil { - if ip4 := a.IP.To4(); ip4 != nil { - host = ip4.String() - } else { - host = a.IP.String() - } - } - default: - host = addr.String() - if h, _, err := net.SplitHostPort(host); err == nil { - host = h - } - host = strings.TrimSpace(host) - if raw, _, ok := strings.Cut(host, "%"); ok { - host = raw - } - if parsed := net.ParseIP(host); parsed != nil { - if ip4 := parsed.To4(); ip4 != nil { - host = ip4.String() - } else { - host = parsed.String() - } - } - } - - host = strings.TrimSpace(host) - localClient = host == "127.0.0.1" || host == "::1" - return host, localClient -} - -func parseAuthPassword(args []string) (string, bool) { - switch len(args) { - case 2: - return args[1], true - case 3: - // Support AUTH by ignoring username for compatibility. - return args[2], true - default: - return "", false - } -} - -func parseSubscribeChannel(args []string) (string, bool) { - if len(args) != 2 { - return "", false - } - return strings.TrimSpace(args[1]), true -} - -func parsePopCount(args []string) (count int, hasCount bool, ok bool) { - if len(args) != 2 && len(args) != 3 { - return 0, false, false - } - if len(args) == 2 { - return 1, false, true - } - parsed, err := strconv.Atoi(strings.TrimSpace(args[2])) - if err != nil { - return 0, true, true - } - return parsed, true, true -} - -func readRESPArray(reader *bufio.Reader) ([]string, error) { - prefix, err := reader.ReadByte() - if err != nil { - return nil, err - } - if prefix != '*' { - return nil, fmt.Errorf("protocol error") - } - line, err := readRESPLine(reader) - if err != nil { - return nil, err - } - count, err := strconv.Atoi(line) - if err != nil || count < 0 { - return nil, fmt.Errorf("protocol error") - } - args := make([]string, 0, count) - for i := 0; i < count; i++ { - value, err := readRESPString(reader) - if err != nil { - return nil, err - } - args = append(args, value) - } - return args, nil -} - -func readRESPString(reader *bufio.Reader) (string, error) { - prefix, err := reader.ReadByte() - if err != nil { - return "", err - } - switch prefix { - case '$': - return readRESPBulkString(reader) - case '+', ':': - return readRESPLine(reader) - default: - return "", fmt.Errorf("protocol error") - } -} - -func readRESPBulkString(reader *bufio.Reader) (string, error) { - line, err := readRESPLine(reader) - if err != nil { - return "", err - } - length, err := strconv.Atoi(line) - if err != nil { - return "", fmt.Errorf("protocol error") - } - if length < 0 { - return "", nil - } - buf := make([]byte, length+2) - if _, err := io.ReadFull(reader, buf); err != nil { - return "", err - } - if length+2 < 2 || buf[length] != '\r' || buf[length+1] != '\n' { - return "", fmt.Errorf("protocol error") - } - return string(buf[:length]), nil -} - -func readRESPLine(reader *bufio.Reader) (string, error) { - line, err := reader.ReadString('\n') - if err != nil { - return "", err - } - line = strings.TrimSuffix(line, "\n") - line = strings.TrimSuffix(line, "\r") - return line, nil -} - -func writeRedisSimpleString(writer *bufio.Writer, value string) error { - if writer == nil { - return net.ErrClosed - } - _, err := writer.WriteString("+" + value + "\r\n") - return err } func writeRedisError(writer *bufio.Writer, message string) error { @@ -479,108 +41,3 @@ func writeRedisError(writer *bufio.Writer, message string) error { _, err := writer.WriteString("-" + message + "\r\n") return err } - -func writeRedisNilBulkString(writer *bufio.Writer) error { - if writer == nil { - return net.ErrClosed - } - _, err := writer.WriteString("$-1\r\n") - return err -} - -func writeRedisBulkString(writer *bufio.Writer, payload []byte) error { - if writer == nil { - return net.ErrClosed - } - if payload == nil { - return writeRedisNilBulkString(writer) - } - if _, err := writer.WriteString("$" + strconv.Itoa(len(payload)) + "\r\n"); err != nil { - return err - } - if _, err := writer.Write(payload); err != nil { - return err - } - _, err := writer.WriteString("\r\n") - return err -} - -func writeRedisArrayOfBulkStrings(writer *bufio.Writer, items [][]byte) error { - if writer == nil { - return net.ErrClosed - } - if _, err := writer.WriteString("*" + strconv.Itoa(len(items)) + "\r\n"); err != nil { - return err - } - for i := range items { - if err := writeRedisBulkString(writer, items[i]); err != nil { - return err - } - } - return nil -} - -func writeRedisInteger(writer *bufio.Writer, value int) error { - if writer == nil { - return net.ErrClosed - } - _, err := writer.WriteString(":" + strconv.Itoa(value) + "\r\n") - return err -} - -func writeRedisArrayHeader(writer *bufio.Writer, count int) error { - if writer == nil { - return net.ErrClosed - } - _, err := writer.WriteString("*" + strconv.Itoa(count) + "\r\n") - return err -} - -func writeRedisPubSubSubscribe(writer *bufio.Writer, channel string, count int) error { - if err := writeRedisArrayHeader(writer, 3); err != nil { - return err - } - if err := writeRedisBulkString(writer, []byte("subscribe")); err != nil { - return err - } - if err := writeRedisBulkString(writer, []byte(channel)); err != nil { - return err - } - return writeRedisInteger(writer, count) -} - -func writeRedisPubSubUnsubscribe(writer *bufio.Writer, channel string, count int) error { - if err := writeRedisArrayHeader(writer, 3); err != nil { - return err - } - if err := writeRedisBulkString(writer, []byte("unsubscribe")); err != nil { - return err - } - if err := writeRedisBulkString(writer, []byte(channel)); err != nil { - return err - } - return writeRedisInteger(writer, count) -} - -func writeRedisPubSubMessage(writer *bufio.Writer, channel string, payload []byte) error { - if err := writeRedisArrayHeader(writer, 3); err != nil { - return err - } - if err := writeRedisBulkString(writer, []byte("message")); err != nil { - return err - } - if err := writeRedisBulkString(writer, []byte(channel)); err != nil { - return err - } - return writeRedisBulkString(writer, payload) -} - -func writeRedisPubSubPong(writer *bufio.Writer, payload []byte) error { - if err := writeRedisArrayHeader(writer, 2); err != nil { - return err - } - if err := writeRedisBulkString(writer, []byte("pong")); err != nil { - return err - } - return writeRedisBulkString(writer, payload) -} diff --git a/internal/api/redis_queue_protocol_integration_test.go b/internal/api/redis_queue_protocol_integration_test.go index 8547e0403..b74a84ca6 100644 --- a/internal/api/redis_queue_protocol_integration_test.go +++ b/internal/api/redis_queue_protocol_integration_test.go @@ -3,14 +3,9 @@ package api import ( "bufio" "bytes" - "encoding/json" "errors" "fmt" - "io" "net" - "net/http" - "net/http/httptest" - "strconv" "strings" "testing" "time" @@ -18,18 +13,6 @@ import ( "github.com/router-for-me/CLIProxyAPI/v7/internal/redisqueue" ) -type remoteAddrConn struct { - net.Conn - remoteAddr net.Addr -} - -func (c *remoteAddrConn) RemoteAddr() net.Addr { - if c == nil { - return nil - } - return c.remoteAddr -} - func startRedisMuxListener(t *testing.T, server *Server) (addr string, stop func()) { t.Helper() @@ -86,17 +69,6 @@ func readTestRESPLine(r *bufio.Reader) (string, error) { return strings.TrimSuffix(line, "\r\n"), nil } -func readTestRESPSimpleString(r *bufio.Reader) (string, error) { - prefix, err := r.ReadByte() - if err != nil { - return "", err - } - if prefix != '+' { - return "", fmt.Errorf("expected simple string prefix '+', got %q", prefix) - } - return readTestRESPLine(r) -} - func readTestRESPError(r *bufio.Reader) (string, error) { prefix, err := r.ReadByte() if err != nil { @@ -108,171 +80,6 @@ func readTestRESPError(r *bufio.Reader) (string, error) { return readTestRESPLine(r) } -func readTestRESPBulkString(r *bufio.Reader) ([]byte, error) { - prefix, err := r.ReadByte() - if err != nil { - return nil, err - } - if prefix != '$' { - return nil, fmt.Errorf("expected bulk string prefix '$', got %q", prefix) - } - - line, err := readTestRESPLine(r) - if err != nil { - return nil, err - } - length, err := strconv.Atoi(line) - if err != nil { - return nil, fmt.Errorf("invalid bulk string length %q: %v", line, err) - } - if length == -1 { - return nil, nil - } - if length < -1 { - return nil, fmt.Errorf("invalid bulk string length %d", length) - } - - payload := make([]byte, length+2) - if _, err := io.ReadFull(r, payload); err != nil { - return nil, err - } - if payload[length] != '\r' || payload[length+1] != '\n' { - return nil, fmt.Errorf("invalid bulk string terminator") - } - return payload[:length], nil -} - -func readRESPArrayOfBulkStrings(r *bufio.Reader) ([][]byte, error) { - prefix, err := r.ReadByte() - if err != nil { - return nil, err - } - if prefix != '*' { - return nil, fmt.Errorf("expected array prefix '*', got %q", prefix) - } - - line, err := readTestRESPLine(r) - if err != nil { - return nil, err - } - count, err := strconv.Atoi(line) - if err != nil { - return nil, fmt.Errorf("invalid array length %q: %v", line, err) - } - if count < 0 { - return nil, fmt.Errorf("invalid array length %d", count) - } - - out := make([][]byte, 0, count) - for i := 0; i < count; i++ { - item, err := readTestRESPBulkString(r) - if err != nil { - return nil, err - } - out = append(out, item) - } - return out, nil -} - -func readTestRESPInteger(r *bufio.Reader) (int, error) { - prefix, err := r.ReadByte() - if err != nil { - return 0, err - } - if prefix != ':' { - return 0, fmt.Errorf("expected integer prefix ':', got %q", prefix) - } - - line, err := readTestRESPLine(r) - if err != nil { - return 0, err - } - value, err := strconv.Atoi(line) - if err != nil { - return 0, fmt.Errorf("invalid integer %q: %v", line, err) - } - return value, nil -} - -func readTestRESPArrayHeader(r *bufio.Reader) (int, error) { - prefix, err := r.ReadByte() - if err != nil { - return 0, err - } - if prefix != '*' { - return 0, fmt.Errorf("expected array prefix '*', got %q", prefix) - } - - line, err := readTestRESPLine(r) - if err != nil { - return 0, err - } - count, err := strconv.Atoi(line) - if err != nil { - return 0, fmt.Errorf("invalid array length %q: %v", line, err) - } - if count < 0 { - return 0, fmt.Errorf("invalid array length %d", count) - } - return count, nil -} - -func readTestRESPPubSubSubscribe(r *bufio.Reader) (string, int, error) { - count, err := readTestRESPArrayHeader(r) - if err != nil { - return "", 0, err - } - if count != 3 { - return "", 0, fmt.Errorf("subscribe array length = %d, want 3", count) - } - - kind, err := readTestRESPBulkString(r) - if err != nil { - return "", 0, err - } - if string(kind) != "subscribe" { - return "", 0, fmt.Errorf("pubsub kind = %q, want subscribe", string(kind)) - } - - channel, err := readTestRESPBulkString(r) - if err != nil { - return "", 0, err - } - subscriptions, err := readTestRESPInteger(r) - if err != nil { - return "", 0, err - } - return string(channel), subscriptions, nil -} - -func readTestRESPPubSubMessage(r *bufio.Reader) (string, []byte, error) { - count, err := readTestRESPArrayHeader(r) - if err != nil { - return "", nil, err - } - if count != 3 { - return "", nil, fmt.Errorf("message array length = %d, want 3", count) - } - - kind, err := readTestRESPBulkString(r) - if err != nil { - return "", nil, err - } - if string(kind) != "message" { - return "", nil, fmt.Errorf("pubsub kind = %q, want message", string(kind)) - } - - channel, err := readTestRESPBulkString(r) - if err != nil { - return "", nil, err - } - payload, err := readTestRESPBulkString(r) - if err != nil { - return "", nil, err - } - return string(channel), payload, nil -} - func TestRedisProtocol_ManagementDisabled_RejectsConnection(t *testing.T) { t.Setenv("MANAGEMENT_PASSWORD", "") redisqueue.SetEnabled(false) @@ -296,13 +103,19 @@ func TestRedisProtocol_ManagementDisabled_RejectsConnection(t *testing.T) { t.Fatalf("failed to write RESP command: %v", errWrite) } + if msg, err := readTestRESPError(bufio.NewReader(conn)); err != nil { + t.Fatalf("failed to read disabled RESP error: %v", err) + } else if msg != "ERR RESP AUTH disabled; use mTLS" { + t.Fatalf("unexpected disabled RESP error: %q", msg) + } + buf := make([]byte, 1) _, errRead := conn.Read(buf) if errRead == nil { - t.Fatalf("expected connection to be closed when management is disabled") + t.Fatalf("expected connection to be closed after disabled RESP error") } if ne, ok := errRead.(net.Error); ok && ne.Timeout() { - t.Fatalf("expected connection to be closed when management is disabled, got timeout: %v", errRead) + t.Fatalf("expected connection to be closed after disabled RESP error, got timeout: %v", errRead) } } @@ -333,17 +146,23 @@ func TestRedisProtocol_HomeEnabled_DisablesConnection(t *testing.T) { _ = conn.SetDeadline(time.Now().Add(2 * time.Second)) _ = writeTestRESPCommand(conn, "PING") + if msg, err := readTestRESPError(bufio.NewReader(conn)); err != nil { + t.Fatalf("failed to read disabled RESP error: %v", err) + } else if msg != "ERR RESP AUTH disabled; use mTLS" { + t.Fatalf("unexpected disabled RESP error: %q", msg) + } + buf := make([]byte, 1) _, errRead := conn.Read(buf) if errRead == nil { - t.Fatalf("expected connection to be closed when home mode is enabled") + t.Fatalf("expected connection to be closed after disabled RESP error") } if ne, ok := errRead.(net.Error); ok && ne.Timeout() { - t.Fatalf("expected connection to be closed when home mode is enabled, got timeout: %v", errRead) + t.Fatalf("expected connection to be closed after disabled RESP error, got timeout: %v", errRead) } } -func TestRedisProtocol_AUTH_And_PopContracts(t *testing.T) { +func TestRedisProtocol_AUTH_DisabledAndClosesConnection(t *testing.T) { const managementPassword = "test-management-password" t.Setenv("MANAGEMENT_PASSWORD", managementPassword) @@ -368,369 +187,21 @@ func TestRedisProtocol_AUTH_And_PopContracts(t *testing.T) { _ = conn.SetDeadline(time.Now().Add(5 * time.Second)) - if errWrite := writeTestRESPCommand(conn, "AUTH", "test-key"); errWrite != nil { - t.Fatalf("failed to write AUTH command: %v", errWrite) - } - if msg, err := readTestRESPError(reader); err != nil { - t.Fatalf("failed to read AUTH error: %v", err) - } else if msg != "ERR invalid management key" { - t.Fatalf("unexpected AUTH error: %q", msg) - } - - if errWrite := writeTestRESPCommand(conn, "LPOP", "queue"); errWrite != nil { - t.Fatalf("failed to write LPOP command: %v", errWrite) - } - if msg, err := readTestRESPError(reader); err != nil { - t.Fatalf("failed to read LPOP NOAUTH error: %v", err) - } else if msg != "NOAUTH Authentication required." { - t.Fatalf("unexpected LPOP NOAUTH error: %q", msg) - } - if errWrite := writeTestRESPCommand(conn, "AUTH", managementPassword); errWrite != nil { t.Fatalf("failed to write AUTH command: %v", errWrite) } - if msg, err := readTestRESPSimpleString(reader); err != nil { - t.Fatalf("failed to read AUTH response: %v", err) - } else if msg != "OK" { - t.Fatalf("unexpected AUTH response: %q", msg) + if msg, err := readTestRESPError(reader); err != nil { + t.Fatalf("failed to read disabled AUTH error: %v", err) + } else if msg != "ERR RESP AUTH disabled; use mTLS" { + t.Fatalf("unexpected disabled AUTH error: %q", msg) } - if !redisqueue.Enabled() { - t.Fatalf("expected redisqueue to be enabled") + buf := make([]byte, 1) + _, errRead := conn.Read(buf) + if errRead == nil { + t.Fatalf("expected connection to be closed after disabled AUTH error") } - redisqueue.Enqueue([]byte("a")) - redisqueue.Enqueue([]byte("b")) - redisqueue.Enqueue([]byte("c")) - - if errWrite := writeTestRESPCommand(conn, "RPOP", "queue"); errWrite != nil { - t.Fatalf("failed to write RPOP command: %v", errWrite) - } - if item, err := readTestRESPBulkString(reader); err != nil { - t.Fatalf("failed to read RPOP response: %v", err) - } else if string(item) != "a" { - t.Fatalf("unexpected RPOP item: %q", string(item)) - } - - if errWrite := writeTestRESPCommand(conn, "LPOP", "queue"); errWrite != nil { - t.Fatalf("failed to write LPOP command: %v", errWrite) - } - if item, err := readTestRESPBulkString(reader); err != nil { - t.Fatalf("failed to read LPOP response: %v", err) - } else if string(item) != "b" { - t.Fatalf("unexpected LPOP item: %q", string(item)) - } - - if errWrite := writeTestRESPCommand(conn, "RPOP", "queue", "10"); errWrite != nil { - t.Fatalf("failed to write RPOP count command: %v", errWrite) - } - items, errItems := readRESPArrayOfBulkStrings(reader) - if errItems != nil { - t.Fatalf("failed to read RPOP count response: %v", errItems) - } - if len(items) != 1 || string(items[0]) != "c" { - t.Fatalf("unexpected RPOP count items: %#v", items) - } - - if errWrite := writeTestRESPCommand(conn, "LPOP", "queue"); errWrite != nil { - t.Fatalf("failed to write LPOP empty command: %v", errWrite) - } - item, errItem := readTestRESPBulkString(reader) - if errItem != nil { - t.Fatalf("failed to read LPOP empty response: %v", errItem) - } - if item != nil { - t.Fatalf("expected nil bulk string for empty queue, got %q", string(item)) - } - - if errWrite := writeTestRESPCommand(conn, "RPOP", "queue", "2"); errWrite != nil { - t.Fatalf("failed to write RPOP empty count command: %v", errWrite) - } - emptyItems, errEmpty := readRESPArrayOfBulkStrings(reader) - if errEmpty != nil { - t.Fatalf("failed to read RPOP empty count response: %v", errEmpty) - } - if len(emptyItems) != 0 { - t.Fatalf("expected empty array for empty queue with count, got %#v", emptyItems) - } -} - -func TestRedisProtocol_SubscribeUsageBroadcastsAndSkipsQueue(t *testing.T) { - const managementPassword = "test-management-password" - - t.Setenv("MANAGEMENT_PASSWORD", managementPassword) - redisqueue.SetEnabled(false) - t.Cleanup(func() { redisqueue.SetEnabled(false) }) - - server := newTestServer(t) - if !server.managementRoutesEnabled.Load() { - t.Fatalf("expected managementRoutesEnabled to be true") - } - - addr, stop := startRedisMuxListener(t, server) - t.Cleanup(stop) - - firstConn, errDialFirst := net.DialTimeout("tcp", addr, time.Second) - if errDialFirst != nil { - t.Fatalf("failed to dial first redis listener: %v", errDialFirst) - } - t.Cleanup(func() { _ = firstConn.Close() }) - firstReader := bufio.NewReader(firstConn) - _ = firstConn.SetDeadline(time.Now().Add(5 * time.Second)) - - if errWrite := writeTestRESPCommand(firstConn, "AUTH", managementPassword); errWrite != nil { - t.Fatalf("failed to write first AUTH command: %v", errWrite) - } - if msg, err := readTestRESPSimpleString(firstReader); err != nil { - t.Fatalf("failed to read first AUTH response: %v", err) - } else if msg != "OK" { - t.Fatalf("unexpected first AUTH response: %q", msg) - } - if errWrite := writeTestRESPCommand(firstConn, "SUBSCRIBE", "usage"); errWrite != nil { - t.Fatalf("failed to write first SUBSCRIBE command: %v", errWrite) - } - if channel, count, err := readTestRESPPubSubSubscribe(firstReader); err != nil { - t.Fatalf("failed to read first SUBSCRIBE response: %v", err) - } else if channel != "usage" || count != 1 { - t.Fatalf("unexpected first SUBSCRIBE response channel=%q count=%d", channel, count) - } - - secondConn, errDialSecond := net.DialTimeout("tcp", addr, time.Second) - if errDialSecond != nil { - t.Fatalf("failed to dial second redis listener: %v", errDialSecond) - } - t.Cleanup(func() { _ = secondConn.Close() }) - secondReader := bufio.NewReader(secondConn) - _ = secondConn.SetDeadline(time.Now().Add(5 * time.Second)) - - if errWrite := writeTestRESPCommand(secondConn, "AUTH", managementPassword); errWrite != nil { - t.Fatalf("failed to write second AUTH command: %v", errWrite) - } - if msg, err := readTestRESPSimpleString(secondReader); err != nil { - t.Fatalf("failed to read second AUTH response: %v", err) - } else if msg != "OK" { - t.Fatalf("unexpected second AUTH response: %q", msg) - } - if errWrite := writeTestRESPCommand(secondConn, "SUBSCRIBE", "usage"); errWrite != nil { - t.Fatalf("failed to write second SUBSCRIBE command: %v", errWrite) - } - if channel, count, err := readTestRESPPubSubSubscribe(secondReader); err != nil { - t.Fatalf("failed to read second SUBSCRIBE response: %v", err) - } else if channel != "usage" || count != 1 { - t.Fatalf("unexpected second SUBSCRIBE response channel=%q count=%d", channel, count) - } - - redisqueue.Enqueue([]byte(`{"id":1}`)) - - if channel, payload, err := readTestRESPPubSubMessage(firstReader); err != nil { - t.Fatalf("failed to read first pubsub message: %v", err) - } else if channel != "usage" || string(payload) != `{"id":1}` { - t.Fatalf("unexpected first pubsub message channel=%q payload=%q", channel, string(payload)) - } - if channel, payload, err := readTestRESPPubSubMessage(secondReader); err != nil { - t.Fatalf("failed to read second pubsub message: %v", err) - } else if channel != "usage" || string(payload) != `{"id":1}` { - t.Fatalf("unexpected second pubsub message channel=%q payload=%q", channel, string(payload)) - } - - popConn, errDialPop := net.DialTimeout("tcp", addr, time.Second) - if errDialPop != nil { - t.Fatalf("failed to dial pop redis listener: %v", errDialPop) - } - t.Cleanup(func() { _ = popConn.Close() }) - popReader := bufio.NewReader(popConn) - _ = popConn.SetDeadline(time.Now().Add(5 * time.Second)) - - if errWrite := writeTestRESPCommand(popConn, "AUTH", managementPassword); errWrite != nil { - t.Fatalf("failed to write pop AUTH command: %v", errWrite) - } - if msg, err := readTestRESPSimpleString(popReader); err != nil { - t.Fatalf("failed to read pop AUTH response: %v", err) - } else if msg != "OK" { - t.Fatalf("unexpected pop AUTH response: %q", msg) - } - if errWrite := writeTestRESPCommand(popConn, "LPOP", "usage"); errWrite != nil { - t.Fatalf("failed to write pop LPOP command: %v", errWrite) - } - item, errItem := readTestRESPBulkString(popReader) - if errItem != nil { - t.Fatalf("failed to read pop LPOP response: %v", errItem) - } - if item != nil { - t.Fatalf("expected subscribed usage to skip queue, got %q", string(item)) - } - - managementReq := httptest.NewRequest(http.MethodGet, "/v0/management/usage-queue?count=1", nil) - managementReq.Header.Set("Authorization", "Bearer "+managementPassword) - managementRR := httptest.NewRecorder() - server.engine.ServeHTTP(managementRR, managementReq) - if managementRR.Code != http.StatusOK { - t.Fatalf("management usage status = %d, want %d body=%s", managementRR.Code, http.StatusOK, managementRR.Body.String()) - } - var managementPayload []json.RawMessage - if errUnmarshal := json.Unmarshal(managementRR.Body.Bytes(), &managementPayload); errUnmarshal != nil { - t.Fatalf("unmarshal management usage response: %v", errUnmarshal) - } - if len(managementPayload) != 0 { - t.Fatalf("expected management usage queue to be empty, got %s", managementRR.Body.String()) - } -} - -func TestRedisProtocol_IPBan_MirrorsManagementPolicy(t *testing.T) { - const managementPassword = "test-management-password" - - t.Setenv("MANAGEMENT_PASSWORD", managementPassword) - redisqueue.SetEnabled(false) - t.Cleanup(func() { redisqueue.SetEnabled(false) }) - - server := newTestServer(t) - if !server.managementRoutesEnabled.Load() { - t.Fatalf("expected managementRoutesEnabled to be true") - } - - clientConn, serverConn := net.Pipe() - t.Cleanup(func() { _ = clientConn.Close() }) - t.Cleanup(func() { _ = serverConn.Close() }) - - fakeRemote := &net.TCPAddr{ - IP: net.ParseIP("1.2.3.4"), - Port: 1234, - } - wrappedConn := &remoteAddrConn{Conn: serverConn, remoteAddr: fakeRemote} - - go server.handleRedisConnection(wrappedConn, bufio.NewReader(wrappedConn)) - - reader := bufio.NewReader(clientConn) - _ = clientConn.SetDeadline(time.Now().Add(5 * time.Second)) - - for i := 0; i < 5; i++ { - if errWrite := writeTestRESPCommand(clientConn, "LPOP", "queue"); errWrite != nil { - t.Fatalf("failed to write LPOP command: %v", errWrite) - } - if msg, err := readTestRESPError(reader); err != nil { - t.Fatalf("failed to read LPOP NOAUTH error: %v", err) - } else if msg != "NOAUTH Authentication required." { - t.Fatalf("unexpected LPOP NOAUTH error at attempt %d: %q", i+1, msg) - } - } - - if errWrite := writeTestRESPCommand(clientConn, "LPOP", "queue"); errWrite != nil { - t.Fatalf("failed to write LPOP command after failures: %v", errWrite) - } - msg, err := readTestRESPError(reader) - if err != nil { - t.Fatalf("failed to read LPOP banned error: %v", err) - } - if !strings.HasPrefix(msg, "ERR IP banned due to too many failed attempts. Try again in") { - t.Fatalf("unexpected LPOP banned error: %q", msg) - } -} - -func TestRedisProtocol_AUTH_IPBan_BlocksCorrectPasswordDuringBan(t *testing.T) { - const managementPassword = "test-management-password" - - t.Setenv("MANAGEMENT_PASSWORD", managementPassword) - redisqueue.SetEnabled(false) - t.Cleanup(func() { redisqueue.SetEnabled(false) }) - - server := newTestServer(t) - if !server.managementRoutesEnabled.Load() { - t.Fatalf("expected managementRoutesEnabled to be true") - } - - clientConn, serverConn := net.Pipe() - t.Cleanup(func() { _ = clientConn.Close() }) - t.Cleanup(func() { _ = serverConn.Close() }) - - fakeRemote := &net.TCPAddr{ - IP: net.ParseIP("1.2.3.4"), - Port: 1234, - } - wrappedConn := &remoteAddrConn{Conn: serverConn, remoteAddr: fakeRemote} - - go server.handleRedisConnection(wrappedConn, bufio.NewReader(wrappedConn)) - - reader := bufio.NewReader(clientConn) - _ = clientConn.SetDeadline(time.Now().Add(5 * time.Second)) - - for i := 0; i < 5; i++ { - if errWrite := writeTestRESPCommand(clientConn, "AUTH", "wrong-password"); errWrite != nil { - t.Fatalf("failed to write AUTH command: %v", errWrite) - } - if msg, err := readTestRESPError(reader); err != nil { - t.Fatalf("failed to read AUTH error: %v", err) - } else if msg != "ERR invalid management key" { - t.Fatalf("unexpected AUTH error at attempt %d: %q", i+1, msg) - } - } - - for i := 0; i < 2; i++ { - if errWrite := writeTestRESPCommand(clientConn, "AUTH", "wrong-password"); errWrite != nil { - t.Fatalf("failed to write AUTH command after failures: %v", errWrite) - } - msg, err := readTestRESPError(reader) - if err != nil { - t.Fatalf("failed to read AUTH banned error: %v", err) - } - if !strings.HasPrefix(msg, "ERR IP banned due to too many failed attempts. Try again in") { - t.Fatalf("unexpected AUTH banned error at attempt %d: %q", i+6, msg) - } - } - - if errWrite := writeTestRESPCommand(clientConn, "AUTH", managementPassword); errWrite != nil { - t.Fatalf("failed to write AUTH command with correct password: %v", errWrite) - } - msg, err := readTestRESPError(reader) - if err != nil { - t.Fatalf("failed to read AUTH banned error for correct password: %v", err) - } - if !strings.HasPrefix(msg, "ERR IP banned due to too many failed attempts. Try again in") { - t.Fatalf("unexpected AUTH banned error for correct password: %q", msg) - } -} - -func TestRedisProtocol_LOCALHOST_AUTH_IPBan_BlocksCorrectPasswordDuringBan(t *testing.T) { - const managementPassword = "test-management-password" - - t.Setenv("MANAGEMENT_PASSWORD", managementPassword) - redisqueue.SetEnabled(false) - t.Cleanup(func() { redisqueue.SetEnabled(false) }) - - server := newTestServer(t) - if !server.managementRoutesEnabled.Load() { - t.Fatalf("expected managementRoutesEnabled to be true") - } - - addr, stop := startRedisMuxListener(t, server) - t.Cleanup(stop) - - conn, errDial := net.DialTimeout("tcp", addr, time.Second) - if errDial != nil { - t.Fatalf("failed to dial redis listener: %v", errDial) - } - t.Cleanup(func() { _ = conn.Close() }) - - reader := bufio.NewReader(conn) - _ = conn.SetDeadline(time.Now().Add(5 * time.Second)) - - for i := 0; i < 5; i++ { - if errWrite := writeTestRESPCommand(conn, "AUTH", "wrong-password"); errWrite != nil { - t.Fatalf("failed to write AUTH command: %v", errWrite) - } - if msg, err := readTestRESPError(reader); err != nil { - t.Fatalf("failed to read AUTH error: %v", err) - } else if msg != "ERR invalid management key" { - t.Fatalf("unexpected AUTH error at attempt %d: %q", i+1, msg) - } - } - - if errWrite := writeTestRESPCommand(conn, "AUTH", managementPassword); errWrite != nil { - t.Fatalf("failed to write AUTH command with correct password: %v", errWrite) - } - msg, err := readTestRESPError(reader) - if err != nil { - t.Fatalf("failed to read AUTH banned error for correct password: %v", err) - } - if !strings.HasPrefix(msg, "ERR IP banned due to too many failed attempts. Try again in") { - t.Fatalf("unexpected AUTH banned error for correct password: %q", msg) + if ne, ok := errRead.(net.Error); ok && ne.Timeout() { + t.Fatalf("expected connection to be closed after disabled AUTH error, got timeout: %v", errRead) } } diff --git a/internal/api/server_test.go b/internal/api/server_test.go index c853a711a..e503fe71b 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -6,6 +6,7 @@ import ( "net/http/httptest" "os" "path/filepath" + "strings" "testing" "time" diff --git a/internal/config/config.go b/internal/config/config.go index ddc6bd535..dd0b05c72 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -37,8 +37,8 @@ type Config struct { // TLS config controls HTTPS server settings. TLS TLSConfig `yaml:"tls" json:"tls"` - // Home config enables the Redis-based control plane integration. - Home HomeConfig `yaml:"home" json:"-"` + // Home config is runtime-only and is populated from -home-jwt. + Home HomeConfig `yaml:"-" json:"-"` // RemoteManagement nests management-related options under 'remote-management'. RemoteManagement RemoteManagement `yaml:"remote-management" json:"-"` @@ -69,8 +69,8 @@ type Config struct { // UsageStatisticsEnabled toggles in-memory usage aggregation; when false, usage data is discarded. UsageStatisticsEnabled bool `yaml:"usage-statistics-enabled" json:"usage-statistics-enabled"` - // RedisUsageQueueRetentionSeconds controls how long (in seconds) usage queue items - // are retained in memory for the Redis RESP interface (LPOP/RPOP). + // RedisUsageQueueRetentionSeconds controls how long usage queue items are retained + // in memory for Management API consumers. // Default: 60. Max: 3600. RedisUsageQueueRetentionSeconds int `yaml:"redis-usage-queue-retention-seconds" json:"redis-usage-queue-retention-seconds"` diff --git a/internal/config/home.go b/internal/config/home.go index 8cf323b6d..07ac1fed6 100644 --- a/internal/config/home.go +++ b/internal/config/home.go @@ -1,11 +1,10 @@ package config -// HomeConfig configures the optional "home" control plane integration over Redis protocol. +// HomeConfig stores runtime-only Home control plane settings from -home-jwt. type HomeConfig struct { Enabled bool `yaml:"enabled" json:"enabled"` Host string `yaml:"host" json:"-"` Port int `yaml:"port" json:"-"` - Password string `yaml:"password" json:"-"` DisableClusterDiscovery bool `yaml:"disable-cluster-discovery" json:"-"` TLS HomeTLSConfig `yaml:"tls" json:"-"` } diff --git a/internal/config/home_test.go b/internal/config/home_test.go index ac26d2cbf..850f3b72e 100644 --- a/internal/config/home_test.go +++ b/internal/config/home_test.go @@ -2,13 +2,12 @@ package config import "testing" -func TestParseConfigBytesHomeTLS(t *testing.T) { +func TestParseConfigBytesIgnoresHomeConfig(t *testing.T) { cfg, err := ParseConfigBytes([]byte(` home: enabled: true host: home.example.com port: 444 - password: secret disable-cluster-discovery: true tls: enable: true @@ -20,31 +19,28 @@ home: t.Fatalf("ParseConfigBytes() error = %v", err) } - if !cfg.Home.Enabled { - t.Fatal("Home.Enabled = false, want true") + if cfg.Home.Enabled { + t.Fatal("Home.Enabled = true, want false") } - if cfg.Home.Host != "home.example.com" { - t.Fatalf("Home.Host = %q, want home.example.com", cfg.Home.Host) + if cfg.Home.Host != "" { + t.Fatalf("Home.Host = %q, want empty", cfg.Home.Host) } - if cfg.Home.Port != 444 { - t.Fatalf("Home.Port = %d, want 444", cfg.Home.Port) + if cfg.Home.Port != 0 { + t.Fatalf("Home.Port = %d, want 0", cfg.Home.Port) } - if cfg.Home.Password != "secret" { - t.Fatalf("Home.Password = %q, want secret", cfg.Home.Password) + if cfg.Home.DisableClusterDiscovery { + t.Fatal("Home.DisableClusterDiscovery = true, want false") } - if !cfg.Home.DisableClusterDiscovery { - t.Fatal("Home.DisableClusterDiscovery = false, want true") + if cfg.Home.TLS.Enable { + t.Fatal("Home.TLS.Enable = true, want false") } - if !cfg.Home.TLS.Enable { - t.Fatal("Home.TLS.Enable = false, want true") + if cfg.Home.TLS.ServerName != "" { + t.Fatalf("Home.TLS.ServerName = %q, want empty", cfg.Home.TLS.ServerName) } - if cfg.Home.TLS.ServerName != "home.example.com" { - t.Fatalf("Home.TLS.ServerName = %q, want home.example.com", cfg.Home.TLS.ServerName) + if cfg.Home.TLS.CACert != "" { + t.Fatalf("Home.TLS.CACert = %q, want empty", cfg.Home.TLS.CACert) } - if cfg.Home.TLS.CACert != "C:/certs/ca.pem" { - t.Fatalf("Home.TLS.CACert = %q, want C:/certs/ca.pem", cfg.Home.TLS.CACert) - } - if !cfg.Home.TLS.InsecureSkipVerify { - t.Fatal("Home.TLS.InsecureSkipVerify = false, want true") + if cfg.Home.TLS.InsecureSkipVerify { + t.Fatal("Home.TLS.InsecureSkipVerify = true, want false") } } diff --git a/internal/home/client.go b/internal/home/client.go index 2c81187e4..0357529e6 100644 --- a/internal/home/client.go +++ b/internal/home/client.go @@ -180,7 +180,6 @@ func (c *Client) redisOptionsLocked(addr string) (*redis.Options, error) { } return &redis.Options{ Addr: addr, - Password: c.homeCfg.Password, TLSConfig: tlsConfig, DialTimeout: homeRedisOperationTimeout, ReadTimeout: homeRedisOperationTimeout, diff --git a/internal/home/client_test.go b/internal/home/client_test.go index b3a1ae583..b0415d89b 100644 --- a/internal/home/client_test.go +++ b/internal/home/client_test.go @@ -37,10 +37,9 @@ func TestAuthDispatchRequestDefaultsCountToOne(t *testing.T) { func TestRedisOptionsHomeTLSDisabled(t *testing.T) { client := New(config.HomeConfig{ - Enabled: true, - Host: "127.0.0.1", - Port: 6379, - Password: "secret", + Enabled: true, + Host: "127.0.0.1", + Port: 6379, }) client.mu.Lock() @@ -53,8 +52,8 @@ func TestRedisOptionsHomeTLSDisabled(t *testing.T) { if options.TLSConfig != nil { t.Fatalf("TLSConfig = %#v, want nil", options.TLSConfig) } - if options.Password != "secret" { - t.Fatalf("Password = %q, want secret", options.Password) + if options.Password != "" { + t.Fatalf("Password = %q, want empty", options.Password) } }