From 167edfec6ccd05c1d5f03bc355050d5ec57ef550 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Tue, 26 May 2026 00:49:36 +0800 Subject: [PATCH] feat(auth): add support for websockets in auth file parsing and patching - Introduced parsing logic to handle `websockets` field in auth files. - Extended `PatchAuthFileFields` to update `websockets` and arbitrary nested metadata fields. - Added tests to validate `websockets` parsing, updating, and persistence. --- .../api/handlers/management/auth_files.go | 483 +++++++++++++----- .../auth_files_patch_fields_test.go | 118 +++++ .../management/auth_files_project_id_test.go | 56 ++ 3 files changed, 522 insertions(+), 135 deletions(-) diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index 291f6ef1e..c32f41a71 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -352,6 +352,18 @@ func (h *Handler) listAuthFilesFromDisk(c *gin.Context) { fileData["note"] = trimmed } } + if wv := gjson.GetBytes(data, "websockets"); wv.Exists() { + switch wv.Type { + case gjson.True: + fileData["websockets"] = true + case gjson.False: + fileData["websockets"] = false + case gjson.String: + if parsed, errParse := strconv.ParseBool(strings.TrimSpace(wv.String())); errParse == nil { + fileData["websockets"] = parsed + } + } + } } files = append(files, fileData) @@ -472,9 +484,43 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H { } } } + if websockets, ok := authWebsocketsValue(auth); ok { + entry["websockets"] = websockets + } return entry } +func authWebsocketsValue(auth *coreauth.Auth) (bool, bool) { + if auth == nil { + return false, false + } + if auth.Attributes != nil { + if raw := strings.TrimSpace(auth.Attributes["websockets"]); raw != "" { + parsed, errParse := strconv.ParseBool(raw) + if errParse == nil { + return parsed, true + } + } + } + if auth.Metadata == nil { + return false, false + } + raw, ok := auth.Metadata["websockets"] + if !ok || raw == nil { + return false, false + } + switch v := raw.(type) { + case bool: + return v, true + case string: + parsed, errParse := strconv.ParseBool(strings.TrimSpace(v)) + if errParse == nil { + return parsed, true + } + } + return false, false +} + func authProjectID(auth *coreauth.Auth) string { if auth == nil { return "" @@ -1150,31 +1196,37 @@ func (h *Handler) PatchAuthFileStatus(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled}) } -// PatchAuthFileFields updates editable fields (prefix, proxy_url, headers, priority, note) of an auth file. +// PatchAuthFileFields updates arbitrary metadata fields of an auth file. func (h *Handler) PatchAuthFileFields(c *gin.Context) { if h.authManager == nil { c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"}) return } - var req struct { - Name string `json:"name"` - Prefix *string `json:"prefix"` - ProxyURL *string `json:"proxy_url"` - Headers map[string]string `json:"headers"` - Priority *int `json:"priority"` - Note *string `json:"note"` - } - if err := c.ShouldBindJSON(&req); err != nil { + var req map[string]json.RawMessage + decoder := json.NewDecoder(c.Request.Body) + decoder.UseNumber() + if err := decoder.Decode(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"}) return } - name := strings.TrimSpace(req.Name) + nameRaw, ok := req["name"] + if !ok { + c.JSON(http.StatusBadRequest, gin.H{"error": "name is required"}) + return + } + var nameValue string + if err := json.Unmarshal(nameRaw, &nameValue); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "name is required"}) + return + } + name := strings.TrimSpace(nameValue) if name == "" { c.JSON(http.StatusBadRequest, gin.H{"error": "name is required"}) return } + delete(req, "name") ctx := c.Request.Context() @@ -1198,136 +1250,35 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) { } changed := false - if req.Prefix != nil { - prefix := strings.TrimSpace(*req.Prefix) - targetAuth.Prefix = prefix + touchedRoots := make(map[string]struct{}, len(req)) + for key, rawValue := range req { + fieldPath := strings.TrimSpace(key) + if fieldPath == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "field name is required"}) + return + } + value, errDecode := decodeAuthFileFieldValue(rawValue) + if errDecode != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("invalid field %s", fieldPath)}) + return + } if targetAuth.Metadata == nil { targetAuth.Metadata = make(map[string]any) } - if prefix == "" { - delete(targetAuth.Metadata, "prefix") - } else { - targetAuth.Metadata["prefix"] = prefix + + if fieldPath == "headers" { + applyAuthFileHeadersPatch(targetAuth, value) + } else if errSet := setAuthFileMetadataValue(targetAuth.Metadata, fieldPath, value); errSet != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": errSet.Error()}) + return + } + if root := rootAuthFileField(fieldPath); root != "" { + touchedRoots[root] = struct{}{} } changed = true } - if req.ProxyURL != nil { - proxyURL := strings.TrimSpace(*req.ProxyURL) - targetAuth.ProxyURL = proxyURL - if targetAuth.Metadata == nil { - targetAuth.Metadata = make(map[string]any) - } - if proxyURL == "" { - delete(targetAuth.Metadata, "proxy_url") - } else { - targetAuth.Metadata["proxy_url"] = proxyURL - } - changed = true - } - if len(req.Headers) > 0 { - existingHeaders := coreauth.ExtractCustomHeadersFromMetadata(targetAuth.Metadata) - nextHeaders := make(map[string]string, len(existingHeaders)) - for k, v := range existingHeaders { - nextHeaders[k] = v - } - headerChanged := false - - for key, value := range req.Headers { - name := strings.TrimSpace(key) - if name == "" { - continue - } - val := strings.TrimSpace(value) - attrKey := "header:" + name - if val == "" { - if _, ok := nextHeaders[name]; ok { - delete(nextHeaders, name) - headerChanged = true - } - if targetAuth.Attributes != nil { - if _, ok := targetAuth.Attributes[attrKey]; ok { - headerChanged = true - } - } - continue - } - if prev, ok := nextHeaders[name]; !ok || prev != val { - headerChanged = true - } - nextHeaders[name] = val - if targetAuth.Attributes != nil { - if prev, ok := targetAuth.Attributes[attrKey]; !ok || prev != val { - headerChanged = true - } - } else { - headerChanged = true - } - } - - if headerChanged { - if targetAuth.Metadata == nil { - targetAuth.Metadata = make(map[string]any) - } - if targetAuth.Attributes == nil { - targetAuth.Attributes = make(map[string]string) - } - - for key, value := range req.Headers { - name := strings.TrimSpace(key) - if name == "" { - continue - } - val := strings.TrimSpace(value) - attrKey := "header:" + name - if val == "" { - delete(nextHeaders, name) - delete(targetAuth.Attributes, attrKey) - continue - } - nextHeaders[name] = val - targetAuth.Attributes[attrKey] = val - } - - if len(nextHeaders) == 0 { - delete(targetAuth.Metadata, "headers") - } else { - metaHeaders := make(map[string]any, len(nextHeaders)) - for k, v := range nextHeaders { - metaHeaders[k] = v - } - targetAuth.Metadata["headers"] = metaHeaders - } - changed = true - } - } - if req.Priority != nil || req.Note != nil { - if targetAuth.Metadata == nil { - targetAuth.Metadata = make(map[string]any) - } - if targetAuth.Attributes == nil { - targetAuth.Attributes = make(map[string]string) - } - - if req.Priority != nil { - if *req.Priority == 0 { - delete(targetAuth.Metadata, "priority") - delete(targetAuth.Attributes, "priority") - } else { - targetAuth.Metadata["priority"] = *req.Priority - targetAuth.Attributes["priority"] = strconv.Itoa(*req.Priority) - } - } - if req.Note != nil { - trimmedNote := strings.TrimSpace(*req.Note) - if trimmedNote == "" { - delete(targetAuth.Metadata, "note") - delete(targetAuth.Attributes, "note") - } else { - targetAuth.Metadata["note"] = trimmedNote - targetAuth.Attributes["note"] = trimmedNote - } - } - changed = true + if changed { + syncAuthFileMetadataFields(targetAuth, touchedRoots) } if !changed { @@ -1345,6 +1296,268 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"status": "ok"}) } +func decodeAuthFileFieldValue(raw json.RawMessage) (any, error) { + decoder := json.NewDecoder(bytes.NewReader(raw)) + decoder.UseNumber() + var value any + if err := decoder.Decode(&value); err != nil { + return nil, err + } + return value, nil +} + +func rootAuthFileField(path string) string { + path = strings.TrimSpace(path) + if path == "" { + return "" + } + if idx := strings.Index(path, "."); idx >= 0 { + return strings.TrimSpace(path[:idx]) + } + return path +} + +func setAuthFileMetadataValue(metadata map[string]any, path string, value any) error { + if metadata == nil { + return fmt.Errorf("metadata is nil") + } + parts := strings.Split(path, ".") + current := metadata + for i, rawPart := range parts { + part := strings.TrimSpace(rawPart) + if part == "" { + return fmt.Errorf("invalid field path: %s", path) + } + if i == len(parts)-1 { + current[part] = value + return nil + } + next, ok := current[part].(map[string]any) + if !ok { + next = make(map[string]any) + current[part] = next + } + current = next + } + return nil +} + +func applyAuthFileHeadersPatch(auth *coreauth.Auth, value any) { + if auth == nil { + return + } + if auth.Metadata == nil { + auth.Metadata = make(map[string]any) + } + headersPatch, ok := authFileHeadersStringMap(value) + if !ok { + auth.Metadata["headers"] = value + return + } + + existingHeaders := coreauth.ExtractCustomHeadersFromMetadata(auth.Metadata) + nextHeaders := make(map[string]string, len(existingHeaders)) + for key, val := range existingHeaders { + nextHeaders[key] = val + } + for key, value := range headersPatch { + name := strings.TrimSpace(key) + if name == "" { + continue + } + val := strings.TrimSpace(value) + if val == "" { + delete(nextHeaders, name) + continue + } + nextHeaders[name] = val + } + + if len(nextHeaders) == 0 { + delete(auth.Metadata, "headers") + return + } + metaHeaders := make(map[string]any, len(nextHeaders)) + for key, value := range nextHeaders { + metaHeaders[key] = value + } + auth.Metadata["headers"] = metaHeaders +} + +func authFileHeadersStringMap(value any) (map[string]string, bool) { + switch typed := value.(type) { + case map[string]string: + return typed, true + case map[string]any: + out := make(map[string]string, len(typed)) + for key, rawValue := range typed { + value, ok := rawValue.(string) + if !ok { + return nil, false + } + out[key] = value + } + return out, true + default: + return nil, false + } +} + +func syncAuthFileMetadataFields(auth *coreauth.Auth, touchedRoots map[string]struct{}) { + if auth == nil || len(touchedRoots) == 0 { + return + } + if _, ok := touchedRoots["prefix"]; ok { + if prefix, okString := auth.Metadata["prefix"].(string); okString { + auth.Prefix = strings.TrimSpace(prefix) + } + } + if _, ok := touchedRoots["proxy_url"]; ok { + if proxyURL, okString := auth.Metadata["proxy_url"].(string); okString { + auth.ProxyURL = strings.TrimSpace(proxyURL) + } + } + if _, ok := touchedRoots["headers"]; ok { + syncAuthFileHeaderAttributes(auth) + } + if _, ok := touchedRoots["priority"]; ok { + syncAuthFilePriorityAttribute(auth) + } + if _, ok := touchedRoots["note"]; ok { + syncAuthFileNoteAttribute(auth) + } + if _, ok := touchedRoots["websockets"]; ok { + syncAuthFileWebsocketsAttribute(auth) + } + if _, ok := touchedRoots["disabled"]; ok { + syncAuthFileDisabledState(auth) + } +} + +func syncAuthFileHeaderAttributes(auth *coreauth.Auth) { + if auth == nil { + return + } + if auth.Attributes == nil { + auth.Attributes = make(map[string]string) + } + for key := range auth.Attributes { + if strings.HasPrefix(key, "header:") { + delete(auth.Attributes, key) + } + } + for name, value := range coreauth.ExtractCustomHeadersFromMetadata(auth.Metadata) { + auth.Attributes["header:"+name] = value + } +} + +func syncAuthFilePriorityAttribute(auth *coreauth.Auth) { + if auth == nil { + return + } + if auth.Attributes == nil { + auth.Attributes = make(map[string]string) + } + priority, ok := authFileIntValue(auth.Metadata["priority"]) + if !ok { + delete(auth.Attributes, "priority") + return + } + if priority == 0 { + delete(auth.Attributes, "priority") + return + } + auth.Attributes["priority"] = strconv.Itoa(priority) +} + +func authFileIntValue(value any) (int, bool) { + switch typed := value.(type) { + case int: + return typed, true + case int64: + return int(typed), true + case float64: + return int(typed), true + case json.Number: + if i, err := typed.Int64(); err == nil { + return int(i), true + } + case string: + if i, err := strconv.Atoi(strings.TrimSpace(typed)); err == nil { + return i, true + } + } + return 0, false +} + +func syncAuthFileNoteAttribute(auth *coreauth.Auth) { + if auth == nil { + return + } + if auth.Attributes == nil { + auth.Attributes = make(map[string]string) + } + note, ok := auth.Metadata["note"].(string) + if !ok { + delete(auth.Attributes, "note") + return + } + note = strings.TrimSpace(note) + if note == "" { + delete(auth.Attributes, "note") + return + } + auth.Attributes["note"] = note +} + +func syncAuthFileWebsocketsAttribute(auth *coreauth.Auth) { + if auth == nil { + return + } + if auth.Attributes == nil { + auth.Attributes = make(map[string]string) + } + websockets, ok := authFileBoolValue(auth.Metadata["websockets"]) + if !ok { + delete(auth.Attributes, "websockets") + return + } + auth.Attributes["websockets"] = strconv.FormatBool(websockets) +} + +func authFileBoolValue(value any) (bool, bool) { + switch typed := value.(type) { + case bool: + return typed, true + case string: + parsed, errParse := strconv.ParseBool(strings.TrimSpace(typed)) + if errParse == nil { + return parsed, true + } + } + return false, false +} + +func syncAuthFileDisabledState(auth *coreauth.Auth) { + if auth == nil { + return + } + disabled, ok := authFileBoolValue(auth.Metadata["disabled"]) + if !ok { + return + } + auth.Disabled = disabled + if disabled { + auth.Status = coreauth.StatusDisabled + if strings.TrimSpace(auth.StatusMessage) == "" { + auth.StatusMessage = "disabled via management API" + } + return + } + auth.Status = coreauth.StatusActive + auth.StatusMessage = "" +} + func (h *Handler) disableAuth(ctx context.Context, id string) { if h == nil || h.authManager == nil { return diff --git a/internal/api/handlers/management/auth_files_patch_fields_test.go b/internal/api/handlers/management/auth_files_patch_fields_test.go index 568700a0d..072e487ee 100644 --- a/internal/api/handlers/management/auth_files_patch_fields_test.go +++ b/internal/api/handlers/management/auth_files_patch_fields_test.go @@ -5,11 +5,14 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "os" + "path/filepath" "strings" "testing" "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + fileauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" ) @@ -162,3 +165,118 @@ func TestPatchAuthFileFields_HeadersEmptyMapIsNoop(t *testing.T) { t.Fatalf("metadata.headers.X-Kee = %#v, want %q", got, "1") } } + +func TestPatchAuthFileFields_WebsocketsFalseIsUpdate(t *testing.T) { + t.Setenv("MANAGEMENT_PASSWORD", "") + gin.SetMode(gin.TestMode) + + store := &memoryAuthStore{} + manager := coreauth.NewManager(store, nil, nil) + record := &coreauth.Auth{ + ID: "codex.json", + FileName: "codex.json", + Provider: "codex", + Attributes: map[string]string{ + "path": "/tmp/codex.json", + "websockets": "true", + }, + Metadata: map[string]any{ + "type": "codex", + "websockets": true, + }, + } + if _, errRegister := manager.Register(context.Background(), record); errRegister != nil { + t.Fatalf("failed to register auth record: %v", errRegister) + } + + h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager) + + body := `{"name":"codex.json","websockets":false}` + rec := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rec) + req := httptest.NewRequest(http.MethodPatch, "/v0/management/auth-files/fields", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + ctx.Request = req + h.PatchAuthFileFields(ctx) + + if rec.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String()) + } + + updated, ok := manager.GetByID("codex.json") + if !ok || updated == nil { + t.Fatalf("expected auth record to exist after patch") + } + if got := updated.Attributes["websockets"]; got != "false" { + t.Fatalf("attrs websockets = %q, want %q", got, "false") + } + if got, ok := updated.Metadata["websockets"].(bool); !ok || got { + t.Fatalf("metadata.websockets = %#v, want false", updated.Metadata["websockets"]) + } +} + +func TestPatchAuthFileFields_ArbitraryFieldsPersistToFile(t *testing.T) { + t.Setenv("MANAGEMENT_PASSWORD", "") + gin.SetMode(gin.TestMode) + + authDir := t.TempDir() + fileName := "generic.json" + filePath := filepath.Join(authDir, fileName) + store := fileauth.NewFileTokenStore() + store.SetBaseDir(authDir) + manager := coreauth.NewManager(store, nil, nil) + record := &coreauth.Auth{ + ID: fileName, + FileName: fileName, + Provider: "codex", + Attributes: map[string]string{ + "path": filePath, + }, + Metadata: map[string]any{ + "type": "codex", + }, + } + if _, errRegister := manager.Register(context.Background(), record); errRegister != nil { + t.Fatalf("failed to register auth record: %v", errRegister) + } + + h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager) + + body := `{"name":"generic.json","abc":true,"nested.cde":true,"fgh":{"ijk":true}}` + rec := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rec) + req := httptest.NewRequest(http.MethodPatch, "/v0/management/auth-files/fields", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + ctx.Request = req + h.PatchAuthFileFields(ctx) + + if rec.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String()) + } + + raw, errRead := os.ReadFile(filePath) + if errRead != nil { + t.Fatalf("failed to read updated auth file: %v", errRead) + } + var data map[string]any + if errUnmarshal := json.Unmarshal(raw, &data); errUnmarshal != nil { + t.Fatalf("failed to unmarshal updated auth file: %v", errUnmarshal) + } + if got := data["abc"]; got != true { + t.Fatalf("abc = %#v, want true", got) + } + nested, ok := data["nested"].(map[string]any) + if !ok { + t.Fatalf("nested = %#v, want object", data["nested"]) + } + if got := nested["cde"]; got != true { + t.Fatalf("nested.cde = %#v, want true", got) + } + fgh, ok := data["fgh"].(map[string]any) + if !ok { + t.Fatalf("fgh = %#v, want object", data["fgh"]) + } + if got := fgh["ijk"]; got != true { + t.Fatalf("fgh.ijk = %#v, want true", got) + } +} diff --git a/internal/api/handlers/management/auth_files_project_id_test.go b/internal/api/handlers/management/auth_files_project_id_test.go index e9634f5ae..0c4629348 100644 --- a/internal/api/handlers/management/auth_files_project_id_test.go +++ b/internal/api/handlers/management/auth_files_project_id_test.go @@ -71,6 +71,62 @@ func TestListAuthFilesFromDisk_IncludesProjectID(t *testing.T) { } } +func TestListAuthFiles_IncludesWebsocketsFromManager(t *testing.T) { + t.Setenv("MANAGEMENT_PASSWORD", "") + gin.SetMode(gin.TestMode) + + authDir := t.TempDir() + fileName := "codex-user@example.com-pro.json" + filePath := filepath.Join(authDir, fileName) + if errWrite := os.WriteFile(filePath, []byte(`{"type":"codex","email":"user@example.com"}`), 0o600); errWrite != nil { + t.Fatalf("failed to write auth file: %v", errWrite) + } + + manager := coreauth.NewManager(nil, nil, nil) + record := &coreauth.Auth{ + ID: fileName, + FileName: fileName, + Provider: "codex", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "path": filePath, + "websockets": "true", + }, + Metadata: map[string]any{ + "type": "codex", + }, + } + if _, errRegister := manager.Register(context.Background(), record); errRegister != nil { + t.Fatalf("failed to register auth record: %v", errRegister) + } + + h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager) + h.tokenStore = &memoryAuthStore{} + + entry := firstAuthFileEntry(t, h) + if got := entry["websockets"]; got != true { + t.Fatalf("expected websockets true, got %#v", got) + } +} + +func TestListAuthFilesFromDisk_IncludesWebsockets(t *testing.T) { + t.Setenv("MANAGEMENT_PASSWORD", "") + gin.SetMode(gin.TestMode) + + authDir := t.TempDir() + filePath := filepath.Join(authDir, "codex-user@example.com-pro.json") + if errWrite := os.WriteFile(filePath, []byte(`{"type":"codex","email":"user@example.com","websockets":false}`), 0o600); errWrite != nil { + t.Fatalf("failed to write auth file: %v", errWrite) + } + + h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, nil) + + entry := firstAuthFileEntry(t, h) + if got := entry["websockets"]; got != false { + t.Fatalf("expected websockets false, got %#v", got) + } +} + func firstAuthFileEntry(t *testing.T, h *Handler) map[string]any { t.Helper()