diff --git a/api/analytic/analytic.go b/api/analytic/analytic.go index e0fbd655..2d0ac8f5 100644 --- a/api/analytic/analytic.go +++ b/api/analytic/analytic.go @@ -9,6 +9,7 @@ import ( "github.com/0xJacky/Nginx-UI/internal/analytic" "github.com/0xJacky/Nginx-UI/internal/helper" "github.com/0xJacky/Nginx-UI/internal/kernel" + "github.com/0xJacky/Nginx-UI/internal/middleware" "github.com/0xJacky/Nginx-UI/internal/version" "github.com/shirou/gopsutil/v4/cpu" "github.com/shirou/gopsutil/v4/host" @@ -22,9 +23,7 @@ import ( func Analytic(c *gin.Context) { var upGrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true - }, + CheckOrigin: middleware.CheckWebSocketOrigin, } // upgrade http to websocket ws, err := upGrader.Upgrade(c.Writer, c.Request, nil) diff --git a/api/analytic/nodes.go b/api/analytic/nodes.go index a6d9eb89..cb502c02 100644 --- a/api/analytic/nodes.go +++ b/api/analytic/nodes.go @@ -1,12 +1,12 @@ package analytic import ( - "net/http" "time" "github.com/0xJacky/Nginx-UI/internal/analytic" "github.com/0xJacky/Nginx-UI/internal/helper" "github.com/0xJacky/Nginx-UI/internal/kernel" + "github.com/0xJacky/Nginx-UI/internal/middleware" "github.com/0xJacky/Nginx-UI/internal/version" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" @@ -16,9 +16,7 @@ import ( func GetNodeStat(c *gin.Context) { var upGrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true - }, + CheckOrigin: middleware.CheckWebSocketOrigin, } // upgrade http to websocket ws, err := upGrader.Upgrade(c.Writer, c.Request, nil) @@ -93,9 +91,7 @@ func GetNodeStat(c *gin.Context) { func GetNodesAnalytic(c *gin.Context) { var upGrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true - }, + CheckOrigin: middleware.CheckWebSocketOrigin, } // upgrade http to websocket ws, err := upGrader.Upgrade(c.Writer, c.Request, nil) diff --git a/api/certificate/issue.go b/api/certificate/issue.go index adc3f0da..8318191b 100644 --- a/api/certificate/issue.go +++ b/api/certificate/issue.go @@ -1,10 +1,9 @@ package certificate import ( - "net/http" - "github.com/0xJacky/Nginx-UI/internal/cert" "github.com/0xJacky/Nginx-UI/internal/helper" + "github.com/0xJacky/Nginx-UI/internal/middleware" "github.com/0xJacky/Nginx-UI/internal/translation" "github.com/0xJacky/Nginx-UI/model" "github.com/0xJacky/Nginx-UI/query" @@ -32,9 +31,7 @@ type IssueCertResponse struct { func IssueCert(c *gin.Context) { name := c.Param("name") var upGrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true - }, + CheckOrigin: middleware.CheckWebSocketOrigin, } // upgrade http to websocket diff --git a/api/certificate/revoke.go b/api/certificate/revoke.go index d33e0318..3d354e91 100644 --- a/api/certificate/revoke.go +++ b/api/certificate/revoke.go @@ -1,9 +1,8 @@ package certificate import ( - "net/http" - "github.com/0xJacky/Nginx-UI/internal/cert" + "github.com/0xJacky/Nginx-UI/internal/middleware" "github.com/0xJacky/Nginx-UI/internal/translation" "github.com/0xJacky/Nginx-UI/query" "github.com/gin-gonic/gin" @@ -41,9 +40,7 @@ func RevokeCert(c *gin.Context) { id := cast.ToUint64(c.Param("id")) var upGrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true - }, + CheckOrigin: middleware.CheckWebSocketOrigin, } // upgrade http to websocket diff --git a/api/cluster/websocket.go b/api/cluster/websocket.go index 93642454..2a3a6c22 100644 --- a/api/cluster/websocket.go +++ b/api/cluster/websocket.go @@ -5,13 +5,13 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" - "net/http" "sync" "time" "github.com/0xJacky/Nginx-UI/internal/analytic" "github.com/0xJacky/Nginx-UI/internal/helper" "github.com/0xJacky/Nginx-UI/internal/kernel" + "github.com/0xJacky/Nginx-UI/internal/middleware" "github.com/0xJacky/Nginx-UI/model" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" @@ -129,9 +129,7 @@ func (h *Hub) BroadcastMessage(event string, data any) { // WebSocket upgrader configuration var upgrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true - }, + CheckOrigin: middleware.CheckWebSocketOrigin, ReadBufferSize: 1024, WriteBufferSize: 1024, } diff --git a/api/event/websocket.go b/api/event/websocket.go index 80bb3fbd..c50d4f1c 100644 --- a/api/event/websocket.go +++ b/api/event/websocket.go @@ -3,13 +3,13 @@ package event import ( "context" "encoding/json" - "net/http" "sync" "time" "github.com/0xJacky/Nginx-UI/internal/event" "github.com/0xJacky/Nginx-UI/internal/helper" "github.com/0xJacky/Nginx-UI/internal/kernel" + "github.com/0xJacky/Nginx-UI/internal/middleware" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "github.com/uozi-tech/cosy/logger" @@ -133,9 +133,7 @@ func (h *Hub) run() { // WebSocket upgrader configuration var upgrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true - }, + CheckOrigin: middleware.CheckWebSocketOrigin, ReadBufferSize: 1024, WriteBufferSize: 1024, } diff --git a/api/geolite/download.go b/api/geolite/download.go index 3035a566..c4006567 100644 --- a/api/geolite/download.go +++ b/api/geolite/download.go @@ -1,9 +1,8 @@ package geolite import ( - "net/http" - "github.com/0xJacky/Nginx-UI/internal/geolite" + "github.com/0xJacky/Nginx-UI/internal/middleware" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "github.com/uozi-tech/cosy/logger" @@ -23,9 +22,7 @@ type DownloadProgressResp struct { func DownloadGeoLiteDB(c *gin.Context) { var upgrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true - }, + CheckOrigin: middleware.CheckWebSocketOrigin, } // Upgrade HTTP to WebSocket diff --git a/api/llm/code_completion.go b/api/llm/code_completion.go index 9506890a..aef13dfd 100644 --- a/api/llm/code_completion.go +++ b/api/llm/code_completion.go @@ -8,6 +8,7 @@ import ( "github.com/0xJacky/Nginx-UI/api" "github.com/0xJacky/Nginx-UI/internal/helper" "github.com/0xJacky/Nginx-UI/internal/llm" + "github.com/0xJacky/Nginx-UI/internal/middleware" "github.com/0xJacky/Nginx-UI/settings" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" @@ -25,9 +26,7 @@ func CodeCompletion(c *gin.Context) { } var upgrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true - }, + CheckOrigin: middleware.CheckWebSocketOrigin, } ws, err := upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { diff --git a/api/nginx/websocket.go b/api/nginx/websocket.go index c48ec234..8fa5e8ed 100644 --- a/api/nginx/websocket.go +++ b/api/nginx/websocket.go @@ -2,12 +2,12 @@ package nginx import ( "context" - "net/http" "sync" "time" "github.com/0xJacky/Nginx-UI/internal/helper" "github.com/0xJacky/Nginx-UI/internal/kernel" + "github.com/0xJacky/Nginx-UI/internal/middleware" "github.com/0xJacky/Nginx-UI/internal/performance" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" @@ -184,9 +184,7 @@ func (h *PerformanceHub) broadcastPerformanceData() { // WebSocket upgrader configuration var nginxPerformanceUpgrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true - }, + CheckOrigin: middleware.CheckWebSocketOrigin, ReadBufferSize: 1024, WriteBufferSize: 1024, } diff --git a/api/nginx_log/websocket.go b/api/nginx_log/websocket.go index d94dc0a1..9e8d1df7 100644 --- a/api/nginx_log/websocket.go +++ b/api/nginx_log/websocket.go @@ -3,11 +3,11 @@ package nginx_log import ( "encoding/json" "io" - "net/http" "os" "runtime" "github.com/0xJacky/Nginx-UI/internal/helper" + "github.com/0xJacky/Nginx-UI/internal/middleware" "github.com/0xJacky/Nginx-UI/internal/nginx" "github.com/0xJacky/Nginx-UI/internal/nginx_log" "github.com/0xJacky/Nginx-UI/internal/nginx_log/utils" @@ -171,9 +171,7 @@ func handleLogControl(ws *websocket.Conn, controlChan chan controlStruct, errCha // Log handles websocket connection for real-time log viewing func Log(c *gin.Context) { var upGrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true - }, + CheckOrigin: middleware.CheckWebSocketOrigin, } // upgrade http to websocket ws, err := upGrader.Upgrade(c.Writer, c.Request, nil) diff --git a/api/sites/websocket.go b/api/sites/websocket.go index f43cd2d3..d32e8250 100644 --- a/api/sites/websocket.go +++ b/api/sites/websocket.go @@ -1,10 +1,10 @@ package sites import ( - "net/http" "sync" "github.com/0xJacky/Nginx-UI/internal/helper" + "github.com/0xJacky/Nginx-UI/internal/middleware" "github.com/0xJacky/Nginx-UI/internal/sitecheck" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" @@ -37,9 +37,7 @@ type PongMessage struct { } var upgrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true - }, + CheckOrigin: middleware.CheckWebSocketOrigin, } // WSManager WebSocket connection manager diff --git a/api/system/self_check.go b/api/system/self_check.go index d80d2a11..e66ed75b 100644 --- a/api/system/self_check.go +++ b/api/system/self_check.go @@ -3,6 +3,7 @@ package system import ( "net/http" + "github.com/0xJacky/Nginx-UI/internal/middleware" "github.com/gorilla/websocket" "github.com/uozi-tech/cosy/logger" @@ -24,9 +25,7 @@ func SelfCheckFix(c *gin.Context) { func CheckWebSocket(c *gin.Context) { var upgrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true - }, + CheckOrigin: middleware.CheckWebSocketOrigin, } ws, err := upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { diff --git a/api/system/upgrade.go b/api/system/upgrade.go index f438d2cf..4c6df5e6 100644 --- a/api/system/upgrade.go +++ b/api/system/upgrade.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/0xJacky/Nginx-UI/internal/helper" + "github.com/0xJacky/Nginx-UI/internal/middleware" "github.com/0xJacky/Nginx-UI/internal/upgrader" "github.com/0xJacky/Nginx-UI/internal/version" "github.com/gin-gonic/gin" @@ -50,9 +51,7 @@ type CoreUpgradeResp struct { func PerformCoreUpgrade(c *gin.Context) { var upGrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true - }, + CheckOrigin: middleware.CheckWebSocketOrigin, } // upgrade http to websocket ws, err := upGrader.Upgrade(c.Writer, c.Request, nil) diff --git a/api/terminal/pty.go b/api/terminal/pty.go index 9964d254..eaf02fb5 100644 --- a/api/terminal/pty.go +++ b/api/terminal/pty.go @@ -1,18 +1,16 @@ package terminal import ( + "github.com/0xJacky/Nginx-UI/internal/middleware" "github.com/0xJacky/Nginx-UI/internal/pty" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "github.com/uozi-tech/cosy/logger" - "net/http" ) func Pty(c *gin.Context) { var upGrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true - }, + CheckOrigin: middleware.CheckWebSocketOrigin, } // upgrade http to websocket ws, err := upGrader.Upgrade(c.Writer, c.Request, nil) diff --git a/api/upstream/upstream.go b/api/upstream/upstream.go index abd1b5ac..614c9cf1 100644 --- a/api/upstream/upstream.go +++ b/api/upstream/upstream.go @@ -8,6 +8,7 @@ import ( "github.com/0xJacky/Nginx-UI/internal/helper" "github.com/0xJacky/Nginx-UI/internal/kernel" + "github.com/0xJacky/Nginx-UI/internal/middleware" "github.com/0xJacky/Nginx-UI/internal/upstream" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" @@ -43,9 +44,7 @@ func GetUpstreamDefinitions(c *gin.Context) { // AvailabilityWebSocket handles WebSocket connections for real-time availability monitoring func AvailabilityWebSocket(c *gin.Context) { var upGrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true - }, + CheckOrigin: middleware.CheckWebSocketOrigin, } // Upgrade HTTP to WebSocket diff --git a/app.example.ini b/app.example.ini index 71052751..713bb238 100644 --- a/app.example.ini +++ b/app.example.ini @@ -53,6 +53,7 @@ Secret = [http] GithubProxy = https://mirror.ghproxy.com/ InsecureSkipVerify = false +WebSocketTrustedOrigins = [logrotate] Enabled = false diff --git a/app/src/pinia/moudule/user.ts b/app/src/pinia/moudule/user.ts index 2c1de845..2dbea2b9 100644 --- a/app/src/pinia/moudule/user.ts +++ b/app/src/pinia/moudule/user.ts @@ -6,30 +6,48 @@ import user from '@/api/user' export const useUserStore = defineStore('user', () => { const cookies = useCookies(['nginx-ui']) + function getCookieOptions(maxAge: number) { + return { + path: '/', + maxAge, + sameSite: 'lax' as const, + secure: window.location.protocol === 'https:', + } + } + const token = ref('') const shortToken = ref('') watch(token, v => { - cookies.set('token', v, { maxAge: 86400 }) + if (v) + cookies.set('token', v, getCookieOptions(86400)) + else + cookies.remove('token', { path: '/' }) }) watch(shortToken, v => { - cookies.set('short_token', v, { maxAge: 86400 }) + if (v) + cookies.set('short_token', v, getCookieOptions(86400)) + else + cookies.remove('short_token', { path: '/' }) }) const secureSessionId = ref('') watch(secureSessionId, v => { - cookies.set('secure_session_id', v, { maxAge: 60 * 3 }) + if (v) + cookies.set('secure_session_id', v, getCookieOptions(60 * 3)) + else + cookies.remove('secure_session_id', { path: '/' }) }) function handleCookieChange({ name, value }: CookieChangeOptions) { if (name === 'token') - token.value = value + token.value = value || '' else if (name === 'short_token') - shortToken.value = value + shortToken.value = value || '' else if (name === 'secure_session_id') - secureSessionId.value = value + secureSessionId.value = value || '' } cookies.addChangeListener(handleCookieChange) diff --git a/docs/guide/config-server.md b/docs/guide/config-server.md index e74a4a84..9d172070 100644 --- a/docs/guide/config-server.md +++ b/docs/guide/config-server.md @@ -268,3 +268,15 @@ Deprecated in `v2.0.0-beta.37`, please use `Http.InsecureSkipVerify` instead. ::: This option is used to skip the verification of the certificate of servers when Nginx UI sends requests to them. + +## Http.WebSocketTrustedOrigins + +- Type: `[]string` +- Default: empty +- Example: `http://localhost:5173,https://admin.example.com` + +This option allows additional trusted browser origins for authenticated WebSocket connections. + +Use it when Nginx UI is accessed through a reverse proxy with a different public origin, through multiple management domains, or during local development where the frontend and backend run on different ports. + +Keep this list as small as possible. Same-origin WebSocket requests do not need to be added here. diff --git a/docs/zh_CN/guide/config-server.md b/docs/zh_CN/guide/config-server.md index ca0ef575..7eda46ad 100644 --- a/docs/zh_CN/guide/config-server.md +++ b/docs/zh_CN/guide/config-server.md @@ -252,3 +252,15 @@ Nginx UI 将不会创建系统初始的 acme 用户,这意味着您无法在 ::: 此选项用于配置 Nginx UI 服务器在与其他服务器建立 TLS 连接时是否跳过证书验证。 + +## Http.WebSocketTrustedOrigins + +- 类型: `[]string` +- 默认值: 空 +- 示例: `http://localhost:5173,https://admin.example.com` + +此选项用于为已认证的 WebSocket 连接额外声明可信浏览器来源。 + +当 Nginx UI 通过带有不同公网域名的反向代理访问、需要同时支持多个管理域名,或本地开发时前后端运行在不同端口时,可以配置该选项。 + +请尽量保持列表最小化。对于同源的 WebSocket 请求,不需要额外加入这里。 diff --git a/internal/middleware/websocket_origin.go b/internal/middleware/websocket_origin.go new file mode 100644 index 00000000..a31d7aa7 --- /dev/null +++ b/internal/middleware/websocket_origin.go @@ -0,0 +1,187 @@ +package middleware + +import ( + "net" + "net/http" + "net/url" + "strings" + + "github.com/0xJacky/Nginx-UI/settings" +) + +// CheckWebSocketOrigin validates browser origins for WebSocket upgrade requests. +// Non-browser requests are only allowed for trusted node-to-node traffic. +func CheckWebSocketOrigin(r *http.Request) bool { + origin := strings.TrimSpace(r.Header.Get("Origin")) + if origin == "" { + return isTrustedNodeRequest(r) + } + + if requestOrigin, ok := getRequestOrigin(r); ok && sameOrigin(origin, requestOrigin) { + return true + } + + for _, allowedOrigin := range settings.HTTPSettings.WebSocketTrustedOrigins { + if sameOrigin(origin, allowedOrigin) { + return true + } + } + + return false +} + +func isTrustedNodeRequest(r *http.Request) bool { + secret := strings.TrimSpace(r.Header.Get("X-Node-Secret")) + if secret == "" { + secret = strings.TrimSpace(r.URL.Query().Get("node_secret")) + } + + return secret != "" && secret == settings.NodeSettings.Secret +} + +func getRequestOrigin(r *http.Request) (string, bool) { + scheme := getForwardedParam(r.Header.Get("Forwarded"), "proto") + host := getForwardedParam(r.Header.Get("Forwarded"), "host") + + if host == "" { + host = firstHeaderValue(r.Header.Get("X-Forwarded-Host")) + } + if scheme == "" { + scheme = firstHeaderValue(r.Header.Get("X-Forwarded-Proto")) + } + if host == "" { + host = strings.TrimSpace(r.Host) + } + if scheme == "" { + if r.TLS != nil { + scheme = "https" + } else { + scheme = "http" + } + } + + return buildNormalizedOrigin(scheme, host) +} + +func sameOrigin(left, right string) bool { + normalizedLeft, ok := normalizeOrigin(left) + if !ok { + return false + } + + normalizedRight, ok := normalizeOrigin(right) + if !ok { + return false + } + + return normalizedLeft == normalizedRight +} + +func normalizeOrigin(raw string) (string, bool) { + u, err := url.Parse(strings.TrimSpace(raw)) + if err != nil || u.Host == "" { + return "", false + } + + scheme, ok := normalizeScheme(u.Scheme) + if !ok { + return "", false + } + + host := normalizeHost(u.Host, scheme) + if host == "" { + return "", false + } + + return scheme + "://" + host, true +} + +func buildNormalizedOrigin(rawScheme, rawHost string) (string, bool) { + scheme, ok := normalizeScheme(rawScheme) + if !ok { + return "", false + } + + host := normalizeHost(rawHost, scheme) + if host == "" { + return "", false + } + + return scheme + "://" + host, true +} + +func normalizeScheme(scheme string) (string, bool) { + switch strings.ToLower(strings.TrimSpace(scheme)) { + case "http", "ws": + return "http", true + case "https", "wss": + return "https", true + default: + return "", false + } +} + +func normalizeHost(host, scheme string) string { + host = firstHeaderValue(host) + if host == "" { + return "" + } + + u, err := url.Parse("//" + host) + if err != nil || u.Hostname() == "" { + return "" + } + + hostname := strings.ToLower(u.Hostname()) + port := u.Port() + + if port == defaultPortForScheme(scheme) { + port = "" + } + + if port != "" { + return net.JoinHostPort(hostname, port) + } + + if strings.Contains(hostname, ":") { + return "[" + hostname + "]" + } + + return hostname +} + +func defaultPortForScheme(scheme string) string { + switch scheme { + case "https": + return "443" + default: + return "80" + } +} + +func firstHeaderValue(value string) string { + if value == "" { + return "" + } + + parts := strings.Split(value, ",") + return strings.TrimSpace(parts[0]) +} + +func getForwardedParam(forwardedValue, key string) string { + if forwardedValue == "" { + return "" + } + + firstEntry := firstHeaderValue(forwardedValue) + for _, part := range strings.Split(firstEntry, ";") { + name, value, ok := strings.Cut(strings.TrimSpace(part), "=") + if !ok || !strings.EqualFold(name, key) { + continue + } + + return strings.Trim(strings.TrimSpace(value), "\"") + } + + return "" +} diff --git a/internal/middleware/websocket_origin_test.go b/internal/middleware/websocket_origin_test.go new file mode 100644 index 00000000..a5ed7db0 --- /dev/null +++ b/internal/middleware/websocket_origin_test.go @@ -0,0 +1,87 @@ +package middleware + +import ( + "crypto/tls" + "net/http/httptest" + "testing" + + "github.com/0xJacky/Nginx-UI/settings" + "github.com/stretchr/testify/assert" +) + +func TestCheckWebSocketOrigin(t *testing.T) { + originalOrigins := settings.HTTPSettings.WebSocketTrustedOrigins + originalSecret := settings.NodeSettings.Secret + + t.Cleanup(func() { + settings.HTTPSettings.WebSocketTrustedOrigins = originalOrigins + settings.NodeSettings.Secret = originalSecret + }) + + t.Run("allows same origin requests", func(t *testing.T) { + settings.HTTPSettings.WebSocketTrustedOrigins = nil + settings.NodeSettings.Secret = "" + + req := httptest.NewRequest("GET", "http://127.0.0.1/ws", nil) + req.Host = "admin.example.com" + req.TLS = &tls.ConnectionState{} + req.Header.Set("Origin", "https://admin.example.com:443") + + assert.True(t, CheckWebSocketOrigin(req)) + }) + + t.Run("allows reverse proxy forwarded origin", func(t *testing.T) { + settings.HTTPSettings.WebSocketTrustedOrigins = nil + settings.NodeSettings.Secret = "" + + req := httptest.NewRequest("GET", "http://127.0.0.1/ws", nil) + req.Host = "127.0.0.1:9000" + req.Header.Set("Origin", "https://panel.example.com") + req.Header.Set("X-Forwarded-Proto", "https") + req.Header.Set("X-Forwarded-Host", "panel.example.com") + + assert.True(t, CheckWebSocketOrigin(req)) + }) + + t.Run("allows configured trusted origins", func(t *testing.T) { + settings.HTTPSettings.WebSocketTrustedOrigins = []string{"http://localhost:5173/"} + settings.NodeSettings.Secret = "" + + req := httptest.NewRequest("GET", "http://127.0.0.1/ws", nil) + req.Host = "127.0.0.1:9000" + req.Header.Set("Origin", "http://localhost:5173") + + assert.True(t, CheckWebSocketOrigin(req)) + }) + + t.Run("allows node secret requests without origin", func(t *testing.T) { + settings.HTTPSettings.WebSocketTrustedOrigins = nil + settings.NodeSettings.Secret = "node-secret" + + req := httptest.NewRequest("GET", "http://127.0.0.1/ws", nil) + req.Header.Set("X-Node-Secret", "node-secret") + + assert.True(t, CheckWebSocketOrigin(req)) + }) + + t.Run("rejects cross site requests", func(t *testing.T) { + settings.HTTPSettings.WebSocketTrustedOrigins = nil + settings.NodeSettings.Secret = "" + + req := httptest.NewRequest("GET", "http://127.0.0.1/ws", nil) + req.Host = "admin.example.com" + req.TLS = &tls.ConnectionState{} + req.Header.Set("Origin", "https://evil.example.com") + + assert.False(t, CheckWebSocketOrigin(req)) + }) + + t.Run("rejects missing origin without trusted node secret", func(t *testing.T) { + settings.HTTPSettings.WebSocketTrustedOrigins = nil + settings.NodeSettings.Secret = "node-secret" + + req := httptest.NewRequest("GET", "http://127.0.0.1/ws?token=abc123", nil) + + assert.False(t, CheckWebSocketOrigin(req)) + }) +} diff --git a/mcp/router.go b/mcp/router.go index d55eb8b2..ac2232d9 100644 --- a/mcp/router.go +++ b/mcp/router.go @@ -11,7 +11,7 @@ func InitRouter(r *gin.Engine) { func(c *gin.Context) { mcp.ServeHTTP(c) }) - r.Any("/mcp_message", middleware.IPWhiteList(), + r.Any("/mcp_message", middleware.IPWhiteList(), middleware.AuthRequired(), func(c *gin.Context) { mcp.ServeHTTP(c) }) diff --git a/mcp/router_test.go b/mcp/router_test.go new file mode 100644 index 00000000..75e831f7 --- /dev/null +++ b/mcp/router_test.go @@ -0,0 +1,35 @@ +package mcp + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/0xJacky/Nginx-UI/settings" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +func TestMCPEndpointsRequireAuthentication(t *testing.T) { + gin.SetMode(gin.TestMode) + + originalIPWhiteList := settings.AuthSettings.IPWhiteList + t.Cleanup(func() { + settings.AuthSettings.IPWhiteList = originalIPWhiteList + }) + + settings.AuthSettings.IPWhiteList = nil + + router := gin.New() + InitRouter(router) + + for _, endpoint := range []string{"/mcp", "/mcp_message"} { + req := httptest.NewRequest(http.MethodPost, endpoint, nil) + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusForbidden, w.Code) + assert.JSONEq(t, `{"message":"Authorization failed"}`, w.Body.String()) + } +} diff --git a/settings/http.go b/settings/http.go index b5f183f3..edd29f55 100644 --- a/settings/http.go +++ b/settings/http.go @@ -1,8 +1,11 @@ package settings type HTTP struct { - GithubProxy string `json:"github_proxy" binding:"omitempty,url"` - InsecureSkipVerify bool `json:"insecure_skip_verify" protected:"true"` + GithubProxy string `json:"github_proxy" binding:"omitempty,url"` + InsecureSkipVerify bool `json:"insecure_skip_verify" protected:"true"` + WebSocketTrustedOrigins []string `json:"websocket_trusted_origins" binding:"omitempty,dive,url" env:"WEBSOCKET_TRUSTED_ORIGINS"` } -var HTTPSettings = &HTTP{} +var HTTPSettings = &HTTP{ + WebSocketTrustedOrigins: []string{}, +} diff --git a/settings/settings_test.go b/settings/settings_test.go index a7bb4ab1..6bd3a185 100644 --- a/settings/settings_test.go +++ b/settings/settings_test.go @@ -63,6 +63,7 @@ func TestSetup(t *testing.T) { // Http _ = os.Setenv("NGINX_UI_HTTP_GITHUB_PROXY", "http://proxy.example.com") _ = os.Setenv("NGINX_UI_HTTP_INSECURE_SKIP_VERIFY", "true") + _ = os.Setenv("NGINX_UI_HTTP_WEBSOCKET_TRUSTED_ORIGINS", "http://localhost:5173,https://admin.example.com") // Logrotate _ = os.Setenv("NGINX_UI_LOGROTATE_ENABLED", "true") @@ -155,6 +156,7 @@ func TestSetup(t *testing.T) { // Http assert.Equal(t, "http://proxy.example.com", HTTPSettings.GithubProxy) assert.Equal(t, true, HTTPSettings.InsecureSkipVerify) + assert.Equal(t, []string{"http://localhost:5173", "https://admin.example.com"}, HTTPSettings.WebSocketTrustedOrigins) // Logrotate assert.Equal(t, true, LogrotateSettings.Enabled)