diff --git a/internal/api/handlers/management/handler.go b/internal/api/handlers/management/handler.go index ee96ed79b..af11366c3 100644 --- a/internal/api/handlers/management/handler.go +++ b/internal/api/handlers/management/handler.go @@ -207,84 +207,42 @@ func (h *Handler) AuthenticateManagementKey(clientIP string, localClient bool, p } envSecret := h.envSecret - fail := func() {} - if !localClient { + now := time.Now() + h.attemptsMu.Lock() + ai := h.failedAttempts[clientIP] + if ai != nil && !ai.blockedUntil.IsZero() { + if now.Before(ai.blockedUntil) { + remaining := ai.blockedUntil.Sub(now).Round(time.Second) + h.attemptsMu.Unlock() + return false, http.StatusForbidden, fmt.Sprintf("IP banned due to too many failed attempts. Try again in %s", remaining) + } + // Ban expired, reset state + ai.blockedUntil = time.Time{} + ai.count = 0 + } + h.attemptsMu.Unlock() + + if !localClient && !allowRemote { + return false, http.StatusForbidden, "remote management disabled" + } + + fail := func() { h.attemptsMu.Lock() - ai := h.failedAttempts[clientIP] - if ai != nil { - if !ai.blockedUntil.IsZero() { - if time.Now().Before(ai.blockedUntil) { - remaining := time.Until(ai.blockedUntil).Round(time.Second) - h.attemptsMu.Unlock() - return false, http.StatusForbidden, fmt.Sprintf("IP banned due to too many failed attempts. Try again in %s", remaining) - } - // Ban expired, reset state - ai.blockedUntil = time.Time{} - ai.count = 0 - } + aip := h.failedAttempts[clientIP] + if aip == nil { + aip = &attemptInfo{} + h.failedAttempts[clientIP] = aip + } + aip.count++ + aip.lastActivity = time.Now() + if aip.count >= maxFailures { + aip.blockedUntil = time.Now().Add(banDuration) + aip.count = 0 } h.attemptsMu.Unlock() - - if !allowRemote { - return false, http.StatusForbidden, "remote management disabled" - } - - fail = func() { - h.attemptsMu.Lock() - aip := h.failedAttempts[clientIP] - if aip == nil { - aip = &attemptInfo{} - h.failedAttempts[clientIP] = aip - } - aip.count++ - aip.lastActivity = time.Now() - if aip.count >= maxFailures { - aip.blockedUntil = time.Now().Add(banDuration) - aip.count = 0 - } - h.attemptsMu.Unlock() - } } - if secretHash == "" && envSecret == "" { - return false, http.StatusForbidden, "remote management key not set" - } - - if provided == "" { - if !localClient { - fail() - } - return false, http.StatusUnauthorized, "missing management key" - } - - if localClient { - if lp := h.localPassword; lp != "" { - if subtle.ConstantTimeCompare([]byte(provided), []byte(lp)) == 1 { - return true, 0, "" - } - } - } - - if envSecret != "" && subtle.ConstantTimeCompare([]byte(provided), []byte(envSecret)) == 1 { - if !localClient { - h.attemptsMu.Lock() - if ai := h.failedAttempts[clientIP]; ai != nil { - ai.count = 0 - ai.blockedUntil = time.Time{} - } - h.attemptsMu.Unlock() - } - return true, 0, "" - } - - if secretHash == "" || bcrypt.CompareHashAndPassword([]byte(secretHash), []byte(provided)) != nil { - if !localClient { - fail() - } - return false, http.StatusUnauthorized, "invalid management key" - } - - if !localClient { + reset := func() { h.attemptsMu.Lock() if ai := h.failedAttempts[clientIP]; ai != nil { ai.count = 0 @@ -293,6 +251,36 @@ func (h *Handler) AuthenticateManagementKey(clientIP string, localClient bool, p h.attemptsMu.Unlock() } + if secretHash == "" && envSecret == "" { + return false, http.StatusForbidden, "remote management key not set" + } + + if provided == "" { + fail() + return false, http.StatusUnauthorized, "missing management key" + } + + if localClient { + if lp := h.localPassword; lp != "" { + if subtle.ConstantTimeCompare([]byte(provided), []byte(lp)) == 1 { + reset() + return true, 0, "" + } + } + } + + if envSecret != "" && subtle.ConstantTimeCompare([]byte(provided), []byte(envSecret)) == 1 { + reset() + return true, 0, "" + } + + if secretHash == "" || bcrypt.CompareHashAndPassword([]byte(secretHash), []byte(provided)) != nil { + fail() + return false, http.StatusUnauthorized, "invalid management key" + } + + reset() + return true, 0, "" } diff --git a/internal/api/handlers/management/handler_test.go b/internal/api/handlers/management/handler_test.go new file mode 100644 index 000000000..f3a6086e9 --- /dev/null +++ b/internal/api/handlers/management/handler_test.go @@ -0,0 +1,38 @@ +package management + +import ( + "net/http" + "strings" + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" +) + +func TestAuthenticateManagementKey_LocalhostIPBan_BlocksCorrectKeyDuringBan(t *testing.T) { + h := &Handler{ + cfg: &config.Config{}, + failedAttempts: make(map[string]*attemptInfo), + envSecret: "test-secret", + } + + for i := 0; i < 5; i++ { + allowed, statusCode, errMsg := h.AuthenticateManagementKey("127.0.0.1", true, "wrong-secret") + if allowed { + t.Fatalf("expected auth to be denied at attempt %d", i+1) + } + if statusCode != http.StatusUnauthorized || errMsg != "invalid management key" { + t.Fatalf("unexpected auth failure at attempt %d: status=%d msg=%q", i+1, statusCode, errMsg) + } + } + + allowed, statusCode, errMsg := h.AuthenticateManagementKey("127.0.0.1", true, "test-secret") + if allowed { + t.Fatalf("expected correct key to be denied while banned") + } + if statusCode != http.StatusForbidden { + t.Fatalf("expected forbidden status while banned, got %d", statusCode) + } + if !strings.HasPrefix(errMsg, "IP banned due to too many failed attempts. Try again in") { + t.Fatalf("unexpected banned message: %q", errMsg) + } +} diff --git a/internal/api/redis_queue_protocol.go b/internal/api/redis_queue_protocol.go index 053a99c75..caaba2316 100644 --- a/internal/api/redis_queue_protocol.go +++ b/internal/api/redis_queue_protocol.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net" + "net/http" "strconv" "strings" @@ -66,10 +67,38 @@ func (s *Server) handleRedisConnection(conn net.Conn, reader *bufio.Reader) { } cmd := strings.ToUpper(strings.TrimSpace(args[0])) + + if cmd != "AUTH" && !authed { + if s.mgmt != nil { + _, statusCode, errMsg := s.mgmt.AuthenticateManagementKey(clientIP, localClient, "") + if statusCode == http.StatusForbidden && strings.HasPrefix(errMsg, "IP banned due to too many failed attempts") { + _ = writeRedisError(writer, "ERR "+errMsg) + } else { + _ = writeRedisError(writer, "NOAUTH Authentication required.") + } + } else { + _ = writeRedisError(writer, "NOAUTH Authentication required.") + } + if !flush() { + return + } + continue + } + switch cmd { case "AUTH": password, ok := parseAuthPassword(args) if !ok { + if s.mgmt != nil { + _, statusCode, errMsg := s.mgmt.AuthenticateManagementKey(clientIP, localClient, "") + if statusCode == http.StatusForbidden && strings.HasPrefix(errMsg, "IP banned due to too many failed attempts") { + _ = writeRedisError(writer, "ERR "+errMsg) + if !flush() { + return + } + continue + } + } _ = writeRedisError(writer, "ERR wrong number of arguments for 'auth' command") if !flush() { return @@ -151,10 +180,35 @@ func resolveRemoteIP(addr net.Addr) (ip string, localClient bool) { if addr == nil { return "", false } - host := addr.String() - if h, _, err := net.SplitHostPort(host); err == nil { - host = h + + var host string + switch a := addr.(type) { + case *net.TCPAddr: + if a != nil && a.IP != nil { + if ip4 := a.IP.To4(); ip4 != nil { + host = ip4.String() + } else { + host = a.IP.String() + } + } + default: + host = addr.String() + if h, _, err := net.SplitHostPort(host); err == nil { + host = h + } + host = strings.TrimSpace(host) + if raw, _, ok := strings.Cut(host, "%"); ok { + host = raw + } + if parsed := net.ParseIP(host); parsed != nil { + if ip4 := parsed.To4(); ip4 != nil { + host = ip4.String() + } else { + host = parsed.String() + } + } } + host = strings.TrimSpace(host) localClient = host == "127.0.0.1" || host == "::1" return host, localClient diff --git a/internal/api/redis_queue_protocol_integration_test.go b/internal/api/redis_queue_protocol_integration_test.go index 18ab0279a..93bfeb866 100644 --- a/internal/api/redis_queue_protocol_integration_test.go +++ b/internal/api/redis_queue_protocol_integration_test.go @@ -15,6 +15,18 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue" ) +type remoteAddrConn struct { + net.Conn + remoteAddr net.Addr +} + +func (c *remoteAddrConn) RemoteAddr() net.Addr { + if c == nil { + return nil + } + return c.remoteAddr +} + func startRedisMuxListener(t *testing.T, server *Server) (addr string, stop func()) { t.Helper() @@ -302,3 +314,163 @@ func TestRedisProtocol_AUTH_And_PopContracts(t *testing.T) { t.Fatalf("expected empty array for empty queue with count, got %#v", emptyItems) } } + +func TestRedisProtocol_IPBan_MirrorsManagementPolicy(t *testing.T) { + const managementPassword = "test-management-password" + + t.Setenv("MANAGEMENT_PASSWORD", managementPassword) + redisqueue.SetEnabled(false) + t.Cleanup(func() { redisqueue.SetEnabled(false) }) + + server := newTestServer(t) + if !server.managementRoutesEnabled.Load() { + t.Fatalf("expected managementRoutesEnabled to be true") + } + + clientConn, serverConn := net.Pipe() + t.Cleanup(func() { _ = clientConn.Close() }) + t.Cleanup(func() { _ = serverConn.Close() }) + + fakeRemote := &net.TCPAddr{ + IP: net.ParseIP("1.2.3.4"), + Port: 1234, + } + wrappedConn := &remoteAddrConn{Conn: serverConn, remoteAddr: fakeRemote} + + go server.handleRedisConnection(wrappedConn, bufio.NewReader(wrappedConn)) + + reader := bufio.NewReader(clientConn) + _ = clientConn.SetDeadline(time.Now().Add(5 * time.Second)) + + for i := 0; i < 5; i++ { + if errWrite := writeTestRESPCommand(clientConn, "LPOP", "queue"); errWrite != nil { + t.Fatalf("failed to write LPOP command: %v", errWrite) + } + if msg, err := readTestRESPError(reader); err != nil { + t.Fatalf("failed to read LPOP NOAUTH error: %v", err) + } else if msg != "NOAUTH Authentication required." { + t.Fatalf("unexpected LPOP NOAUTH error at attempt %d: %q", i+1, msg) + } + } + + if errWrite := writeTestRESPCommand(clientConn, "LPOP", "queue"); errWrite != nil { + t.Fatalf("failed to write LPOP command after failures: %v", errWrite) + } + msg, err := readTestRESPError(reader) + if err != nil { + t.Fatalf("failed to read LPOP banned error: %v", err) + } + if !strings.HasPrefix(msg, "ERR IP banned due to too many failed attempts. Try again in") { + t.Fatalf("unexpected LPOP banned error: %q", msg) + } +} + +func TestRedisProtocol_AUTH_IPBan_BlocksCorrectPasswordDuringBan(t *testing.T) { + const managementPassword = "test-management-password" + + t.Setenv("MANAGEMENT_PASSWORD", managementPassword) + redisqueue.SetEnabled(false) + t.Cleanup(func() { redisqueue.SetEnabled(false) }) + + server := newTestServer(t) + if !server.managementRoutesEnabled.Load() { + t.Fatalf("expected managementRoutesEnabled to be true") + } + + clientConn, serverConn := net.Pipe() + t.Cleanup(func() { _ = clientConn.Close() }) + t.Cleanup(func() { _ = serverConn.Close() }) + + fakeRemote := &net.TCPAddr{ + IP: net.ParseIP("1.2.3.4"), + Port: 1234, + } + wrappedConn := &remoteAddrConn{Conn: serverConn, remoteAddr: fakeRemote} + + go server.handleRedisConnection(wrappedConn, bufio.NewReader(wrappedConn)) + + reader := bufio.NewReader(clientConn) + _ = clientConn.SetDeadline(time.Now().Add(5 * time.Second)) + + for i := 0; i < 5; i++ { + if errWrite := writeTestRESPCommand(clientConn, "AUTH", "wrong-password"); errWrite != nil { + t.Fatalf("failed to write AUTH command: %v", errWrite) + } + if msg, err := readTestRESPError(reader); err != nil { + t.Fatalf("failed to read AUTH error: %v", err) + } else if msg != "ERR invalid management key" { + t.Fatalf("unexpected AUTH error at attempt %d: %q", i+1, msg) + } + } + + for i := 0; i < 2; i++ { + if errWrite := writeTestRESPCommand(clientConn, "AUTH", "wrong-password"); errWrite != nil { + t.Fatalf("failed to write AUTH command after failures: %v", errWrite) + } + msg, err := readTestRESPError(reader) + if err != nil { + t.Fatalf("failed to read AUTH banned error: %v", err) + } + if !strings.HasPrefix(msg, "ERR IP banned due to too many failed attempts. Try again in") { + t.Fatalf("unexpected AUTH banned error at attempt %d: %q", i+6, msg) + } + } + + if errWrite := writeTestRESPCommand(clientConn, "AUTH", managementPassword); errWrite != nil { + t.Fatalf("failed to write AUTH command with correct password: %v", errWrite) + } + msg, err := readTestRESPError(reader) + if err != nil { + t.Fatalf("failed to read AUTH banned error for correct password: %v", err) + } + if !strings.HasPrefix(msg, "ERR IP banned due to too many failed attempts. Try again in") { + t.Fatalf("unexpected AUTH banned error for correct password: %q", msg) + } +} + +func TestRedisProtocol_LOCALHOST_AUTH_IPBan_BlocksCorrectPasswordDuringBan(t *testing.T) { + const managementPassword = "test-management-password" + + t.Setenv("MANAGEMENT_PASSWORD", managementPassword) + redisqueue.SetEnabled(false) + t.Cleanup(func() { redisqueue.SetEnabled(false) }) + + server := newTestServer(t) + if !server.managementRoutesEnabled.Load() { + t.Fatalf("expected managementRoutesEnabled to be true") + } + + addr, stop := startRedisMuxListener(t, server) + t.Cleanup(stop) + + conn, errDial := net.DialTimeout("tcp", addr, time.Second) + if errDial != nil { + t.Fatalf("failed to dial redis listener: %v", errDial) + } + t.Cleanup(func() { _ = conn.Close() }) + + reader := bufio.NewReader(conn) + _ = conn.SetDeadline(time.Now().Add(5 * time.Second)) + + for i := 0; i < 5; i++ { + if errWrite := writeTestRESPCommand(conn, "AUTH", "wrong-password"); errWrite != nil { + t.Fatalf("failed to write AUTH command: %v", errWrite) + } + if msg, err := readTestRESPError(reader); err != nil { + t.Fatalf("failed to read AUTH error: %v", err) + } else if msg != "ERR invalid management key" { + t.Fatalf("unexpected AUTH error at attempt %d: %q", i+1, msg) + } + } + + if errWrite := writeTestRESPCommand(conn, "AUTH", managementPassword); errWrite != nil { + t.Fatalf("failed to write AUTH command with correct password: %v", errWrite) + } + msg, err := readTestRESPError(reader) + if err != nil { + t.Fatalf("failed to read AUTH banned error for correct password: %v", err) + } + if !strings.HasPrefix(msg, "ERR IP banned due to too many failed attempts. Try again in") { + t.Fatalf("unexpected AUTH banned error for correct password: %q", msg) + } +}