mirror of
https://github.com/gotify/server.git
synced 2026-05-31 15:31:40 +08:00
fix: simplify authentication handling
This commit is contained in:
@@ -10,15 +10,20 @@ import (
|
||||
"github.com/gotify/server/v2/model"
|
||||
)
|
||||
|
||||
type authState int
|
||||
|
||||
const (
|
||||
headerName = "X-Gotify-Key"
|
||||
authStateSkip authState = iota
|
||||
authStateForbidden
|
||||
authStateOk
|
||||
)
|
||||
|
||||
const headerName = "X-Gotify-Key"
|
||||
|
||||
// The Database interface for encapsulating database access.
|
||||
type Database interface {
|
||||
GetApplicationByToken(token string) (*model.Application, error)
|
||||
GetClientByToken(token string) (*model.Client, error)
|
||||
GetPluginConfByToken(token string) (*model.PluginConf, error)
|
||||
GetUserByName(name string) (*model.User, error)
|
||||
GetUserByID(id uint) (*model.User, error)
|
||||
UpdateClientTokensLastUsed(tokens []string, t *time.Time) error
|
||||
@@ -30,72 +35,154 @@ type Auth struct {
|
||||
DB Database
|
||||
}
|
||||
|
||||
type authenticate func(tokenID string, user *model.User) (authenticated, success bool, userId uint, err error)
|
||||
|
||||
// RequireAdmin returns a gin middleware which requires a client token or basic authentication header to be supplied
|
||||
// with the request. Also the authenticated user must be an administrator.
|
||||
func (a *Auth) RequireAdmin() gin.HandlerFunc {
|
||||
return a.requireToken(func(tokenID string, user *model.User) (bool, bool, uint, error) {
|
||||
if user != nil {
|
||||
return true, user.Admin, user.ID, nil
|
||||
}
|
||||
if token, err := a.DB.GetClientByToken(tokenID); err != nil {
|
||||
return false, false, 0, err
|
||||
} else if token != nil {
|
||||
user, err := a.DB.GetUserByID(token.UserID)
|
||||
if err != nil {
|
||||
return false, false, token.UserID, err
|
||||
}
|
||||
return true, user.Admin, token.UserID, nil
|
||||
}
|
||||
return false, false, 0, nil
|
||||
})
|
||||
func (a *Auth) RequireAdmin(ctx *gin.Context) {
|
||||
a.evaluateOr401(ctx, a.user(true), a.client(true))
|
||||
}
|
||||
|
||||
// RequireClient returns a gin middleware which requires a client token or basic authentication header to be supplied
|
||||
// with the request.
|
||||
func (a *Auth) RequireClient() gin.HandlerFunc {
|
||||
return a.requireToken(func(tokenID string, user *model.User) (bool, bool, uint, error) {
|
||||
if user != nil {
|
||||
return true, true, user.ID, nil
|
||||
}
|
||||
if client, err := a.DB.GetClientByToken(tokenID); err != nil {
|
||||
return false, false, 0, err
|
||||
} else if client != nil {
|
||||
now := time.Now()
|
||||
if client.LastUsed == nil || client.LastUsed.Add(5*time.Minute).Before(now) {
|
||||
if err := a.DB.UpdateClientTokensLastUsed([]string{tokenID}, &now); err != nil {
|
||||
return false, false, 0, err
|
||||
}
|
||||
}
|
||||
return true, true, client.UserID, nil
|
||||
}
|
||||
return false, false, 0, nil
|
||||
})
|
||||
func (a *Auth) RequireClient(ctx *gin.Context) {
|
||||
a.evaluateOr401(ctx, a.user(false), a.client(false))
|
||||
}
|
||||
|
||||
// RequireApplicationToken returns a gin middleware which requires an application token to be supplied with the request.
|
||||
func (a *Auth) RequireApplicationToken() gin.HandlerFunc {
|
||||
return a.requireToken(func(tokenID string, user *model.User) (bool, bool, uint, error) {
|
||||
if user != nil {
|
||||
return true, false, 0, nil
|
||||
}
|
||||
if app, err := a.DB.GetApplicationByToken(tokenID); err != nil {
|
||||
return false, false, 0, err
|
||||
} else if app != nil {
|
||||
now := time.Now()
|
||||
if app.LastUsed == nil || app.LastUsed.Add(5*time.Minute).Before(now) {
|
||||
if err := a.DB.UpdateApplicationTokenLastUsed(tokenID, &now); err != nil {
|
||||
return false, false, 0, err
|
||||
}
|
||||
}
|
||||
return true, true, app.UserID, nil
|
||||
}
|
||||
return false, false, 0, nil
|
||||
})
|
||||
func (a *Auth) RequireApplicationToken(ctx *gin.Context) {
|
||||
if a.evaluate(ctx, a.application) {
|
||||
return
|
||||
}
|
||||
state, err := a.user(false)(ctx)
|
||||
if err != nil {
|
||||
ctx.AbortWithError(500, err)
|
||||
}
|
||||
if state != authStateSkip {
|
||||
// Return to the user that it's valid authentication, but we don't allow user auth for application endpoints.
|
||||
a.abort403(ctx)
|
||||
return
|
||||
}
|
||||
a.abort401(ctx)
|
||||
}
|
||||
|
||||
func (a *Auth) tokenFromQueryOrHeader(ctx *gin.Context) string {
|
||||
func (a *Auth) Optional(ctx *gin.Context) {
|
||||
if !a.evaluate(ctx, a.user(false), a.client(false)) {
|
||||
RegisterAuthentication(ctx, nil, 0, "")
|
||||
ctx.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Auth) evaluate(ctx *gin.Context, funcs ...func(ctx *gin.Context) (authState, error)) bool {
|
||||
for _, fn := range funcs {
|
||||
state, err := fn(ctx)
|
||||
if err != nil {
|
||||
ctx.AbortWithError(500, err)
|
||||
return true
|
||||
}
|
||||
switch state {
|
||||
case authStateForbidden:
|
||||
a.abort403(ctx)
|
||||
return true
|
||||
case authStateOk:
|
||||
ctx.Next()
|
||||
return true
|
||||
case authStateSkip:
|
||||
continue
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *Auth) evaluateOr401(ctx *gin.Context, funcs ...func(ctx *gin.Context) (authState, error)) {
|
||||
if !a.evaluate(ctx, funcs...) {
|
||||
a.abort401(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Auth) abort401(ctx *gin.Context) {
|
||||
ctx.AbortWithError(401, errors.New("you need to provide a valid access token or user credentials to access this api"))
|
||||
}
|
||||
|
||||
func (a *Auth) abort403(ctx *gin.Context) {
|
||||
ctx.AbortWithError(403, errors.New("you are not allowed to access this api"))
|
||||
}
|
||||
|
||||
func (a *Auth) user(requireAdmin bool) func(ctx *gin.Context) (authState, error) {
|
||||
return func(ctx *gin.Context) (authState, error) {
|
||||
if name, pass, ok := ctx.Request.BasicAuth(); ok {
|
||||
if user, err := a.DB.GetUserByName(name); err != nil {
|
||||
return authStateSkip, err
|
||||
} else if user != nil && password.ComparePassword(user.Pass, []byte(pass)) {
|
||||
RegisterAuthentication(ctx, user, user.ID, "")
|
||||
|
||||
if requireAdmin && !user.Admin {
|
||||
return authStateForbidden, nil
|
||||
}
|
||||
return authStateOk, nil
|
||||
}
|
||||
}
|
||||
return authStateSkip, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Auth) client(requireAdmin bool) func(ctx *gin.Context) (authState, error) {
|
||||
return func(ctx *gin.Context) (authState, error) {
|
||||
token := a.readTokenFromRequest(ctx)
|
||||
if token == "" {
|
||||
return authStateSkip, nil
|
||||
}
|
||||
client, err := a.DB.GetClientByToken(token)
|
||||
if err != nil {
|
||||
return authStateSkip, err
|
||||
}
|
||||
if client == nil {
|
||||
return authStateSkip, nil
|
||||
}
|
||||
RegisterAuthentication(ctx, nil, client.UserID, client.Token)
|
||||
|
||||
now := time.Now()
|
||||
if client.LastUsed == nil || client.LastUsed.Add(5*time.Minute).Before(now) {
|
||||
if err := a.DB.UpdateClientTokensLastUsed([]string{client.Token}, &now); err != nil {
|
||||
return authStateSkip, err
|
||||
}
|
||||
}
|
||||
|
||||
if requireAdmin {
|
||||
if user, err := a.DB.GetUserByID(client.UserID); err != nil {
|
||||
return authStateSkip, err
|
||||
} else if !user.Admin {
|
||||
return authStateForbidden, nil
|
||||
}
|
||||
}
|
||||
|
||||
return authStateOk, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Auth) application(ctx *gin.Context) (authState, error) {
|
||||
token := a.readTokenFromRequest(ctx)
|
||||
if token == "" {
|
||||
return authStateSkip, nil
|
||||
}
|
||||
app, err := a.DB.GetApplicationByToken(token)
|
||||
if err != nil {
|
||||
return authStateSkip, err
|
||||
}
|
||||
if app == nil {
|
||||
return authStateSkip, nil
|
||||
}
|
||||
RegisterAuthentication(ctx, nil, app.UserID, app.Token)
|
||||
|
||||
now := time.Now()
|
||||
if app.LastUsed == nil || app.LastUsed.Add(5*time.Minute).Before(now) {
|
||||
if err := a.DB.UpdateApplicationTokenLastUsed(app.Token, &now); err != nil {
|
||||
return authStateSkip, err
|
||||
}
|
||||
}
|
||||
|
||||
return authStateOk, nil
|
||||
}
|
||||
|
||||
func (a *Auth) readTokenFromRequest(ctx *gin.Context) string {
|
||||
if token := a.tokenFromQuery(ctx); token != "" {
|
||||
return token
|
||||
} else if token := a.tokenFromXGotifyHeader(ctx); token != "" {
|
||||
@@ -128,67 +215,3 @@ func (a *Auth) tokenFromAuthorizationHeader(ctx *gin.Context) string {
|
||||
|
||||
return authHeader[len(prefix):]
|
||||
}
|
||||
|
||||
func (a *Auth) userFromBasicAuth(ctx *gin.Context) (*model.User, error) {
|
||||
if name, pass, ok := ctx.Request.BasicAuth(); ok {
|
||||
if user, err := a.DB.GetUserByName(name); err != nil {
|
||||
return nil, err
|
||||
} else if user != nil && password.ComparePassword(user.Pass, []byte(pass)) {
|
||||
return user, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Auth) requireToken(auth authenticate) gin.HandlerFunc {
|
||||
return func(ctx *gin.Context) {
|
||||
token := a.tokenFromQueryOrHeader(ctx)
|
||||
user, err := a.userFromBasicAuth(ctx)
|
||||
if err != nil {
|
||||
ctx.AbortWithError(500, errors.New("an error occurred while authenticating user"))
|
||||
return
|
||||
}
|
||||
|
||||
if user != nil || token != "" {
|
||||
authenticated, ok, userID, err := auth(token, user)
|
||||
if err != nil {
|
||||
ctx.AbortWithError(500, errors.New("an error occurred while authenticating user"))
|
||||
return
|
||||
} else if ok {
|
||||
RegisterAuthentication(ctx, user, userID, token)
|
||||
ctx.Next()
|
||||
return
|
||||
} else if authenticated {
|
||||
ctx.AbortWithError(403, errors.New("you are not allowed to access this api"))
|
||||
return
|
||||
}
|
||||
}
|
||||
ctx.AbortWithError(401, errors.New("you need to provide a valid access token or user credentials to access this api"))
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Auth) Optional() gin.HandlerFunc {
|
||||
return func(ctx *gin.Context) {
|
||||
token := a.tokenFromQueryOrHeader(ctx)
|
||||
user, err := a.userFromBasicAuth(ctx)
|
||||
if err != nil {
|
||||
RegisterAuthentication(ctx, nil, 0, "")
|
||||
ctx.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if user != nil {
|
||||
RegisterAuthentication(ctx, user, user.ID, token)
|
||||
ctx.Next()
|
||||
return
|
||||
} else if token != "" {
|
||||
if tokenClient, err := a.DB.GetClientByToken(token); err == nil && tokenClient != nil {
|
||||
RegisterAuthentication(ctx, user, tokenClient.UserID, token)
|
||||
ctx.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
RegisterAuthentication(ctx, nil, 0, "")
|
||||
ctx.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,7 +82,7 @@ func (s *AuthenticationSuite) assertQueryRequest(key, value string, f fMiddlewar
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ = gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest("GET", fmt.Sprintf("/?%s=%s", key, value), nil)
|
||||
f()(ctx)
|
||||
f(ctx)
|
||||
assert.Equal(s.T(), code, recorder.Code)
|
||||
return ctx
|
||||
}
|
||||
@@ -91,7 +91,7 @@ func (s *AuthenticationSuite) TestNothingProvided() {
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest("GET", "/", nil)
|
||||
s.auth.RequireApplicationToken()(ctx)
|
||||
s.auth.RequireApplicationToken(ctx)
|
||||
assert.Equal(s.T(), 401, recorder.Code)
|
||||
}
|
||||
|
||||
@@ -215,9 +215,9 @@ func (s *AuthenticationSuite) assertHeaderRequest(key, value string, f fMiddlewa
|
||||
ctx, _ = gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest("GET", "/", nil)
|
||||
ctx.Request.Header.Set(key, value)
|
||||
f()(ctx)
|
||||
f(ctx)
|
||||
assert.Equal(s.T(), code, recorder.Code)
|
||||
return ctx
|
||||
}
|
||||
|
||||
type fMiddleware func() gin.HandlerFunc
|
||||
type fMiddleware gin.HandlerFunc
|
||||
|
||||
@@ -120,8 +120,8 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co
|
||||
g.Use(cors.New(auth.CorsConfig(conf)))
|
||||
|
||||
{
|
||||
g.GET("/plugin", authentication.RequireClient(), pluginHandler.GetPlugins)
|
||||
pluginRoute := g.Group("/plugin/", authentication.RequireClient())
|
||||
g.GET("/plugin", authentication.RequireClient, pluginHandler.GetPlugins)
|
||||
pluginRoute := g.Group("/plugin/", authentication.RequireClient)
|
||||
{
|
||||
pluginRoute.GET("/:id/config", pluginHandler.GetConfig)
|
||||
pluginRoute.POST("/:id/config", pluginHandler.UpdateConfig)
|
||||
@@ -131,7 +131,7 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co
|
||||
}
|
||||
}
|
||||
|
||||
g.Group("/user").Use(authentication.Optional()).POST("", userHandler.CreateUser)
|
||||
g.Group("/user").Use(authentication.Optional).POST("", userHandler.CreateUser)
|
||||
|
||||
g.OPTIONS("/*any")
|
||||
|
||||
@@ -150,11 +150,11 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co
|
||||
ctx.JSON(200, vInfo)
|
||||
})
|
||||
|
||||
g.Group("/").Use(authentication.RequireApplicationToken()).POST("/message", messageHandler.CreateMessage)
|
||||
g.Group("/").Use(authentication.RequireApplicationToken).POST("/message", messageHandler.CreateMessage)
|
||||
|
||||
clientAuth := g.Group("")
|
||||
{
|
||||
clientAuth.Use(authentication.RequireClient())
|
||||
clientAuth.Use(authentication.RequireClient)
|
||||
app := clientAuth.Group("/application")
|
||||
{
|
||||
app.GET("", applicationHandler.GetApplications)
|
||||
@@ -206,7 +206,7 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co
|
||||
|
||||
authAdmin := g.Group("/user")
|
||||
{
|
||||
authAdmin.Use(authentication.RequireAdmin())
|
||||
authAdmin.Use(authentication.RequireAdmin)
|
||||
|
||||
authAdmin.GET("", userHandler.GetUsers)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user