From 58677b32eff56f6485c3724e5d2761c07dca26a4 Mon Sep 17 00:00:00 2001 From: Jannis Mattheis Date: Sat, 11 Apr 2026 21:37:10 +0200 Subject: [PATCH] fix: add client elevatedUntil --- api/client.go | 58 +++++++++++++++++++ api/client_test.go | 74 ++++++++++++++++++++++++ api/oidc.go | 85 ++++++++++++++++++++++++++++ database/client.go | 5 ++ docs/spec.json | 136 +++++++++++++++++++++++++++++++++++++++++++++ model/client.go | 4 ++ model/elevate.go | 17 ++++++ router/router.go | 2 + 8 files changed, 381 insertions(+) create mode 100644 model/elevate.go diff --git a/api/client.go b/api/client.go index 865d74d..3736f31 100644 --- a/api/client.go +++ b/api/client.go @@ -1,7 +1,9 @@ package api import ( + "errors" "fmt" + "time" "github.com/gin-gonic/gin" "github.com/gotify/server/v2/auth" @@ -16,6 +18,7 @@ type ClientDatabase interface { GetClientsByUser(userID uint) ([]*model.Client, error) DeleteClientByID(id uint) error UpdateClient(client *model.Client) error + UpdateClientElevatedUntil(id uint, t *time.Time) error } // The ClientAPI provides handlers for managing clients and applications. @@ -235,6 +238,61 @@ func (a *ClientAPI) DeleteClient(ctx *gin.Context) { }) } +// swagger:operation POST /client:elevate client elevateClient +// +// Elevate a client session. +// +// --- +// consumes: [application/json] +// produces: [application/json] +// parameters: +// - name: body +// in: body +// description: the elevation request +// required: true +// schema: +// $ref: "#/definitions/ElevateRequest" +// security: [clientTokenAuthorizationHeader: [], clientTokenHeader: [], clientTokenQuery: [], basicAuth: []] +// responses: +// 204: +// description: Ok +// 400: +// description: Bad Request +// schema: +// $ref: "#/definitions/Error" +// 401: +// description: Unauthorized +// schema: +// $ref: "#/definitions/Error" +// 404: +// description: Not Found +// schema: +// $ref: "#/definitions/Error" +func (a *ClientAPI) ElevateClient(ctx *gin.Context) { + var params model.ElevateRequest + if err := ctx.Bind(¶ms); err != nil { + return + } + + client, err := a.DB.GetClientByID(params.ID) + if err != nil { + ctx.AbortWithError(500, err) + return + } + if client == nil || client.UserID != auth.GetUserID(ctx) { + ctx.AbortWithError(404, errors.New("client not found")) + return + } + + elevatedUntil := time.Now().Add(time.Duration(params.DurationSeconds) * time.Second) + if err := a.DB.UpdateClientElevatedUntil(client.ID, &elevatedUntil); err != nil { + ctx.AbortWithError(500, err) + return + } + + ctx.Status(204) +} + func (a *ClientAPI) clientExists(token string) bool { client, _ := a.DB.GetClientByToken(token) return client != nil diff --git a/api/client_test.go b/api/client_test.go index e2b3f28..0d73515 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -5,6 +5,7 @@ import ( "net/url" "strings" "testing" + "time" "github.com/gin-gonic/gin" "github.com/gotify/server/v2/mode" @@ -223,11 +224,84 @@ func (s *ClientSuite) Test_UpdateClient_WithMissingAttributes_expectBadRequest() assert.Equal(s.T(), 400, s.recorder.Code) } +func (s *ClientSuite) Test_ElevateClient_expectSuccess() { + s.db.User(5).Client(8) + + test.WithUser(s.ctx, 5) + s.withJSONBody(`{"id":8,"durationSeconds":900}`) + + before := time.Now() + s.a.ElevateClient(s.ctx) + after := time.Now() + + assert.Equal(s.T(), 204, s.ctx.Writer.Status()) + client, err := s.db.GetClientByID(8) + assert.NoError(s.T(), err) + assert.NotNil(s.T(), client.ElevatedUntil) + assert.WithinRange(s.T(), *client.ElevatedUntil, before.Add(15*time.Minute), after.Add(15*time.Minute)) +} + +func (s *ClientSuite) Test_ElevateClient_expectNotFoundOnMissingClient() { + s.db.User(5) + + test.WithUser(s.ctx, 5) + s.withJSONBody(`{"id":8,"durationSeconds":900}`) + + s.a.ElevateClient(s.ctx) + + assert.Equal(s.T(), 404, s.recorder.Code) +} + +func (s *ClientSuite) Test_ElevateClient_expectNotFoundOnCurrentUserIsNotOwner() { + s.db.User(5).Client(8) + s.db.User(2) + + test.WithUser(s.ctx, 2) + s.withJSONBody(`{"id":8,"durationSeconds":900}`) + + s.a.ElevateClient(s.ctx) + + assert.Equal(s.T(), 404, s.recorder.Code) + client, err := s.db.GetClientByID(8) + assert.NoError(s.T(), err) + assert.Nil(s.T(), client.ElevatedUntil) +} + +func (s *ClientSuite) Test_ElevateClient_expectBadRequestOnMissingID() { + s.db.User(5) + + test.WithUser(s.ctx, 5) + s.withJSONBody(`{"durationSeconds":900}`) + + s.a.ElevateClient(s.ctx) + + assert.Equal(s.T(), 400, s.recorder.Code) +} + +func (s *ClientSuite) Test_ElevateClient_expectBadRequestOnMissingDuration() { + s.db.User(5).Client(8) + + test.WithUser(s.ctx, 5) + s.withJSONBody(`{"id":8}`) + + s.a.ElevateClient(s.ctx) + + assert.Equal(s.T(), 400, s.recorder.Code) + client, err := s.db.GetClientByID(8) + assert.NoError(s.T(), err) + assert.Nil(s.T(), client.ElevatedUntil) +} + func (s *ClientSuite) withFormData(formData string) { s.ctx.Request = httptest.NewRequest("POST", "/token", strings.NewReader(formData)) s.ctx.Request.Header.Set("Content-Type", "application/x-www-form-urlencoded") } +func (s *ClientSuite) withJSONBody(body string) { + s.ctx.Request = httptest.NewRequest("POST", "/client:elevate", strings.NewReader(body)) + s.ctx.Request.Header.Set("Content-Type", "application/json") +} + func withURL(ctx *gin.Context, scheme, host string) { ctx.Set("location", &url.URL{Scheme: scheme, Host: host}) } diff --git a/api/oidc.go b/api/oidc.go index 002fcb1..da2cfd6 100644 --- a/api/oidc.go +++ b/api/oidc.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "errors" "fmt" + "io" "log" "net/http" "time" @@ -70,6 +71,7 @@ type pendingOIDCSession struct { RedirectURI string ClientName string CreatedAt time.Time + Elevate *model.ElevateRequest } // OIDCAPI provides handlers for OIDC authentication. @@ -122,6 +124,48 @@ func (a *OIDCAPI) LoginHandler() gin.HandlerFunc { }) } +// swagger:operation GET /auth/oidc/elevate oidc oidcElevate +// +// Start the OIDC flow to elevate an existing client session (browser). +// +// Redirects the user to the OIDC provider's authorization endpoint. After +// successful authentication, the referenced client session is elevated for +// the requested duration. +// +// --- +// parameters: +// - name: id +// in: query +// description: the client id to elevate +// required: true +// type: integer +// format: int64 +// - name: durationSeconds +// in: query +// description: how long the elevation should last, in seconds +// required: true +// type: integer +// responses: +// 302: +// description: Redirect to OIDC provider +// default: +// description: Error +// schema: +// $ref: "#/definitions/Error" +func (a *OIDCAPI) ElevateHandler(ctx *gin.Context) { + var elevate model.ElevateRequest + if err := ctx.BindQuery(&elevate); err != nil { + return + } + state, err := a.generateState() + if err != nil { + ctx.AbortWithError(http.StatusInternalServerError, err) + return + } + a.pendingSessions.Set(time.Now(), state, &pendingOIDCSession{CreatedAt: time.Now(), Elevate: &elevate}) + rp.AuthURLHandler(func() string { return state }, a.Provider)(ctx.Writer, ctx.Request) +} + // swagger:operation GET /auth/oidc/callback oidc oidcCallback // // Handle the OIDC provider callback (browser). @@ -142,6 +186,8 @@ func (a *OIDCAPI) LoginHandler() gin.HandlerFunc { // required: true // type: string // responses: +// 200: +// description: ok // 307: // description: Redirect to UI // default: @@ -160,6 +206,12 @@ func (a *OIDCAPI) CallbackHandler() gin.HandlerFunc { http.Error(w, "unknown or expired state", http.StatusBadRequest) return } + + if session.Elevate != nil { + a.handleElevationCallback(w, session.Elevate, user) + return + } + client, err := a.createClient(session.ClientName, user.ID) if err != nil { http.Error(w, fmt.Sprintf("failed to create client: %v", err), http.StatusInternalServerError) @@ -175,6 +227,39 @@ func (a *OIDCAPI) CallbackHandler() gin.HandlerFunc { return gin.WrapF(rp.CodeExchangeHandler(rp.UserinfoCallback(callback), a.Provider)) } +func (a *OIDCAPI) handleElevationCallback(w http.ResponseWriter, elevate *model.ElevateRequest, user *model.User) { + client, err := a.DB.GetClientByID(elevate.ID) + if err != nil { + http.Error(w, fmt.Sprintf("database error: %v", err), http.StatusInternalServerError) + return + } + if client == nil || client.UserID != user.ID { + http.Error(w, "client not found", http.StatusNotFound) + return + } + elevatedUntil := time.Now().Add(time.Duration(elevate.DurationSeconds) * time.Second) + if err := a.DB.UpdateClientElevatedUntil(client.ID, &elevatedUntil); err != nil { + http.Error(w, fmt.Sprintf("failed to elevate session: %v", err), http.StatusInternalServerError) + return + } + + // The UI rechecks the authentication when the tab is closed. + w.WriteHeader(http.StatusOK) + w.Header().Add("content-type", "text/html") + io.WriteString(w, ` + + + Gotify Session Elevation + + + + +

Gotify session elevation successful. Close this tab to continue.

+ + +`) +} + // swagger:operation POST /auth/oidc/external/authorize oidc externalAuthorize // // Initiate the OIDC authorization flow for a native app. diff --git a/database/client.go b/database/client.go index f85c494..5194610 100644 --- a/database/client.go +++ b/database/client.go @@ -62,3 +62,8 @@ func (d *GormDatabase) UpdateClient(client *model.Client) error { func (d *GormDatabase) UpdateClientTokensLastUsed(tokens []string, t *time.Time) error { return d.DB.Model(&model.Client{}).Where("token IN (?)", tokens).Update("last_used", t).Error } + +// UpdateClientElevatedUntil updates the elevated_until timestamp of a client by token. +func (d *GormDatabase) UpdateClientElevatedUntil(id uint, t *time.Time) error { + return d.DB.Model(&model.Client{}).Where("id = ?", id).Update("elevated_until", t).Error +} diff --git a/docs/spec.json b/docs/spec.json index 2e6db5a..b1a24a7 100644 --- a/docs/spec.json +++ b/docs/spec.json @@ -702,6 +702,9 @@ } ], "responses": { + "200": { + "description": "ok" + }, "307": { "description": "Redirect to UI" }, @@ -714,6 +717,44 @@ } } }, + "/auth/oidc/elevate": { + "get": { + "description": "Redirects the user to the OIDC provider's authorization endpoint. After\nsuccessful authentication, the referenced client session is elevated for\nthe requested duration.", + "tags": [ + "oidc" + ], + "summary": "Start the OIDC flow to elevate an existing client session (browser).", + "operationId": "oidcElevate", + "parameters": [ + { + "type": "integer", + "format": "int64", + "description": "the client id to elevate", + "name": "id", + "in": "query", + "required": true + }, + { + "type": "integer", + "description": "how long the elevation should last, in seconds", + "name": "durationSeconds", + "in": "query", + "required": true + } + ], + "responses": { + "302": { + "description": "Redirect to OIDC provider" + }, + "default": { + "description": "Error", + "schema": { + "$ref": "#/definitions/Error" + } + } + } + } + }, "/auth/oidc/external/authorize": { "post": { "description": "The app generates a PKCE code_verifier and code_challenge, then calls this\nendpoint. The server forwards the code_challenge to the OIDC provider and\nreturns the authorization URL for the app to open in a browser.", @@ -1086,6 +1127,69 @@ } } }, + "/client:elevate": { + "post": { + "security": [ + { + "clientTokenAuthorizationHeader": [] + }, + { + "clientTokenHeader": [] + }, + { + "clientTokenQuery": [] + }, + { + "basicAuth": [] + } + ], + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "client" + ], + "summary": "Elevate a client session.", + "operationId": "elevateClient", + "parameters": [ + { + "description": "the elevation request", + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/ElevateRequest" + } + } + ], + "responses": { + "204": { + "description": "Ok" + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/Error" + } + }, + "401": { + "description": "Unauthorized", + "schema": { + "$ref": "#/definitions/Error" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/Error" + } + } + } + } + }, "/current/user": { "get": { "security": [ @@ -2432,6 +2536,13 @@ "name" ], "properties": { + "elevatedUntil": { + "description": "The time until which this client's session is elevated.", + "type": "string", + "format": "date-time", + "x-go-name": "ElevatedUntil", + "readOnly": true + }, "id": { "description": "The client id.", "type": "integer", @@ -2512,6 +2623,31 @@ }, "x-go-package": "github.com/gotify/server/v2/model" }, + "ElevateRequest": { + "type": "object", + "title": "ElevateRequest parameters for client elevation.", + "required": [ + "id", + "durationSeconds" + ], + "properties": { + "durationSeconds": { + "description": "How long the elevation should last, in seconds.", + "type": "integer", + "format": "int64", + "x-go-name": "DurationSeconds", + "example": 900 + }, + "id": { + "description": "The client ID to elevate.", + "type": "integer", + "format": "int64", + "x-go-name": "ID", + "example": 5 + } + }, + "x-go-package": "github.com/gotify/server/v2/model" + }, "Error": { "description": "The Error contains error relevant information.", "type": "object", diff --git a/model/client.go b/model/client.go index 9b96c82..674be27 100644 --- a/model/client.go +++ b/model/client.go @@ -31,4 +31,8 @@ type Client struct { // read only: true // example: 2019-01-01T00:00:00Z LastUsed *time.Time `json:"lastUsed"` + // The time until which this client's session is elevated. + // + // read only: true + ElevatedUntil *time.Time `json:"elevatedUntil,omitempty"` } diff --git a/model/elevate.go b/model/elevate.go new file mode 100644 index 0000000..8ffd688 --- /dev/null +++ b/model/elevate.go @@ -0,0 +1,17 @@ +package model + +// ElevateRequest parameters for client elevation. +// +// swagger:model ElevateRequest +type ElevateRequest struct { + // The client ID to elevate. + // + // required: true + // example: 5 + ID uint `form:"id" query:"id" json:"id" binding:"required"` + // How long the elevation should last, in seconds. + // + // required: true + // example: 900 + DurationSeconds int `form:"durationSeconds" query:"durationSeconds" json:"durationSeconds" binding:"required"` +} diff --git a/router/router.go b/router/router.go index de05b1c..f13bc2e 100644 --- a/router/router.go +++ b/router/router.go @@ -113,6 +113,7 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co oidcGroup.GET("/callback", oidcHandler.CallbackHandler()) oidcGroup.POST("/external/authorize", oidcHandler.ExternalAuthorizeHandler) oidcGroup.POST("/external/token", oidcHandler.ExternalTokenHandler) + oidcGroup.GET("/elevate", oidcHandler.ElevateHandler) } g.Match([]string{"GET", "HEAD"}, "/health", healthHandler.Health) @@ -214,6 +215,7 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co client.PUT("/:id", clientHandler.UpdateClient) } + client.POST("/client:elevate", clientHandler.ElevateClient) message := clientAuth.Group("/message") {