mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-05-31 11:51:20 +08:00
test: remove unused Redis protocol tests and helpers
- Removed obsolete Redis protocol test cases and helper functions that were no longer relevant due to recent architecture changes. - Streamlined remaining test files to align with updated Redis handling and connection management logic.
This commit is contained in:
@@ -218,6 +218,10 @@ OmniRoute 是一个面向多供应商大语言模型的 AI 网关:它提供兼
|
||||
|
||||
一个公开的 CLIProxyAPI 兼容二开版本和配套管理面板,尽量保持与上游一致的使用方式,同时恢复内置使用量统计,并补充缓存命中率、首字响应时间、TPS 记录和面向 Docker 自托管的安装说明。
|
||||
|
||||
### [Codex Switch](https://github.com/9ycrooked/CodexSwitch)
|
||||
|
||||
这是一个使用 Tauri 2 + Vue 3 构建的工具,用于管理多个 OpenAI Codex 桌面账户。它可以在已保存的 ChatGPT/Codex 认证配置之间切换,实时查看 5 小时和每周配额使用情况,验证 token 健康状态,查看当前账户详情,并在无需手动复制的情况下导入或保存 auth.json 文件。
|
||||
|
||||
> [!NOTE]
|
||||
> 如果你开发了 CLIProxyAPI 的移植或衍生项目,请提交 PR 将其添加到此列表中。
|
||||
|
||||
|
||||
@@ -217,6 +217,10 @@ OmniRouteはマルチプロバイダーLLM向けのAIゲートウェイです:
|
||||
|
||||
上流に近い使い方を維持する公開CLIProxyAPI互換フォーク兼管理パネルです。内蔵の使用量統計を復元し、キャッシュヒット率、初回バイト待ち時間、TPSの記録、Docker向けのセルフホスト手順を追加しています。
|
||||
|
||||
### [Codex Switch](https://github.com/9ycrooked/CodexSwitch)
|
||||
|
||||
Tauri 2 + Vue 3で構築された、複数のOpenAI Codexデスクトップアカウントを管理するためのツールです。保存済みのChatGPT/Codex認証プロファイルを切り替え、5時間および週次クォータ使用量をリアルタイムで確認し、tokenの状態を検証し、現在のアカウント詳細を表示し、手動コピーなしでauth.jsonファイルをインポートまたは保存できます。
|
||||
|
||||
> [!NOTE]
|
||||
> CLIProxyAPIの移植版またはそれに触発されたプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。
|
||||
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
package main
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseHomeFlagConfigHostPort(t *testing.T) {
|
||||
cfg, err := parseHomeFlagConfig("home.example.com:8327", "secret")
|
||||
if err != nil {
|
||||
t.Fatalf("parseHomeFlagConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if !cfg.Enabled {
|
||||
t.Fatal("Enabled = false, want true")
|
||||
}
|
||||
if cfg.Host != "home.example.com" {
|
||||
t.Fatalf("Host = %q, want home.example.com", cfg.Host)
|
||||
}
|
||||
if cfg.Port != 8327 {
|
||||
t.Fatalf("Port = %d, want 8327", cfg.Port)
|
||||
}
|
||||
if cfg.Password != "secret" {
|
||||
t.Fatalf("Password = %q, want secret", cfg.Password)
|
||||
}
|
||||
if cfg.TLS.Enable {
|
||||
t.Fatal("TLS.Enable = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseHomeFlagConfigRediss(t *testing.T) {
|
||||
cfg, err := parseHomeFlagConfig("rediss://:url-secret@home.example.com:444?server-name=home.example.com&skip_verify=true&ca-cert=C%3A%2Fcerts%2Fca.pem", "")
|
||||
if err != nil {
|
||||
t.Fatalf("parseHomeFlagConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if cfg.Host != "home.example.com" {
|
||||
t.Fatalf("Host = %q, want home.example.com", cfg.Host)
|
||||
}
|
||||
if cfg.Port != 444 {
|
||||
t.Fatalf("Port = %d, want 444", cfg.Port)
|
||||
}
|
||||
if cfg.Password != "url-secret" {
|
||||
t.Fatalf("Password = %q, want url-secret", cfg.Password)
|
||||
}
|
||||
if !cfg.TLS.Enable {
|
||||
t.Fatal("TLS.Enable = false, want true")
|
||||
}
|
||||
if cfg.TLS.ServerName != "home.example.com" {
|
||||
t.Fatalf("TLS.ServerName = %q, want home.example.com", cfg.TLS.ServerName)
|
||||
}
|
||||
if !cfg.TLS.InsecureSkipVerify {
|
||||
t.Fatal("TLS.InsecureSkipVerify = false, want true")
|
||||
}
|
||||
if cfg.TLS.CACert != "C:/certs/ca.pem" {
|
||||
t.Fatalf("TLS.CACert = %q, want C:/certs/ca.pem", cfg.TLS.CACert)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseHomeFlagConfigPasswordFlagOverridesURLPassword(t *testing.T) {
|
||||
cfg, err := parseHomeFlagConfig("rediss://:url-secret@home.example.com:444", "flag-secret")
|
||||
if err != nil {
|
||||
t.Fatalf("parseHomeFlagConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if cfg.Password != "flag-secret" {
|
||||
t.Fatalf("Password = %q, want flag-secret", cfg.Password)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseHomeFlagConfigDisableClusterDiscovery(t *testing.T) {
|
||||
cfg, err := parseHomeFlagConfig("redis://home.example.com:8327?disable-cluster-discovery=true", "")
|
||||
if err != nil {
|
||||
t.Fatalf("parseHomeFlagConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if !cfg.DisableClusterDiscovery {
|
||||
t.Fatal("DisableClusterDiscovery = false, want true")
|
||||
}
|
||||
}
|
||||
@@ -10,11 +10,9 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -53,120 +51,6 @@ func init() {
|
||||
buildinfo.BuildDate = BuildDate
|
||||
}
|
||||
|
||||
func parseHomeFlagConfig(rawAddr string, password string) (config.HomeConfig, error) {
|
||||
rawAddr = strings.TrimSpace(rawAddr)
|
||||
if rawAddr == "" {
|
||||
return config.HomeConfig{}, fmt.Errorf("address is empty")
|
||||
}
|
||||
|
||||
if strings.Contains(rawAddr, "://") {
|
||||
return parseHomeURLConfig(rawAddr, password)
|
||||
}
|
||||
|
||||
host, portStr, errSplit := net.SplitHostPort(rawAddr)
|
||||
if errSplit != nil {
|
||||
return config.HomeConfig{}, fmt.Errorf("expected host:port, redis://host:port, or rediss://host:port: %w", errSplit)
|
||||
}
|
||||
|
||||
host = strings.TrimSpace(host)
|
||||
if host == "" {
|
||||
return config.HomeConfig{}, fmt.Errorf("host is empty")
|
||||
}
|
||||
|
||||
port, errPort := parseHomePort(portStr)
|
||||
if errPort != nil {
|
||||
return config.HomeConfig{}, errPort
|
||||
}
|
||||
|
||||
return config.HomeConfig{
|
||||
Enabled: true,
|
||||
Host: host,
|
||||
Port: port,
|
||||
Password: password,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseHomeURLConfig(rawAddr string, password string) (config.HomeConfig, error) {
|
||||
parsed, errParse := url.Parse(rawAddr)
|
||||
if errParse != nil {
|
||||
return config.HomeConfig{}, fmt.Errorf("parse URL: %w", errParse)
|
||||
}
|
||||
|
||||
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
|
||||
if scheme != "redis" && scheme != "rediss" {
|
||||
return config.HomeConfig{}, fmt.Errorf("unsupported URL scheme %q", parsed.Scheme)
|
||||
}
|
||||
|
||||
host := strings.TrimSpace(parsed.Hostname())
|
||||
if host == "" {
|
||||
return config.HomeConfig{}, fmt.Errorf("host is empty")
|
||||
}
|
||||
|
||||
port, errPort := parseHomePort(parsed.Port())
|
||||
if errPort != nil {
|
||||
return config.HomeConfig{}, errPort
|
||||
}
|
||||
|
||||
if password == "" && parsed.User != nil {
|
||||
if urlPassword, ok := parsed.User.Password(); ok {
|
||||
password = urlPassword
|
||||
}
|
||||
}
|
||||
|
||||
homeCfg := config.HomeConfig{
|
||||
Enabled: true,
|
||||
Host: host,
|
||||
Port: port,
|
||||
Password: password,
|
||||
}
|
||||
query := parsed.Query()
|
||||
homeCfg.DisableClusterDiscovery = parseHomeBoolQuery(query, "disable-cluster-discovery", "disable_cluster_discovery")
|
||||
|
||||
if scheme == "rediss" {
|
||||
homeCfg.TLS.Enable = true
|
||||
homeCfg.TLS.ServerName = strings.TrimSpace(firstHomeQueryValue(query, "server-name", "server_name"))
|
||||
homeCfg.TLS.InsecureSkipVerify = parseHomeBoolQuery(query, "insecure-skip-verify", "insecure_skip_verify", "skip_verify")
|
||||
homeCfg.TLS.CACert = strings.TrimSpace(firstHomeQueryValue(query, "ca-cert", "ca_cert"))
|
||||
}
|
||||
|
||||
return homeCfg, nil
|
||||
}
|
||||
|
||||
func parseHomePort(rawPort string) (int, error) {
|
||||
rawPort = strings.TrimSpace(rawPort)
|
||||
if rawPort == "" {
|
||||
return 0, fmt.Errorf("port is empty")
|
||||
}
|
||||
|
||||
port, errPort := strconv.Atoi(rawPort)
|
||||
if errPort != nil || port <= 0 || port > 65535 {
|
||||
return 0, fmt.Errorf("invalid port %q", rawPort)
|
||||
}
|
||||
|
||||
return port, nil
|
||||
}
|
||||
|
||||
func firstHomeQueryValue(values url.Values, keys ...string) string {
|
||||
for _, key := range keys {
|
||||
if value := values.Get(key); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseHomeBoolQuery(values url.Values, keys ...string) bool {
|
||||
for _, key := range keys {
|
||||
value := strings.TrimSpace(values.Get(key))
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
parsed, errParse := strconv.ParseBool(value)
|
||||
return errParse == nil && parsed
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// main is the entry point of the application.
|
||||
// It parses command-line flags, loads configuration, and starts the appropriate
|
||||
// service based on the provided flags (login, codex-login, or server mode).
|
||||
@@ -188,8 +72,6 @@ func main() {
|
||||
var vertexImportPrefix string
|
||||
var configPath string
|
||||
var password string
|
||||
var homeAddr string
|
||||
var homePassword string
|
||||
var homeJWT string
|
||||
var homeDisableClusterDiscovery bool
|
||||
var tuiMode bool
|
||||
@@ -211,10 +93,8 @@ func main() {
|
||||
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
||||
flag.StringVar(&vertexImportPrefix, "vertex-import-prefix", "", "Prefix for Vertex model namespacing (use with -vertex-import)")
|
||||
flag.StringVar(&password, "password", "", "")
|
||||
flag.StringVar(&homeAddr, "home", "", "Home control plane address in host:port, redis://host:port, or rediss://host:port format (loads config from home and skips local config file)")
|
||||
flag.StringVar(&homePassword, "home-password", "", "Home control plane password (Redis AUTH)")
|
||||
flag.StringVar(&homeJWT, "home-jwt", "", "Home control plane JWT for mTLS certificate bootstrap and connection")
|
||||
flag.BoolVar(&homeDisableClusterDiscovery, "home-disable-cluster-discovery", false, "Disable Home CLUSTER NODES discovery and keep using the configured -home address")
|
||||
flag.BoolVar(&homeDisableClusterDiscovery, "home-disable-cluster-discovery", false, "Disable Home CLUSTER NODES discovery and keep using the configured -home-jwt address")
|
||||
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
|
||||
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
|
||||
flag.BoolVar(&localModel, "local-model", false, "Use embedded model catalog only, skip remote model fetching")
|
||||
@@ -302,17 +182,6 @@ func main() {
|
||||
}
|
||||
writableBase := util.WritablePath()
|
||||
|
||||
// Allow env var fallback for home flags so they can be configured without command args.
|
||||
if strings.TrimSpace(homeAddr) == "" {
|
||||
if v, ok := lookupEnv("HOME_ADDR", "home_addr"); ok {
|
||||
homeAddr = v
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(homePassword) == "" {
|
||||
if v, ok := lookupEnv("HOME_PASSWORD", "home_password"); ok {
|
||||
homePassword = v
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(homeJWT) == "" {
|
||||
if v, ok := lookupEnv("HOME_JWT", "home_jwt"); ok {
|
||||
homeJWT = v
|
||||
@@ -426,53 +295,6 @@ func main() {
|
||||
configFilePath = filepath.Join(wd, "config.yaml")
|
||||
}
|
||||
|
||||
// Local stores are intentionally disabled when config is loaded from home.
|
||||
usePostgresStore = false
|
||||
useObjectStore = false
|
||||
useGitStore = false
|
||||
} else if strings.TrimSpace(homeAddr) != "" {
|
||||
configLoadedFromHome = true
|
||||
trimmedHomePassword := strings.TrimSpace(homePassword)
|
||||
homeCfg, errHomeCfg := parseHomeFlagConfig(homeAddr, trimmedHomePassword)
|
||||
if errHomeCfg != nil {
|
||||
log.Errorf("invalid -home address %q: %v", homeAddr, errHomeCfg)
|
||||
return
|
||||
}
|
||||
if homeDisableClusterDiscovery {
|
||||
homeCfg.DisableClusterDiscovery = true
|
||||
}
|
||||
homeClient := home.New(homeCfg)
|
||||
defer homeClient.Close()
|
||||
|
||||
ctxHome, cancelHome := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
raw, errGetConfig := homeClient.GetConfig(ctxHome)
|
||||
cancelHome()
|
||||
if errGetConfig != nil {
|
||||
log.Errorf("failed to fetch config from home: %v", errGetConfig)
|
||||
return
|
||||
}
|
||||
|
||||
parsed, errParseConfig := config.ParseConfigBytes(raw)
|
||||
if errParseConfig != nil {
|
||||
log.Errorf("failed to parse config payload from home: %v", errParseConfig)
|
||||
return
|
||||
}
|
||||
if parsed == nil {
|
||||
parsed = &config.Config{}
|
||||
}
|
||||
parsed.Home = homeCfg
|
||||
parsed.Port = 8317 // Default to 8317 for home mode, can be overridden by home config
|
||||
parsed.UsageStatisticsEnabled = true
|
||||
cfg = parsed
|
||||
|
||||
// Keep a non-empty config path for downstream components (log paths, management assets, etc),
|
||||
// but do not require the file to exist when loading config from home.
|
||||
if strings.TrimSpace(configPath) != "" {
|
||||
configFilePath = configPath
|
||||
} else {
|
||||
configFilePath = filepath.Join(wd, "config.yaml")
|
||||
}
|
||||
|
||||
// Local stores are intentionally disabled when config is loaded from home.
|
||||
usePostgresStore = false
|
||||
useObjectStore = false
|
||||
|
||||
@@ -11,26 +11,6 @@ tls:
|
||||
cert: ""
|
||||
key: ""
|
||||
|
||||
# Optional "home" control plane integration over Redis protocol.
|
||||
home:
|
||||
enabled: false
|
||||
host: "127.0.0.1"
|
||||
port: 6379
|
||||
password: ""
|
||||
# Keep CPA pinned to the configured home address instead of switching to CLUSTER NODES entries.
|
||||
# Useful when Home is behind NAT, Docker networking, or a reverse proxy.
|
||||
disable-cluster-discovery: false
|
||||
# Optional TLS for the outbound Redis connection to the home control plane.
|
||||
# Enable this when connecting through rediss:// or an SSL stream proxy.
|
||||
tls:
|
||||
enable: false
|
||||
# Optional SNI/certificate name override. Leave empty to use the configured home host.
|
||||
server-name: ""
|
||||
# Trust a private CA bundle in addition to system roots.
|
||||
ca-cert: ""
|
||||
# Only for testing self-signed endpoints; disables certificate verification.
|
||||
insecure-skip-verify: false
|
||||
|
||||
# Management API settings
|
||||
remote-management:
|
||||
# Whether to allow remote (non-localhost) management access.
|
||||
@@ -86,8 +66,8 @@ error-logs-max-files: 10
|
||||
# When false, disable in-memory usage statistics aggregation
|
||||
usage-statistics-enabled: false
|
||||
|
||||
# How long (in seconds) Redis usage queue items are retained in memory for the RESP interface (LPOP/RPOP).
|
||||
# Note: the in-process Redis RESP usage output is disabled when home.enabled is true.
|
||||
# How long (in seconds) usage queue items are retained in memory for the Management API.
|
||||
# The local Redis RESP usage output is disabled.
|
||||
# Default: 60. Max: 3600.
|
||||
redis-usage-queue-retention-seconds: 60
|
||||
|
||||
|
||||
@@ -103,20 +103,8 @@ func (s *Server) routeMuxConnection(conn net.Conn, httpListener *muxListener) {
|
||||
}
|
||||
|
||||
if isRedisRESPPrefix(prefix[0]) {
|
||||
if s.cfg != nil && s.cfg.Home.Enabled {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
log.Errorf("failed to close redis connection while home mode is enabled: %v", errClose)
|
||||
}
|
||||
return
|
||||
}
|
||||
if !s.managementRoutesEnabled.Load() {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
log.Errorf("failed to close redis connection while management is disabled: %v", errClose)
|
||||
}
|
||||
return
|
||||
}
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
s.handleRedisConnection(conn, reader)
|
||||
s.handleRedisConnection(conn)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -2,25 +2,11 @@ package api
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/redisqueue"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const redisUsageChannel = "usage"
|
||||
|
||||
type redisSubscriptionCommand struct {
|
||||
args []string
|
||||
err error
|
||||
}
|
||||
|
||||
func isRedisRESPPrefix(prefix byte) bool {
|
||||
switch prefix {
|
||||
case '*', '$', '+', '-', ':':
|
||||
@@ -30,13 +16,11 @@ func isRedisRESPPrefix(prefix byte) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleRedisConnection(conn net.Conn, reader *bufio.Reader) {
|
||||
if s == nil || conn == nil || reader == nil {
|
||||
func (s *Server) handleRedisConnection(conn net.Conn) {
|
||||
if s == nil || conn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
clientIP, localClient := resolveRemoteIP(conn.RemoteAddr())
|
||||
authed := false
|
||||
writer := bufio.NewWriter(conn)
|
||||
defer func() {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
@@ -44,432 +28,10 @@ func (s *Server) handleRedisConnection(conn net.Conn, reader *bufio.Reader) {
|
||||
}
|
||||
}()
|
||||
|
||||
flush := func() bool {
|
||||
if errFlush := writer.Flush(); errFlush != nil {
|
||||
log.Errorf("redis protocol flush error: %v", errFlush)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
_ = writeRedisError(writer, "ERR RESP AUTH disabled; use mTLS")
|
||||
if errFlush := writer.Flush(); errFlush != nil {
|
||||
log.Errorf("redis protocol flush error: %v", errFlush)
|
||||
}
|
||||
|
||||
if s.cfg != nil && s.cfg.Home.Enabled {
|
||||
_ = writeRedisError(writer, "ERR redis usage output disabled in home mode")
|
||||
_ = writer.Flush()
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
if !s.managementRoutesEnabled.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
args, err := readRESPArray(reader)
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
_ = writeRedisError(writer, "ERR "+err.Error())
|
||||
_ = writer.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
if len(args) == 0 {
|
||||
_ = writeRedisError(writer, "ERR empty command")
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
cmd := strings.ToUpper(strings.TrimSpace(args[0]))
|
||||
|
||||
if cmd != "AUTH" && !authed {
|
||||
if s.mgmt != nil {
|
||||
_, statusCode, errMsg := s.mgmt.AuthenticateManagementKey(clientIP, localClient, "")
|
||||
if statusCode == http.StatusForbidden && strings.HasPrefix(errMsg, "IP banned due to too many failed attempts") {
|
||||
_ = writeRedisError(writer, "ERR "+errMsg)
|
||||
} else {
|
||||
_ = writeRedisError(writer, "NOAUTH Authentication required.")
|
||||
}
|
||||
} else {
|
||||
_ = writeRedisError(writer, "NOAUTH Authentication required.")
|
||||
}
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
switch cmd {
|
||||
case "AUTH":
|
||||
password, ok := parseAuthPassword(args)
|
||||
if !ok {
|
||||
if s.mgmt != nil {
|
||||
_, statusCode, errMsg := s.mgmt.AuthenticateManagementKey(clientIP, localClient, "")
|
||||
if statusCode == http.StatusForbidden && strings.HasPrefix(errMsg, "IP banned due to too many failed attempts") {
|
||||
_ = writeRedisError(writer, "ERR "+errMsg)
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
_ = writeRedisError(writer, "ERR wrong number of arguments for 'auth' command")
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if s.mgmt == nil {
|
||||
_ = writeRedisError(writer, "ERR remote management disabled")
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
allowed, _, errMsg := s.mgmt.AuthenticateManagementKey(clientIP, localClient, password)
|
||||
if !allowed {
|
||||
_ = writeRedisError(writer, "ERR "+errMsg)
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
authed = true
|
||||
_ = writeRedisSimpleString(writer, "OK")
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
case "SUBSCRIBE":
|
||||
if !authed {
|
||||
_ = writeRedisError(writer, "NOAUTH Authentication required.")
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
channel, ok := parseSubscribeChannel(args)
|
||||
if !ok {
|
||||
_ = writeRedisError(writer, "ERR wrong number of arguments for 'subscribe' command")
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if !strings.EqualFold(channel, redisUsageChannel) {
|
||||
_ = writeRedisError(writer, fmt.Sprintf("ERR unsupported channel '%s'", channel))
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
messages, unsubscribe := redisqueue.SubscribeUsage()
|
||||
if errWrite := writeRedisPubSubSubscribe(writer, redisUsageChannel, 1); errWrite != nil {
|
||||
unsubscribe()
|
||||
log.Errorf("redis protocol subscribe response error: %v", errWrite)
|
||||
return
|
||||
}
|
||||
if !flush() {
|
||||
unsubscribe()
|
||||
return
|
||||
}
|
||||
s.streamRedisUsageSubscription(reader, writer, messages, unsubscribe)
|
||||
return
|
||||
case "LPOP", "RPOP":
|
||||
if !authed {
|
||||
_ = writeRedisError(writer, "NOAUTH Authentication required.")
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
count, hasCount, ok := parsePopCount(args)
|
||||
if !ok {
|
||||
_ = writeRedisError(writer, "ERR wrong number of arguments for '"+strings.ToLower(cmd)+"' command")
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if count <= 0 {
|
||||
_ = writeRedisError(writer, "ERR value is not an integer or out of range")
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
items := redisqueue.PopOldest(count)
|
||||
if hasCount {
|
||||
_ = writeRedisArrayOfBulkStrings(writer, items)
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if len(items) == 0 {
|
||||
_ = writeRedisNilBulkString(writer)
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
_ = writeRedisBulkString(writer, items[0])
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
default:
|
||||
_ = writeRedisError(writer, fmt.Sprintf("ERR unknown command '%s'", strings.ToLower(cmd)))
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) streamRedisUsageSubscription(reader *bufio.Reader, writer *bufio.Writer, messages <-chan []byte, unsubscribe func()) {
|
||||
if unsubscribe == nil {
|
||||
return
|
||||
}
|
||||
defer unsubscribe()
|
||||
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
|
||||
commands := make(chan redisSubscriptionCommand, 1)
|
||||
go readRedisSubscriptionCommands(reader, commands, done)
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg, ok := <-messages:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if errWrite := writeRedisPubSubMessage(writer, redisUsageChannel, msg); errWrite != nil {
|
||||
log.Errorf("redis protocol publish message error: %v", errWrite)
|
||||
return
|
||||
}
|
||||
if errFlush := writer.Flush(); errFlush != nil {
|
||||
log.Errorf("redis protocol flush error: %v", errFlush)
|
||||
return
|
||||
}
|
||||
case command, ok := <-commands:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
keepOpen := handleRedisSubscriptionCommand(writer, command)
|
||||
if errFlush := writer.Flush(); errFlush != nil {
|
||||
log.Errorf("redis protocol flush error: %v", errFlush)
|
||||
return
|
||||
}
|
||||
if !keepOpen {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func readRedisSubscriptionCommands(reader *bufio.Reader, commands chan<- redisSubscriptionCommand, done <-chan struct{}) {
|
||||
defer close(commands)
|
||||
|
||||
for {
|
||||
args, err := readRESPArray(reader)
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
select {
|
||||
case commands <- redisSubscriptionCommand{err: err}:
|
||||
case <-done:
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
select {
|
||||
case commands <- redisSubscriptionCommand{args: args}:
|
||||
case <-done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func handleRedisSubscriptionCommand(writer *bufio.Writer, command redisSubscriptionCommand) bool {
|
||||
if command.err != nil {
|
||||
_ = writeRedisError(writer, "ERR "+command.err.Error())
|
||||
return false
|
||||
}
|
||||
if len(command.args) == 0 {
|
||||
_ = writeRedisError(writer, "ERR empty command")
|
||||
return true
|
||||
}
|
||||
|
||||
cmd := strings.ToUpper(strings.TrimSpace(command.args[0]))
|
||||
switch cmd {
|
||||
case "PING":
|
||||
payload := []byte(nil)
|
||||
if len(command.args) > 1 {
|
||||
payload = []byte(command.args[1])
|
||||
}
|
||||
_ = writeRedisPubSubPong(writer, payload)
|
||||
return true
|
||||
case "UNSUBSCRIBE":
|
||||
_ = writeRedisPubSubUnsubscribe(writer, redisUsageChannel, 0)
|
||||
return false
|
||||
case "QUIT":
|
||||
_ = writeRedisSimpleString(writer, "OK")
|
||||
return false
|
||||
default:
|
||||
_ = writeRedisError(writer, fmt.Sprintf("ERR unknown command '%s'", strings.ToLower(cmd)))
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func resolveRemoteIP(addr net.Addr) (ip string, localClient bool) {
|
||||
if addr == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
var host string
|
||||
switch a := addr.(type) {
|
||||
case *net.TCPAddr:
|
||||
if a != nil && a.IP != nil {
|
||||
if ip4 := a.IP.To4(); ip4 != nil {
|
||||
host = ip4.String()
|
||||
} else {
|
||||
host = a.IP.String()
|
||||
}
|
||||
}
|
||||
default:
|
||||
host = addr.String()
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = h
|
||||
}
|
||||
host = strings.TrimSpace(host)
|
||||
if raw, _, ok := strings.Cut(host, "%"); ok {
|
||||
host = raw
|
||||
}
|
||||
if parsed := net.ParseIP(host); parsed != nil {
|
||||
if ip4 := parsed.To4(); ip4 != nil {
|
||||
host = ip4.String()
|
||||
} else {
|
||||
host = parsed.String()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
host = strings.TrimSpace(host)
|
||||
localClient = host == "127.0.0.1" || host == "::1"
|
||||
return host, localClient
|
||||
}
|
||||
|
||||
func parseAuthPassword(args []string) (string, bool) {
|
||||
switch len(args) {
|
||||
case 2:
|
||||
return args[1], true
|
||||
case 3:
|
||||
// Support AUTH <username> <password> by ignoring username for compatibility.
|
||||
return args[2], true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func parseSubscribeChannel(args []string) (string, bool) {
|
||||
if len(args) != 2 {
|
||||
return "", false
|
||||
}
|
||||
return strings.TrimSpace(args[1]), true
|
||||
}
|
||||
|
||||
func parsePopCount(args []string) (count int, hasCount bool, ok bool) {
|
||||
if len(args) != 2 && len(args) != 3 {
|
||||
return 0, false, false
|
||||
}
|
||||
if len(args) == 2 {
|
||||
return 1, false, true
|
||||
}
|
||||
parsed, err := strconv.Atoi(strings.TrimSpace(args[2]))
|
||||
if err != nil {
|
||||
return 0, true, true
|
||||
}
|
||||
return parsed, true, true
|
||||
}
|
||||
|
||||
func readRESPArray(reader *bufio.Reader) ([]string, error) {
|
||||
prefix, err := reader.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if prefix != '*' {
|
||||
return nil, fmt.Errorf("protocol error")
|
||||
}
|
||||
line, err := readRESPLine(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
count, err := strconv.Atoi(line)
|
||||
if err != nil || count < 0 {
|
||||
return nil, fmt.Errorf("protocol error")
|
||||
}
|
||||
args := make([]string, 0, count)
|
||||
for i := 0; i < count; i++ {
|
||||
value, err := readRESPString(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
args = append(args, value)
|
||||
}
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func readRESPString(reader *bufio.Reader) (string, error) {
|
||||
prefix, err := reader.ReadByte()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
switch prefix {
|
||||
case '$':
|
||||
return readRESPBulkString(reader)
|
||||
case '+', ':':
|
||||
return readRESPLine(reader)
|
||||
default:
|
||||
return "", fmt.Errorf("protocol error")
|
||||
}
|
||||
}
|
||||
|
||||
func readRESPBulkString(reader *bufio.Reader) (string, error) {
|
||||
line, err := readRESPLine(reader)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
length, err := strconv.Atoi(line)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("protocol error")
|
||||
}
|
||||
if length < 0 {
|
||||
return "", nil
|
||||
}
|
||||
buf := make([]byte, length+2)
|
||||
if _, err := io.ReadFull(reader, buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if length+2 < 2 || buf[length] != '\r' || buf[length+1] != '\n' {
|
||||
return "", fmt.Errorf("protocol error")
|
||||
}
|
||||
return string(buf[:length]), nil
|
||||
}
|
||||
|
||||
func readRESPLine(reader *bufio.Reader) (string, error) {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
line = strings.TrimSuffix(line, "\n")
|
||||
line = strings.TrimSuffix(line, "\r")
|
||||
return line, nil
|
||||
}
|
||||
|
||||
func writeRedisSimpleString(writer *bufio.Writer, value string) error {
|
||||
if writer == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
_, err := writer.WriteString("+" + value + "\r\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func writeRedisError(writer *bufio.Writer, message string) error {
|
||||
@@ -479,108 +41,3 @@ func writeRedisError(writer *bufio.Writer, message string) error {
|
||||
_, err := writer.WriteString("-" + message + "\r\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func writeRedisNilBulkString(writer *bufio.Writer) error {
|
||||
if writer == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
_, err := writer.WriteString("$-1\r\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func writeRedisBulkString(writer *bufio.Writer, payload []byte) error {
|
||||
if writer == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
if payload == nil {
|
||||
return writeRedisNilBulkString(writer)
|
||||
}
|
||||
if _, err := writer.WriteString("$" + strconv.Itoa(len(payload)) + "\r\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := writer.Write(payload); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := writer.WriteString("\r\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func writeRedisArrayOfBulkStrings(writer *bufio.Writer, items [][]byte) error {
|
||||
if writer == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
if _, err := writer.WriteString("*" + strconv.Itoa(len(items)) + "\r\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
for i := range items {
|
||||
if err := writeRedisBulkString(writer, items[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeRedisInteger(writer *bufio.Writer, value int) error {
|
||||
if writer == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
_, err := writer.WriteString(":" + strconv.Itoa(value) + "\r\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func writeRedisArrayHeader(writer *bufio.Writer, count int) error {
|
||||
if writer == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
_, err := writer.WriteString("*" + strconv.Itoa(count) + "\r\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func writeRedisPubSubSubscribe(writer *bufio.Writer, channel string, count int) error {
|
||||
if err := writeRedisArrayHeader(writer, 3); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeRedisBulkString(writer, []byte("subscribe")); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeRedisBulkString(writer, []byte(channel)); err != nil {
|
||||
return err
|
||||
}
|
||||
return writeRedisInteger(writer, count)
|
||||
}
|
||||
|
||||
func writeRedisPubSubUnsubscribe(writer *bufio.Writer, channel string, count int) error {
|
||||
if err := writeRedisArrayHeader(writer, 3); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeRedisBulkString(writer, []byte("unsubscribe")); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeRedisBulkString(writer, []byte(channel)); err != nil {
|
||||
return err
|
||||
}
|
||||
return writeRedisInteger(writer, count)
|
||||
}
|
||||
|
||||
func writeRedisPubSubMessage(writer *bufio.Writer, channel string, payload []byte) error {
|
||||
if err := writeRedisArrayHeader(writer, 3); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeRedisBulkString(writer, []byte("message")); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeRedisBulkString(writer, []byte(channel)); err != nil {
|
||||
return err
|
||||
}
|
||||
return writeRedisBulkString(writer, payload)
|
||||
}
|
||||
|
||||
func writeRedisPubSubPong(writer *bufio.Writer, payload []byte) error {
|
||||
if err := writeRedisArrayHeader(writer, 2); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeRedisBulkString(writer, []byte("pong")); err != nil {
|
||||
return err
|
||||
}
|
||||
return writeRedisBulkString(writer, payload)
|
||||
}
|
||||
|
||||
@@ -3,14 +3,9 @@ package api
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -18,18 +13,6 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/redisqueue"
|
||||
)
|
||||
|
||||
type remoteAddrConn struct {
|
||||
net.Conn
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
func (c *remoteAddrConn) RemoteAddr() net.Addr {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func startRedisMuxListener(t *testing.T, server *Server) (addr string, stop func()) {
|
||||
t.Helper()
|
||||
|
||||
@@ -86,17 +69,6 @@ func readTestRESPLine(r *bufio.Reader) (string, error) {
|
||||
return strings.TrimSuffix(line, "\r\n"), nil
|
||||
}
|
||||
|
||||
func readTestRESPSimpleString(r *bufio.Reader) (string, error) {
|
||||
prefix, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if prefix != '+' {
|
||||
return "", fmt.Errorf("expected simple string prefix '+', got %q", prefix)
|
||||
}
|
||||
return readTestRESPLine(r)
|
||||
}
|
||||
|
||||
func readTestRESPError(r *bufio.Reader) (string, error) {
|
||||
prefix, err := r.ReadByte()
|
||||
if err != nil {
|
||||
@@ -108,171 +80,6 @@ func readTestRESPError(r *bufio.Reader) (string, error) {
|
||||
return readTestRESPLine(r)
|
||||
}
|
||||
|
||||
func readTestRESPBulkString(r *bufio.Reader) ([]byte, error) {
|
||||
prefix, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if prefix != '$' {
|
||||
return nil, fmt.Errorf("expected bulk string prefix '$', got %q", prefix)
|
||||
}
|
||||
|
||||
line, err := readTestRESPLine(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
length, err := strconv.Atoi(line)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid bulk string length %q: %v", line, err)
|
||||
}
|
||||
if length == -1 {
|
||||
return nil, nil
|
||||
}
|
||||
if length < -1 {
|
||||
return nil, fmt.Errorf("invalid bulk string length %d", length)
|
||||
}
|
||||
|
||||
payload := make([]byte, length+2)
|
||||
if _, err := io.ReadFull(r, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if payload[length] != '\r' || payload[length+1] != '\n' {
|
||||
return nil, fmt.Errorf("invalid bulk string terminator")
|
||||
}
|
||||
return payload[:length], nil
|
||||
}
|
||||
|
||||
func readRESPArrayOfBulkStrings(r *bufio.Reader) ([][]byte, error) {
|
||||
prefix, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if prefix != '*' {
|
||||
return nil, fmt.Errorf("expected array prefix '*', got %q", prefix)
|
||||
}
|
||||
|
||||
line, err := readTestRESPLine(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
count, err := strconv.Atoi(line)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid array length %q: %v", line, err)
|
||||
}
|
||||
if count < 0 {
|
||||
return nil, fmt.Errorf("invalid array length %d", count)
|
||||
}
|
||||
|
||||
out := make([][]byte, 0, count)
|
||||
for i := 0; i < count; i++ {
|
||||
item, err := readTestRESPBulkString(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, item)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func readTestRESPInteger(r *bufio.Reader) (int, error) {
|
||||
prefix, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if prefix != ':' {
|
||||
return 0, fmt.Errorf("expected integer prefix ':', got %q", prefix)
|
||||
}
|
||||
|
||||
line, err := readTestRESPLine(r)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
value, err := strconv.Atoi(line)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid integer %q: %v", line, err)
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func readTestRESPArrayHeader(r *bufio.Reader) (int, error) {
|
||||
prefix, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if prefix != '*' {
|
||||
return 0, fmt.Errorf("expected array prefix '*', got %q", prefix)
|
||||
}
|
||||
|
||||
line, err := readTestRESPLine(r)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
count, err := strconv.Atoi(line)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid array length %q: %v", line, err)
|
||||
}
|
||||
if count < 0 {
|
||||
return 0, fmt.Errorf("invalid array length %d", count)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func readTestRESPPubSubSubscribe(r *bufio.Reader) (string, int, error) {
|
||||
count, err := readTestRESPArrayHeader(r)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
if count != 3 {
|
||||
return "", 0, fmt.Errorf("subscribe array length = %d, want 3", count)
|
||||
}
|
||||
|
||||
kind, err := readTestRESPBulkString(r)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
if string(kind) != "subscribe" {
|
||||
return "", 0, fmt.Errorf("pubsub kind = %q, want subscribe", string(kind))
|
||||
}
|
||||
|
||||
channel, err := readTestRESPBulkString(r)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
subscriptions, err := readTestRESPInteger(r)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
return string(channel), subscriptions, nil
|
||||
}
|
||||
|
||||
func readTestRESPPubSubMessage(r *bufio.Reader) (string, []byte, error) {
|
||||
count, err := readTestRESPArrayHeader(r)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if count != 3 {
|
||||
return "", nil, fmt.Errorf("message array length = %d, want 3", count)
|
||||
}
|
||||
|
||||
kind, err := readTestRESPBulkString(r)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if string(kind) != "message" {
|
||||
return "", nil, fmt.Errorf("pubsub kind = %q, want message", string(kind))
|
||||
}
|
||||
|
||||
channel, err := readTestRESPBulkString(r)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
payload, err := readTestRESPBulkString(r)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return string(channel), payload, nil
|
||||
}
|
||||
|
||||
func TestRedisProtocol_ManagementDisabled_RejectsConnection(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
redisqueue.SetEnabled(false)
|
||||
@@ -296,13 +103,19 @@ func TestRedisProtocol_ManagementDisabled_RejectsConnection(t *testing.T) {
|
||||
t.Fatalf("failed to write RESP command: %v", errWrite)
|
||||
}
|
||||
|
||||
if msg, err := readTestRESPError(bufio.NewReader(conn)); err != nil {
|
||||
t.Fatalf("failed to read disabled RESP error: %v", err)
|
||||
} else if msg != "ERR RESP AUTH disabled; use mTLS" {
|
||||
t.Fatalf("unexpected disabled RESP error: %q", msg)
|
||||
}
|
||||
|
||||
buf := make([]byte, 1)
|
||||
_, errRead := conn.Read(buf)
|
||||
if errRead == nil {
|
||||
t.Fatalf("expected connection to be closed when management is disabled")
|
||||
t.Fatalf("expected connection to be closed after disabled RESP error")
|
||||
}
|
||||
if ne, ok := errRead.(net.Error); ok && ne.Timeout() {
|
||||
t.Fatalf("expected connection to be closed when management is disabled, got timeout: %v", errRead)
|
||||
t.Fatalf("expected connection to be closed after disabled RESP error, got timeout: %v", errRead)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -333,17 +146,23 @@ func TestRedisProtocol_HomeEnabled_DisablesConnection(t *testing.T) {
|
||||
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
|
||||
_ = writeTestRESPCommand(conn, "PING")
|
||||
|
||||
if msg, err := readTestRESPError(bufio.NewReader(conn)); err != nil {
|
||||
t.Fatalf("failed to read disabled RESP error: %v", err)
|
||||
} else if msg != "ERR RESP AUTH disabled; use mTLS" {
|
||||
t.Fatalf("unexpected disabled RESP error: %q", msg)
|
||||
}
|
||||
|
||||
buf := make([]byte, 1)
|
||||
_, errRead := conn.Read(buf)
|
||||
if errRead == nil {
|
||||
t.Fatalf("expected connection to be closed when home mode is enabled")
|
||||
t.Fatalf("expected connection to be closed after disabled RESP error")
|
||||
}
|
||||
if ne, ok := errRead.(net.Error); ok && ne.Timeout() {
|
||||
t.Fatalf("expected connection to be closed when home mode is enabled, got timeout: %v", errRead)
|
||||
t.Fatalf("expected connection to be closed after disabled RESP error, got timeout: %v", errRead)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisProtocol_AUTH_And_PopContracts(t *testing.T) {
|
||||
func TestRedisProtocol_AUTH_DisabledAndClosesConnection(t *testing.T) {
|
||||
const managementPassword = "test-management-password"
|
||||
|
||||
t.Setenv("MANAGEMENT_PASSWORD", managementPassword)
|
||||
@@ -368,369 +187,21 @@ func TestRedisProtocol_AUTH_And_PopContracts(t *testing.T) {
|
||||
|
||||
_ = conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "AUTH", "test-key"); errWrite != nil {
|
||||
t.Fatalf("failed to write AUTH command: %v", errWrite)
|
||||
}
|
||||
if msg, err := readTestRESPError(reader); err != nil {
|
||||
t.Fatalf("failed to read AUTH error: %v", err)
|
||||
} else if msg != "ERR invalid management key" {
|
||||
t.Fatalf("unexpected AUTH error: %q", msg)
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "LPOP", "queue"); errWrite != nil {
|
||||
t.Fatalf("failed to write LPOP command: %v", errWrite)
|
||||
}
|
||||
if msg, err := readTestRESPError(reader); err != nil {
|
||||
t.Fatalf("failed to read LPOP NOAUTH error: %v", err)
|
||||
} else if msg != "NOAUTH Authentication required." {
|
||||
t.Fatalf("unexpected LPOP NOAUTH error: %q", msg)
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "AUTH", managementPassword); errWrite != nil {
|
||||
t.Fatalf("failed to write AUTH command: %v", errWrite)
|
||||
}
|
||||
if msg, err := readTestRESPSimpleString(reader); err != nil {
|
||||
t.Fatalf("failed to read AUTH response: %v", err)
|
||||
} else if msg != "OK" {
|
||||
t.Fatalf("unexpected AUTH response: %q", msg)
|
||||
if msg, err := readTestRESPError(reader); err != nil {
|
||||
t.Fatalf("failed to read disabled AUTH error: %v", err)
|
||||
} else if msg != "ERR RESP AUTH disabled; use mTLS" {
|
||||
t.Fatalf("unexpected disabled AUTH error: %q", msg)
|
||||
}
|
||||
|
||||
if !redisqueue.Enabled() {
|
||||
t.Fatalf("expected redisqueue to be enabled")
|
||||
buf := make([]byte, 1)
|
||||
_, errRead := conn.Read(buf)
|
||||
if errRead == nil {
|
||||
t.Fatalf("expected connection to be closed after disabled AUTH error")
|
||||
}
|
||||
redisqueue.Enqueue([]byte("a"))
|
||||
redisqueue.Enqueue([]byte("b"))
|
||||
redisqueue.Enqueue([]byte("c"))
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "RPOP", "queue"); errWrite != nil {
|
||||
t.Fatalf("failed to write RPOP command: %v", errWrite)
|
||||
}
|
||||
if item, err := readTestRESPBulkString(reader); err != nil {
|
||||
t.Fatalf("failed to read RPOP response: %v", err)
|
||||
} else if string(item) != "a" {
|
||||
t.Fatalf("unexpected RPOP item: %q", string(item))
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "LPOP", "queue"); errWrite != nil {
|
||||
t.Fatalf("failed to write LPOP command: %v", errWrite)
|
||||
}
|
||||
if item, err := readTestRESPBulkString(reader); err != nil {
|
||||
t.Fatalf("failed to read LPOP response: %v", err)
|
||||
} else if string(item) != "b" {
|
||||
t.Fatalf("unexpected LPOP item: %q", string(item))
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "RPOP", "queue", "10"); errWrite != nil {
|
||||
t.Fatalf("failed to write RPOP count command: %v", errWrite)
|
||||
}
|
||||
items, errItems := readRESPArrayOfBulkStrings(reader)
|
||||
if errItems != nil {
|
||||
t.Fatalf("failed to read RPOP count response: %v", errItems)
|
||||
}
|
||||
if len(items) != 1 || string(items[0]) != "c" {
|
||||
t.Fatalf("unexpected RPOP count items: %#v", items)
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "LPOP", "queue"); errWrite != nil {
|
||||
t.Fatalf("failed to write LPOP empty command: %v", errWrite)
|
||||
}
|
||||
item, errItem := readTestRESPBulkString(reader)
|
||||
if errItem != nil {
|
||||
t.Fatalf("failed to read LPOP empty response: %v", errItem)
|
||||
}
|
||||
if item != nil {
|
||||
t.Fatalf("expected nil bulk string for empty queue, got %q", string(item))
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "RPOP", "queue", "2"); errWrite != nil {
|
||||
t.Fatalf("failed to write RPOP empty count command: %v", errWrite)
|
||||
}
|
||||
emptyItems, errEmpty := readRESPArrayOfBulkStrings(reader)
|
||||
if errEmpty != nil {
|
||||
t.Fatalf("failed to read RPOP empty count response: %v", errEmpty)
|
||||
}
|
||||
if len(emptyItems) != 0 {
|
||||
t.Fatalf("expected empty array for empty queue with count, got %#v", emptyItems)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisProtocol_SubscribeUsageBroadcastsAndSkipsQueue(t *testing.T) {
|
||||
const managementPassword = "test-management-password"
|
||||
|
||||
t.Setenv("MANAGEMENT_PASSWORD", managementPassword)
|
||||
redisqueue.SetEnabled(false)
|
||||
t.Cleanup(func() { redisqueue.SetEnabled(false) })
|
||||
|
||||
server := newTestServer(t)
|
||||
if !server.managementRoutesEnabled.Load() {
|
||||
t.Fatalf("expected managementRoutesEnabled to be true")
|
||||
}
|
||||
|
||||
addr, stop := startRedisMuxListener(t, server)
|
||||
t.Cleanup(stop)
|
||||
|
||||
firstConn, errDialFirst := net.DialTimeout("tcp", addr, time.Second)
|
||||
if errDialFirst != nil {
|
||||
t.Fatalf("failed to dial first redis listener: %v", errDialFirst)
|
||||
}
|
||||
t.Cleanup(func() { _ = firstConn.Close() })
|
||||
firstReader := bufio.NewReader(firstConn)
|
||||
_ = firstConn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
|
||||
if errWrite := writeTestRESPCommand(firstConn, "AUTH", managementPassword); errWrite != nil {
|
||||
t.Fatalf("failed to write first AUTH command: %v", errWrite)
|
||||
}
|
||||
if msg, err := readTestRESPSimpleString(firstReader); err != nil {
|
||||
t.Fatalf("failed to read first AUTH response: %v", err)
|
||||
} else if msg != "OK" {
|
||||
t.Fatalf("unexpected first AUTH response: %q", msg)
|
||||
}
|
||||
if errWrite := writeTestRESPCommand(firstConn, "SUBSCRIBE", "usage"); errWrite != nil {
|
||||
t.Fatalf("failed to write first SUBSCRIBE command: %v", errWrite)
|
||||
}
|
||||
if channel, count, err := readTestRESPPubSubSubscribe(firstReader); err != nil {
|
||||
t.Fatalf("failed to read first SUBSCRIBE response: %v", err)
|
||||
} else if channel != "usage" || count != 1 {
|
||||
t.Fatalf("unexpected first SUBSCRIBE response channel=%q count=%d", channel, count)
|
||||
}
|
||||
|
||||
secondConn, errDialSecond := net.DialTimeout("tcp", addr, time.Second)
|
||||
if errDialSecond != nil {
|
||||
t.Fatalf("failed to dial second redis listener: %v", errDialSecond)
|
||||
}
|
||||
t.Cleanup(func() { _ = secondConn.Close() })
|
||||
secondReader := bufio.NewReader(secondConn)
|
||||
_ = secondConn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
|
||||
if errWrite := writeTestRESPCommand(secondConn, "AUTH", managementPassword); errWrite != nil {
|
||||
t.Fatalf("failed to write second AUTH command: %v", errWrite)
|
||||
}
|
||||
if msg, err := readTestRESPSimpleString(secondReader); err != nil {
|
||||
t.Fatalf("failed to read second AUTH response: %v", err)
|
||||
} else if msg != "OK" {
|
||||
t.Fatalf("unexpected second AUTH response: %q", msg)
|
||||
}
|
||||
if errWrite := writeTestRESPCommand(secondConn, "SUBSCRIBE", "usage"); errWrite != nil {
|
||||
t.Fatalf("failed to write second SUBSCRIBE command: %v", errWrite)
|
||||
}
|
||||
if channel, count, err := readTestRESPPubSubSubscribe(secondReader); err != nil {
|
||||
t.Fatalf("failed to read second SUBSCRIBE response: %v", err)
|
||||
} else if channel != "usage" || count != 1 {
|
||||
t.Fatalf("unexpected second SUBSCRIBE response channel=%q count=%d", channel, count)
|
||||
}
|
||||
|
||||
redisqueue.Enqueue([]byte(`{"id":1}`))
|
||||
|
||||
if channel, payload, err := readTestRESPPubSubMessage(firstReader); err != nil {
|
||||
t.Fatalf("failed to read first pubsub message: %v", err)
|
||||
} else if channel != "usage" || string(payload) != `{"id":1}` {
|
||||
t.Fatalf("unexpected first pubsub message channel=%q payload=%q", channel, string(payload))
|
||||
}
|
||||
if channel, payload, err := readTestRESPPubSubMessage(secondReader); err != nil {
|
||||
t.Fatalf("failed to read second pubsub message: %v", err)
|
||||
} else if channel != "usage" || string(payload) != `{"id":1}` {
|
||||
t.Fatalf("unexpected second pubsub message channel=%q payload=%q", channel, string(payload))
|
||||
}
|
||||
|
||||
popConn, errDialPop := net.DialTimeout("tcp", addr, time.Second)
|
||||
if errDialPop != nil {
|
||||
t.Fatalf("failed to dial pop redis listener: %v", errDialPop)
|
||||
}
|
||||
t.Cleanup(func() { _ = popConn.Close() })
|
||||
popReader := bufio.NewReader(popConn)
|
||||
_ = popConn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
|
||||
if errWrite := writeTestRESPCommand(popConn, "AUTH", managementPassword); errWrite != nil {
|
||||
t.Fatalf("failed to write pop AUTH command: %v", errWrite)
|
||||
}
|
||||
if msg, err := readTestRESPSimpleString(popReader); err != nil {
|
||||
t.Fatalf("failed to read pop AUTH response: %v", err)
|
||||
} else if msg != "OK" {
|
||||
t.Fatalf("unexpected pop AUTH response: %q", msg)
|
||||
}
|
||||
if errWrite := writeTestRESPCommand(popConn, "LPOP", "usage"); errWrite != nil {
|
||||
t.Fatalf("failed to write pop LPOP command: %v", errWrite)
|
||||
}
|
||||
item, errItem := readTestRESPBulkString(popReader)
|
||||
if errItem != nil {
|
||||
t.Fatalf("failed to read pop LPOP response: %v", errItem)
|
||||
}
|
||||
if item != nil {
|
||||
t.Fatalf("expected subscribed usage to skip queue, got %q", string(item))
|
||||
}
|
||||
|
||||
managementReq := httptest.NewRequest(http.MethodGet, "/v0/management/usage-queue?count=1", nil)
|
||||
managementReq.Header.Set("Authorization", "Bearer "+managementPassword)
|
||||
managementRR := httptest.NewRecorder()
|
||||
server.engine.ServeHTTP(managementRR, managementReq)
|
||||
if managementRR.Code != http.StatusOK {
|
||||
t.Fatalf("management usage status = %d, want %d body=%s", managementRR.Code, http.StatusOK, managementRR.Body.String())
|
||||
}
|
||||
var managementPayload []json.RawMessage
|
||||
if errUnmarshal := json.Unmarshal(managementRR.Body.Bytes(), &managementPayload); errUnmarshal != nil {
|
||||
t.Fatalf("unmarshal management usage response: %v", errUnmarshal)
|
||||
}
|
||||
if len(managementPayload) != 0 {
|
||||
t.Fatalf("expected management usage queue to be empty, got %s", managementRR.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisProtocol_IPBan_MirrorsManagementPolicy(t *testing.T) {
|
||||
const managementPassword = "test-management-password"
|
||||
|
||||
t.Setenv("MANAGEMENT_PASSWORD", managementPassword)
|
||||
redisqueue.SetEnabled(false)
|
||||
t.Cleanup(func() { redisqueue.SetEnabled(false) })
|
||||
|
||||
server := newTestServer(t)
|
||||
if !server.managementRoutesEnabled.Load() {
|
||||
t.Fatalf("expected managementRoutesEnabled to be true")
|
||||
}
|
||||
|
||||
clientConn, serverConn := net.Pipe()
|
||||
t.Cleanup(func() { _ = clientConn.Close() })
|
||||
t.Cleanup(func() { _ = serverConn.Close() })
|
||||
|
||||
fakeRemote := &net.TCPAddr{
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Port: 1234,
|
||||
}
|
||||
wrappedConn := &remoteAddrConn{Conn: serverConn, remoteAddr: fakeRemote}
|
||||
|
||||
go server.handleRedisConnection(wrappedConn, bufio.NewReader(wrappedConn))
|
||||
|
||||
reader := bufio.NewReader(clientConn)
|
||||
_ = clientConn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
if errWrite := writeTestRESPCommand(clientConn, "LPOP", "queue"); errWrite != nil {
|
||||
t.Fatalf("failed to write LPOP command: %v", errWrite)
|
||||
}
|
||||
if msg, err := readTestRESPError(reader); err != nil {
|
||||
t.Fatalf("failed to read LPOP NOAUTH error: %v", err)
|
||||
} else if msg != "NOAUTH Authentication required." {
|
||||
t.Fatalf("unexpected LPOP NOAUTH error at attempt %d: %q", i+1, msg)
|
||||
}
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(clientConn, "LPOP", "queue"); errWrite != nil {
|
||||
t.Fatalf("failed to write LPOP command after failures: %v", errWrite)
|
||||
}
|
||||
msg, err := readTestRESPError(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read LPOP banned error: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(msg, "ERR IP banned due to too many failed attempts. Try again in") {
|
||||
t.Fatalf("unexpected LPOP banned error: %q", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisProtocol_AUTH_IPBan_BlocksCorrectPasswordDuringBan(t *testing.T) {
|
||||
const managementPassword = "test-management-password"
|
||||
|
||||
t.Setenv("MANAGEMENT_PASSWORD", managementPassword)
|
||||
redisqueue.SetEnabled(false)
|
||||
t.Cleanup(func() { redisqueue.SetEnabled(false) })
|
||||
|
||||
server := newTestServer(t)
|
||||
if !server.managementRoutesEnabled.Load() {
|
||||
t.Fatalf("expected managementRoutesEnabled to be true")
|
||||
}
|
||||
|
||||
clientConn, serverConn := net.Pipe()
|
||||
t.Cleanup(func() { _ = clientConn.Close() })
|
||||
t.Cleanup(func() { _ = serverConn.Close() })
|
||||
|
||||
fakeRemote := &net.TCPAddr{
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Port: 1234,
|
||||
}
|
||||
wrappedConn := &remoteAddrConn{Conn: serverConn, remoteAddr: fakeRemote}
|
||||
|
||||
go server.handleRedisConnection(wrappedConn, bufio.NewReader(wrappedConn))
|
||||
|
||||
reader := bufio.NewReader(clientConn)
|
||||
_ = clientConn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
if errWrite := writeTestRESPCommand(clientConn, "AUTH", "wrong-password"); errWrite != nil {
|
||||
t.Fatalf("failed to write AUTH command: %v", errWrite)
|
||||
}
|
||||
if msg, err := readTestRESPError(reader); err != nil {
|
||||
t.Fatalf("failed to read AUTH error: %v", err)
|
||||
} else if msg != "ERR invalid management key" {
|
||||
t.Fatalf("unexpected AUTH error at attempt %d: %q", i+1, msg)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
if errWrite := writeTestRESPCommand(clientConn, "AUTH", "wrong-password"); errWrite != nil {
|
||||
t.Fatalf("failed to write AUTH command after failures: %v", errWrite)
|
||||
}
|
||||
msg, err := readTestRESPError(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read AUTH banned error: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(msg, "ERR IP banned due to too many failed attempts. Try again in") {
|
||||
t.Fatalf("unexpected AUTH banned error at attempt %d: %q", i+6, msg)
|
||||
}
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(clientConn, "AUTH", managementPassword); errWrite != nil {
|
||||
t.Fatalf("failed to write AUTH command with correct password: %v", errWrite)
|
||||
}
|
||||
msg, err := readTestRESPError(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read AUTH banned error for correct password: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(msg, "ERR IP banned due to too many failed attempts. Try again in") {
|
||||
t.Fatalf("unexpected AUTH banned error for correct password: %q", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisProtocol_LOCALHOST_AUTH_IPBan_BlocksCorrectPasswordDuringBan(t *testing.T) {
|
||||
const managementPassword = "test-management-password"
|
||||
|
||||
t.Setenv("MANAGEMENT_PASSWORD", managementPassword)
|
||||
redisqueue.SetEnabled(false)
|
||||
t.Cleanup(func() { redisqueue.SetEnabled(false) })
|
||||
|
||||
server := newTestServer(t)
|
||||
if !server.managementRoutesEnabled.Load() {
|
||||
t.Fatalf("expected managementRoutesEnabled to be true")
|
||||
}
|
||||
|
||||
addr, stop := startRedisMuxListener(t, server)
|
||||
t.Cleanup(stop)
|
||||
|
||||
conn, errDial := net.DialTimeout("tcp", addr, time.Second)
|
||||
if errDial != nil {
|
||||
t.Fatalf("failed to dial redis listener: %v", errDial)
|
||||
}
|
||||
t.Cleanup(func() { _ = conn.Close() })
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
_ = conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
if errWrite := writeTestRESPCommand(conn, "AUTH", "wrong-password"); errWrite != nil {
|
||||
t.Fatalf("failed to write AUTH command: %v", errWrite)
|
||||
}
|
||||
if msg, err := readTestRESPError(reader); err != nil {
|
||||
t.Fatalf("failed to read AUTH error: %v", err)
|
||||
} else if msg != "ERR invalid management key" {
|
||||
t.Fatalf("unexpected AUTH error at attempt %d: %q", i+1, msg)
|
||||
}
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "AUTH", managementPassword); errWrite != nil {
|
||||
t.Fatalf("failed to write AUTH command with correct password: %v", errWrite)
|
||||
}
|
||||
msg, err := readTestRESPError(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read AUTH banned error for correct password: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(msg, "ERR IP banned due to too many failed attempts. Try again in") {
|
||||
t.Fatalf("unexpected AUTH banned error for correct password: %q", msg)
|
||||
if ne, ok := errRead.(net.Error); ok && ne.Timeout() {
|
||||
t.Fatalf("expected connection to be closed after disabled AUTH error, got timeout: %v", errRead)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
||||
@@ -37,8 +37,8 @@ type Config struct {
|
||||
// TLS config controls HTTPS server settings.
|
||||
TLS TLSConfig `yaml:"tls" json:"tls"`
|
||||
|
||||
// Home config enables the Redis-based control plane integration.
|
||||
Home HomeConfig `yaml:"home" json:"-"`
|
||||
// Home config is runtime-only and is populated from -home-jwt.
|
||||
Home HomeConfig `yaml:"-" json:"-"`
|
||||
|
||||
// RemoteManagement nests management-related options under 'remote-management'.
|
||||
RemoteManagement RemoteManagement `yaml:"remote-management" json:"-"`
|
||||
@@ -69,8 +69,8 @@ type Config struct {
|
||||
// UsageStatisticsEnabled toggles in-memory usage aggregation; when false, usage data is discarded.
|
||||
UsageStatisticsEnabled bool `yaml:"usage-statistics-enabled" json:"usage-statistics-enabled"`
|
||||
|
||||
// RedisUsageQueueRetentionSeconds controls how long (in seconds) usage queue items
|
||||
// are retained in memory for the Redis RESP interface (LPOP/RPOP).
|
||||
// RedisUsageQueueRetentionSeconds controls how long usage queue items are retained
|
||||
// in memory for Management API consumers.
|
||||
// Default: 60. Max: 3600.
|
||||
RedisUsageQueueRetentionSeconds int `yaml:"redis-usage-queue-retention-seconds" json:"redis-usage-queue-retention-seconds"`
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
package config
|
||||
|
||||
// HomeConfig configures the optional "home" control plane integration over Redis protocol.
|
||||
// HomeConfig stores runtime-only Home control plane settings from -home-jwt.
|
||||
type HomeConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
Host string `yaml:"host" json:"-"`
|
||||
Port int `yaml:"port" json:"-"`
|
||||
Password string `yaml:"password" json:"-"`
|
||||
DisableClusterDiscovery bool `yaml:"disable-cluster-discovery" json:"-"`
|
||||
TLS HomeTLSConfig `yaml:"tls" json:"-"`
|
||||
}
|
||||
|
||||
@@ -2,13 +2,12 @@ package config
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseConfigBytesHomeTLS(t *testing.T) {
|
||||
func TestParseConfigBytesIgnoresHomeConfig(t *testing.T) {
|
||||
cfg, err := ParseConfigBytes([]byte(`
|
||||
home:
|
||||
enabled: true
|
||||
host: home.example.com
|
||||
port: 444
|
||||
password: secret
|
||||
disable-cluster-discovery: true
|
||||
tls:
|
||||
enable: true
|
||||
@@ -20,31 +19,28 @@ home:
|
||||
t.Fatalf("ParseConfigBytes() error = %v", err)
|
||||
}
|
||||
|
||||
if !cfg.Home.Enabled {
|
||||
t.Fatal("Home.Enabled = false, want true")
|
||||
if cfg.Home.Enabled {
|
||||
t.Fatal("Home.Enabled = true, want false")
|
||||
}
|
||||
if cfg.Home.Host != "home.example.com" {
|
||||
t.Fatalf("Home.Host = %q, want home.example.com", cfg.Home.Host)
|
||||
if cfg.Home.Host != "" {
|
||||
t.Fatalf("Home.Host = %q, want empty", cfg.Home.Host)
|
||||
}
|
||||
if cfg.Home.Port != 444 {
|
||||
t.Fatalf("Home.Port = %d, want 444", cfg.Home.Port)
|
||||
if cfg.Home.Port != 0 {
|
||||
t.Fatalf("Home.Port = %d, want 0", cfg.Home.Port)
|
||||
}
|
||||
if cfg.Home.Password != "secret" {
|
||||
t.Fatalf("Home.Password = %q, want secret", cfg.Home.Password)
|
||||
if cfg.Home.DisableClusterDiscovery {
|
||||
t.Fatal("Home.DisableClusterDiscovery = true, want false")
|
||||
}
|
||||
if !cfg.Home.DisableClusterDiscovery {
|
||||
t.Fatal("Home.DisableClusterDiscovery = false, want true")
|
||||
if cfg.Home.TLS.Enable {
|
||||
t.Fatal("Home.TLS.Enable = true, want false")
|
||||
}
|
||||
if !cfg.Home.TLS.Enable {
|
||||
t.Fatal("Home.TLS.Enable = false, want true")
|
||||
if cfg.Home.TLS.ServerName != "" {
|
||||
t.Fatalf("Home.TLS.ServerName = %q, want empty", cfg.Home.TLS.ServerName)
|
||||
}
|
||||
if cfg.Home.TLS.ServerName != "home.example.com" {
|
||||
t.Fatalf("Home.TLS.ServerName = %q, want home.example.com", cfg.Home.TLS.ServerName)
|
||||
if cfg.Home.TLS.CACert != "" {
|
||||
t.Fatalf("Home.TLS.CACert = %q, want empty", cfg.Home.TLS.CACert)
|
||||
}
|
||||
if cfg.Home.TLS.CACert != "C:/certs/ca.pem" {
|
||||
t.Fatalf("Home.TLS.CACert = %q, want C:/certs/ca.pem", cfg.Home.TLS.CACert)
|
||||
}
|
||||
if !cfg.Home.TLS.InsecureSkipVerify {
|
||||
t.Fatal("Home.TLS.InsecureSkipVerify = false, want true")
|
||||
if cfg.Home.TLS.InsecureSkipVerify {
|
||||
t.Fatal("Home.TLS.InsecureSkipVerify = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -180,7 +180,6 @@ func (c *Client) redisOptionsLocked(addr string) (*redis.Options, error) {
|
||||
}
|
||||
return &redis.Options{
|
||||
Addr: addr,
|
||||
Password: c.homeCfg.Password,
|
||||
TLSConfig: tlsConfig,
|
||||
DialTimeout: homeRedisOperationTimeout,
|
||||
ReadTimeout: homeRedisOperationTimeout,
|
||||
|
||||
@@ -37,10 +37,9 @@ func TestAuthDispatchRequestDefaultsCountToOne(t *testing.T) {
|
||||
|
||||
func TestRedisOptionsHomeTLSDisabled(t *testing.T) {
|
||||
client := New(config.HomeConfig{
|
||||
Enabled: true,
|
||||
Host: "127.0.0.1",
|
||||
Port: 6379,
|
||||
Password: "secret",
|
||||
Enabled: true,
|
||||
Host: "127.0.0.1",
|
||||
Port: 6379,
|
||||
})
|
||||
|
||||
client.mu.Lock()
|
||||
@@ -53,8 +52,8 @@ func TestRedisOptionsHomeTLSDisabled(t *testing.T) {
|
||||
if options.TLSConfig != nil {
|
||||
t.Fatalf("TLSConfig = %#v, want nil", options.TLSConfig)
|
||||
}
|
||||
if options.Password != "secret" {
|
||||
t.Fatalf("Password = %q, want secret", options.Password)
|
||||
if options.Password != "" {
|
||||
t.Fatalf("Password = %q, want empty", options.Password)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user