fix: simplify authentication handling

This commit is contained in:
Jannis Mattheis
2026-03-22 13:02:06 +01:00
parent 143438055d
commit 39357c67d9
3 changed files with 154 additions and 131 deletions

View File

@@ -10,15 +10,20 @@ import (
"github.com/gotify/server/v2/model"
)
type authState int
const (
headerName = "X-Gotify-Key"
authStateSkip authState = iota
authStateForbidden
authStateOk
)
const headerName = "X-Gotify-Key"
// The Database interface for encapsulating database access.
type Database interface {
GetApplicationByToken(token string) (*model.Application, error)
GetClientByToken(token string) (*model.Client, error)
GetPluginConfByToken(token string) (*model.PluginConf, error)
GetUserByName(name string) (*model.User, error)
GetUserByID(id uint) (*model.User, error)
UpdateClientTokensLastUsed(tokens []string, t *time.Time) error
@@ -30,72 +35,154 @@ type Auth struct {
DB Database
}
type authenticate func(tokenID string, user *model.User) (authenticated, success bool, userId uint, err error)
// RequireAdmin returns a gin middleware which requires a client token or basic authentication header to be supplied
// with the request. Also the authenticated user must be an administrator.
func (a *Auth) RequireAdmin() gin.HandlerFunc {
return a.requireToken(func(tokenID string, user *model.User) (bool, bool, uint, error) {
if user != nil {
return true, user.Admin, user.ID, nil
}
if token, err := a.DB.GetClientByToken(tokenID); err != nil {
return false, false, 0, err
} else if token != nil {
user, err := a.DB.GetUserByID(token.UserID)
if err != nil {
return false, false, token.UserID, err
}
return true, user.Admin, token.UserID, nil
}
return false, false, 0, nil
})
func (a *Auth) RequireAdmin(ctx *gin.Context) {
a.evaluateOr401(ctx, a.user(true), a.client(true))
}
// RequireClient returns a gin middleware which requires a client token or basic authentication header to be supplied
// with the request.
func (a *Auth) RequireClient() gin.HandlerFunc {
return a.requireToken(func(tokenID string, user *model.User) (bool, bool, uint, error) {
if user != nil {
return true, true, user.ID, nil
}
if client, err := a.DB.GetClientByToken(tokenID); err != nil {
return false, false, 0, err
} else if client != nil {
now := time.Now()
if client.LastUsed == nil || client.LastUsed.Add(5*time.Minute).Before(now) {
if err := a.DB.UpdateClientTokensLastUsed([]string{tokenID}, &now); err != nil {
return false, false, 0, err
}
}
return true, true, client.UserID, nil
}
return false, false, 0, nil
})
func (a *Auth) RequireClient(ctx *gin.Context) {
a.evaluateOr401(ctx, a.user(false), a.client(false))
}
// RequireApplicationToken returns a gin middleware which requires an application token to be supplied with the request.
func (a *Auth) RequireApplicationToken() gin.HandlerFunc {
return a.requireToken(func(tokenID string, user *model.User) (bool, bool, uint, error) {
if user != nil {
return true, false, 0, nil
}
if app, err := a.DB.GetApplicationByToken(tokenID); err != nil {
return false, false, 0, err
} else if app != nil {
now := time.Now()
if app.LastUsed == nil || app.LastUsed.Add(5*time.Minute).Before(now) {
if err := a.DB.UpdateApplicationTokenLastUsed(tokenID, &now); err != nil {
return false, false, 0, err
}
}
return true, true, app.UserID, nil
}
return false, false, 0, nil
})
func (a *Auth) RequireApplicationToken(ctx *gin.Context) {
if a.evaluate(ctx, a.application) {
return
}
state, err := a.user(false)(ctx)
if err != nil {
ctx.AbortWithError(500, err)
}
if state != authStateSkip {
// Return to the user that it's valid authentication, but we don't allow user auth for application endpoints.
a.abort403(ctx)
return
}
a.abort401(ctx)
}
func (a *Auth) tokenFromQueryOrHeader(ctx *gin.Context) string {
func (a *Auth) Optional(ctx *gin.Context) {
if !a.evaluate(ctx, a.user(false), a.client(false)) {
RegisterAuthentication(ctx, nil, 0, "")
ctx.Next()
}
}
func (a *Auth) evaluate(ctx *gin.Context, funcs ...func(ctx *gin.Context) (authState, error)) bool {
for _, fn := range funcs {
state, err := fn(ctx)
if err != nil {
ctx.AbortWithError(500, err)
return true
}
switch state {
case authStateForbidden:
a.abort403(ctx)
return true
case authStateOk:
ctx.Next()
return true
case authStateSkip:
continue
}
}
return false
}
func (a *Auth) evaluateOr401(ctx *gin.Context, funcs ...func(ctx *gin.Context) (authState, error)) {
if !a.evaluate(ctx, funcs...) {
a.abort401(ctx)
}
}
func (a *Auth) abort401(ctx *gin.Context) {
ctx.AbortWithError(401, errors.New("you need to provide a valid access token or user credentials to access this api"))
}
func (a *Auth) abort403(ctx *gin.Context) {
ctx.AbortWithError(403, errors.New("you are not allowed to access this api"))
}
func (a *Auth) user(requireAdmin bool) func(ctx *gin.Context) (authState, error) {
return func(ctx *gin.Context) (authState, error) {
if name, pass, ok := ctx.Request.BasicAuth(); ok {
if user, err := a.DB.GetUserByName(name); err != nil {
return authStateSkip, err
} else if user != nil && password.ComparePassword(user.Pass, []byte(pass)) {
RegisterAuthentication(ctx, user, user.ID, "")
if requireAdmin && !user.Admin {
return authStateForbidden, nil
}
return authStateOk, nil
}
}
return authStateSkip, nil
}
}
func (a *Auth) client(requireAdmin bool) func(ctx *gin.Context) (authState, error) {
return func(ctx *gin.Context) (authState, error) {
token := a.readTokenFromRequest(ctx)
if token == "" {
return authStateSkip, nil
}
client, err := a.DB.GetClientByToken(token)
if err != nil {
return authStateSkip, err
}
if client == nil {
return authStateSkip, nil
}
RegisterAuthentication(ctx, nil, client.UserID, client.Token)
now := time.Now()
if client.LastUsed == nil || client.LastUsed.Add(5*time.Minute).Before(now) {
if err := a.DB.UpdateClientTokensLastUsed([]string{client.Token}, &now); err != nil {
return authStateSkip, err
}
}
if requireAdmin {
if user, err := a.DB.GetUserByID(client.UserID); err != nil {
return authStateSkip, err
} else if !user.Admin {
return authStateForbidden, nil
}
}
return authStateOk, nil
}
}
func (a *Auth) application(ctx *gin.Context) (authState, error) {
token := a.readTokenFromRequest(ctx)
if token == "" {
return authStateSkip, nil
}
app, err := a.DB.GetApplicationByToken(token)
if err != nil {
return authStateSkip, err
}
if app == nil {
return authStateSkip, nil
}
RegisterAuthentication(ctx, nil, app.UserID, app.Token)
now := time.Now()
if app.LastUsed == nil || app.LastUsed.Add(5*time.Minute).Before(now) {
if err := a.DB.UpdateApplicationTokenLastUsed(app.Token, &now); err != nil {
return authStateSkip, err
}
}
return authStateOk, nil
}
func (a *Auth) readTokenFromRequest(ctx *gin.Context) string {
if token := a.tokenFromQuery(ctx); token != "" {
return token
} else if token := a.tokenFromXGotifyHeader(ctx); token != "" {
@@ -128,67 +215,3 @@ func (a *Auth) tokenFromAuthorizationHeader(ctx *gin.Context) string {
return authHeader[len(prefix):]
}
func (a *Auth) userFromBasicAuth(ctx *gin.Context) (*model.User, error) {
if name, pass, ok := ctx.Request.BasicAuth(); ok {
if user, err := a.DB.GetUserByName(name); err != nil {
return nil, err
} else if user != nil && password.ComparePassword(user.Pass, []byte(pass)) {
return user, nil
}
}
return nil, nil
}
func (a *Auth) requireToken(auth authenticate) gin.HandlerFunc {
return func(ctx *gin.Context) {
token := a.tokenFromQueryOrHeader(ctx)
user, err := a.userFromBasicAuth(ctx)
if err != nil {
ctx.AbortWithError(500, errors.New("an error occurred while authenticating user"))
return
}
if user != nil || token != "" {
authenticated, ok, userID, err := auth(token, user)
if err != nil {
ctx.AbortWithError(500, errors.New("an error occurred while authenticating user"))
return
} else if ok {
RegisterAuthentication(ctx, user, userID, token)
ctx.Next()
return
} else if authenticated {
ctx.AbortWithError(403, errors.New("you are not allowed to access this api"))
return
}
}
ctx.AbortWithError(401, errors.New("you need to provide a valid access token or user credentials to access this api"))
}
}
func (a *Auth) Optional() gin.HandlerFunc {
return func(ctx *gin.Context) {
token := a.tokenFromQueryOrHeader(ctx)
user, err := a.userFromBasicAuth(ctx)
if err != nil {
RegisterAuthentication(ctx, nil, 0, "")
ctx.Next()
return
}
if user != nil {
RegisterAuthentication(ctx, user, user.ID, token)
ctx.Next()
return
} else if token != "" {
if tokenClient, err := a.DB.GetClientByToken(token); err == nil && tokenClient != nil {
RegisterAuthentication(ctx, user, tokenClient.UserID, token)
ctx.Next()
return
}
}
RegisterAuthentication(ctx, nil, 0, "")
ctx.Next()
}
}

View File

@@ -82,7 +82,7 @@ func (s *AuthenticationSuite) assertQueryRequest(key, value string, f fMiddlewar
recorder := httptest.NewRecorder()
ctx, _ = gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest("GET", fmt.Sprintf("/?%s=%s", key, value), nil)
f()(ctx)
f(ctx)
assert.Equal(s.T(), code, recorder.Code)
return ctx
}
@@ -91,7 +91,7 @@ func (s *AuthenticationSuite) TestNothingProvided() {
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest("GET", "/", nil)
s.auth.RequireApplicationToken()(ctx)
s.auth.RequireApplicationToken(ctx)
assert.Equal(s.T(), 401, recorder.Code)
}
@@ -215,9 +215,9 @@ func (s *AuthenticationSuite) assertHeaderRequest(key, value string, f fMiddlewa
ctx, _ = gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest("GET", "/", nil)
ctx.Request.Header.Set(key, value)
f()(ctx)
f(ctx)
assert.Equal(s.T(), code, recorder.Code)
return ctx
}
type fMiddleware func() gin.HandlerFunc
type fMiddleware gin.HandlerFunc

View File

@@ -120,8 +120,8 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co
g.Use(cors.New(auth.CorsConfig(conf)))
{
g.GET("/plugin", authentication.RequireClient(), pluginHandler.GetPlugins)
pluginRoute := g.Group("/plugin/", authentication.RequireClient())
g.GET("/plugin", authentication.RequireClient, pluginHandler.GetPlugins)
pluginRoute := g.Group("/plugin/", authentication.RequireClient)
{
pluginRoute.GET("/:id/config", pluginHandler.GetConfig)
pluginRoute.POST("/:id/config", pluginHandler.UpdateConfig)
@@ -131,7 +131,7 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co
}
}
g.Group("/user").Use(authentication.Optional()).POST("", userHandler.CreateUser)
g.Group("/user").Use(authentication.Optional).POST("", userHandler.CreateUser)
g.OPTIONS("/*any")
@@ -150,11 +150,11 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co
ctx.JSON(200, vInfo)
})
g.Group("/").Use(authentication.RequireApplicationToken()).POST("/message", messageHandler.CreateMessage)
g.Group("/").Use(authentication.RequireApplicationToken).POST("/message", messageHandler.CreateMessage)
clientAuth := g.Group("")
{
clientAuth.Use(authentication.RequireClient())
clientAuth.Use(authentication.RequireClient)
app := clientAuth.Group("/application")
{
app.GET("", applicationHandler.GetApplications)
@@ -206,7 +206,7 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co
authAdmin := g.Group("/user")
{
authAdmin.Use(authentication.RequireAdmin())
authAdmin.Use(authentication.RequireAdmin)
authAdmin.GET("", userHandler.GetUsers)