feat(auth): add support for websockets in auth file parsing and patching

- Introduced parsing logic to handle `websockets` field in auth files.
- Extended `PatchAuthFileFields` to update `websockets` and arbitrary nested metadata fields.
- Added tests to validate `websockets` parsing, updating, and persistence.
This commit is contained in:
Luis Pater
2026-05-26 00:49:36 +08:00
parent a0bb1f3a2b
commit 167edfec6c
3 changed files with 522 additions and 135 deletions

View File

@@ -352,6 +352,18 @@ func (h *Handler) listAuthFilesFromDisk(c *gin.Context) {
fileData["note"] = trimmed
}
}
if wv := gjson.GetBytes(data, "websockets"); wv.Exists() {
switch wv.Type {
case gjson.True:
fileData["websockets"] = true
case gjson.False:
fileData["websockets"] = false
case gjson.String:
if parsed, errParse := strconv.ParseBool(strings.TrimSpace(wv.String())); errParse == nil {
fileData["websockets"] = parsed
}
}
}
}
files = append(files, fileData)
@@ -472,9 +484,43 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
}
}
}
if websockets, ok := authWebsocketsValue(auth); ok {
entry["websockets"] = websockets
}
return entry
}
func authWebsocketsValue(auth *coreauth.Auth) (bool, bool) {
if auth == nil {
return false, false
}
if auth.Attributes != nil {
if raw := strings.TrimSpace(auth.Attributes["websockets"]); raw != "" {
parsed, errParse := strconv.ParseBool(raw)
if errParse == nil {
return parsed, true
}
}
}
if auth.Metadata == nil {
return false, false
}
raw, ok := auth.Metadata["websockets"]
if !ok || raw == nil {
return false, false
}
switch v := raw.(type) {
case bool:
return v, true
case string:
parsed, errParse := strconv.ParseBool(strings.TrimSpace(v))
if errParse == nil {
return parsed, true
}
}
return false, false
}
func authProjectID(auth *coreauth.Auth) string {
if auth == nil {
return ""
@@ -1150,31 +1196,37 @@ func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
}
// PatchAuthFileFields updates editable fields (prefix, proxy_url, headers, priority, note) of an auth file.
// PatchAuthFileFields updates arbitrary metadata fields of an auth file.
func (h *Handler) PatchAuthFileFields(c *gin.Context) {
if h.authManager == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
return
}
var req struct {
Name string `json:"name"`
Prefix *string `json:"prefix"`
ProxyURL *string `json:"proxy_url"`
Headers map[string]string `json:"headers"`
Priority *int `json:"priority"`
Note *string `json:"note"`
}
if err := c.ShouldBindJSON(&req); err != nil {
var req map[string]json.RawMessage
decoder := json.NewDecoder(c.Request.Body)
decoder.UseNumber()
if err := decoder.Decode(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
return
}
name := strings.TrimSpace(req.Name)
nameRaw, ok := req["name"]
if !ok {
c.JSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return
}
var nameValue string
if err := json.Unmarshal(nameRaw, &nameValue); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return
}
name := strings.TrimSpace(nameValue)
if name == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return
}
delete(req, "name")
ctx := c.Request.Context()
@@ -1198,136 +1250,35 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) {
}
changed := false
if req.Prefix != nil {
prefix := strings.TrimSpace(*req.Prefix)
targetAuth.Prefix = prefix
touchedRoots := make(map[string]struct{}, len(req))
for key, rawValue := range req {
fieldPath := strings.TrimSpace(key)
if fieldPath == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "field name is required"})
return
}
value, errDecode := decodeAuthFileFieldValue(rawValue)
if errDecode != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("invalid field %s", fieldPath)})
return
}
if targetAuth.Metadata == nil {
targetAuth.Metadata = make(map[string]any)
}
if prefix == "" {
delete(targetAuth.Metadata, "prefix")
} else {
targetAuth.Metadata["prefix"] = prefix
if fieldPath == "headers" {
applyAuthFileHeadersPatch(targetAuth, value)
} else if errSet := setAuthFileMetadataValue(targetAuth.Metadata, fieldPath, value); errSet != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": errSet.Error()})
return
}
if root := rootAuthFileField(fieldPath); root != "" {
touchedRoots[root] = struct{}{}
}
changed = true
}
if req.ProxyURL != nil {
proxyURL := strings.TrimSpace(*req.ProxyURL)
targetAuth.ProxyURL = proxyURL
if targetAuth.Metadata == nil {
targetAuth.Metadata = make(map[string]any)
}
if proxyURL == "" {
delete(targetAuth.Metadata, "proxy_url")
} else {
targetAuth.Metadata["proxy_url"] = proxyURL
}
changed = true
}
if len(req.Headers) > 0 {
existingHeaders := coreauth.ExtractCustomHeadersFromMetadata(targetAuth.Metadata)
nextHeaders := make(map[string]string, len(existingHeaders))
for k, v := range existingHeaders {
nextHeaders[k] = v
}
headerChanged := false
for key, value := range req.Headers {
name := strings.TrimSpace(key)
if name == "" {
continue
}
val := strings.TrimSpace(value)
attrKey := "header:" + name
if val == "" {
if _, ok := nextHeaders[name]; ok {
delete(nextHeaders, name)
headerChanged = true
}
if targetAuth.Attributes != nil {
if _, ok := targetAuth.Attributes[attrKey]; ok {
headerChanged = true
}
}
continue
}
if prev, ok := nextHeaders[name]; !ok || prev != val {
headerChanged = true
}
nextHeaders[name] = val
if targetAuth.Attributes != nil {
if prev, ok := targetAuth.Attributes[attrKey]; !ok || prev != val {
headerChanged = true
}
} else {
headerChanged = true
}
}
if headerChanged {
if targetAuth.Metadata == nil {
targetAuth.Metadata = make(map[string]any)
}
if targetAuth.Attributes == nil {
targetAuth.Attributes = make(map[string]string)
}
for key, value := range req.Headers {
name := strings.TrimSpace(key)
if name == "" {
continue
}
val := strings.TrimSpace(value)
attrKey := "header:" + name
if val == "" {
delete(nextHeaders, name)
delete(targetAuth.Attributes, attrKey)
continue
}
nextHeaders[name] = val
targetAuth.Attributes[attrKey] = val
}
if len(nextHeaders) == 0 {
delete(targetAuth.Metadata, "headers")
} else {
metaHeaders := make(map[string]any, len(nextHeaders))
for k, v := range nextHeaders {
metaHeaders[k] = v
}
targetAuth.Metadata["headers"] = metaHeaders
}
changed = true
}
}
if req.Priority != nil || req.Note != nil {
if targetAuth.Metadata == nil {
targetAuth.Metadata = make(map[string]any)
}
if targetAuth.Attributes == nil {
targetAuth.Attributes = make(map[string]string)
}
if req.Priority != nil {
if *req.Priority == 0 {
delete(targetAuth.Metadata, "priority")
delete(targetAuth.Attributes, "priority")
} else {
targetAuth.Metadata["priority"] = *req.Priority
targetAuth.Attributes["priority"] = strconv.Itoa(*req.Priority)
}
}
if req.Note != nil {
trimmedNote := strings.TrimSpace(*req.Note)
if trimmedNote == "" {
delete(targetAuth.Metadata, "note")
delete(targetAuth.Attributes, "note")
} else {
targetAuth.Metadata["note"] = trimmedNote
targetAuth.Attributes["note"] = trimmedNote
}
}
changed = true
if changed {
syncAuthFileMetadataFields(targetAuth, touchedRoots)
}
if !changed {
@@ -1345,6 +1296,268 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
}
func decodeAuthFileFieldValue(raw json.RawMessage) (any, error) {
decoder := json.NewDecoder(bytes.NewReader(raw))
decoder.UseNumber()
var value any
if err := decoder.Decode(&value); err != nil {
return nil, err
}
return value, nil
}
func rootAuthFileField(path string) string {
path = strings.TrimSpace(path)
if path == "" {
return ""
}
if idx := strings.Index(path, "."); idx >= 0 {
return strings.TrimSpace(path[:idx])
}
return path
}
func setAuthFileMetadataValue(metadata map[string]any, path string, value any) error {
if metadata == nil {
return fmt.Errorf("metadata is nil")
}
parts := strings.Split(path, ".")
current := metadata
for i, rawPart := range parts {
part := strings.TrimSpace(rawPart)
if part == "" {
return fmt.Errorf("invalid field path: %s", path)
}
if i == len(parts)-1 {
current[part] = value
return nil
}
next, ok := current[part].(map[string]any)
if !ok {
next = make(map[string]any)
current[part] = next
}
current = next
}
return nil
}
func applyAuthFileHeadersPatch(auth *coreauth.Auth, value any) {
if auth == nil {
return
}
if auth.Metadata == nil {
auth.Metadata = make(map[string]any)
}
headersPatch, ok := authFileHeadersStringMap(value)
if !ok {
auth.Metadata["headers"] = value
return
}
existingHeaders := coreauth.ExtractCustomHeadersFromMetadata(auth.Metadata)
nextHeaders := make(map[string]string, len(existingHeaders))
for key, val := range existingHeaders {
nextHeaders[key] = val
}
for key, value := range headersPatch {
name := strings.TrimSpace(key)
if name == "" {
continue
}
val := strings.TrimSpace(value)
if val == "" {
delete(nextHeaders, name)
continue
}
nextHeaders[name] = val
}
if len(nextHeaders) == 0 {
delete(auth.Metadata, "headers")
return
}
metaHeaders := make(map[string]any, len(nextHeaders))
for key, value := range nextHeaders {
metaHeaders[key] = value
}
auth.Metadata["headers"] = metaHeaders
}
func authFileHeadersStringMap(value any) (map[string]string, bool) {
switch typed := value.(type) {
case map[string]string:
return typed, true
case map[string]any:
out := make(map[string]string, len(typed))
for key, rawValue := range typed {
value, ok := rawValue.(string)
if !ok {
return nil, false
}
out[key] = value
}
return out, true
default:
return nil, false
}
}
func syncAuthFileMetadataFields(auth *coreauth.Auth, touchedRoots map[string]struct{}) {
if auth == nil || len(touchedRoots) == 0 {
return
}
if _, ok := touchedRoots["prefix"]; ok {
if prefix, okString := auth.Metadata["prefix"].(string); okString {
auth.Prefix = strings.TrimSpace(prefix)
}
}
if _, ok := touchedRoots["proxy_url"]; ok {
if proxyURL, okString := auth.Metadata["proxy_url"].(string); okString {
auth.ProxyURL = strings.TrimSpace(proxyURL)
}
}
if _, ok := touchedRoots["headers"]; ok {
syncAuthFileHeaderAttributes(auth)
}
if _, ok := touchedRoots["priority"]; ok {
syncAuthFilePriorityAttribute(auth)
}
if _, ok := touchedRoots["note"]; ok {
syncAuthFileNoteAttribute(auth)
}
if _, ok := touchedRoots["websockets"]; ok {
syncAuthFileWebsocketsAttribute(auth)
}
if _, ok := touchedRoots["disabled"]; ok {
syncAuthFileDisabledState(auth)
}
}
func syncAuthFileHeaderAttributes(auth *coreauth.Auth) {
if auth == nil {
return
}
if auth.Attributes == nil {
auth.Attributes = make(map[string]string)
}
for key := range auth.Attributes {
if strings.HasPrefix(key, "header:") {
delete(auth.Attributes, key)
}
}
for name, value := range coreauth.ExtractCustomHeadersFromMetadata(auth.Metadata) {
auth.Attributes["header:"+name] = value
}
}
func syncAuthFilePriorityAttribute(auth *coreauth.Auth) {
if auth == nil {
return
}
if auth.Attributes == nil {
auth.Attributes = make(map[string]string)
}
priority, ok := authFileIntValue(auth.Metadata["priority"])
if !ok {
delete(auth.Attributes, "priority")
return
}
if priority == 0 {
delete(auth.Attributes, "priority")
return
}
auth.Attributes["priority"] = strconv.Itoa(priority)
}
func authFileIntValue(value any) (int, bool) {
switch typed := value.(type) {
case int:
return typed, true
case int64:
return int(typed), true
case float64:
return int(typed), true
case json.Number:
if i, err := typed.Int64(); err == nil {
return int(i), true
}
case string:
if i, err := strconv.Atoi(strings.TrimSpace(typed)); err == nil {
return i, true
}
}
return 0, false
}
func syncAuthFileNoteAttribute(auth *coreauth.Auth) {
if auth == nil {
return
}
if auth.Attributes == nil {
auth.Attributes = make(map[string]string)
}
note, ok := auth.Metadata["note"].(string)
if !ok {
delete(auth.Attributes, "note")
return
}
note = strings.TrimSpace(note)
if note == "" {
delete(auth.Attributes, "note")
return
}
auth.Attributes["note"] = note
}
func syncAuthFileWebsocketsAttribute(auth *coreauth.Auth) {
if auth == nil {
return
}
if auth.Attributes == nil {
auth.Attributes = make(map[string]string)
}
websockets, ok := authFileBoolValue(auth.Metadata["websockets"])
if !ok {
delete(auth.Attributes, "websockets")
return
}
auth.Attributes["websockets"] = strconv.FormatBool(websockets)
}
func authFileBoolValue(value any) (bool, bool) {
switch typed := value.(type) {
case bool:
return typed, true
case string:
parsed, errParse := strconv.ParseBool(strings.TrimSpace(typed))
if errParse == nil {
return parsed, true
}
}
return false, false
}
func syncAuthFileDisabledState(auth *coreauth.Auth) {
if auth == nil {
return
}
disabled, ok := authFileBoolValue(auth.Metadata["disabled"])
if !ok {
return
}
auth.Disabled = disabled
if disabled {
auth.Status = coreauth.StatusDisabled
if strings.TrimSpace(auth.StatusMessage) == "" {
auth.StatusMessage = "disabled via management API"
}
return
}
auth.Status = coreauth.StatusActive
auth.StatusMessage = ""
}
func (h *Handler) disableAuth(ctx context.Context, id string) {
if h == nil || h.authManager == nil {
return

View File

@@ -5,11 +5,14 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
fileauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
)
@@ -162,3 +165,118 @@ func TestPatchAuthFileFields_HeadersEmptyMapIsNoop(t *testing.T) {
t.Fatalf("metadata.headers.X-Kee = %#v, want %q", got, "1")
}
}
func TestPatchAuthFileFields_WebsocketsFalseIsUpdate(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
store := &memoryAuthStore{}
manager := coreauth.NewManager(store, nil, nil)
record := &coreauth.Auth{
ID: "codex.json",
FileName: "codex.json",
Provider: "codex",
Attributes: map[string]string{
"path": "/tmp/codex.json",
"websockets": "true",
},
Metadata: map[string]any{
"type": "codex",
"websockets": true,
},
}
if _, errRegister := manager.Register(context.Background(), record); errRegister != nil {
t.Fatalf("failed to register auth record: %v", errRegister)
}
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager)
body := `{"name":"codex.json","websockets":false}`
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
req := httptest.NewRequest(http.MethodPatch, "/v0/management/auth-files/fields", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
ctx.Request = req
h.PatchAuthFileFields(ctx)
if rec.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
}
updated, ok := manager.GetByID("codex.json")
if !ok || updated == nil {
t.Fatalf("expected auth record to exist after patch")
}
if got := updated.Attributes["websockets"]; got != "false" {
t.Fatalf("attrs websockets = %q, want %q", got, "false")
}
if got, ok := updated.Metadata["websockets"].(bool); !ok || got {
t.Fatalf("metadata.websockets = %#v, want false", updated.Metadata["websockets"])
}
}
func TestPatchAuthFileFields_ArbitraryFieldsPersistToFile(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
authDir := t.TempDir()
fileName := "generic.json"
filePath := filepath.Join(authDir, fileName)
store := fileauth.NewFileTokenStore()
store.SetBaseDir(authDir)
manager := coreauth.NewManager(store, nil, nil)
record := &coreauth.Auth{
ID: fileName,
FileName: fileName,
Provider: "codex",
Attributes: map[string]string{
"path": filePath,
},
Metadata: map[string]any{
"type": "codex",
},
}
if _, errRegister := manager.Register(context.Background(), record); errRegister != nil {
t.Fatalf("failed to register auth record: %v", errRegister)
}
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
body := `{"name":"generic.json","abc":true,"nested.cde":true,"fgh":{"ijk":true}}`
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
req := httptest.NewRequest(http.MethodPatch, "/v0/management/auth-files/fields", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
ctx.Request = req
h.PatchAuthFileFields(ctx)
if rec.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
}
raw, errRead := os.ReadFile(filePath)
if errRead != nil {
t.Fatalf("failed to read updated auth file: %v", errRead)
}
var data map[string]any
if errUnmarshal := json.Unmarshal(raw, &data); errUnmarshal != nil {
t.Fatalf("failed to unmarshal updated auth file: %v", errUnmarshal)
}
if got := data["abc"]; got != true {
t.Fatalf("abc = %#v, want true", got)
}
nested, ok := data["nested"].(map[string]any)
if !ok {
t.Fatalf("nested = %#v, want object", data["nested"])
}
if got := nested["cde"]; got != true {
t.Fatalf("nested.cde = %#v, want true", got)
}
fgh, ok := data["fgh"].(map[string]any)
if !ok {
t.Fatalf("fgh = %#v, want object", data["fgh"])
}
if got := fgh["ijk"]; got != true {
t.Fatalf("fgh.ijk = %#v, want true", got)
}
}

View File

@@ -71,6 +71,62 @@ func TestListAuthFilesFromDisk_IncludesProjectID(t *testing.T) {
}
}
func TestListAuthFiles_IncludesWebsocketsFromManager(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
authDir := t.TempDir()
fileName := "codex-user@example.com-pro.json"
filePath := filepath.Join(authDir, fileName)
if errWrite := os.WriteFile(filePath, []byte(`{"type":"codex","email":"user@example.com"}`), 0o600); errWrite != nil {
t.Fatalf("failed to write auth file: %v", errWrite)
}
manager := coreauth.NewManager(nil, nil, nil)
record := &coreauth.Auth{
ID: fileName,
FileName: fileName,
Provider: "codex",
Status: coreauth.StatusActive,
Attributes: map[string]string{
"path": filePath,
"websockets": "true",
},
Metadata: map[string]any{
"type": "codex",
},
}
if _, errRegister := manager.Register(context.Background(), record); errRegister != nil {
t.Fatalf("failed to register auth record: %v", errRegister)
}
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
h.tokenStore = &memoryAuthStore{}
entry := firstAuthFileEntry(t, h)
if got := entry["websockets"]; got != true {
t.Fatalf("expected websockets true, got %#v", got)
}
}
func TestListAuthFilesFromDisk_IncludesWebsockets(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
authDir := t.TempDir()
filePath := filepath.Join(authDir, "codex-user@example.com-pro.json")
if errWrite := os.WriteFile(filePath, []byte(`{"type":"codex","email":"user@example.com","websockets":false}`), 0o600); errWrite != nil {
t.Fatalf("failed to write auth file: %v", errWrite)
}
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, nil)
entry := firstAuthFileEntry(t, h)
if got := entry["websockets"]; got != false {
t.Fatalf("expected websockets false, got %#v", got)
}
}
func firstAuthFileEntry(t *testing.T, h *Handler) map[string]any {
t.Helper()