From 39357c67d91e2fd2db5d3f74c39f2858df5e3604 Mon Sep 17 00:00:00 2001 From: Jannis Mattheis Date: Sun, 22 Mar 2026 13:02:06 +0100 Subject: [PATCH] fix: simplify authentication handling --- auth/authentication.go | 265 ++++++++++++++++++++---------------- auth/authentication_test.go | 8 +- router/router.go | 12 +- 3 files changed, 154 insertions(+), 131 deletions(-) diff --git a/auth/authentication.go b/auth/authentication.go index 295d67b..560f75d 100644 --- a/auth/authentication.go +++ b/auth/authentication.go @@ -10,15 +10,20 @@ import ( "github.com/gotify/server/v2/model" ) +type authState int + const ( - headerName = "X-Gotify-Key" + authStateSkip authState = iota + authStateForbidden + authStateOk ) +const headerName = "X-Gotify-Key" + // The Database interface for encapsulating database access. type Database interface { GetApplicationByToken(token string) (*model.Application, error) GetClientByToken(token string) (*model.Client, error) - GetPluginConfByToken(token string) (*model.PluginConf, error) GetUserByName(name string) (*model.User, error) GetUserByID(id uint) (*model.User, error) UpdateClientTokensLastUsed(tokens []string, t *time.Time) error @@ -30,72 +35,154 @@ type Auth struct { DB Database } -type authenticate func(tokenID string, user *model.User) (authenticated, success bool, userId uint, err error) - // RequireAdmin returns a gin middleware which requires a client token or basic authentication header to be supplied // with the request. Also the authenticated user must be an administrator. -func (a *Auth) RequireAdmin() gin.HandlerFunc { - return a.requireToken(func(tokenID string, user *model.User) (bool, bool, uint, error) { - if user != nil { - return true, user.Admin, user.ID, nil - } - if token, err := a.DB.GetClientByToken(tokenID); err != nil { - return false, false, 0, err - } else if token != nil { - user, err := a.DB.GetUserByID(token.UserID) - if err != nil { - return false, false, token.UserID, err - } - return true, user.Admin, token.UserID, nil - } - return false, false, 0, nil - }) +func (a *Auth) RequireAdmin(ctx *gin.Context) { + a.evaluateOr401(ctx, a.user(true), a.client(true)) } // RequireClient returns a gin middleware which requires a client token or basic authentication header to be supplied // with the request. -func (a *Auth) RequireClient() gin.HandlerFunc { - return a.requireToken(func(tokenID string, user *model.User) (bool, bool, uint, error) { - if user != nil { - return true, true, user.ID, nil - } - if client, err := a.DB.GetClientByToken(tokenID); err != nil { - return false, false, 0, err - } else if client != nil { - now := time.Now() - if client.LastUsed == nil || client.LastUsed.Add(5*time.Minute).Before(now) { - if err := a.DB.UpdateClientTokensLastUsed([]string{tokenID}, &now); err != nil { - return false, false, 0, err - } - } - return true, true, client.UserID, nil - } - return false, false, 0, nil - }) +func (a *Auth) RequireClient(ctx *gin.Context) { + a.evaluateOr401(ctx, a.user(false), a.client(false)) } // RequireApplicationToken returns a gin middleware which requires an application token to be supplied with the request. -func (a *Auth) RequireApplicationToken() gin.HandlerFunc { - return a.requireToken(func(tokenID string, user *model.User) (bool, bool, uint, error) { - if user != nil { - return true, false, 0, nil - } - if app, err := a.DB.GetApplicationByToken(tokenID); err != nil { - return false, false, 0, err - } else if app != nil { - now := time.Now() - if app.LastUsed == nil || app.LastUsed.Add(5*time.Minute).Before(now) { - if err := a.DB.UpdateApplicationTokenLastUsed(tokenID, &now); err != nil { - return false, false, 0, err - } - } - return true, true, app.UserID, nil - } - return false, false, 0, nil - }) +func (a *Auth) RequireApplicationToken(ctx *gin.Context) { + if a.evaluate(ctx, a.application) { + return + } + state, err := a.user(false)(ctx) + if err != nil { + ctx.AbortWithError(500, err) + } + if state != authStateSkip { + // Return to the user that it's valid authentication, but we don't allow user auth for application endpoints. + a.abort403(ctx) + return + } + a.abort401(ctx) } -func (a *Auth) tokenFromQueryOrHeader(ctx *gin.Context) string { +func (a *Auth) Optional(ctx *gin.Context) { + if !a.evaluate(ctx, a.user(false), a.client(false)) { + RegisterAuthentication(ctx, nil, 0, "") + ctx.Next() + } +} + +func (a *Auth) evaluate(ctx *gin.Context, funcs ...func(ctx *gin.Context) (authState, error)) bool { + for _, fn := range funcs { + state, err := fn(ctx) + if err != nil { + ctx.AbortWithError(500, err) + return true + } + switch state { + case authStateForbidden: + a.abort403(ctx) + return true + case authStateOk: + ctx.Next() + return true + case authStateSkip: + continue + } + } + return false +} + +func (a *Auth) evaluateOr401(ctx *gin.Context, funcs ...func(ctx *gin.Context) (authState, error)) { + if !a.evaluate(ctx, funcs...) { + a.abort401(ctx) + } +} + +func (a *Auth) abort401(ctx *gin.Context) { + ctx.AbortWithError(401, errors.New("you need to provide a valid access token or user credentials to access this api")) +} + +func (a *Auth) abort403(ctx *gin.Context) { + ctx.AbortWithError(403, errors.New("you are not allowed to access this api")) +} + +func (a *Auth) user(requireAdmin bool) func(ctx *gin.Context) (authState, error) { + return func(ctx *gin.Context) (authState, error) { + if name, pass, ok := ctx.Request.BasicAuth(); ok { + if user, err := a.DB.GetUserByName(name); err != nil { + return authStateSkip, err + } else if user != nil && password.ComparePassword(user.Pass, []byte(pass)) { + RegisterAuthentication(ctx, user, user.ID, "") + + if requireAdmin && !user.Admin { + return authStateForbidden, nil + } + return authStateOk, nil + } + } + return authStateSkip, nil + } +} + +func (a *Auth) client(requireAdmin bool) func(ctx *gin.Context) (authState, error) { + return func(ctx *gin.Context) (authState, error) { + token := a.readTokenFromRequest(ctx) + if token == "" { + return authStateSkip, nil + } + client, err := a.DB.GetClientByToken(token) + if err != nil { + return authStateSkip, err + } + if client == nil { + return authStateSkip, nil + } + RegisterAuthentication(ctx, nil, client.UserID, client.Token) + + now := time.Now() + if client.LastUsed == nil || client.LastUsed.Add(5*time.Minute).Before(now) { + if err := a.DB.UpdateClientTokensLastUsed([]string{client.Token}, &now); err != nil { + return authStateSkip, err + } + } + + if requireAdmin { + if user, err := a.DB.GetUserByID(client.UserID); err != nil { + return authStateSkip, err + } else if !user.Admin { + return authStateForbidden, nil + } + } + + return authStateOk, nil + } +} + +func (a *Auth) application(ctx *gin.Context) (authState, error) { + token := a.readTokenFromRequest(ctx) + if token == "" { + return authStateSkip, nil + } + app, err := a.DB.GetApplicationByToken(token) + if err != nil { + return authStateSkip, err + } + if app == nil { + return authStateSkip, nil + } + RegisterAuthentication(ctx, nil, app.UserID, app.Token) + + now := time.Now() + if app.LastUsed == nil || app.LastUsed.Add(5*time.Minute).Before(now) { + if err := a.DB.UpdateApplicationTokenLastUsed(app.Token, &now); err != nil { + return authStateSkip, err + } + } + + return authStateOk, nil +} + +func (a *Auth) readTokenFromRequest(ctx *gin.Context) string { if token := a.tokenFromQuery(ctx); token != "" { return token } else if token := a.tokenFromXGotifyHeader(ctx); token != "" { @@ -128,67 +215,3 @@ func (a *Auth) tokenFromAuthorizationHeader(ctx *gin.Context) string { return authHeader[len(prefix):] } - -func (a *Auth) userFromBasicAuth(ctx *gin.Context) (*model.User, error) { - if name, pass, ok := ctx.Request.BasicAuth(); ok { - if user, err := a.DB.GetUserByName(name); err != nil { - return nil, err - } else if user != nil && password.ComparePassword(user.Pass, []byte(pass)) { - return user, nil - } - } - return nil, nil -} - -func (a *Auth) requireToken(auth authenticate) gin.HandlerFunc { - return func(ctx *gin.Context) { - token := a.tokenFromQueryOrHeader(ctx) - user, err := a.userFromBasicAuth(ctx) - if err != nil { - ctx.AbortWithError(500, errors.New("an error occurred while authenticating user")) - return - } - - if user != nil || token != "" { - authenticated, ok, userID, err := auth(token, user) - if err != nil { - ctx.AbortWithError(500, errors.New("an error occurred while authenticating user")) - return - } else if ok { - RegisterAuthentication(ctx, user, userID, token) - ctx.Next() - return - } else if authenticated { - ctx.AbortWithError(403, errors.New("you are not allowed to access this api")) - return - } - } - ctx.AbortWithError(401, errors.New("you need to provide a valid access token or user credentials to access this api")) - } -} - -func (a *Auth) Optional() gin.HandlerFunc { - return func(ctx *gin.Context) { - token := a.tokenFromQueryOrHeader(ctx) - user, err := a.userFromBasicAuth(ctx) - if err != nil { - RegisterAuthentication(ctx, nil, 0, "") - ctx.Next() - return - } - - if user != nil { - RegisterAuthentication(ctx, user, user.ID, token) - ctx.Next() - return - } else if token != "" { - if tokenClient, err := a.DB.GetClientByToken(token); err == nil && tokenClient != nil { - RegisterAuthentication(ctx, user, tokenClient.UserID, token) - ctx.Next() - return - } - } - RegisterAuthentication(ctx, nil, 0, "") - ctx.Next() - } -} diff --git a/auth/authentication_test.go b/auth/authentication_test.go index 92efff4..092ef98 100644 --- a/auth/authentication_test.go +++ b/auth/authentication_test.go @@ -82,7 +82,7 @@ func (s *AuthenticationSuite) assertQueryRequest(key, value string, f fMiddlewar recorder := httptest.NewRecorder() ctx, _ = gin.CreateTestContext(recorder) ctx.Request = httptest.NewRequest("GET", fmt.Sprintf("/?%s=%s", key, value), nil) - f()(ctx) + f(ctx) assert.Equal(s.T(), code, recorder.Code) return ctx } @@ -91,7 +91,7 @@ func (s *AuthenticationSuite) TestNothingProvided() { recorder := httptest.NewRecorder() ctx, _ := gin.CreateTestContext(recorder) ctx.Request = httptest.NewRequest("GET", "/", nil) - s.auth.RequireApplicationToken()(ctx) + s.auth.RequireApplicationToken(ctx) assert.Equal(s.T(), 401, recorder.Code) } @@ -215,9 +215,9 @@ func (s *AuthenticationSuite) assertHeaderRequest(key, value string, f fMiddlewa ctx, _ = gin.CreateTestContext(recorder) ctx.Request = httptest.NewRequest("GET", "/", nil) ctx.Request.Header.Set(key, value) - f()(ctx) + f(ctx) assert.Equal(s.T(), code, recorder.Code) return ctx } -type fMiddleware func() gin.HandlerFunc +type fMiddleware gin.HandlerFunc diff --git a/router/router.go b/router/router.go index 8b19137..2c166b3 100644 --- a/router/router.go +++ b/router/router.go @@ -120,8 +120,8 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co g.Use(cors.New(auth.CorsConfig(conf))) { - g.GET("/plugin", authentication.RequireClient(), pluginHandler.GetPlugins) - pluginRoute := g.Group("/plugin/", authentication.RequireClient()) + g.GET("/plugin", authentication.RequireClient, pluginHandler.GetPlugins) + pluginRoute := g.Group("/plugin/", authentication.RequireClient) { pluginRoute.GET("/:id/config", pluginHandler.GetConfig) pluginRoute.POST("/:id/config", pluginHandler.UpdateConfig) @@ -131,7 +131,7 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co } } - g.Group("/user").Use(authentication.Optional()).POST("", userHandler.CreateUser) + g.Group("/user").Use(authentication.Optional).POST("", userHandler.CreateUser) g.OPTIONS("/*any") @@ -150,11 +150,11 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co ctx.JSON(200, vInfo) }) - g.Group("/").Use(authentication.RequireApplicationToken()).POST("/message", messageHandler.CreateMessage) + g.Group("/").Use(authentication.RequireApplicationToken).POST("/message", messageHandler.CreateMessage) clientAuth := g.Group("") { - clientAuth.Use(authentication.RequireClient()) + clientAuth.Use(authentication.RequireClient) app := clientAuth.Group("/application") { app.GET("", applicationHandler.GetApplications) @@ -206,7 +206,7 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co authAdmin := g.Group("/user") { - authAdmin.Use(authentication.RequireAdmin()) + authAdmin.Use(authentication.RequireAdmin) authAdmin.GET("", userHandler.GetUsers)