fix: add client elevatedUntil

This commit is contained in:
Jannis Mattheis
2026-04-11 21:37:10 +02:00
parent 410571dd18
commit 58677b32ef
8 changed files with 381 additions and 0 deletions

View File

@@ -1,7 +1,9 @@
package api
import (
"errors"
"fmt"
"time"
"github.com/gin-gonic/gin"
"github.com/gotify/server/v2/auth"
@@ -16,6 +18,7 @@ type ClientDatabase interface {
GetClientsByUser(userID uint) ([]*model.Client, error)
DeleteClientByID(id uint) error
UpdateClient(client *model.Client) error
UpdateClientElevatedUntil(id uint, t *time.Time) error
}
// The ClientAPI provides handlers for managing clients and applications.
@@ -235,6 +238,61 @@ func (a *ClientAPI) DeleteClient(ctx *gin.Context) {
})
}
// swagger:operation POST /client:elevate client elevateClient
//
// Elevate a client session.
//
// ---
// consumes: [application/json]
// produces: [application/json]
// parameters:
// - name: body
// in: body
// description: the elevation request
// required: true
// schema:
// $ref: "#/definitions/ElevateRequest"
// security: [clientTokenAuthorizationHeader: [], clientTokenHeader: [], clientTokenQuery: [], basicAuth: []]
// responses:
// 204:
// description: Ok
// 400:
// description: Bad Request
// schema:
// $ref: "#/definitions/Error"
// 401:
// description: Unauthorized
// schema:
// $ref: "#/definitions/Error"
// 404:
// description: Not Found
// schema:
// $ref: "#/definitions/Error"
func (a *ClientAPI) ElevateClient(ctx *gin.Context) {
var params model.ElevateRequest
if err := ctx.Bind(&params); err != nil {
return
}
client, err := a.DB.GetClientByID(params.ID)
if err != nil {
ctx.AbortWithError(500, err)
return
}
if client == nil || client.UserID != auth.GetUserID(ctx) {
ctx.AbortWithError(404, errors.New("client not found"))
return
}
elevatedUntil := time.Now().Add(time.Duration(params.DurationSeconds) * time.Second)
if err := a.DB.UpdateClientElevatedUntil(client.ID, &elevatedUntil); err != nil {
ctx.AbortWithError(500, err)
return
}
ctx.Status(204)
}
func (a *ClientAPI) clientExists(token string) bool {
client, _ := a.DB.GetClientByToken(token)
return client != nil

View File

@@ -5,6 +5,7 @@ import (
"net/url"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/gotify/server/v2/mode"
@@ -223,11 +224,84 @@ func (s *ClientSuite) Test_UpdateClient_WithMissingAttributes_expectBadRequest()
assert.Equal(s.T(), 400, s.recorder.Code)
}
func (s *ClientSuite) Test_ElevateClient_expectSuccess() {
s.db.User(5).Client(8)
test.WithUser(s.ctx, 5)
s.withJSONBody(`{"id":8,"durationSeconds":900}`)
before := time.Now()
s.a.ElevateClient(s.ctx)
after := time.Now()
assert.Equal(s.T(), 204, s.ctx.Writer.Status())
client, err := s.db.GetClientByID(8)
assert.NoError(s.T(), err)
assert.NotNil(s.T(), client.ElevatedUntil)
assert.WithinRange(s.T(), *client.ElevatedUntil, before.Add(15*time.Minute), after.Add(15*time.Minute))
}
func (s *ClientSuite) Test_ElevateClient_expectNotFoundOnMissingClient() {
s.db.User(5)
test.WithUser(s.ctx, 5)
s.withJSONBody(`{"id":8,"durationSeconds":900}`)
s.a.ElevateClient(s.ctx)
assert.Equal(s.T(), 404, s.recorder.Code)
}
func (s *ClientSuite) Test_ElevateClient_expectNotFoundOnCurrentUserIsNotOwner() {
s.db.User(5).Client(8)
s.db.User(2)
test.WithUser(s.ctx, 2)
s.withJSONBody(`{"id":8,"durationSeconds":900}`)
s.a.ElevateClient(s.ctx)
assert.Equal(s.T(), 404, s.recorder.Code)
client, err := s.db.GetClientByID(8)
assert.NoError(s.T(), err)
assert.Nil(s.T(), client.ElevatedUntil)
}
func (s *ClientSuite) Test_ElevateClient_expectBadRequestOnMissingID() {
s.db.User(5)
test.WithUser(s.ctx, 5)
s.withJSONBody(`{"durationSeconds":900}`)
s.a.ElevateClient(s.ctx)
assert.Equal(s.T(), 400, s.recorder.Code)
}
func (s *ClientSuite) Test_ElevateClient_expectBadRequestOnMissingDuration() {
s.db.User(5).Client(8)
test.WithUser(s.ctx, 5)
s.withJSONBody(`{"id":8}`)
s.a.ElevateClient(s.ctx)
assert.Equal(s.T(), 400, s.recorder.Code)
client, err := s.db.GetClientByID(8)
assert.NoError(s.T(), err)
assert.Nil(s.T(), client.ElevatedUntil)
}
func (s *ClientSuite) withFormData(formData string) {
s.ctx.Request = httptest.NewRequest("POST", "/token", strings.NewReader(formData))
s.ctx.Request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
}
func (s *ClientSuite) withJSONBody(body string) {
s.ctx.Request = httptest.NewRequest("POST", "/client:elevate", strings.NewReader(body))
s.ctx.Request.Header.Set("Content-Type", "application/json")
}
func withURL(ctx *gin.Context, scheme, host string) {
ctx.Set("location", &url.URL{Scheme: scheme, Host: host})
}

View File

@@ -6,6 +6,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"io"
"log"
"net/http"
"time"
@@ -70,6 +71,7 @@ type pendingOIDCSession struct {
RedirectURI string
ClientName string
CreatedAt time.Time
Elevate *model.ElevateRequest
}
// OIDCAPI provides handlers for OIDC authentication.
@@ -122,6 +124,48 @@ func (a *OIDCAPI) LoginHandler() gin.HandlerFunc {
})
}
// swagger:operation GET /auth/oidc/elevate oidc oidcElevate
//
// Start the OIDC flow to elevate an existing client session (browser).
//
// Redirects the user to the OIDC provider's authorization endpoint. After
// successful authentication, the referenced client session is elevated for
// the requested duration.
//
// ---
// parameters:
// - name: id
// in: query
// description: the client id to elevate
// required: true
// type: integer
// format: int64
// - name: durationSeconds
// in: query
// description: how long the elevation should last, in seconds
// required: true
// type: integer
// responses:
// 302:
// description: Redirect to OIDC provider
// default:
// description: Error
// schema:
// $ref: "#/definitions/Error"
func (a *OIDCAPI) ElevateHandler(ctx *gin.Context) {
var elevate model.ElevateRequest
if err := ctx.BindQuery(&elevate); err != nil {
return
}
state, err := a.generateState()
if err != nil {
ctx.AbortWithError(http.StatusInternalServerError, err)
return
}
a.pendingSessions.Set(time.Now(), state, &pendingOIDCSession{CreatedAt: time.Now(), Elevate: &elevate})
rp.AuthURLHandler(func() string { return state }, a.Provider)(ctx.Writer, ctx.Request)
}
// swagger:operation GET /auth/oidc/callback oidc oidcCallback
//
// Handle the OIDC provider callback (browser).
@@ -142,6 +186,8 @@ func (a *OIDCAPI) LoginHandler() gin.HandlerFunc {
// required: true
// type: string
// responses:
// 200:
// description: ok
// 307:
// description: Redirect to UI
// default:
@@ -160,6 +206,12 @@ func (a *OIDCAPI) CallbackHandler() gin.HandlerFunc {
http.Error(w, "unknown or expired state", http.StatusBadRequest)
return
}
if session.Elevate != nil {
a.handleElevationCallback(w, session.Elevate, user)
return
}
client, err := a.createClient(session.ClientName, user.ID)
if err != nil {
http.Error(w, fmt.Sprintf("failed to create client: %v", err), http.StatusInternalServerError)
@@ -175,6 +227,39 @@ func (a *OIDCAPI) CallbackHandler() gin.HandlerFunc {
return gin.WrapF(rp.CodeExchangeHandler(rp.UserinfoCallback(callback), a.Provider))
}
func (a *OIDCAPI) handleElevationCallback(w http.ResponseWriter, elevate *model.ElevateRequest, user *model.User) {
client, err := a.DB.GetClientByID(elevate.ID)
if err != nil {
http.Error(w, fmt.Sprintf("database error: %v", err), http.StatusInternalServerError)
return
}
if client == nil || client.UserID != user.ID {
http.Error(w, "client not found", http.StatusNotFound)
return
}
elevatedUntil := time.Now().Add(time.Duration(elevate.DurationSeconds) * time.Second)
if err := a.DB.UpdateClientElevatedUntil(client.ID, &elevatedUntil); err != nil {
http.Error(w, fmt.Sprintf("failed to elevate session: %v", err), http.StatusInternalServerError)
return
}
// The UI rechecks the authentication when the tab is closed.
w.WriteHeader(http.StatusOK)
w.Header().Add("content-type", "text/html")
io.WriteString(w, `<!DOCTYPE html>
<html lang="en">
<head>
<title>Gotify Session Elevation</title>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width,initial-scale=1" />
</head>
<body>
<h1 style="text-align:center">Gotify session elevation successful. Close this tab to continue.</h1>
<script>window.close();</script>
</body>
</html>`)
}
// swagger:operation POST /auth/oidc/external/authorize oidc externalAuthorize
//
// Initiate the OIDC authorization flow for a native app.