mirror of
https://github.com/0xJacky/nginx-ui.git
synced 2026-05-07 22:41:41 +08:00
- Introduced `WebSocketTrustedOrigins` setting in `app.example.ini` and corresponding documentation. - Refactored WebSocket origin checks across multiple API endpoints to utilize the new middleware for improved security. - Added tests for the new origin validation logic to ensure proper handling of trusted origins and node secret requests.
331 lines
7.6 KiB
Go
331 lines
7.6 KiB
Go
package cluster
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/0xJacky/Nginx-UI/internal/analytic"
|
|
"github.com/0xJacky/Nginx-UI/internal/helper"
|
|
"github.com/0xJacky/Nginx-UI/internal/kernel"
|
|
"github.com/0xJacky/Nginx-UI/internal/middleware"
|
|
"github.com/0xJacky/Nginx-UI/model"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/uozi-tech/cosy/logger"
|
|
)
|
|
|
|
// WebSocketMessage is the JSON envelope pushed to cluster-monitoring
// WebSocket clients: an event name plus an event-specific payload.
type WebSocketMessage struct {
	// Event names the kind of message (this file emits "message" and "heartbeat").
	Event string `json:"event"`
	// Data is the payload; it is handed to the JSON encoder as-is.
	Data any `json:"data"`
}
|
|
|
|
// Client represents a WebSocket client connection for cluster node monitoring
|
|
type Client struct {
|
|
conn *websocket.Conn
|
|
send chan WebSocketMessage
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
}
|
|
|
|
// Hub maintains the set of active clients and broadcasts messages to them
|
|
type Hub struct {
|
|
clients map[*Client]bool
|
|
broadcast chan WebSocketMessage
|
|
register chan *Client
|
|
unregister chan *Client
|
|
mutex sync.RWMutex
|
|
}
|
|
|
|
var (
|
|
hub *Hub
|
|
hubOnce sync.Once
|
|
)
|
|
|
|
// GetHub returns the singleton hub instance
|
|
func GetHub() *Hub {
|
|
hubOnce.Do(func() {
|
|
hub = &Hub{
|
|
clients: make(map[*Client]bool),
|
|
broadcast: make(chan WebSocketMessage, 1024), // Increased buffer size
|
|
register: make(chan *Client),
|
|
unregister: make(chan *Client),
|
|
}
|
|
go hub.run()
|
|
})
|
|
return hub
|
|
}
|
|
|
|
// run handles the main hub loop
|
|
func (h *Hub) run() {
|
|
for {
|
|
select {
|
|
case client := <-h.register:
|
|
h.mutex.Lock()
|
|
h.clients[client] = true
|
|
h.mutex.Unlock()
|
|
logger.Debug("Cluster node client connected, total clients:", len(h.clients))
|
|
|
|
case client := <-h.unregister:
|
|
h.mutex.Lock()
|
|
if _, ok := h.clients[client]; ok {
|
|
delete(h.clients, client)
|
|
close(client.send)
|
|
}
|
|
h.mutex.Unlock()
|
|
logger.Debug("Cluster node client disconnected, total clients:", len(h.clients))
|
|
|
|
case message := <-h.broadcast:
|
|
h.mutex.RLock()
|
|
deadClients := make([]*Client, 0)
|
|
for client := range h.clients {
|
|
select {
|
|
case client.send <- message:
|
|
case <-time.After(100 * time.Millisecond):
|
|
// Client is too slow, mark for removal
|
|
logger.Debug("Client send channel timeout, marking for removal")
|
|
deadClients = append(deadClients, client)
|
|
default:
|
|
// Channel is full, mark for removal
|
|
logger.Debug("Client send channel full, marking for removal")
|
|
deadClients = append(deadClients, client)
|
|
}
|
|
}
|
|
h.mutex.RUnlock()
|
|
|
|
// Clean up dead clients
|
|
if len(deadClients) > 0 {
|
|
h.mutex.Lock()
|
|
for _, client := range deadClients {
|
|
if _, ok := h.clients[client]; ok {
|
|
close(client.send)
|
|
delete(h.clients, client)
|
|
client.cancel() // Trigger client cleanup
|
|
}
|
|
}
|
|
h.mutex.Unlock()
|
|
logger.Info("Cleaned up slow/unresponsive clients", "count", len(deadClients))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// BroadcastMessage sends a message to all connected clients
|
|
func (h *Hub) BroadcastMessage(event string, data any) {
|
|
message := WebSocketMessage{
|
|
Event: event,
|
|
Data: data,
|
|
}
|
|
select {
|
|
case h.broadcast <- message:
|
|
default:
|
|
logger.Warn("Cluster node broadcast channel full, message dropped")
|
|
}
|
|
}
|
|
|
|
// WebSocket upgrader configuration
|
|
var upgrader = websocket.Upgrader{
|
|
CheckOrigin: middleware.CheckWebSocketOrigin,
|
|
ReadBufferSize: 1024,
|
|
WriteBufferSize: 1024,
|
|
}
|
|
|
|
type respNode struct {
|
|
*model.Node
|
|
Status bool `json:"status"`
|
|
}
|
|
|
|
// GetAllEnabledNodeWS handles WebSocket connections for real-time node monitoring
|
|
func GetAllEnabledNodeWS(c *gin.Context) {
|
|
ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
|
if err != nil {
|
|
logger.Error("Failed to upgrade connection:", err)
|
|
return
|
|
}
|
|
defer ws.Close()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
client := &Client{
|
|
conn: ws,
|
|
send: make(chan WebSocketMessage, 1024), // Increased buffer size
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
}
|
|
|
|
hub := GetHub()
|
|
hub.register <- client
|
|
|
|
// Start goroutines for handling node monitoring
|
|
go client.handleNodeMonitoring()
|
|
|
|
// Start write and read pumps
|
|
go client.writePump()
|
|
client.readPump()
|
|
}
|
|
|
|
// handleNodeMonitoring monitors node status and sends updates
|
|
func (c *Client) handleNodeMonitoring() {
|
|
interval := 10 * time.Second
|
|
heartbeatInterval := 30 * time.Second
|
|
|
|
getNodeData := func() (interface{}, bool) {
|
|
// Query nodes directly from database
|
|
var nodes []model.Node
|
|
err := model.UseDB().Where("enabled = ?", true).Find(&nodes).Error
|
|
if err != nil {
|
|
logger.Error("Failed to query nodes:", err)
|
|
return nil, false
|
|
}
|
|
|
|
// Transform nodes to response format
|
|
var result []respNode
|
|
for _, node := range nodes {
|
|
result = append(result, respNode{
|
|
Node: &node,
|
|
Status: analytic.GetNode(&node).Status,
|
|
})
|
|
}
|
|
|
|
return result, true
|
|
}
|
|
|
|
getHash := func(data interface{}) string {
|
|
bytes, _ := json.Marshal(data)
|
|
hash := sha256.New()
|
|
hash.Write(bytes)
|
|
hashSum := hash.Sum(nil)
|
|
return hex.EncodeToString(hashSum)
|
|
}
|
|
|
|
var dataHash string
|
|
|
|
// Send initial data
|
|
data, ok := getNodeData()
|
|
if ok {
|
|
dataHash = getHash(data)
|
|
c.sendMessage("message", data)
|
|
}
|
|
|
|
ticker := time.NewTicker(interval)
|
|
heartbeatTicker := time.NewTicker(heartbeatInterval)
|
|
defer ticker.Stop()
|
|
defer heartbeatTicker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
data, ok := getNodeData()
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
newHash := getHash(data)
|
|
if dataHash != newHash {
|
|
dataHash = newHash
|
|
c.sendMessage("message", data)
|
|
}
|
|
|
|
case <-heartbeatTicker.C:
|
|
c.sendMessage("heartbeat", "")
|
|
|
|
case <-c.ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// sendMessage sends a message to the client with timeout and better error handling
|
|
func (c *Client) sendMessage(event string, data any) {
|
|
message := WebSocketMessage{
|
|
Event: event,
|
|
Data: data,
|
|
}
|
|
|
|
select {
|
|
case c.send <- message:
|
|
case <-time.After(5 * time.Second):
|
|
logger.Warn("Client send channel full, message dropped after timeout", "event", event)
|
|
// Force disconnect slow clients to prevent resource leakage
|
|
c.cancel()
|
|
default:
|
|
logger.Warn("Client send channel full, message dropped immediately", "event", event)
|
|
// For non-critical messages, we can drop them immediately
|
|
if event != "heartbeat" {
|
|
logger.Info("Dropping non-critical message due to full channel", "event", event)
|
|
}
|
|
}
|
|
}
|
|
|
|
// writePump pumps messages from the hub to the websocket connection
|
|
func (c *Client) writePump() {
|
|
ticker := time.NewTicker(54 * time.Second)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case message, ok := <-c.send:
|
|
c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
|
if !ok {
|
|
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
|
|
return
|
|
}
|
|
|
|
if err := c.conn.WriteJSON(message); err != nil {
|
|
logger.Error("Error writing message to websocket:", err)
|
|
return
|
|
}
|
|
|
|
case <-ticker.C:
|
|
c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
|
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
|
return
|
|
}
|
|
case <-kernel.Context.Done():
|
|
return
|
|
case <-c.ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// readPump pumps messages from the websocket connection to the hub
|
|
func (c *Client) readPump() {
|
|
defer func() {
|
|
hub := GetHub()
|
|
hub.unregister <- c
|
|
c.conn.Close()
|
|
c.cancel()
|
|
}()
|
|
|
|
c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
|
c.conn.SetPongHandler(func(string) error {
|
|
c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
|
return nil
|
|
})
|
|
|
|
go func() {
|
|
for {
|
|
_, _, err := c.conn.ReadMessage()
|
|
if err != nil {
|
|
if helper.IsUnexpectedWebsocketError(err) {
|
|
logger.Error("Websocket error:", err)
|
|
}
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
select {
|
|
case <-kernel.Context.Done():
|
|
return
|
|
case <-c.ctx.Done():
|
|
return
|
|
}
|
|
}
|