From bcbb94906c3c635278c662453ea56b90760d078d Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Thu, 14 May 2026 00:21:31 +0800 Subject: [PATCH] feat(client): add cluster node failover and improve reconnection handling - Introduced cluster node management with `clusterNode` and `clusterNodesEnvelope` types. - Added failover handling for reconnection failures with configurable threshold (`homeReconnectFailoverThreshold`). - Implemented node switching and dynamic cluster target updates. - Enhanced Redis client management with centralized locking for concurrency safety. - Updated configuration refresh logic to prioritize the best cluster node. - Improved debug logging for reconnect failures and node switching. --- internal/home/client.go | 268 +++++++++++++++++++++++++++++++++++----- 1 file changed, 238 insertions(+), 30 deletions(-) diff --git a/internal/home/client.go b/internal/home/client.go index 23082cc69..40a191fe2 100644 --- a/internal/home/client.go +++ b/internal/home/client.go @@ -6,7 +6,9 @@ import ( "errors" "fmt" "net/http" + "sort" "strings" + "sync" "sync/atomic" "time" @@ -22,7 +24,8 @@ const ( redisKeyUsage = "usage" redisKeyRequestLog = "request-log" - homeReconnectInterval = time.Second + homeReconnectInterval = time.Second + homeReconnectFailoverThreshold = 3 ) var ( @@ -34,23 +37,48 @@ var ( ErrModelsNotFound = errors.New("home models not found") ) +type clusterNode struct { + IP string `json:"ip"` + Port int `json:"port"` + ClientCount int `json:"client_count"` + IsMaster bool `json:"is_master"` + LastSeenAt time.Time `json:"last_seen_at"` +} + +type clusterNodesEnvelope struct { + OK bool `json:"ok"` + Nodes []clusterNode `json:"nodes"` +} + type Client struct { - homeCfg config.HomeConfig + mu sync.Mutex + + homeCfg config.HomeConfig + seedHost string + seedPort int cmd *redis.Client sub *redis.Client - heartbeatOK atomic.Bool + heartbeatOK atomic.Bool + clusterNodes []clusterNode + reconnectFailures int } func New(homeCfg config.HomeConfig) *Client { - return &Client{homeCfg: homeCfg} + return &Client{ + homeCfg: homeCfg, + seedHost: strings.TrimSpace(homeCfg.Host), + seedPort: homeCfg.Port, + } } func (c *Client) Enabled() bool { if c == nil { return false } + c.mu.Lock() + defer c.mu.Unlock() return c.homeCfg.Enabled } @@ -69,6 +97,12 @@ func (c *Client) Close() { return } c.heartbeatOK.Store(false) + c.mu.Lock() + defer c.mu.Unlock() + c.closeClientsLocked() +} + +func (c *Client) closeClientsLocked() { if c.cmd != nil { _ = c.cmd.Close() } @@ -83,6 +117,12 @@ func (c *Client) addr() (string, bool) { if c == nil { return "", false } + c.mu.Lock() + defer c.mu.Unlock() + return c.addrLocked() +} + +func (c *Client) addrLocked() (string, bool) { host := strings.TrimSpace(c.homeCfg.Host) if host == "" { return "", false @@ -100,7 +140,10 @@ func (c *Client) ensureClients() error { if !c.Enabled() { return ErrDisabled } - addr, ok := c.addr() + c.mu.Lock() + defer c.mu.Unlock() + + addr, ok := c.addrLocked() if !ok { return fmt.Errorf("home: invalid address (host=%q port=%d)", c.homeCfg.Host, c.homeCfg.Port) } @@ -120,21 +163,172 @@ func (c *Client) ensureClients() error { return nil } +func (c *Client) commandClient() (*redis.Client, error) { + if errEnsure := c.ensureClients(); errEnsure != nil { + return nil, errEnsure + } + c.mu.Lock() + cmd := c.cmd + c.mu.Unlock() + if cmd == nil { + return nil, ErrNotConnected + } + return cmd, nil +} + +func (c *Client) subscriptionClient() (*redis.Client, error) { + if errEnsure := c.ensureClients(); errEnsure != nil { + return nil, errEnsure + } + c.mu.Lock() + sub := c.sub + c.mu.Unlock() + if sub == nil { + return nil, ErrNotConnected + } + return sub, nil +} + func (c *Client) Ping(ctx context.Context) error { - if err := c.ensureClients(); err != nil { - return err + cmd, errClient := c.commandClient() + if errClient != nil { + return errClient } - if c.cmd == nil { - return ErrNotConnected + return cmd.Ping(ctx).Err() +} + +func (c *Client) refreshBestClusterNode(ctx context.Context) { + switched, errRefresh := c.refreshClusterNodes(ctx) + if errRefresh != nil { + log.Debugf("home cluster nodes unavailable: %v", errRefresh) + return } - return c.cmd.Ping(ctx).Err() + if switched { + if addr, ok := c.addr(); ok { + log.Infof("home cluster target switched to %s", addr) + } + } +} + +func (c *Client) refreshClusterNodes(ctx context.Context) (bool, error) { + if ctx == nil { + ctx = context.Background() + } + cmd, errClient := c.commandClient() + if errClient != nil { + return false, errClient + } + raw, errDo := cmd.Do(ctx, "CLUSTER", "NODES").Text() + if errDo != nil { + return false, errDo + } + + var envelope clusterNodesEnvelope + if errUnmarshal := json.Unmarshal([]byte(raw), &envelope); errUnmarshal != nil { + return false, errUnmarshal + } + nodes := normalizeClusterNodes(envelope.Nodes) + if len(nodes) == 0 { + return false, nil + } + + c.mu.Lock() + defer c.mu.Unlock() + c.clusterNodes = nodes + c.reconnectFailures = 0 + return c.switchToNodeLocked(nodes[0]), nil +} + +func normalizeClusterNodes(nodes []clusterNode) []clusterNode { + out := make([]clusterNode, 0, len(nodes)) + for _, node := range nodes { + node.IP = strings.TrimSpace(node.IP) + if node.IP == "" || node.Port <= 0 { + continue + } + if node.ClientCount < 0 { + node.ClientCount = 0 + } + out = append(out, node) + } + sort.SliceStable(out, func(i, j int) bool { + return out[i].ClientCount < out[j].ClientCount + }) + return out +} + +func (c *Client) switchToNodeLocked(node clusterNode) bool { + host := strings.TrimSpace(node.IP) + if host == "" || node.Port <= 0 { + return false + } + if strings.TrimSpace(c.homeCfg.Host) == host && c.homeCfg.Port == node.Port { + return false + } + c.homeCfg.Host = host + c.homeCfg.Port = node.Port + c.closeClientsLocked() + return true +} + +func (c *Client) markReconnectFailure(reason string) { + switched, addr := c.failoverAfterReconnectFailure() + if switched { + log.Warnf("home control center unavailable after repeated %s failures; switching to %s", reason, addr) + } +} + +func (c *Client) failoverAfterReconnectFailure() (bool, string) { + if c == nil { + return false, "" + } + c.mu.Lock() + defer c.mu.Unlock() + + c.reconnectFailures++ + if c.reconnectFailures < homeReconnectFailoverThreshold { + return false, "" + } + c.reconnectFailures = 0 + + currentHost := strings.TrimSpace(c.homeCfg.Host) + currentPort := c.homeCfg.Port + candidates := append([]clusterNode(nil), c.clusterNodes...) + if strings.TrimSpace(c.seedHost) != "" && c.seedPort > 0 { + candidates = append(candidates, clusterNode{IP: c.seedHost, Port: c.seedPort}) + } + for _, node := range candidates { + host := strings.TrimSpace(node.IP) + if host == "" || node.Port <= 0 { + continue + } + if host == currentHost && node.Port == currentPort { + continue + } + if c.switchToNodeLocked(clusterNode{IP: host, Port: node.Port}) { + addr, _ := c.addrLocked() + return true, addr + } + } + return false, "" +} + +func (c *Client) resetReconnectFailures() { + if c == nil { + return + } + c.mu.Lock() + c.reconnectFailures = 0 + c.mu.Unlock() } func (c *Client) GetConfig(ctx context.Context) ([]byte, error) { - if err := c.ensureClients(); err != nil { - return nil, err + c.refreshBestClusterNode(ctx) + cmd, errClient := c.commandClient() + if errClient != nil { + return nil, errClient } - raw, err := c.cmd.Get(ctx, redisKeyConfig).Bytes() + raw, err := cmd.Get(ctx, redisKeyConfig).Bytes() if errors.Is(err, redis.Nil) { return nil, ErrConfigNotFound } @@ -148,10 +342,11 @@ func (c *Client) GetConfig(ctx context.Context) ([]byte, error) { } func (c *Client) GetModels(ctx context.Context) ([]byte, error) { - if err := c.ensureClients(); err != nil { - return nil, err + cmd, errClient := c.commandClient() + if errClient != nil { + return nil, errClient } - raw, err := c.cmd.Get(ctx, redisKeyModels).Bytes() + raw, err := cmd.Get(ctx, redisKeyModels).Bytes() if errors.Is(err, redis.Nil) { return nil, ErrModelsNotFound } @@ -204,8 +399,9 @@ func newAuthDispatchRequest(requestedModel string, sessionID string, headers htt } func (c *Client) RPopAuth(ctx context.Context, requestedModel string, sessionID string, headers http.Header, count int) ([]byte, error) { - if err := c.ensureClients(); err != nil { - return nil, err + cmd, errClient := c.commandClient() + if errClient != nil { + return nil, errClient } requestedModel = strings.TrimSpace(requestedModel) if requestedModel == "" { @@ -217,7 +413,7 @@ func (c *Client) RPopAuth(ctx context.Context, requestedModel string, sessionID return nil, err } - raw, err := c.cmd.RPop(ctx, string(keyBytes)).Bytes() + raw, err := cmd.RPop(ctx, string(keyBytes)).Bytes() if errors.Is(err, redis.Nil) { return nil, ErrAuthNotFound } @@ -231,8 +427,9 @@ func (c *Client) RPopAuth(ctx context.Context, requestedModel string, sessionID } func (c *Client) GetRefreshAuth(ctx context.Context, authIndex string) ([]byte, error) { - if err := c.ensureClients(); err != nil { - return nil, err + cmd, errClient := c.commandClient() + if errClient != nil { + return nil, errClient } authIndex = strings.TrimSpace(authIndex) if authIndex == "" { @@ -247,7 +444,7 @@ func (c *Client) GetRefreshAuth(ctx context.Context, authIndex string) ([]byte, return nil, err } - raw, err := c.cmd.Get(ctx, string(keyBytes)).Bytes() + raw, err := cmd.Get(ctx, string(keyBytes)).Bytes() if errors.Is(err, redis.Nil) { return nil, ErrAuthNotFound } @@ -261,23 +458,25 @@ func (c *Client) GetRefreshAuth(ctx context.Context, authIndex string) ([]byte, } func (c *Client) LPushUsage(ctx context.Context, payload []byte) error { - if err := c.ensureClients(); err != nil { - return err + cmd, errClient := c.commandClient() + if errClient != nil { + return errClient } if len(payload) == 0 { return nil } - return c.cmd.LPush(ctx, redisKeyUsage, payload).Err() + return cmd.LPush(ctx, redisKeyUsage, payload).Err() } func (c *Client) RPushRequestLog(ctx context.Context, payload []byte) error { - if err := c.ensureClients(); err != nil { - return err + cmd, errClient := c.commandClient() + if errClient != nil { + return errClient } if len(payload) == 0 { return nil } - return c.cmd.RPush(ctx, redisKeyRequestLog, payload).Err() + return cmd.RPush(ctx, redisKeyRequestLog, payload).Err() } // StartConfigSubscriber connects to home, fetches config once via GET config, then subscribes to @@ -312,12 +511,14 @@ func (c *Client) StartConfigSubscriber(ctx context.Context, onConfig func([]byte if errEnsure := c.ensureClients(); errEnsure != nil { log.Warn("unable to connect to home control center, retrying in 1 second") + c.markReconnectFailure("connect") sleepWithContext(ctx, homeReconnectInterval) continue } if errPing := c.Ping(ctx); errPing != nil { log.Warn("unable to connect to home control center, retrying in 1 second") + c.markReconnectFailure("ping") sleepWithContext(ctx, homeReconnectInterval) continue } @@ -325,6 +526,7 @@ func (c *Client) StartConfigSubscriber(ctx context.Context, onConfig func([]byte raw, errGet := c.GetConfig(ctx) if errGet != nil { log.Warn("unable to fetch config from home control center, retrying in 1 second") + c.markReconnectFailure("config fetch") sleepWithContext(ctx, homeReconnectInterval) continue } @@ -334,13 +536,16 @@ func (c *Client) StartConfigSubscriber(ctx context.Context, onConfig func([]byte continue } - if c.sub == nil { + sub, errSubClient := c.subscriptionClient() + if errSubClient != nil { + c.markReconnectFailure("subscribe client") sleepWithContext(ctx, homeReconnectInterval) continue } - pubsub := c.sub.Subscribe(ctx, redisChannelConfig) + pubsub := sub.Subscribe(ctx, redisChannelConfig) if pubsub == nil { + c.markReconnectFailure("subscribe") sleepWithContext(ctx, homeReconnectInterval) continue } @@ -348,10 +553,12 @@ func (c *Client) StartConfigSubscriber(ctx context.Context, onConfig func([]byte // Ensure the subscription is established before marking heartbeat OK. if _, errReceive := pubsub.Receive(ctx); errReceive != nil { _ = pubsub.Close() + c.markReconnectFailure("subscribe") sleepWithContext(ctx, homeReconnectInterval) continue } + c.resetReconnectFailures() c.heartbeatOK.Store(true) for { @@ -359,6 +566,7 @@ func (c *Client) StartConfigSubscriber(ctx context.Context, onConfig func([]byte if errMsg != nil { _ = pubsub.Close() c.heartbeatOK.Store(false) + c.markReconnectFailure("subscription") sleepWithContext(ctx, homeReconnectInterval) break }