mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-05-31 20:02:36 +08:00
feat(client): add cluster node failover and improve reconnection handling
- Introduced cluster node management with `clusterNode` and `clusterNodesEnvelope` types. - Added failover handling for reconnection failures with configurable threshold (`homeReconnectFailoverThreshold`). - Implemented node switching and dynamic cluster target updates. - Enhanced Redis client management with centralized locking for concurrency safety. - Updated configuration refresh logic to prioritize the best cluster node. - Improved debug logging for reconnect failures and node switching.
This commit is contained in:
@@ -6,7 +6,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -22,7 +24,8 @@ const (
|
||||
redisKeyUsage = "usage"
|
||||
redisKeyRequestLog = "request-log"
|
||||
|
||||
homeReconnectInterval = time.Second
|
||||
homeReconnectInterval = time.Second
|
||||
homeReconnectFailoverThreshold = 3
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -34,23 +37,48 @@ var (
|
||||
ErrModelsNotFound = errors.New("home models not found")
|
||||
)
|
||||
|
||||
type clusterNode struct {
|
||||
IP string `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
ClientCount int `json:"client_count"`
|
||||
IsMaster bool `json:"is_master"`
|
||||
LastSeenAt time.Time `json:"last_seen_at"`
|
||||
}
|
||||
|
||||
type clusterNodesEnvelope struct {
|
||||
OK bool `json:"ok"`
|
||||
Nodes []clusterNode `json:"nodes"`
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
homeCfg config.HomeConfig
|
||||
mu sync.Mutex
|
||||
|
||||
homeCfg config.HomeConfig
|
||||
seedHost string
|
||||
seedPort int
|
||||
|
||||
cmd *redis.Client
|
||||
sub *redis.Client
|
||||
|
||||
heartbeatOK atomic.Bool
|
||||
heartbeatOK atomic.Bool
|
||||
clusterNodes []clusterNode
|
||||
reconnectFailures int
|
||||
}
|
||||
|
||||
func New(homeCfg config.HomeConfig) *Client {
|
||||
return &Client{homeCfg: homeCfg}
|
||||
return &Client{
|
||||
homeCfg: homeCfg,
|
||||
seedHost: strings.TrimSpace(homeCfg.Host),
|
||||
seedPort: homeCfg.Port,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Enabled() bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.homeCfg.Enabled
|
||||
}
|
||||
|
||||
@@ -69,6 +97,12 @@ func (c *Client) Close() {
|
||||
return
|
||||
}
|
||||
c.heartbeatOK.Store(false)
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.closeClientsLocked()
|
||||
}
|
||||
|
||||
func (c *Client) closeClientsLocked() {
|
||||
if c.cmd != nil {
|
||||
_ = c.cmd.Close()
|
||||
}
|
||||
@@ -83,6 +117,12 @@ func (c *Client) addr() (string, bool) {
|
||||
if c == nil {
|
||||
return "", false
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.addrLocked()
|
||||
}
|
||||
|
||||
func (c *Client) addrLocked() (string, bool) {
|
||||
host := strings.TrimSpace(c.homeCfg.Host)
|
||||
if host == "" {
|
||||
return "", false
|
||||
@@ -100,7 +140,10 @@ func (c *Client) ensureClients() error {
|
||||
if !c.Enabled() {
|
||||
return ErrDisabled
|
||||
}
|
||||
addr, ok := c.addr()
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
addr, ok := c.addrLocked()
|
||||
if !ok {
|
||||
return fmt.Errorf("home: invalid address (host=%q port=%d)", c.homeCfg.Host, c.homeCfg.Port)
|
||||
}
|
||||
@@ -120,21 +163,172 @@ func (c *Client) ensureClients() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) commandClient() (*redis.Client, error) {
|
||||
if errEnsure := c.ensureClients(); errEnsure != nil {
|
||||
return nil, errEnsure
|
||||
}
|
||||
c.mu.Lock()
|
||||
cmd := c.cmd
|
||||
c.mu.Unlock()
|
||||
if cmd == nil {
|
||||
return nil, ErrNotConnected
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
func (c *Client) subscriptionClient() (*redis.Client, error) {
|
||||
if errEnsure := c.ensureClients(); errEnsure != nil {
|
||||
return nil, errEnsure
|
||||
}
|
||||
c.mu.Lock()
|
||||
sub := c.sub
|
||||
c.mu.Unlock()
|
||||
if sub == nil {
|
||||
return nil, ErrNotConnected
|
||||
}
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
func (c *Client) Ping(ctx context.Context) error {
|
||||
if err := c.ensureClients(); err != nil {
|
||||
return err
|
||||
cmd, errClient := c.commandClient()
|
||||
if errClient != nil {
|
||||
return errClient
|
||||
}
|
||||
if c.cmd == nil {
|
||||
return ErrNotConnected
|
||||
return cmd.Ping(ctx).Err()
|
||||
}
|
||||
|
||||
func (c *Client) refreshBestClusterNode(ctx context.Context) {
|
||||
switched, errRefresh := c.refreshClusterNodes(ctx)
|
||||
if errRefresh != nil {
|
||||
log.Debugf("home cluster nodes unavailable: %v", errRefresh)
|
||||
return
|
||||
}
|
||||
return c.cmd.Ping(ctx).Err()
|
||||
if switched {
|
||||
if addr, ok := c.addr(); ok {
|
||||
log.Infof("home cluster target switched to %s", addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) refreshClusterNodes(ctx context.Context) (bool, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
cmd, errClient := c.commandClient()
|
||||
if errClient != nil {
|
||||
return false, errClient
|
||||
}
|
||||
raw, errDo := cmd.Do(ctx, "CLUSTER", "NODES").Text()
|
||||
if errDo != nil {
|
||||
return false, errDo
|
||||
}
|
||||
|
||||
var envelope clusterNodesEnvelope
|
||||
if errUnmarshal := json.Unmarshal([]byte(raw), &envelope); errUnmarshal != nil {
|
||||
return false, errUnmarshal
|
||||
}
|
||||
nodes := normalizeClusterNodes(envelope.Nodes)
|
||||
if len(nodes) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.clusterNodes = nodes
|
||||
c.reconnectFailures = 0
|
||||
return c.switchToNodeLocked(nodes[0]), nil
|
||||
}
|
||||
|
||||
func normalizeClusterNodes(nodes []clusterNode) []clusterNode {
|
||||
out := make([]clusterNode, 0, len(nodes))
|
||||
for _, node := range nodes {
|
||||
node.IP = strings.TrimSpace(node.IP)
|
||||
if node.IP == "" || node.Port <= 0 {
|
||||
continue
|
||||
}
|
||||
if node.ClientCount < 0 {
|
||||
node.ClientCount = 0
|
||||
}
|
||||
out = append(out, node)
|
||||
}
|
||||
sort.SliceStable(out, func(i, j int) bool {
|
||||
return out[i].ClientCount < out[j].ClientCount
|
||||
})
|
||||
return out
|
||||
}
|
||||
|
||||
func (c *Client) switchToNodeLocked(node clusterNode) bool {
|
||||
host := strings.TrimSpace(node.IP)
|
||||
if host == "" || node.Port <= 0 {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(c.homeCfg.Host) == host && c.homeCfg.Port == node.Port {
|
||||
return false
|
||||
}
|
||||
c.homeCfg.Host = host
|
||||
c.homeCfg.Port = node.Port
|
||||
c.closeClientsLocked()
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Client) markReconnectFailure(reason string) {
|
||||
switched, addr := c.failoverAfterReconnectFailure()
|
||||
if switched {
|
||||
log.Warnf("home control center unavailable after repeated %s failures; switching to %s", reason, addr)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) failoverAfterReconnectFailure() (bool, string) {
|
||||
if c == nil {
|
||||
return false, ""
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.reconnectFailures++
|
||||
if c.reconnectFailures < homeReconnectFailoverThreshold {
|
||||
return false, ""
|
||||
}
|
||||
c.reconnectFailures = 0
|
||||
|
||||
currentHost := strings.TrimSpace(c.homeCfg.Host)
|
||||
currentPort := c.homeCfg.Port
|
||||
candidates := append([]clusterNode(nil), c.clusterNodes...)
|
||||
if strings.TrimSpace(c.seedHost) != "" && c.seedPort > 0 {
|
||||
candidates = append(candidates, clusterNode{IP: c.seedHost, Port: c.seedPort})
|
||||
}
|
||||
for _, node := range candidates {
|
||||
host := strings.TrimSpace(node.IP)
|
||||
if host == "" || node.Port <= 0 {
|
||||
continue
|
||||
}
|
||||
if host == currentHost && node.Port == currentPort {
|
||||
continue
|
||||
}
|
||||
if c.switchToNodeLocked(clusterNode{IP: host, Port: node.Port}) {
|
||||
addr, _ := c.addrLocked()
|
||||
return true, addr
|
||||
}
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
func (c *Client) resetReconnectFailures() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.reconnectFailures = 0
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *Client) GetConfig(ctx context.Context) ([]byte, error) {
|
||||
if err := c.ensureClients(); err != nil {
|
||||
return nil, err
|
||||
c.refreshBestClusterNode(ctx)
|
||||
cmd, errClient := c.commandClient()
|
||||
if errClient != nil {
|
||||
return nil, errClient
|
||||
}
|
||||
raw, err := c.cmd.Get(ctx, redisKeyConfig).Bytes()
|
||||
raw, err := cmd.Get(ctx, redisKeyConfig).Bytes()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, ErrConfigNotFound
|
||||
}
|
||||
@@ -148,10 +342,11 @@ func (c *Client) GetConfig(ctx context.Context) ([]byte, error) {
|
||||
}
|
||||
|
||||
func (c *Client) GetModels(ctx context.Context) ([]byte, error) {
|
||||
if err := c.ensureClients(); err != nil {
|
||||
return nil, err
|
||||
cmd, errClient := c.commandClient()
|
||||
if errClient != nil {
|
||||
return nil, errClient
|
||||
}
|
||||
raw, err := c.cmd.Get(ctx, redisKeyModels).Bytes()
|
||||
raw, err := cmd.Get(ctx, redisKeyModels).Bytes()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, ErrModelsNotFound
|
||||
}
|
||||
@@ -204,8 +399,9 @@ func newAuthDispatchRequest(requestedModel string, sessionID string, headers htt
|
||||
}
|
||||
|
||||
func (c *Client) RPopAuth(ctx context.Context, requestedModel string, sessionID string, headers http.Header, count int) ([]byte, error) {
|
||||
if err := c.ensureClients(); err != nil {
|
||||
return nil, err
|
||||
cmd, errClient := c.commandClient()
|
||||
if errClient != nil {
|
||||
return nil, errClient
|
||||
}
|
||||
requestedModel = strings.TrimSpace(requestedModel)
|
||||
if requestedModel == "" {
|
||||
@@ -217,7 +413,7 @@ func (c *Client) RPopAuth(ctx context.Context, requestedModel string, sessionID
|
||||
return nil, err
|
||||
}
|
||||
|
||||
raw, err := c.cmd.RPop(ctx, string(keyBytes)).Bytes()
|
||||
raw, err := cmd.RPop(ctx, string(keyBytes)).Bytes()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, ErrAuthNotFound
|
||||
}
|
||||
@@ -231,8 +427,9 @@ func (c *Client) RPopAuth(ctx context.Context, requestedModel string, sessionID
|
||||
}
|
||||
|
||||
func (c *Client) GetRefreshAuth(ctx context.Context, authIndex string) ([]byte, error) {
|
||||
if err := c.ensureClients(); err != nil {
|
||||
return nil, err
|
||||
cmd, errClient := c.commandClient()
|
||||
if errClient != nil {
|
||||
return nil, errClient
|
||||
}
|
||||
authIndex = strings.TrimSpace(authIndex)
|
||||
if authIndex == "" {
|
||||
@@ -247,7 +444,7 @@ func (c *Client) GetRefreshAuth(ctx context.Context, authIndex string) ([]byte,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
raw, err := c.cmd.Get(ctx, string(keyBytes)).Bytes()
|
||||
raw, err := cmd.Get(ctx, string(keyBytes)).Bytes()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, ErrAuthNotFound
|
||||
}
|
||||
@@ -261,23 +458,25 @@ func (c *Client) GetRefreshAuth(ctx context.Context, authIndex string) ([]byte,
|
||||
}
|
||||
|
||||
func (c *Client) LPushUsage(ctx context.Context, payload []byte) error {
|
||||
if err := c.ensureClients(); err != nil {
|
||||
return err
|
||||
cmd, errClient := c.commandClient()
|
||||
if errClient != nil {
|
||||
return errClient
|
||||
}
|
||||
if len(payload) == 0 {
|
||||
return nil
|
||||
}
|
||||
return c.cmd.LPush(ctx, redisKeyUsage, payload).Err()
|
||||
return cmd.LPush(ctx, redisKeyUsage, payload).Err()
|
||||
}
|
||||
|
||||
func (c *Client) RPushRequestLog(ctx context.Context, payload []byte) error {
|
||||
if err := c.ensureClients(); err != nil {
|
||||
return err
|
||||
cmd, errClient := c.commandClient()
|
||||
if errClient != nil {
|
||||
return errClient
|
||||
}
|
||||
if len(payload) == 0 {
|
||||
return nil
|
||||
}
|
||||
return c.cmd.RPush(ctx, redisKeyRequestLog, payload).Err()
|
||||
return cmd.RPush(ctx, redisKeyRequestLog, payload).Err()
|
||||
}
|
||||
|
||||
// StartConfigSubscriber connects to home, fetches config once via GET config, then subscribes to
|
||||
@@ -312,12 +511,14 @@ func (c *Client) StartConfigSubscriber(ctx context.Context, onConfig func([]byte
|
||||
|
||||
if errEnsure := c.ensureClients(); errEnsure != nil {
|
||||
log.Warn("unable to connect to home control center, retrying in 1 second")
|
||||
c.markReconnectFailure("connect")
|
||||
sleepWithContext(ctx, homeReconnectInterval)
|
||||
continue
|
||||
}
|
||||
|
||||
if errPing := c.Ping(ctx); errPing != nil {
|
||||
log.Warn("unable to connect to home control center, retrying in 1 second")
|
||||
c.markReconnectFailure("ping")
|
||||
sleepWithContext(ctx, homeReconnectInterval)
|
||||
continue
|
||||
}
|
||||
@@ -325,6 +526,7 @@ func (c *Client) StartConfigSubscriber(ctx context.Context, onConfig func([]byte
|
||||
raw, errGet := c.GetConfig(ctx)
|
||||
if errGet != nil {
|
||||
log.Warn("unable to fetch config from home control center, retrying in 1 second")
|
||||
c.markReconnectFailure("config fetch")
|
||||
sleepWithContext(ctx, homeReconnectInterval)
|
||||
continue
|
||||
}
|
||||
@@ -334,13 +536,16 @@ func (c *Client) StartConfigSubscriber(ctx context.Context, onConfig func([]byte
|
||||
continue
|
||||
}
|
||||
|
||||
if c.sub == nil {
|
||||
sub, errSubClient := c.subscriptionClient()
|
||||
if errSubClient != nil {
|
||||
c.markReconnectFailure("subscribe client")
|
||||
sleepWithContext(ctx, homeReconnectInterval)
|
||||
continue
|
||||
}
|
||||
|
||||
pubsub := c.sub.Subscribe(ctx, redisChannelConfig)
|
||||
pubsub := sub.Subscribe(ctx, redisChannelConfig)
|
||||
if pubsub == nil {
|
||||
c.markReconnectFailure("subscribe")
|
||||
sleepWithContext(ctx, homeReconnectInterval)
|
||||
continue
|
||||
}
|
||||
@@ -348,10 +553,12 @@ func (c *Client) StartConfigSubscriber(ctx context.Context, onConfig func([]byte
|
||||
// Ensure the subscription is established before marking heartbeat OK.
|
||||
if _, errReceive := pubsub.Receive(ctx); errReceive != nil {
|
||||
_ = pubsub.Close()
|
||||
c.markReconnectFailure("subscribe")
|
||||
sleepWithContext(ctx, homeReconnectInterval)
|
||||
continue
|
||||
}
|
||||
|
||||
c.resetReconnectFailures()
|
||||
c.heartbeatOK.Store(true)
|
||||
|
||||
for {
|
||||
@@ -359,6 +566,7 @@ func (c *Client) StartConfigSubscriber(ctx context.Context, onConfig func([]byte
|
||||
if errMsg != nil {
|
||||
_ = pubsub.Close()
|
||||
c.heartbeatOK.Store(false)
|
||||
c.markReconnectFailure("subscription")
|
||||
sleepWithContext(ctx, homeReconnectInterval)
|
||||
break
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user