From 48104abf51037159dd7267b3b9d82ffb6bf14fcf Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Sat, 16 May 2026 19:57:19 +0800 Subject: [PATCH 1/2] feat(home): implement home control plane integration with Redis and TLS support --- cmd/server/home_flag.go | 124 +++++++++++++++++++++++++++++++++++ cmd/server/home_flag_test.go | 66 +++++++++++++++++++ cmd/server/main.go | 27 ++------ config.example.yaml | 10 +++ internal/config/home.go | 17 +++-- internal/config/home_test.go | 46 +++++++++++++ internal/home/client.go | 82 ++++++++++++++++++++--- internal/home/client_test.go | 85 ++++++++++++++++++++++++ 8 files changed, 422 insertions(+), 35 deletions(-) create mode 100644 cmd/server/home_flag.go create mode 100644 cmd/server/home_flag_test.go create mode 100644 internal/config/home_test.go diff --git a/cmd/server/home_flag.go b/cmd/server/home_flag.go new file mode 100644 index 000000000..2d79ef833 --- /dev/null +++ b/cmd/server/home_flag.go @@ -0,0 +1,124 @@ +package main + +import ( + "fmt" + "net" + "net/url" + "strconv" + "strings" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" +) + +func parseHomeFlagConfig(rawAddr string, password string) (config.HomeConfig, error) { + rawAddr = strings.TrimSpace(rawAddr) + if rawAddr == "" { + return config.HomeConfig{}, fmt.Errorf("address is empty") + } + + if strings.Contains(rawAddr, "://") { + return parseHomeURLConfig(rawAddr, password) + } + + host, portStr, errSplit := net.SplitHostPort(rawAddr) + if errSplit != nil { + return config.HomeConfig{}, fmt.Errorf("expected host:port, redis://host:port, or rediss://host:port: %w", errSplit) + } + + host = strings.TrimSpace(host) + if host == "" { + return config.HomeConfig{}, fmt.Errorf("host is empty") + } + + port, errPort := parseHomePort(portStr) + if errPort != nil { + return config.HomeConfig{}, errPort + } + + return config.HomeConfig{ + Enabled: true, + Host: host, + Port: port, + Password: password, + }, nil +} + +func parseHomeURLConfig(rawAddr string, password string) (config.HomeConfig, error) { + parsed, errParse := url.Parse(rawAddr) + if errParse != nil { + return config.HomeConfig{}, fmt.Errorf("parse URL: %w", errParse) + } + + scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme)) + if scheme != "redis" && scheme != "rediss" { + return config.HomeConfig{}, fmt.Errorf("unsupported URL scheme %q", parsed.Scheme) + } + + host := strings.TrimSpace(parsed.Hostname()) + if host == "" { + return config.HomeConfig{}, fmt.Errorf("host is empty") + } + + port, errPort := parseHomePort(parsed.Port()) + if errPort != nil { + return config.HomeConfig{}, errPort + } + + if password == "" && parsed.User != nil { + if urlPassword, ok := parsed.User.Password(); ok { + password = urlPassword + } + } + + homeCfg := config.HomeConfig{ + Enabled: true, + Host: host, + Port: port, + Password: password, + } + + if scheme == "rediss" { + homeCfg.TLS.Enable = true + query := parsed.Query() + homeCfg.TLS.ServerName = strings.TrimSpace(firstHomeQueryValue(query, "server-name", "server_name")) + homeCfg.TLS.InsecureSkipVerify = parseHomeBoolQuery(query, "insecure-skip-verify", "insecure_skip_verify", "skip_verify") + homeCfg.TLS.CACert = strings.TrimSpace(firstHomeQueryValue(query, "ca-cert", "ca_cert")) + } + + return homeCfg, nil +} + +func parseHomePort(rawPort string) (int, error) { + rawPort = strings.TrimSpace(rawPort) + if rawPort == "" { + return 0, fmt.Errorf("port is empty") + } + + port, errPort := strconv.Atoi(rawPort) + if errPort != nil || port <= 0 || port > 65535 { + return 0, fmt.Errorf("invalid port %q", rawPort) + } + + return port, nil +} + +func firstHomeQueryValue(values url.Values, keys ...string) string { + for _, key := range keys { + if value := values.Get(key); value != "" { + return value + } + } + return "" +} + +func parseHomeBoolQuery(values url.Values, keys ...string) bool { + for _, key := range keys { + value := strings.TrimSpace(values.Get(key)) + if value == "" { + continue + } + parsed, errParse := strconv.ParseBool(value) + return errParse == nil && parsed + } + return false +} diff --git a/cmd/server/home_flag_test.go b/cmd/server/home_flag_test.go new file mode 100644 index 000000000..9947f9402 --- /dev/null +++ b/cmd/server/home_flag_test.go @@ -0,0 +1,66 @@ +package main + +import "testing" + +func TestParseHomeFlagConfigHostPort(t *testing.T) { + cfg, err := parseHomeFlagConfig("home.example.com:8327", "secret") + if err != nil { + t.Fatalf("parseHomeFlagConfig() error = %v", err) + } + + if !cfg.Enabled { + t.Fatal("Enabled = false, want true") + } + if cfg.Host != "home.example.com" { + t.Fatalf("Host = %q, want home.example.com", cfg.Host) + } + if cfg.Port != 8327 { + t.Fatalf("Port = %d, want 8327", cfg.Port) + } + if cfg.Password != "secret" { + t.Fatalf("Password = %q, want secret", cfg.Password) + } + if cfg.TLS.Enable { + t.Fatal("TLS.Enable = true, want false") + } +} + +func TestParseHomeFlagConfigRediss(t *testing.T) { + cfg, err := parseHomeFlagConfig("rediss://:url-secret@home.example.com:444?server-name=home.example.com&skip_verify=true&ca-cert=C%3A%2Fcerts%2Fca.pem", "") + if err != nil { + t.Fatalf("parseHomeFlagConfig() error = %v", err) + } + + if cfg.Host != "home.example.com" { + t.Fatalf("Host = %q, want home.example.com", cfg.Host) + } + if cfg.Port != 444 { + t.Fatalf("Port = %d, want 444", cfg.Port) + } + if cfg.Password != "url-secret" { + t.Fatalf("Password = %q, want url-secret", cfg.Password) + } + if !cfg.TLS.Enable { + t.Fatal("TLS.Enable = false, want true") + } + if cfg.TLS.ServerName != "home.example.com" { + t.Fatalf("TLS.ServerName = %q, want home.example.com", cfg.TLS.ServerName) + } + if !cfg.TLS.InsecureSkipVerify { + t.Fatal("TLS.InsecureSkipVerify = false, want true") + } + if cfg.TLS.CACert != "C:/certs/ca.pem" { + t.Fatalf("TLS.CACert = %q, want C:/certs/ca.pem", cfg.TLS.CACert) + } +} + +func TestParseHomeFlagConfigPasswordFlagOverridesURLPassword(t *testing.T) { + cfg, err := parseHomeFlagConfig("rediss://:url-secret@home.example.com:444", "flag-secret") + if err != nil { + t.Fatalf("parseHomeFlagConfig() error = %v", err) + } + + if cfg.Password != "flag-secret" { + t.Fatalf("Password = %q, want flag-secret", cfg.Password) + } +} diff --git a/cmd/server/main.go b/cmd/server/main.go index 1ef830066..70f7c9531 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -10,11 +10,9 @@ import ( "fmt" "io" "io/fs" - "net" "net/url" "os" "path/filepath" - "strconv" "strings" "time" @@ -93,7 +91,7 @@ func main() { flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file") flag.StringVar(&vertexImportPrefix, "vertex-import-prefix", "", "Prefix for Vertex model namespacing (use with -vertex-import)") flag.StringVar(&password, "password", "", "") - flag.StringVar(&homeAddr, "home", "", "Home control plane address in host:port format (loads config from home and skips local config file)") + flag.StringVar(&homeAddr, "home", "", "Home control plane address in host:port, redis://host:port, or rediss://host:port format (loads config from home and skips local config file)") flag.StringVar(&homePassword, "home-password", "", "Home control plane password (Redis AUTH)") flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI") flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server") @@ -247,28 +245,11 @@ func main() { if strings.TrimSpace(homeAddr) != "" { configLoadedFromHome = true trimmedHomePassword := strings.TrimSpace(homePassword) - host, portStr, errSplit := net.SplitHostPort(strings.TrimSpace(homeAddr)) - if errSplit != nil { - log.Errorf("invalid -home address %q (expected host:port): %v", homeAddr, errSplit) + homeCfg, errHomeCfg := parseHomeFlagConfig(homeAddr, trimmedHomePassword) + if errHomeCfg != nil { + log.Errorf("invalid -home address %q: %v", homeAddr, errHomeCfg) return } - host = strings.TrimSpace(host) - if host == "" { - log.Errorf("invalid -home address %q: host is empty", homeAddr) - return - } - port, errPort := strconv.Atoi(strings.TrimSpace(portStr)) - if errPort != nil || port <= 0 { - log.Errorf("invalid -home address %q: invalid port %q", homeAddr, portStr) - return - } - - homeCfg := config.HomeConfig{ - Enabled: true, - Host: host, - Port: port, - Password: trimmedHomePassword, - } homeClient := home.New(homeCfg) defer homeClient.Close() diff --git a/config.example.yaml b/config.example.yaml index 886d775a5..d9a4fc047 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -17,6 +17,16 @@ home: host: "127.0.0.1" port: 6379 password: "" + # Optional TLS for the outbound Redis connection to the home control plane. + # Enable this when connecting through rediss:// or an SSL stream proxy. + tls: + enable: false + # Optional SNI/certificate name override. Leave empty to use the configured home host. + server-name: "" + # Trust a private CA bundle in addition to system roots. + ca-cert: "" + # Only for testing self-signed endpoints; disables certificate verification. + insecure-skip-verify: false # Management API settings remote-management: diff --git a/internal/config/home.go b/internal/config/home.go index 03c917323..ffcdd4b7a 100644 --- a/internal/config/home.go +++ b/internal/config/home.go @@ -2,8 +2,17 @@ package config // HomeConfig configures the optional "home" control plane integration over Redis protocol. type HomeConfig struct { - Enabled bool `yaml:"enabled" json:"enabled"` - Host string `yaml:"host" json:"-"` - Port int `yaml:"port" json:"-"` - Password string `yaml:"password" json:"-"` + Enabled bool `yaml:"enabled" json:"enabled"` + Host string `yaml:"host" json:"-"` + Port int `yaml:"port" json:"-"` + Password string `yaml:"password" json:"-"` + TLS HomeTLSConfig `yaml:"tls" json:"-"` +} + +// HomeTLSConfig configures client-side TLS for the home Redis connection. +type HomeTLSConfig struct { + Enable bool `yaml:"enable" json:"-"` + ServerName string `yaml:"server-name" json:"-"` + InsecureSkipVerify bool `yaml:"insecure-skip-verify" json:"-"` + CACert string `yaml:"ca-cert" json:"-"` } diff --git a/internal/config/home_test.go b/internal/config/home_test.go new file mode 100644 index 000000000..2a5d64fb3 --- /dev/null +++ b/internal/config/home_test.go @@ -0,0 +1,46 @@ +package config + +import "testing" + +func TestParseConfigBytesHomeTLS(t *testing.T) { + cfg, err := ParseConfigBytes([]byte(` +home: + enabled: true + host: home.example.com + port: 444 + password: secret + tls: + enable: true + server-name: home.example.com + ca-cert: C:/certs/ca.pem + insecure-skip-verify: true +`)) + if err != nil { + t.Fatalf("ParseConfigBytes() error = %v", err) + } + + if !cfg.Home.Enabled { + t.Fatal("Home.Enabled = false, want true") + } + if cfg.Home.Host != "home.example.com" { + t.Fatalf("Home.Host = %q, want home.example.com", cfg.Home.Host) + } + if cfg.Home.Port != 444 { + t.Fatalf("Home.Port = %d, want 444", cfg.Home.Port) + } + if cfg.Home.Password != "secret" { + t.Fatalf("Home.Password = %q, want secret", cfg.Home.Password) + } + if !cfg.Home.TLS.Enable { + t.Fatal("Home.TLS.Enable = false, want true") + } + if cfg.Home.TLS.ServerName != "home.example.com" { + t.Fatalf("Home.TLS.ServerName = %q, want home.example.com", cfg.Home.TLS.ServerName) + } + if cfg.Home.TLS.CACert != "C:/certs/ca.pem" { + t.Fatalf("Home.TLS.CACert = %q, want C:/certs/ca.pem", cfg.Home.TLS.CACert) + } + if !cfg.Home.TLS.InsecureSkipVerify { + t.Fatal("Home.TLS.InsecureSkipVerify = false, want true") + } +} diff --git a/internal/home/client.go b/internal/home/client.go index 9e7a9056f..5d0c96cea 100644 --- a/internal/home/client.go +++ b/internal/home/client.go @@ -2,11 +2,14 @@ package home import ( "context" + "crypto/tls" + "crypto/x509" "encoding/json" "errors" "fmt" "net" "net/http" + "os" "sort" "strconv" "strings" @@ -151,20 +154,83 @@ func (c *Client) ensureClients() error { } if c.cmd == nil { - c.cmd = redis.NewClient(&redis.Options{ - Addr: addr, - Password: c.homeCfg.Password, - }) + options, errOptions := c.redisOptionsLocked(addr) + if errOptions != nil { + return errOptions + } + c.cmd = redis.NewClient(options) } if c.sub == nil { - c.sub = redis.NewClient(&redis.Options{ - Addr: addr, - Password: c.homeCfg.Password, - }) + options, errOptions := c.redisOptionsLocked(addr) + if errOptions != nil { + return errOptions + } + c.sub = redis.NewClient(options) } return nil } +func (c *Client) redisOptionsLocked(addr string) (*redis.Options, error) { + tlsConfig, errTLS := c.homeTLSConfigLocked() + if errTLS != nil { + return nil, errTLS + } + return &redis.Options{ + Addr: addr, + Password: c.homeCfg.Password, + TLSConfig: tlsConfig, + }, nil +} + +func (c *Client) homeTLSConfigLocked() (*tls.Config, error) { + serverName := strings.TrimSpace(c.homeCfg.TLS.ServerName) + if serverName == "" { + serverName = strings.TrimSpace(c.seedHost) + } + if serverName == "" { + serverName = strings.TrimSpace(c.homeCfg.Host) + } + return newHomeTLSConfig(c.homeCfg.TLS, serverName) +} + +func newHomeTLSConfig(cfg config.HomeTLSConfig, fallbackServerName string) (*tls.Config, error) { + if !cfg.Enable { + return nil, nil + } + + serverName := strings.TrimSpace(cfg.ServerName) + if serverName == "" { + serverName = strings.TrimSpace(fallbackServerName) + } + + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + ServerName: serverName, + InsecureSkipVerify: cfg.InsecureSkipVerify, + } + + caCertPath := strings.TrimSpace(cfg.CACert) + if caCertPath == "" { + return tlsConfig, nil + } + + caCertPEM, errRead := os.ReadFile(caCertPath) + if errRead != nil { + return nil, fmt.Errorf("home tls: read ca-cert: %w", errRead) + } + + certPool, errPool := x509.SystemCertPool() + if errPool != nil || certPool == nil { + certPool = x509.NewCertPool() + } + if !certPool.AppendCertsFromPEM(caCertPEM) { + return nil, fmt.Errorf("home tls: ca-cert contains no PEM certificates") + } + tlsConfig.RootCAs = certPool + + return tlsConfig, nil +} + func (c *Client) commandClient() (*redis.Client, error) { if errEnsure := c.ensureClients(); errEnsure != nil { return nil, errEnsure diff --git a/internal/home/client_test.go b/internal/home/client_test.go index 625e77bca..65148f676 100644 --- a/internal/home/client_test.go +++ b/internal/home/client_test.go @@ -1,9 +1,12 @@ package home import ( + "crypto/tls" "encoding/json" "net/http" "testing" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" ) func TestAuthDispatchRequestIncludesCount(t *testing.T) { @@ -30,3 +33,85 @@ func TestAuthDispatchRequestDefaultsCountToOne(t *testing.T) { t.Fatalf("count = %d, want 1", req.Count) } } + +func TestRedisOptionsHomeTLSDisabled(t *testing.T) { + client := New(config.HomeConfig{ + Enabled: true, + Host: "127.0.0.1", + Port: 6379, + Password: "secret", + }) + + client.mu.Lock() + options, err := client.redisOptionsLocked("127.0.0.1:6379") + client.mu.Unlock() + if err != nil { + t.Fatalf("redisOptionsLocked() error = %v", err) + } + + if options.TLSConfig != nil { + t.Fatalf("TLSConfig = %#v, want nil", options.TLSConfig) + } + if options.Password != "secret" { + t.Fatalf("Password = %q, want secret", options.Password) + } +} + +func TestRedisOptionsHomeTLSEnabledUsesSeedHostAsServerName(t *testing.T) { + client := New(config.HomeConfig{ + Enabled: true, + Host: "home.example.com", + Port: 444, + TLS: config.HomeTLSConfig{ + Enable: true, + }, + }) + client.homeCfg.Host = "127.0.0.1" + + client.mu.Lock() + options, err := client.redisOptionsLocked("127.0.0.1:444") + client.mu.Unlock() + if err != nil { + t.Fatalf("redisOptionsLocked() error = %v", err) + } + + if options.TLSConfig == nil { + t.Fatal("TLSConfig is nil") + } + if options.TLSConfig.ServerName != "home.example.com" { + t.Fatalf("ServerName = %q, want home.example.com", options.TLSConfig.ServerName) + } + if options.TLSConfig.MinVersion != tls.VersionTLS12 { + t.Fatalf("MinVersion = %d, want TLS 1.2", options.TLSConfig.MinVersion) + } +} + +func TestRedisOptionsHomeTLSEnabledUsesExplicitServerName(t *testing.T) { + client := New(config.HomeConfig{ + Enabled: true, + Host: "127.0.0.1", + Port: 444, + TLS: config.HomeTLSConfig{ + Enable: true, + ServerName: "home.example.com", + InsecureSkipVerify: true, + }, + }) + + client.mu.Lock() + options, err := client.redisOptionsLocked("127.0.0.1:444") + client.mu.Unlock() + if err != nil { + t.Fatalf("redisOptionsLocked() error = %v", err) + } + + if options.TLSConfig == nil { + t.Fatal("TLSConfig is nil") + } + if options.TLSConfig.ServerName != "home.example.com" { + t.Fatalf("ServerName = %q, want home.example.com", options.TLSConfig.ServerName) + } + if !options.TLSConfig.InsecureSkipVerify { + t.Fatal("InsecureSkipVerify = false, want true") + } +} From 644d5ea618fd4bdc57bf087622ecd1b6f6f08b39 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Sat, 16 May 2026 20:25:29 +0800 Subject: [PATCH 2/2] feat(home): add support for disabling cluster discovery in Redis configuration --- cmd/server/home_flag.go | 3 ++- cmd/server/home_flag_test.go | 11 ++++++++++ cmd/server/main.go | 5 +++++ config.example.yaml | 3 +++ internal/config/home.go | 11 +++++----- internal/config/home_test.go | 4 ++++ internal/home/client.go | 23 ++++++++++++++++++++ internal/home/client_test.go | 42 ++++++++++++++++++++++++++++++++++++ 8 files changed, 96 insertions(+), 6 deletions(-) diff --git a/cmd/server/home_flag.go b/cmd/server/home_flag.go index 2d79ef833..ade94fbf3 100644 --- a/cmd/server/home_flag.go +++ b/cmd/server/home_flag.go @@ -76,10 +76,11 @@ func parseHomeURLConfig(rawAddr string, password string) (config.HomeConfig, err Port: port, Password: password, } + query := parsed.Query() + homeCfg.DisableClusterDiscovery = parseHomeBoolQuery(query, "disable-cluster-discovery", "disable_cluster_discovery") if scheme == "rediss" { homeCfg.TLS.Enable = true - query := parsed.Query() homeCfg.TLS.ServerName = strings.TrimSpace(firstHomeQueryValue(query, "server-name", "server_name")) homeCfg.TLS.InsecureSkipVerify = parseHomeBoolQuery(query, "insecure-skip-verify", "insecure_skip_verify", "skip_verify") homeCfg.TLS.CACert = strings.TrimSpace(firstHomeQueryValue(query, "ca-cert", "ca_cert")) diff --git a/cmd/server/home_flag_test.go b/cmd/server/home_flag_test.go index 9947f9402..e98d85f17 100644 --- a/cmd/server/home_flag_test.go +++ b/cmd/server/home_flag_test.go @@ -64,3 +64,14 @@ func TestParseHomeFlagConfigPasswordFlagOverridesURLPassword(t *testing.T) { t.Fatalf("Password = %q, want flag-secret", cfg.Password) } } + +func TestParseHomeFlagConfigDisableClusterDiscovery(t *testing.T) { + cfg, err := parseHomeFlagConfig("redis://home.example.com:8327?disable-cluster-discovery=true", "") + if err != nil { + t.Fatalf("parseHomeFlagConfig() error = %v", err) + } + + if !cfg.DisableClusterDiscovery { + t.Fatal("DisableClusterDiscovery = false, want true") + } +} diff --git a/cmd/server/main.go b/cmd/server/main.go index 70f7c9531..7da5b087a 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -73,6 +73,7 @@ func main() { var password string var homeAddr string var homePassword string + var homeDisableClusterDiscovery bool var tuiMode bool var standalone bool var localModel bool @@ -93,6 +94,7 @@ func main() { flag.StringVar(&password, "password", "", "") flag.StringVar(&homeAddr, "home", "", "Home control plane address in host:port, redis://host:port, or rediss://host:port format (loads config from home and skips local config file)") flag.StringVar(&homePassword, "home-password", "", "Home control plane password (Redis AUTH)") + flag.BoolVar(&homeDisableClusterDiscovery, "home-disable-cluster-discovery", false, "Disable Home CLUSTER NODES discovery and keep using the configured -home address") flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI") flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server") flag.BoolVar(&localModel, "local-model", false, "Use embedded model catalog only, skip remote model fetching") @@ -250,6 +252,9 @@ func main() { log.Errorf("invalid -home address %q: %v", homeAddr, errHomeCfg) return } + if homeDisableClusterDiscovery { + homeCfg.DisableClusterDiscovery = true + } homeClient := home.New(homeCfg) defer homeClient.Close() diff --git a/config.example.yaml b/config.example.yaml index d9a4fc047..d49c378cb 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -17,6 +17,9 @@ home: host: "127.0.0.1" port: 6379 password: "" + # Keep CPA pinned to the configured home address instead of switching to CLUSTER NODES entries. + # Useful when Home is behind NAT, Docker networking, or a reverse proxy. + disable-cluster-discovery: false # Optional TLS for the outbound Redis connection to the home control plane. # Enable this when connecting through rediss:// or an SSL stream proxy. tls: diff --git a/internal/config/home.go b/internal/config/home.go index ffcdd4b7a..8e7945b40 100644 --- a/internal/config/home.go +++ b/internal/config/home.go @@ -2,11 +2,12 @@ package config // HomeConfig configures the optional "home" control plane integration over Redis protocol. type HomeConfig struct { - Enabled bool `yaml:"enabled" json:"enabled"` - Host string `yaml:"host" json:"-"` - Port int `yaml:"port" json:"-"` - Password string `yaml:"password" json:"-"` - TLS HomeTLSConfig `yaml:"tls" json:"-"` + Enabled bool `yaml:"enabled" json:"enabled"` + Host string `yaml:"host" json:"-"` + Port int `yaml:"port" json:"-"` + Password string `yaml:"password" json:"-"` + DisableClusterDiscovery bool `yaml:"disable-cluster-discovery" json:"-"` + TLS HomeTLSConfig `yaml:"tls" json:"-"` } // HomeTLSConfig configures client-side TLS for the home Redis connection. diff --git a/internal/config/home_test.go b/internal/config/home_test.go index 2a5d64fb3..ac26d2cbf 100644 --- a/internal/config/home_test.go +++ b/internal/config/home_test.go @@ -9,6 +9,7 @@ home: host: home.example.com port: 444 password: secret + disable-cluster-discovery: true tls: enable: true server-name: home.example.com @@ -31,6 +32,9 @@ home: if cfg.Home.Password != "secret" { t.Fatalf("Home.Password = %q, want secret", cfg.Home.Password) } + if !cfg.Home.DisableClusterDiscovery { + t.Fatal("Home.DisableClusterDiscovery = false, want true") + } if !cfg.Home.TLS.Enable { t.Fatal("Home.TLS.Enable = false, want true") } diff --git a/internal/home/client.go b/internal/home/client.go index 5d0c96cea..3edd3135a 100644 --- a/internal/home/client.go +++ b/internal/home/client.go @@ -265,7 +265,23 @@ func (c *Client) Ping(ctx context.Context) error { return cmd.Ping(ctx).Err() } +func (c *Client) clusterDiscoveryEnabled() bool { + if c == nil { + return false + } + c.mu.Lock() + defer c.mu.Unlock() + return c.clusterDiscoveryEnabledLocked() +} + +func (c *Client) clusterDiscoveryEnabledLocked() bool { + return !c.homeCfg.DisableClusterDiscovery +} + func (c *Client) refreshBestClusterNode(ctx context.Context) { + if !c.clusterDiscoveryEnabled() { + return + } switched, errRefresh := c.refreshClusterNodes(ctx) if errRefresh != nil { log.Debugf("home cluster nodes unavailable: %v", errRefresh) @@ -279,6 +295,9 @@ func (c *Client) refreshBestClusterNode(ctx context.Context) { } func (c *Client) refreshClusterNodes(ctx context.Context) (bool, error) { + if !c.clusterDiscoveryEnabled() { + return false, nil + } if ctx == nil { ctx = context.Background() } @@ -353,6 +372,10 @@ func (c *Client) failoverAfterReconnectFailure() (bool, string) { c.mu.Lock() defer c.mu.Unlock() + if !c.clusterDiscoveryEnabledLocked() { + c.reconnectFailures = 0 + return false, "" + } c.reconnectFailures++ if c.reconnectFailures < homeReconnectFailoverThreshold { return false, "" diff --git a/internal/home/client_test.go b/internal/home/client_test.go index 65148f676..b3a1ae583 100644 --- a/internal/home/client_test.go +++ b/internal/home/client_test.go @@ -1,6 +1,7 @@ package home import ( + "context" "crypto/tls" "encoding/json" "net/http" @@ -115,3 +116,44 @@ func TestRedisOptionsHomeTLSEnabledUsesExplicitServerName(t *testing.T) { t.Fatal("InsecureSkipVerify = false, want true") } } + +func TestRefreshClusterNodesDisabledSkipsRedisCommand(t *testing.T) { + client := New(config.HomeConfig{ + Enabled: true, + Host: "127.0.0.1", + Port: 1, + DisableClusterDiscovery: true, + }) + + switched, err := client.refreshClusterNodes(context.Background()) + if err != nil { + t.Fatalf("refreshClusterNodes() error = %v", err) + } + if switched { + t.Fatal("refreshClusterNodes() switched = true, want false") + } + if client.cmd != nil || client.sub != nil { + t.Fatalf("redis clients were initialized when cluster discovery was disabled") + } +} + +func TestFailoverAfterReconnectFailureDisabledDoesNotSwitchToClusterNode(t *testing.T) { + client := New(config.HomeConfig{ + Enabled: true, + Host: "seed.example.com", + Port: 8327, + DisableClusterDiscovery: true, + }) + client.mu.Lock() + client.clusterNodes = []clusterNode{{IP: "other.example.com", Port: 8327}} + client.reconnectFailures = homeReconnectFailoverThreshold - 1 + client.mu.Unlock() + + switched, addr := client.failoverAfterReconnectFailure() + if switched { + t.Fatalf("failoverAfterReconnectFailure() switched to %s, want no switch", addr) + } + if got, _ := client.addr(); got != "seed.example.com:8327" { + t.Fatalf("addr() = %q, want seed.example.com:8327", got) + } +}