From e38ba28db52f37b2273b8a4b9edbdb1e7d191080 Mon Sep 17 00:00:00 2001 From: LTbinglingfeng Date: Fri, 12 Jun 2026 23:15:00 +0800 Subject: [PATCH] feat(pluginstore): add plugin store support --- config.example.yaml | 6 +- internal/api/handlers/management/handler.go | 31 +- .../api/handlers/management/plugin_store.go | 286 ++++++++++++++++++ .../handlers/management/plugin_store_test.go | 258 ++++++++++++++++ internal/api/server.go | 2 + internal/httpfetch/httpfetch.go | 62 ++++ internal/httpfetch/httpfetch_test.go | 67 ++++ internal/managementasset/updater.go | 51 +--- internal/pluginhost/adapters_test.go | 22 +- internal/pluginhost/command_line_test.go | 14 +- internal/pluginhost/host.go | 15 + internal/pluginhost/host_test.go | 49 +++ internal/pluginhost/platform.go | 5 + internal/pluginstore/checksum.go | 45 +++ internal/pluginstore/github.go | 130 ++++++++ internal/pluginstore/github_test.go | 93 ++++++ internal/pluginstore/install.go | 277 +++++++++++++++++ internal/pluginstore/install_test.go | 241 +++++++++++++++ internal/pluginstore/registry.go | 156 ++++++++++ internal/pluginstore/registry_test.go | 167 ++++++++++ internal/pluginstore/version.go | 69 +++++ internal/pluginstore/version_test.go | 34 +++ internal/thinking/validate.go | 2 +- sdk/cliproxy/auth/oauth_model_alias_test.go | 26 +- .../service_oauth_model_alias_test.go | 26 +- 25 files changed, 2031 insertions(+), 103 deletions(-) create mode 100644 internal/api/handlers/management/plugin_store.go create mode 100644 internal/api/handlers/management/plugin_store_test.go create mode 100644 internal/httpfetch/httpfetch.go create mode 100644 internal/httpfetch/httpfetch_test.go create mode 100644 internal/pluginstore/checksum.go create mode 100644 internal/pluginstore/github.go create mode 100644 internal/pluginstore/github_test.go create mode 100644 internal/pluginstore/install.go create mode 100644 internal/pluginstore/install_test.go create mode 100644 internal/pluginstore/registry.go create mode 100644 internal/pluginstore/registry_test.go create mode 100644 internal/pluginstore/version.go create mode 100644 internal/pluginstore/version_test.go diff --git a/config.example.yaml b/config.example.yaml index 98a3d7539..d8c97e873 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -392,9 +392,9 @@ nonstream-keepalive-interval: 0 # xai: # - name: "grok-4.3" # alias: "grok-latest" -# qoder: # plugin provider keys are supported for OAuth plugins -# - name: "qmodel_latest" -# alias: "qlatest" +# sample-provider: # plugin provider keys are supported for OAuth plugins +# - name: "sample-model-latest" +# alias: "sample-latest" # OAuth provider excluded models # oauth-excluded-models: diff --git a/internal/api/handlers/management/handler.go b/internal/api/handlers/management/handler.go index 01e96f053..63d1edc86 100644 --- a/internal/api/handlers/management/handler.go +++ b/internal/api/handlers/management/handler.go @@ -16,6 +16,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v7/internal/buildinfo" "github.com/router-for-me/CLIProxyAPI/v7/internal/config" "github.com/router-for-me/CLIProxyAPI/v7/internal/pluginhost" + "github.com/router-for-me/CLIProxyAPI/v7/internal/pluginstore" sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" "golang.org/x/crypto/bcrypt" @@ -35,20 +36,22 @@ const attemptMaxIdleTime = 2 * time.Hour // Handler aggregates config reference, persistence path and helpers. type Handler struct { - cfg *config.Config - configFilePath string - mu sync.Mutex - attemptsMu sync.Mutex - failedAttempts map[string]*attemptInfo // keyed by client IP - authManager *coreauth.Manager - tokenStore coreauth.Store - localPassword string - allowRemoteOverride bool - envSecret string - logDir string - postAuthHook coreauth.PostAuthHook - postAuthPersistHook coreauth.PostAuthHook - pluginHost *pluginhost.Host + cfg *config.Config + configFilePath string + mu sync.Mutex + attemptsMu sync.Mutex + failedAttempts map[string]*attemptInfo // keyed by client IP + authManager *coreauth.Manager + tokenStore coreauth.Store + localPassword string + allowRemoteOverride bool + envSecret string + logDir string + postAuthHook coreauth.PostAuthHook + postAuthPersistHook coreauth.PostAuthHook + pluginHost *pluginhost.Host + pluginStoreRegistryURL string + pluginStoreHTTPClient pluginstore.HTTPDoer } // NewHandler creates a new management handler instance. diff --git a/internal/api/handlers/management/plugin_store.go b/internal/api/handlers/management/plugin_store.go new file mode 100644 index 000000000..9a84f271f --- /dev/null +++ b/internal/api/handlers/management/plugin_store.go @@ -0,0 +1,286 @@ +package management + +import ( + "errors" + "fmt" + "net/http" + "runtime" + "strings" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + "github.com/router-for-me/CLIProxyAPI/v7/internal/pluginhost" + "github.com/router-for-me/CLIProxyAPI/v7/internal/pluginstore" + "github.com/router-for-me/CLIProxyAPI/v7/internal/util" + sdkconfig "github.com/router-for-me/CLIProxyAPI/v7/sdk/config" +) + +type pluginStoreListResponse struct { + PluginsEnabled bool `json:"plugins_enabled"` + PluginsDir string `json:"plugins_dir"` + Plugins []pluginStoreListEntry `json:"plugins"` +} + +type pluginStoreListEntry struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Author string `json:"author"` + Version string `json:"version"` + Repository string `json:"repository"` + Logo string `json:"logo,omitempty"` + Homepage string `json:"homepage,omitempty"` + License string `json:"license,omitempty"` + Tags []string `json:"tags,omitempty"` + Installed bool `json:"installed"` + InstalledVersion string `json:"installed_version"` + Path string `json:"path"` + Configured bool `json:"configured"` + Registered bool `json:"registered"` + Enabled bool `json:"enabled"` + EffectiveEnabled bool `json:"effective_enabled"` + UpdateAvailable bool `json:"update_available"` +} + +type pluginInstallResponse struct { + Status string `json:"status"` + ID string `json:"id"` + Version string `json:"version"` + Path string `json:"path"` + PluginsEnabled bool `json:"plugins_enabled"` + RestartRequired bool `json:"restart_required"` +} + +type pluginLocalStatus struct { + Installed bool + InstalledVersion string + Path string + Configured bool + Registered bool + Enabled bool + EffectiveEnabled bool +} + +func (h *Handler) ListPluginStore(c *gin.Context) { + pluginsEnabled, pluginsDir, proxyURL, configs, host := h.pluginStoreSnapshot() + client := h.newPluginStoreClient(proxyURL) + registry, errRegistry := client.FetchRegistry(c.Request.Context()) + if errRegistry != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": "plugin_store_registry_failed", "message": errRegistry.Error()}) + return + } + statuses, errStatus := pluginLocalStatuses(pluginsEnabled, pluginsDir, configs, host) + if errStatus != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "plugin_discovery_failed", "message": errStatus.Error()}) + return + } + + entries := make([]pluginStoreListEntry, 0, len(registry.Plugins)) + for _, plugin := range registry.Plugins { + status := statuses[plugin.ID] + installedVersion := status.InstalledVersion + entries = append(entries, pluginStoreListEntry{ + ID: plugin.ID, + Name: plugin.Name, + Description: plugin.Description, + Author: plugin.Author, + Version: plugin.Version, + Repository: plugin.Repository, + Logo: plugin.Logo, + Homepage: plugin.Homepage, + License: plugin.License, + Tags: append([]string{}, plugin.Tags...), + Installed: status.Installed, + InstalledVersion: installedVersion, + Path: status.Path, + Configured: status.Configured, + Registered: status.Registered, + Enabled: status.Enabled, + EffectiveEnabled: status.EffectiveEnabled, + UpdateAvailable: pluginstore.UpdateAvailable(installedVersion, plugin.Version), + }) + } + + c.JSON(http.StatusOK, pluginStoreListResponse{ + PluginsEnabled: pluginsEnabled, + PluginsDir: pluginsDir, + Plugins: entries, + }) +} + +func (h *Handler) InstallPluginFromStore(c *gin.Context) { + h.installPluginFromStore(c, runtime.GOOS, runtime.GOARCH) +} + +func (h *Handler) installPluginFromStore(c *gin.Context, goos, goarch string) { + id, okID := pluginIDFromRequest(c) + if !okID { + return + } + pluginsEnabled, pluginsDir, proxyURL, _, host := h.pluginStoreSnapshot() + client := h.newPluginStoreClient(proxyURL) + registry, errRegistry := client.FetchRegistry(c.Request.Context()) + if errRegistry != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": "plugin_store_registry_failed", "message": errRegistry.Error()}) + return + } + plugin, okPlugin := registry.PluginByID(id) + if !okPlugin { + c.JSON(http.StatusNotFound, gin.H{"error": "plugin_not_found", "message": "plugin not found in registry"}) + return + } + + pluginIsLoaded := func() bool { return pluginLoaded(host, id) } + result, errInstall := client.Install(c.Request.Context(), plugin, pluginstore.InstallOptions{ + PluginsDir: pluginsDir, + GOOS: goos, + GOARCH: goarch, + PluginLoaded: pluginIsLoaded, + }) + if errInstall != nil { + if errors.Is(errInstall, pluginstore.ErrLoadedPluginLocked) { + c.JSON(http.StatusConflict, gin.H{ + "error": "plugin_update_requires_restart", + "message": "loaded Windows plugins cannot be overwritten while the server is running", + "restart_required": true, + }) + return + } + c.JSON(http.StatusBadGateway, gin.H{"error": "plugin_install_failed", "message": errInstall.Error()}) + return + } + // Sample after the install so the response reflects the library state at + // the time the new file landed on disk. + restartRequired := pluginIsLoaded() + + h.mu.Lock() + defer h.mu.Unlock() + if h.cfg == nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "config_unavailable", + "message": fmt.Sprintf("plugin file installed at %s but config is unavailable to enable it", result.Path), + "path": result.Path, + }) + return + } + if errEnable := h.enablePluginConfigLocked(id); errEnable != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "config_update_failed", + "message": fmt.Sprintf("plugin file installed at %s but enabling it in config failed: %s", result.Path, errEnable.Error()), + "path": result.Path, + }) + return + } + if errSave := config.SaveConfigPreserveComments(h.configFilePath, h.cfg); errSave != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "config_save_failed", + "message": fmt.Sprintf("plugin file installed at %s but saving config failed: %s", result.Path, errSave.Error()), + "path": result.Path, + }) + return + } + + c.JSON(http.StatusOK, pluginInstallResponse{ + Status: "installed", + ID: result.ID, + Version: result.Version, + Path: result.Path, + PluginsEnabled: pluginsEnabled, + RestartRequired: restartRequired, + }) +} + +// enablePluginConfigLocked sets plugins.configs..enabled to true while preserving +// the rest of the plugin's raw configuration. Callers must hold h.mu. +func (h *Handler) enablePluginConfigLocked(id string) error { + ensurePluginConfigMap(h.cfg) + node := pluginConfigNode(h.cfg.Plugins.Configs[id]) + setYAMLMappingValue(node, "enabled", boolYAMLNode(true)) + updated, errConfig := pluginInstanceConfigFromNode(node) + if errConfig != nil { + return fmt.Errorf("decode plugin config: %w", errConfig) + } + h.cfg.Plugins.Configs[id] = updated + return nil +} + +func (h *Handler) pluginStoreSnapshot() (bool, string, string, map[string]config.PluginInstanceConfig, *pluginhost.Host) { + if h == nil || h.cfg == nil { + return false, "plugins", "", map[string]config.PluginInstanceConfig{}, nil + } + h.mu.Lock() + defer h.mu.Unlock() + pluginsEnabled := h.cfg.Plugins.Enabled + pluginsDir := normalizedPluginsDir(h.cfg.Plugins.Dir) + proxyURL := strings.TrimSpace(h.cfg.ProxyURL) + configs := make(map[string]config.PluginInstanceConfig, len(h.cfg.Plugins.Configs)) + for id, item := range h.cfg.Plugins.Configs { + configs[id] = item + } + return pluginsEnabled, pluginsDir, proxyURL, configs, h.pluginHost +} + +func (h *Handler) newPluginStoreClient(proxyURL string) pluginstore.Client { + registryURL := "" + var httpClient pluginstore.HTTPDoer + if h != nil { + registryURL = strings.TrimSpace(h.pluginStoreRegistryURL) + httpClient = h.pluginStoreHTTPClient + } + if registryURL == "" { + registryURL = pluginstore.DefaultRegistryURL + } + if httpClient != nil { + return pluginstore.Client{HTTPClient: httpClient, RegistryURL: registryURL} + } + client := &http.Client{} + if strings.TrimSpace(proxyURL) != "" { + util.SetProxy(&sdkconfig.SDKConfig{ProxyURL: strings.TrimSpace(proxyURL)}, client) + } + return pluginstore.Client{HTTPClient: client, RegistryURL: registryURL} +} + +func pluginLocalStatuses(pluginsEnabled bool, pluginsDir string, configs map[string]config.PluginInstanceConfig, host *pluginhost.Host) (map[string]pluginLocalStatus, error) { + statuses := map[string]pluginLocalStatus{} + files, errDiscover := pluginhost.DiscoverPluginFiles(pluginsDir) + if errDiscover != nil { + return nil, errDiscover + } + for _, file := range files { + status := statuses[file.ID] + status.Installed = true + status.Path = file.Path + status.Enabled = true + statuses[file.ID] = status + } + for id, item := range configs { + status := statuses[id] + status.Configured = true + status.Enabled = pluginInstanceEnabled(item) + statuses[id] = status + } + if host != nil { + for _, info := range host.RegisteredPlugins() { + status := statuses[info.ID] + status.Installed = true + status.Registered = true + status.InstalledVersion = strings.TrimSpace(info.Metadata.Version) + if _, configured := configs[info.ID]; !configured && !status.Enabled { + status.Enabled = true + } + statuses[info.ID] = status + } + } + for id, status := range statuses { + status.EffectiveEnabled = pluginsEnabled && status.Enabled && status.Registered + statuses[id] = status + } + return statuses, nil +} + +func pluginLoaded(host *pluginhost.Host, id string) bool { + if host == nil { + return false + } + return host.PluginLoaded(id) +} diff --git a/internal/api/handlers/management/plugin_store_test.go b/internal/api/handlers/management/plugin_store_test.go new file mode 100644 index 000000000..a6ec621a5 --- /dev/null +++ b/internal/api/handlers/management/plugin_store_test.go @@ -0,0 +1,258 @@ +package management + +import ( + "archive/zip" + "bytes" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" +) + +func TestListPluginStoreMergesInstalledStatus(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + pluginsDir := writeManagementPluginFile(t, "sample-provider") + h := &Handler{ + cfg: &config.Config{ + Plugins: config.PluginsConfig{ + Enabled: true, + Dir: pluginsDir, + Configs: map[string]config.PluginInstanceConfig{ + "sample-provider": pluginConfigFromYAML(t, "enabled: true\nmode: fast\n"), + }, + }, + }, + configFilePath: writeTestConfigFile(t), + pluginStoreRegistryURL: "https://registry.example/registry.json", + pluginStoreHTTPClient: fakePluginStoreHTTPClient{ + "https://registry.example/registry.json": registryJSON(t), + }, + } + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/v0/management/plugin-store", nil) + + h.ListPluginStore(c) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + var body pluginStoreListResponse + if errDecode := json.Unmarshal(rec.Body.Bytes(), &body); errDecode != nil { + t.Fatalf("Unmarshal() error = %v; body=%s", errDecode, rec.Body.String()) + } + if !body.PluginsEnabled { + t.Fatal("plugins_enabled = false, want true") + } + if len(body.Plugins) != 1 { + t.Fatalf("plugins len = %d, want 1", len(body.Plugins)) + } + entry := body.Plugins[0] + if !entry.Installed || !entry.Configured || !entry.Enabled { + t.Fatalf("store entry status = %#v, want installed configured enabled", entry) + } + if entry.Registered || entry.EffectiveEnabled { + t.Fatalf("runtime status = registered %v effective %v, want false false", entry.Registered, entry.EffectiveEnabled) + } + if entry.InstalledVersion != "" { + t.Fatalf("installed_version = %q, want empty for unregistered plugin", entry.InstalledVersion) + } + if entry.UpdateAvailable { + t.Fatal("update_available = true, want false when installed version is unknown") + } + if entry.Path == "" { + t.Fatal("path is empty") + } +} + +func TestInstallPluginFromStoreWritesFileAndEnablesConfig(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + pluginsDir := t.TempDir() + archiveData := makeManagementPluginStoreZip(t, "sample-provider"+managementPluginExtension(runtime.GOOS), "library-data") + archiveName := "sample-provider_0.1.0_" + runtime.GOOS + "_" + runtime.GOARCH + ".zip" + checksum := sha256.Sum256(archiveData) + h := &Handler{ + cfg: &config.Config{ + Plugins: config.PluginsConfig{ + Enabled: false, + Dir: pluginsDir, + Configs: map[string]config.PluginInstanceConfig{ + "sample-provider": pluginConfigFromYAML(t, "enabled: false\nmode: fast\n"), + }, + }, + }, + configFilePath: writeTestConfigFile(t), + pluginStoreRegistryURL: "https://registry.example/registry.json", + pluginStoreHTTPClient: fakePluginStoreHTTPClient{ + "https://registry.example/registry.json": registryJSON(t), + "https://api.github.com/repos/author-name/cliproxy-sample-provider-plugin/releases/tags/v0.1.0": []byte(`{ + "tag_name": "v0.1.0", + "assets": [ + {"name": "` + archiveName + `", "browser_download_url": "https://downloads.example/` + archiveName + `"}, + {"name": "checksums.txt", "browser_download_url": "https://downloads.example/checksums.txt"} + ] + }`), + "https://downloads.example/" + archiveName: archiveData, + "https://downloads.example/checksums.txt": []byte(hex.EncodeToString(checksum[:]) + " " + archiveName + "\n"), + }, + } + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Params = gin.Params{{Key: "id", Value: "sample-provider"}} + c.Request = httptest.NewRequest(http.MethodPost, "/v0/management/plugin-store/sample-provider/install", nil) + + h.InstallPluginFromStore(c) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + var body pluginInstallResponse + if errDecode := json.Unmarshal(rec.Body.Bytes(), &body); errDecode != nil { + t.Fatalf("Unmarshal() error = %v; body=%s", errDecode, rec.Body.String()) + } + if body.Status != "installed" || body.ID != "sample-provider" || body.Version != "0.1.0" { + t.Fatalf("install response = %#v", body) + } + if body.PluginsEnabled { + t.Fatal("plugins_enabled = true, want false") + } + if body.RestartRequired { + t.Fatal("restart_required = true, want false") + } + targetPath := filepath.Join(pluginsDir, runtime.GOOS, runtime.GOARCH, "sample-provider"+managementPluginExtension(runtime.GOOS)) + data, errRead := os.ReadFile(targetPath) + if errRead != nil { + t.Fatalf("ReadFile(%s) error = %v", targetPath, errRead) + } + if string(data) != "library-data" { + t.Fatalf("installed file = %q, want library-data", data) + } + item := h.cfg.Plugins.Configs["sample-provider"] + if item.Enabled == nil || !*item.Enabled { + t.Fatalf("plugin enabled = %#v, want true", item.Enabled) + } + if h.cfg.Plugins.Enabled { + t.Fatal("global plugins.enabled changed to true") + } + raw := marshalPluginRaw(t, item) + if !strings.Contains(raw, "mode: fast") { + t.Fatalf("plugin raw config lost custom field:\n%s", raw) + } +} + +func TestEnablePluginConfigLockedPreservesExistingFields(t *testing.T) { + t.Parallel() + + h := &Handler{ + cfg: &config.Config{ + Plugins: config.PluginsConfig{ + Enabled: false, + Configs: map[string]config.PluginInstanceConfig{ + "sample-provider": pluginConfigFromYAML(t, "enabled: false\npriority: 5\nmode: fast\n"), + }, + }, + }, + } + + if errEnable := h.enablePluginConfigLocked("sample-provider"); errEnable != nil { + t.Fatalf("enablePluginConfigLocked() error = %v", errEnable) + } + if h.cfg.Plugins.Enabled { + t.Fatal("global Plugins.Enabled changed to true") + } + item := h.cfg.Plugins.Configs["sample-provider"] + if item.Enabled == nil || !*item.Enabled { + t.Fatalf("plugin enabled = %#v, want true", item.Enabled) + } + if item.Priority != 5 { + t.Fatalf("plugin priority = %d, want 5", item.Priority) + } + raw := marshalPluginRaw(t, item) + if !strings.Contains(raw, "mode: fast") { + t.Fatalf("plugin raw config lost custom field:\n%s", raw) + } +} + +func TestEnablePluginConfigLockedCreatesMissingConfig(t *testing.T) { + t.Parallel() + + h := &Handler{cfg: &config.Config{}} + if errEnable := h.enablePluginConfigLocked("sample-provider"); errEnable != nil { + t.Fatalf("enablePluginConfigLocked() error = %v", errEnable) + } + item := h.cfg.Plugins.Configs["sample-provider"] + if item.Enabled == nil || !*item.Enabled { + t.Fatalf("plugin enabled = %#v, want true", item.Enabled) + } +} + +type fakePluginStoreHTTPClient map[string][]byte + +func (c fakePluginStoreHTTPClient) Do(req *http.Request) (*http.Response, error) { + body, ok := c[req.URL.String()] + if !ok { + return &http.Response{ + StatusCode: http.StatusNotFound, + Body: io.NopCloser(strings.NewReader("not found")), + Header: make(http.Header), + Request: req, + }, nil + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(body)), + Header: make(http.Header), + Request: req, + }, nil +} + +func registryJSON(t *testing.T) []byte { + t.Helper() + + return []byte(`{ + "schema_version": 1, + "plugins": [{ + "id": "sample-provider", + "name": "Sample Provider", + "description": "Adds sample provider support.", + "author": "author-name", + "version": "0.1.0", + "repository": "https://github.com/author-name/cliproxy-sample-provider-plugin", + "tags": ["provider"] + }] + }`) +} + +func makeManagementPluginStoreZip(t *testing.T, name string, content string) []byte { + t.Helper() + + var buffer bytes.Buffer + writer := zip.NewWriter(&buffer) + file, errCreate := writer.Create(name) + if errCreate != nil { + t.Fatalf("Create(%s) error = %v", name, errCreate) + } + if _, errWrite := file.Write([]byte(content)); errWrite != nil { + t.Fatalf("Write(%s) error = %v", name, errWrite) + } + if errClose := writer.Close(); errClose != nil { + t.Fatalf("Close() error = %v", errClose) + } + return buffer.Bytes() +} diff --git a/internal/api/server.go b/internal/api/server.go index 0c27bcb16..dc939b524 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -603,6 +603,8 @@ func (s *Server) registerManagementRoutes() { mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML) mgmt.GET("/latest-version", s.mgmt.GetLatestVersion) mgmt.GET("/plugins", s.mgmt.ListPlugins) + mgmt.GET("/plugin-store", s.mgmt.ListPluginStore) + mgmt.POST("/plugin-store/:id/install", s.mgmt.InstallPluginFromStore) mgmt.PATCH("/plugins/:id/enabled", s.mgmt.PatchPluginEnabled) mgmt.PUT("/plugins/:id/config", s.mgmt.PutPluginConfig) mgmt.PATCH("/plugins/:id/config", s.mgmt.PatchPluginConfig) diff --git a/internal/httpfetch/httpfetch.go b/internal/httpfetch/httpfetch.go new file mode 100644 index 000000000..ce2bcb185 --- /dev/null +++ b/internal/httpfetch/httpfetch.go @@ -0,0 +1,62 @@ +package httpfetch + +import ( + "context" + "fmt" + "io" + "net/http" + "strings" + + log "github.com/sirupsen/logrus" +) + +// Doer abstracts the HTTP client used to execute requests. +type Doer interface { + Do(*http.Request) (*http.Response, error) +} + +// GetBytes performs a GET request with the supplied headers, requires a +// success status, and returns the response body. When maxSize is positive +// the body is rejected once it exceeds maxSize bytes. +func GetBytes(ctx context.Context, client Doer, requestURL string, headers map[string]string, maxSize int64) ([]byte, error) { + if client == nil { + client = http.DefaultClient + } + req, errRequest := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil) + if errRequest != nil { + return nil, fmt.Errorf("create request: %w", errRequest) + } + for key, value := range headers { + if value != "" { + req.Header.Set(key, value) + } + } + + resp, errDo := client.Do(req) + if errDo != nil { + return nil, fmt.Errorf("request failed: %w", errDo) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.WithError(errClose).Debug("failed to close response body") + } + }() + + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + return nil, fmt.Errorf("unexpected status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + reader := io.Reader(resp.Body) + if maxSize > 0 { + reader = io.LimitReader(resp.Body, maxSize+1) + } + data, errRead := io.ReadAll(reader) + if errRead != nil { + return nil, fmt.Errorf("read response: %w", errRead) + } + if maxSize > 0 && int64(len(data)) > maxSize { + return nil, fmt.Errorf("response exceeds maximum allowed size of %d bytes", maxSize) + } + return data, nil +} diff --git a/internal/httpfetch/httpfetch_test.go b/internal/httpfetch/httpfetch_test.go new file mode 100644 index 000000000..227e43817 --- /dev/null +++ b/internal/httpfetch/httpfetch_test.go @@ -0,0 +1,67 @@ +package httpfetch + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestGetBytesReturnsBodyAndSendsHeaders(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("User-Agent") != "agent" || r.Header.Get("Accept") != "application/json" { + http.Error(w, "missing headers", http.StatusBadRequest) + return + } + _, _ = w.Write([]byte("payload")) + })) + t.Cleanup(server.Close) + + data, errGet := GetBytes(context.Background(), server.Client(), server.URL, map[string]string{ + "User-Agent": "agent", + "Accept": "application/json", + }, 0) + if errGet != nil { + t.Fatalf("GetBytes() error = %v", errGet) + } + if string(data) != "payload" { + t.Fatalf("GetBytes() = %q, want payload", data) + } +} + +func TestGetBytesRejectsErrorStatus(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "missing", http.StatusNotFound) + })) + t.Cleanup(server.Close) + + _, errGet := GetBytes(context.Background(), server.Client(), server.URL, nil, 0) + if errGet == nil { + t.Fatal("GetBytes() error = nil") + } + if !strings.Contains(errGet.Error(), "unexpected status 404") { + t.Fatalf("GetBytes() error = %v, want status 404", errGet) + } +} + +func TestGetBytesEnforcesMaxSize(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("0123456789")) + })) + t.Cleanup(server.Close) + + _, errGet := GetBytes(context.Background(), server.Client(), server.URL, nil, 4) + if errGet == nil { + t.Fatal("GetBytes() error = nil") + } + if !strings.Contains(errGet.Error(), "maximum allowed size") { + t.Fatalf("GetBytes() error = %v, want size limit error", errGet) + } +} diff --git a/internal/managementasset/updater.go b/internal/managementasset/updater.go index 58499fa5a..b9f884106 100644 --- a/internal/managementasset/updater.go +++ b/internal/managementasset/updater.go @@ -18,6 +18,7 @@ import ( "time" "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + "github.com/router-for-me/CLIProxyAPI/v7/internal/httpfetch" "github.com/router-for-me/CLIProxyAPI/v7/internal/util" sdkconfig "github.com/router-for-me/CLIProxyAPI/v7/sdk/config" log "github.com/sirupsen/logrus" @@ -345,32 +346,22 @@ func fetchLatestAsset(ctx context.Context, client *http.Client, releaseURL strin releaseURL = defaultManagementReleaseURL } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, releaseURL, nil) - if err != nil { - return nil, "", fmt.Errorf("create release request: %w", err) + headers := map[string]string{ + "Accept": "application/vnd.github+json", + "User-Agent": httpUserAgent, } - req.Header.Set("Accept", "application/vnd.github+json") - req.Header.Set("User-Agent", httpUserAgent) gitURL := strings.ToLower(strings.TrimSpace(os.Getenv("GITSTORE_GIT_URL"))) if tok := strings.TrimSpace(os.Getenv("GITSTORE_GIT_TOKEN")); tok != "" && strings.Contains(gitURL, "github.com") { - req.Header.Set("Authorization", "Bearer "+tok) + headers["Authorization"] = "Bearer " + tok } - resp, err := client.Do(req) + data, err := httpfetch.GetBytes(ctx, client, releaseURL, headers, 0) if err != nil { - return nil, "", fmt.Errorf("execute release request: %w", err) - } - defer func() { - _ = resp.Body.Close() - }() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) - return nil, "", fmt.Errorf("unexpected release status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + return nil, "", fmt.Errorf("fetch release: %w", err) } var release releaseResponse - if err = json.NewDecoder(resp.Body).Decode(&release); err != nil { + if err = json.Unmarshal(data, &release); err != nil { return nil, "", fmt.Errorf("decode release response: %w", err) } @@ -390,31 +381,9 @@ func downloadAsset(ctx context.Context, client *http.Client, downloadURL string) return nil, "", fmt.Errorf("empty download url") } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadURL, nil) + data, err := httpfetch.GetBytes(ctx, client, downloadURL, map[string]string{"User-Agent": httpUserAgent}, maxAssetDownloadSize) if err != nil { - return nil, "", fmt.Errorf("create download request: %w", err) - } - req.Header.Set("User-Agent", httpUserAgent) - - resp, err := client.Do(req) - if err != nil { - return nil, "", fmt.Errorf("execute download request: %w", err) - } - defer func() { - _ = resp.Body.Close() - }() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) - return nil, "", fmt.Errorf("unexpected download status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) - } - - data, err := io.ReadAll(io.LimitReader(resp.Body, maxAssetDownloadSize+1)) - if err != nil { - return nil, "", fmt.Errorf("read download body: %w", err) - } - if int64(len(data)) > maxAssetDownloadSize { - return nil, "", fmt.Errorf("download exceeds maximum allowed size of %d bytes", maxAssetDownloadSize) + return nil, "", fmt.Errorf("download asset: %w", err) } sum := sha256.Sum256(data) diff --git a/internal/pluginhost/adapters_test.go b/internal/pluginhost/adapters_test.go index 58aa75f3a..64de0ad18 100644 --- a/internal/pluginhost/adapters_test.go +++ b/internal/pluginhost/adapters_test.go @@ -623,25 +623,25 @@ func TestRegisterExecutorsOAuthScopeSkipsStaticModelClientButRegistersExecutor(t manager := newFakeExecutorManager() staticCalled := false host := newHostWithRecords(capabilityRecord{ - id: "qoder", + id: "sample-provider", plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{ - AuthProvider: fakeAuthProvider{identifier: "qoder"}, + AuthProvider: fakeAuthProvider{identifier: "sample-provider"}, ModelProvider: modelProviderFunc{ staticModels: func(ctx context.Context, req pluginapi.StaticModelRequest) (pluginapi.ModelResponse, error) { staticCalled = true return pluginapi.ModelResponse{ - Provider: "qoder", + Provider: "sample-provider", Models: []pluginapi.ModelInfo{{ID: "static-model"}}, }, nil }, modelsForAuth: func(ctx context.Context, req pluginapi.AuthModelRequest) (pluginapi.ModelResponse, error) { return pluginapi.ModelResponse{ - Provider: "qoder", + Provider: "sample-provider", Models: []pluginapi.ModelInfo{{ID: "oauth-model"}}, }, nil }, }, - Executor: &fakeExecutor{identifier: "qoder"}, + Executor: &fakeExecutor{identifier: "sample-provider"}, ExecutorModelScope: pluginapi.ExecutorModelScopeOAuth, }}, }) @@ -652,21 +652,21 @@ func TestRegisterExecutorsOAuthScopeSkipsStaticModelClientButRegistersExecutor(t if staticCalled { t.Fatal("StaticModels was called for an OAuth-only executor") } - if _, okExecutor := manager.executors["qoder"]; !okExecutor { + if _, okExecutor := manager.executors["sample-provider"]; !okExecutor { t.Fatal("OAuth-only executor was not registered") } - if _, okClient := modelRegistry.clients[pluginExecutorModelClientID("qoder", "qoder")]; okClient { + if _, okClient := modelRegistry.clients[pluginExecutorModelClientID("sample-provider", "sample-provider")]; okClient { t.Fatal("OAuth-only executor registered a static model client") } - if got := host.ModelsForProvider("qoder"); len(got) != 0 { + if got := host.ModelsForProvider("sample-provider"); len(got) != 0 { t.Fatalf("OAuth-only provider models = %#v, want none", got) } result := host.ModelsForAuth(context.Background(), &coreauth.Auth{ - ID: "qoder-auth", - Provider: "qoder", + ID: "sample-provider-auth", + Provider: "sample-provider", }) - if !result.Handled || result.Provider != "qoder" || len(result.Models) != 1 || result.Models[0].ID != "oauth-model" { + if !result.Handled || result.Provider != "sample-provider" || len(result.Models) != 1 || result.Models[0].ID != "oauth-model" { t.Fatalf("OAuth model result = %#v, want oauth-model", result) } } diff --git a/internal/pluginhost/command_line_test.go b/internal/pluginhost/command_line_test.go index 93f05024b..a0d3e25d1 100644 --- a/internal/pluginhost/command_line_test.go +++ b/internal/pluginhost/command_line_test.go @@ -127,16 +127,16 @@ func TestExecuteCommandLinePersistsReturnedAuths(t *testing.T) { response: pluginapi.CommandLineExecutionResponse{ Stdout: []byte("login ok\n"), Auths: []pluginapi.AuthData{{ - Provider: "Qoder", - ID: "qoder.json", - FileName: "qoder.json", + Provider: "Sample-Provider", + ID: "sample-provider.json", + FileName: "sample-provider.json", Label: "Luis", StorageJSON: []byte(`{"token":"secret"}`), }}, }, } host := newHostWithRecords(capabilityRecord{ - id: "qoder", + id: "sample-provider", plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{CommandLinePlugin: plugin}}, }) host.runtimeConfig = &config.Config{AuthDir: authDir} @@ -160,13 +160,13 @@ func TestExecuteCommandLinePersistsReturnedAuths(t *testing.T) { t.Fatalf("saved auths = %d, want 1", len(store.saved)) } saved := store.saved[0] - if saved.Provider != "qoder" || saved.ID != "qoder.json" || saved.FileName != "qoder.json" { - t.Fatalf("saved auth = %#v, want normalized qoder auth", saved) + if saved.Provider != "sample-provider" || saved.ID != "sample-provider.json" || saved.FileName != "sample-provider.json" { + t.Fatalf("saved auth = %#v, want normalized sample provider auth", saved) } if saved.Storage == nil { t.Fatal("saved auth storage = nil, want plugin token storage") } - if store.paths[0] != filepath.Join(authDir, "qoder.json") { + if store.paths[0] != filepath.Join(authDir, "sample-provider.json") { t.Fatalf("saved path = %q, want auth dir path", store.paths[0]) } } diff --git a/internal/pluginhost/host.go b/internal/pluginhost/host.go index ffa596ad5..6f563b63b 100644 --- a/internal/pluginhost/host.go +++ b/internal/pluginhost/host.go @@ -114,6 +114,21 @@ func (h *Host) Snapshot() *Snapshot { return emptySnapshot() } +// PluginLoaded reports whether a plugin dynamic library is still loaded by the host. +func (h *Host) PluginLoaded(id string) bool { + if h == nil { + return false + } + id = strings.TrimSpace(id) + if id == "" { + return false + } + h.mu.Lock() + defer h.mu.Unlock() + _, ok := h.loaded[id] + return ok +} + func (h *Host) ApplyConfig(ctx context.Context, cfg *config.Config) { if h == nil { return diff --git a/internal/pluginhost/host_test.go b/internal/pluginhost/host_test.go index 78354a5f1..0075204a8 100644 --- a/internal/pluginhost/host_test.go +++ b/internal/pluginhost/host_test.go @@ -64,6 +64,55 @@ func TestHostApplyConfig_DisabledPluginSkipsCapability(t *testing.T) { } } +func TestPluginLoadedTracksLoadedPluginAfterDisabled(t *testing.T) { + disabled := false + loader := newTestSymbolLoader() + plugin := &testPlugin{ + registerResult: validTestPlugin("alpha"), + reconfigureResult: validTestPlugin("alpha"), + } + loader.lookups["alpha"] = newTestSymbolLookup(plugin) + h := NewForTest(loader) + t.Cleanup(h.ShutdownAll) + pluginsDir := makePluginDir(t, "alpha") + + h.ApplyConfig(context.Background(), &config.Config{ + Plugins: config.PluginsConfig{ + Enabled: true, + Dir: pluginsDir, + }, + }) + + if !h.PluginLoaded("alpha") { + t.Fatal("PluginLoaded(alpha) = false, want true after load") + } + if len(h.RegisteredPlugins()) != 1 { + t.Fatalf("RegisteredPlugins() len = %d, want 1", len(h.RegisteredPlugins())) + } + + h.ApplyConfig(context.Background(), &config.Config{ + Plugins: config.PluginsConfig{ + Enabled: true, + Dir: pluginsDir, + Configs: map[string]config.PluginInstanceConfig{ + "alpha": {Enabled: &disabled}, + }, + }, + }) + + if len(h.RegisteredPlugins()) != 0 { + t.Fatalf("RegisteredPlugins() len = %d, want 0 after disable", len(h.RegisteredPlugins())) + } + if !h.PluginLoaded("alpha") { + t.Fatal("PluginLoaded(alpha) = false, want true while library remains loaded") + } + + h.ShutdownAll() + if h.PluginLoaded("alpha") { + t.Fatal("PluginLoaded(alpha) = true, want false after ShutdownAll") + } +} + func TestHostApplyConfigRegistersPluginThinkingApplier(t *testing.T) { loader := newTestSymbolLoader() plugin := &testPlugin{ diff --git a/internal/pluginhost/platform.go b/internal/pluginhost/platform.go index 4ea9b86e6..5926a96a5 100644 --- a/internal/pluginhost/platform.go +++ b/internal/pluginhost/platform.go @@ -44,6 +44,11 @@ func pluginIDFromPath(path string) string { return base } +// PluginExtension returns the dynamic library file extension used for goos. +func PluginExtension(goos string) string { + return pluginExtension(goos) +} + func pluginExtension(goos string) string { switch goos { case "darwin": diff --git a/internal/pluginstore/checksum.go b/internal/pluginstore/checksum.go new file mode 100644 index 000000000..fc248ea60 --- /dev/null +++ b/internal/pluginstore/checksum.go @@ -0,0 +1,45 @@ +package pluginstore + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "strings" +) + +func ParseChecksums(data []byte) (map[string]string, error) { + out := map[string]string{} + for lineNumber, rawLine := range strings.Split(string(data), "\n") { + line := strings.TrimSpace(rawLine) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + fields := strings.Fields(line) + if len(fields) < 2 { + return nil, fmt.Errorf("line %d: invalid checksum entry", lineNumber+1) + } + hash := strings.ToLower(strings.TrimSpace(fields[0])) + if len(hash) != sha256.Size*2 { + return nil, fmt.Errorf("line %d: invalid sha256 length", lineNumber+1) + } + if _, errDecode := hex.DecodeString(hash); errDecode != nil { + return nil, fmt.Errorf("line %d: invalid sha256: %w", lineNumber+1, errDecode) + } + name := strings.TrimPrefix(strings.TrimSpace(fields[1]), "*") + out[name] = hash + } + return out, nil +} + +func VerifyChecksum(name string, data []byte, checksums map[string]string) error { + expected := strings.ToLower(strings.TrimSpace(checksums[name])) + if expected == "" { + return fmt.Errorf("checksum for %s not found", name) + } + actualBytes := sha256.Sum256(data) + actual := hex.EncodeToString(actualBytes[:]) + if actual != expected { + return fmt.Errorf("checksum mismatch for %s", name) + } + return nil +} diff --git a/internal/pluginstore/github.go b/internal/pluginstore/github.go new file mode 100644 index 000000000..1132b1cab --- /dev/null +++ b/internal/pluginstore/github.go @@ -0,0 +1,130 @@ +package pluginstore + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/httpfetch" +) + +const userAgent = "CLIProxyAPI" + +// HTTPDoer abstracts the HTTP client used to execute requests. +type HTTPDoer = httpfetch.Doer + +type Client struct { + HTTPClient HTTPDoer + RegistryURL string + UserAgent string +} + +type Release struct { + TagName string `json:"tag_name"` + Assets []ReleaseAsset `json:"assets"` +} + +type ReleaseAsset struct { + Name string `json:"name"` + BrowserDownloadURL string `json:"browser_download_url"` +} + +func (c Client) FetchRegistry(ctx context.Context) (Registry, error) { + registryURL := strings.TrimSpace(c.RegistryURL) + if registryURL == "" { + registryURL = DefaultRegistryURL + } + data, errDownload := c.get(ctx, registryURL, "application/json") + if errDownload != nil { + return Registry{}, errDownload + } + registry, errParse := ParseRegistry(data) + if errParse != nil { + return Registry{}, errParse + } + return registry, nil +} + +func (c Client) FetchRelease(ctx context.Context, plugin Plugin) (Release, error) { + owner, repo, errRepository := GitHubRepositoryParts(plugin.Repository) + if errRepository != nil { + return Release{}, errRepository + } + releaseURL := fmt.Sprintf( + "https://api.github.com/repos/%s/%s/releases/tags/%s", + url.PathEscape(owner), + url.PathEscape(repo), + url.PathEscape("v"+strings.TrimSpace(plugin.Version)), + ) + data, errDownload := c.get(ctx, releaseURL, "application/vnd.github+json") + if errDownload != nil { + return Release{}, errDownload + } + var release Release + if errDecode := json.Unmarshal(data, &release); errDecode != nil { + return Release{}, fmt.Errorf("decode release: %w", errDecode) + } + return release, nil +} + +func (c Client) DownloadAsset(ctx context.Context, asset ReleaseAsset) ([]byte, error) { + if strings.TrimSpace(asset.BrowserDownloadURL) == "" { + return nil, fmt.Errorf("asset %q missing browser_download_url", asset.Name) + } + return c.get(ctx, asset.BrowserDownloadURL, "application/octet-stream") +} + +func (c Client) get(ctx context.Context, requestURL string, accept string) ([]byte, error) { + return httpfetch.GetBytes(ctx, c.httpClient(), requestURL, map[string]string{ + "Accept": accept, + "User-Agent": c.userAgent(), + }, 0) +} + +func (c Client) httpClient() HTTPDoer { + if c.HTTPClient != nil { + return c.HTTPClient + } + return http.DefaultClient +} + +func (c Client) userAgent() string { + if strings.TrimSpace(c.UserAgent) != "" { + return strings.TrimSpace(c.UserAgent) + } + return userAgent +} + +func SelectReleaseAssets(release Release, id, version, goos, goarch string) (ReleaseAsset, ReleaseAsset, error) { + archiveName := ArchiveName(id, version, goos, goarch) + var archiveAsset ReleaseAsset + var checksumAsset ReleaseAsset + for _, asset := range release.Assets { + switch strings.TrimSpace(asset.Name) { + case archiveName: + archiveAsset = asset + case "checksums.txt": + checksumAsset = asset + } + } + if strings.TrimSpace(archiveAsset.Name) == "" { + return ReleaseAsset{}, ReleaseAsset{}, fmt.Errorf("release asset %s not found", archiveName) + } + if strings.TrimSpace(checksumAsset.Name) == "" { + return ReleaseAsset{}, ReleaseAsset{}, fmt.Errorf("release asset checksums.txt not found") + } + return archiveAsset, checksumAsset, nil +} + +func ArchiveName(id, version, goos, goarch string) string { + return fmt.Sprintf( + "%s_%s_%s_%s.zip", + strings.TrimSpace(id), + strings.TrimSpace(version), + strings.TrimSpace(goos), + strings.TrimSpace(goarch), + ) +} diff --git a/internal/pluginstore/github_test.go b/internal/pluginstore/github_test.go new file mode 100644 index 000000000..39b2c2f9a --- /dev/null +++ b/internal/pluginstore/github_test.go @@ -0,0 +1,93 @@ +package pluginstore + +import ( + "crypto/sha256" + "encoding/hex" + "strings" + "testing" +) + +func TestSelectReleaseAssets(t *testing.T) { + t.Parallel() + + release := Release{Assets: []ReleaseAsset{ + {Name: "sample-provider_0.1.0_darwin_arm64.zip", BrowserDownloadURL: "https://example.com/sample-provider.zip"}, + {Name: "checksums.txt", BrowserDownloadURL: "https://example.com/checksums.txt"}, + }} + archiveAsset, checksumAsset, errSelect := SelectReleaseAssets(release, "sample-provider", "0.1.0", "darwin", "arm64") + if errSelect != nil { + t.Fatalf("SelectReleaseAssets() error = %v", errSelect) + } + if archiveAsset.BrowserDownloadURL != "https://example.com/sample-provider.zip" { + t.Fatalf("archive URL = %q", archiveAsset.BrowserDownloadURL) + } + if checksumAsset.BrowserDownloadURL != "https://example.com/checksums.txt" { + t.Fatalf("checksum URL = %q", checksumAsset.BrowserDownloadURL) + } +} + +func TestSelectReleaseAssetsRejectsMissingAssets(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + release Release + wantErr string + }{ + { + name: "missing zip", + release: Release{Assets: []ReleaseAsset{ + {Name: "checksums.txt", BrowserDownloadURL: "https://example.com/checksums.txt"}, + }}, + wantErr: "sample-provider_0.1.0_darwin_arm64.zip", + }, + { + name: "missing checksum", + release: Release{Assets: []ReleaseAsset{ + {Name: "sample-provider_0.1.0_darwin_arm64.zip", BrowserDownloadURL: "https://example.com/sample-provider.zip"}, + }}, + wantErr: "checksums.txt", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + _, _, errSelect := SelectReleaseAssets(tt.release, "sample-provider", "0.1.0", "darwin", "arm64") + if errSelect == nil { + t.Fatal("SelectReleaseAssets() error = nil") + } + if !strings.Contains(errSelect.Error(), tt.wantErr) { + t.Fatalf("SelectReleaseAssets() error = %v, want substring %q", errSelect, tt.wantErr) + } + }) + } +} + +func TestParseChecksumsAndVerifyChecksum(t *testing.T) { + t.Parallel() + + data := []byte("zip-data") + sum := sha256.Sum256(data) + checksumText := hex.EncodeToString(sum[:]) + " sample-provider_0.1.0_darwin_arm64.zip\n" + checksums, errParse := ParseChecksums([]byte(checksumText)) + if errParse != nil { + t.Fatalf("ParseChecksums() error = %v", errParse) + } + if errVerify := VerifyChecksum("sample-provider_0.1.0_darwin_arm64.zip", data, checksums); errVerify != nil { + t.Fatalf("VerifyChecksum() error = %v", errVerify) + } +} + +func TestVerifyChecksumRejectsMissingAndMismatch(t *testing.T) { + t.Parallel() + + sum := sha256.Sum256([]byte("zip-data")) + checksums := map[string]string{"sample-provider.zip": hex.EncodeToString(sum[:])} + if errVerify := VerifyChecksum("missing.zip", []byte("zip-data"), checksums); errVerify == nil { + t.Fatal("VerifyChecksum() missing checksum error = nil") + } + if errVerify := VerifyChecksum("sample-provider.zip", []byte("other"), checksums); errVerify == nil { + t.Fatal("VerifyChecksum() mismatch error = nil") + } +} diff --git a/internal/pluginstore/install.go b/internal/pluginstore/install.go new file mode 100644 index 000000000..ef3e3e2cf --- /dev/null +++ b/internal/pluginstore/install.go @@ -0,0 +1,277 @@ +package pluginstore + +import ( + "archive/zip" + "bytes" + "context" + "errors" + "fmt" + "io" + "os" + "path" + "path/filepath" + "runtime" + "strings" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/pluginhost" + log "github.com/sirupsen/logrus" +) + +type InstallOptions struct { + PluginsDir string + GOOS string + GOARCH string + // PluginLoaded reports whether the plugin's dynamic library is currently + // loaded by the running host. Loaded libraries cannot be overwritten on + // Windows, so installs targeting Windows are rejected while it returns true. + PluginLoaded func() bool +} + +// ErrLoadedPluginLocked is returned when an install would overwrite a plugin +// library that is loaded by the running process on Windows. +var ErrLoadedPluginLocked = errors.New("loaded plugin library cannot be overwritten while the server is running") + +type InstallResult struct { + ID string `json:"id"` + Version string `json:"version"` + Path string `json:"path"` + Overwritten bool `json:"overwritten"` +} + +func (c Client) Install(ctx context.Context, plugin Plugin, options InstallOptions) (InstallResult, error) { + if errValidate := ValidatePlugin(plugin); errValidate != nil { + return InstallResult{}, errValidate + } + options = normalizeInstallOptions(options) + if loadedPluginInstallBlocked(options) { + return InstallResult{}, ErrLoadedPluginLocked + } + release, errRelease := c.FetchRelease(ctx, plugin) + if errRelease != nil { + return InstallResult{}, errRelease + } + archiveAsset, checksumAsset, errAssets := SelectReleaseAssets(release, plugin.ID, plugin.Version, options.GOOS, options.GOARCH) + if errAssets != nil { + return InstallResult{}, errAssets + } + archiveData, errArchive := c.DownloadAsset(ctx, archiveAsset) + if errArchive != nil { + return InstallResult{}, fmt.Errorf("download %s: %w", archiveAsset.Name, errArchive) + } + checksumData, errChecksum := c.DownloadAsset(ctx, checksumAsset) + if errChecksum != nil { + return InstallResult{}, fmt.Errorf("download checksums.txt: %w", errChecksum) + } + checksums, errParse := ParseChecksums(checksumData) + if errParse != nil { + return InstallResult{}, errParse + } + if errVerify := VerifyChecksum(archiveAsset.Name, archiveData, checksums); errVerify != nil { + return InstallResult{}, errVerify + } + return InstallArchive(archiveData, plugin, options) +} + +func InstallArchive(archiveData []byte, plugin Plugin, options InstallOptions) (InstallResult, error) { + options = normalizeInstallOptions(options) + id := strings.TrimSpace(plugin.ID) + if !pluginhost.ValidatePluginID(id) { + return InstallResult{}, fmt.Errorf("invalid plugin id %q", plugin.ID) + } + reader, errZip := zip.NewReader(bytes.NewReader(archiveData), int64(len(archiveData))) + if errZip != nil { + return InstallResult{}, fmt.Errorf("open zip: %w", errZip) + } + + libraryData, mode, errLibrary := readTargetLibrary(reader, id, options.GOOS) + if errLibrary != nil { + return InstallResult{}, errLibrary + } + + targetPath, errTarget := installTargetPath(options, id) + if errTarget != nil { + return InstallResult{}, errTarget + } + overwritten := false + if _, errStat := os.Stat(targetPath); errStat == nil { + overwritten = true + } else if !errors.Is(errStat, os.ErrNotExist) { + return InstallResult{}, fmt.Errorf("stat target plugin: %w", errStat) + } + // Re-check immediately before writing: the plugin may have been loaded + // while the archive was being downloaded and verified. + if loadedPluginInstallBlocked(options) { + return InstallResult{}, ErrLoadedPluginLocked + } + if errWrite := writeFileAtomic(targetPath, libraryData, mode); errWrite != nil { + return InstallResult{}, errWrite + } + return InstallResult{ + ID: id, + Version: strings.TrimSpace(plugin.Version), + Path: targetPath, + Overwritten: overwritten, + }, nil +} + +func installTargetPath(options InstallOptions, id string) (string, error) { + defaultPath := filepath.Join(options.PluginsDir, options.GOOS, options.GOARCH, id+pluginhost.PluginExtension(options.GOOS)) + if options.GOOS != runtime.GOOS || options.GOARCH != runtime.GOARCH { + return defaultPath, nil + } + files, errDiscover := pluginhost.DiscoverPluginFiles(options.PluginsDir) + if errDiscover != nil { + return "", fmt.Errorf("discover current plugin files: %w", errDiscover) + } + for _, file := range files { + if file.ID == id && strings.TrimSpace(file.Path) != "" { + return file.Path, nil + } + } + return defaultPath, nil +} + +func readTargetLibrary(reader *zip.Reader, id string, goos string) ([]byte, os.FileMode, error) { + targetName := strings.TrimSpace(id) + pluginhost.PluginExtension(goos) + var target *zip.File + for _, file := range reader.File { + cleanedName, errClean := cleanZipName(file.Name) + if errClean != nil { + return nil, 0, errClean + } + if file.FileInfo().IsDir() { + continue + } + if !regularZipFile(file) { + return nil, 0, fmt.Errorf("zip entry %s is not a regular file", file.Name) + } + if !hasDynamicLibraryExtension(cleanedName) { + continue + } + if cleanedName != targetName { + if path.Base(cleanedName) == targetName { + return nil, 0, fmt.Errorf("target dynamic library must be at zip root") + } + return nil, 0, fmt.Errorf("dynamic library filename must be %s", targetName) + } + if target != nil { + return nil, 0, fmt.Errorf("zip contains multiple target dynamic libraries") + } + target = file + } + if target == nil { + return nil, 0, fmt.Errorf("zip does not contain %s", targetName) + } + + handle, errOpen := target.Open() + if errOpen != nil { + return nil, 0, fmt.Errorf("open %s: %w", targetName, errOpen) + } + defer func() { + if errClose := handle.Close(); errClose != nil { + log.WithError(errClose).Debug("failed to close plugin archive entry") + } + }() + data, errRead := io.ReadAll(handle) + if errRead != nil { + return nil, 0, fmt.Errorf("read %s: %w", targetName, errRead) + } + mode := target.FileInfo().Mode().Perm() + if mode == 0 { + mode = 0o755 + } + return data, mode, nil +} + +func cleanZipName(name string) (string, error) { + if strings.TrimSpace(name) == "" { + return "", fmt.Errorf("zip entry has empty name") + } + if strings.Contains(name, `\`) { + return "", fmt.Errorf("zip entry %s uses backslash path separators", name) + } + if path.IsAbs(name) { + return "", fmt.Errorf("zip entry %s is absolute", name) + } + cleaned := path.Clean(name) + if cleaned == "." || cleaned == ".." || strings.HasPrefix(cleaned, "../") { + return "", fmt.Errorf("zip entry %s escapes archive root", name) + } + return cleaned, nil +} + +func regularZipFile(file *zip.File) bool { + mode := file.FileInfo().Mode() + return mode.IsRegular() || mode.Type() == 0 +} + +func hasDynamicLibraryExtension(name string) bool { + lowerName := strings.ToLower(name) + return strings.HasSuffix(lowerName, ".dylib") || strings.HasSuffix(lowerName, ".so") || strings.HasSuffix(lowerName, ".dll") +} + +func writeFileAtomic(targetPath string, data []byte, mode os.FileMode) error { + targetDir := filepath.Dir(targetPath) + if errMkdir := os.MkdirAll(targetDir, 0o755); errMkdir != nil { + return fmt.Errorf("create plugin directory: %w", errMkdir) + } + + temp, errTemp := os.CreateTemp(targetDir, "."+filepath.Base(targetPath)+".tmp-*") + if errTemp != nil { + return fmt.Errorf("create temp plugin file: %w", errTemp) + } + tempPath := temp.Name() + removeTemp := true + closed := false + defer func() { + if !closed { + if errClose := temp.Close(); errClose != nil { + log.WithError(errClose).Debug("failed to close temp plugin file") + } + } + if removeTemp { + if errRemove := os.Remove(tempPath); errRemove != nil && !errors.Is(errRemove, os.ErrNotExist) { + log.WithError(errRemove).Debug("failed to remove temp plugin file") + } + } + }() + + if errChmod := temp.Chmod(mode); errChmod != nil { + return fmt.Errorf("chmod temp plugin file: %w", errChmod) + } + if _, errWrite := temp.Write(data); errWrite != nil { + return fmt.Errorf("write temp plugin file: %w", errWrite) + } + if errSync := temp.Sync(); errSync != nil { + return fmt.Errorf("sync temp plugin file: %w", errSync) + } + if errClose := temp.Close(); errClose != nil { + return fmt.Errorf("close temp plugin file: %w", errClose) + } + closed = true + if errRename := os.Rename(tempPath, targetPath); errRename != nil { + return fmt.Errorf("install plugin file: %w", errRename) + } + removeTemp = false + return nil +} + +func loadedPluginInstallBlocked(options InstallOptions) bool { + return options.PluginLoaded != nil && strings.EqualFold(options.GOOS, "windows") && options.PluginLoaded() +} + +func normalizeInstallOptions(options InstallOptions) InstallOptions { + options.PluginsDir = strings.TrimSpace(options.PluginsDir) + if options.PluginsDir == "" { + options.PluginsDir = "plugins" + } + options.GOOS = strings.TrimSpace(options.GOOS) + if options.GOOS == "" { + options.GOOS = runtime.GOOS + } + options.GOARCH = strings.TrimSpace(options.GOARCH) + if options.GOARCH == "" { + options.GOARCH = runtime.GOARCH + } + return options +} diff --git a/internal/pluginstore/install_test.go b/internal/pluginstore/install_test.go new file mode 100644 index 000000000..aacd8103a --- /dev/null +++ b/internal/pluginstore/install_test.go @@ -0,0 +1,241 @@ +package pluginstore + +import ( + "archive/zip" + "bytes" + "context" + "errors" + "net/http" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/pluginhost" +) + +func TestInstallBlocksLoadedWindowsPlugin(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + goos string + loaded bool + wantBlocked bool + }{ + {name: "windows loaded", goos: "windows", loaded: true, wantBlocked: true}, + {name: "windows not loaded", goos: "windows", loaded: false, wantBlocked: false}, + {name: "linux loaded", goos: "linux", loaded: true, wantBlocked: false}, + {name: "darwin loaded", goos: "darwin", loaded: true, wantBlocked: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + _, errInstall := Client{HTTPClient: failingHTTPDoer{}}.Install(context.Background(), testPlugin(), InstallOptions{ + PluginsDir: t.TempDir(), + GOOS: tt.goos, + GOARCH: "amd64", + PluginLoaded: func() bool { return tt.loaded }, + }) + if errInstall == nil { + t.Fatal("Install() error = nil") + } + if gotBlocked := errors.Is(errInstall, ErrLoadedPluginLocked); gotBlocked != tt.wantBlocked { + t.Fatalf("Install() error = %v, blocked = %v, want %v", errInstall, gotBlocked, tt.wantBlocked) + } + }) + } +} + +func TestInstallArchiveBlocksLoadedWindowsPluginBeforeWrite(t *testing.T) { + t.Parallel() + + _, errInstall := InstallArchive(makeZip(t, map[string]string{ + "sample-provider.dll": "library-data", + }), testPlugin(), InstallOptions{ + PluginsDir: t.TempDir(), + GOOS: "windows", + GOARCH: "amd64", + PluginLoaded: func() bool { return true }, + }) + if !errors.Is(errInstall, ErrLoadedPluginLocked) { + t.Fatalf("InstallArchive() error = %v, want ErrLoadedPluginLocked", errInstall) + } +} + +func TestInstallArchiveWritesPlatformPlugin(t *testing.T) { + t.Parallel() + + root := t.TempDir() + result, errInstall := InstallArchive(makeZip(t, map[string]string{ + "README.md": "ignored", + "sample-provider.dylib": "library-data", + }), testPlugin(), InstallOptions{PluginsDir: root, GOOS: "darwin", GOARCH: "arm64"}) + if errInstall != nil { + t.Fatalf("InstallArchive() error = %v", errInstall) + } + wantPath := filepath.Join(root, "darwin", "arm64", "sample-provider.dylib") + if result.Path != wantPath { + t.Fatalf("Path = %q, want %q", result.Path, wantPath) + } + data, errRead := os.ReadFile(wantPath) + if errRead != nil { + t.Fatalf("ReadFile() error = %v", errRead) + } + if string(data) != "library-data" { + t.Fatalf("installed data = %q", data) + } +} + +func TestInstallArchiveReportsOverwrite(t *testing.T) { + t.Parallel() + + root := t.TempDir() + targetDir := filepath.Join(root, "darwin", "arm64") + if errMkdir := os.MkdirAll(targetDir, 0o755); errMkdir != nil { + t.Fatalf("MkdirAll() error = %v", errMkdir) + } + if errWrite := os.WriteFile(filepath.Join(targetDir, "sample-provider.dylib"), []byte("old"), 0o644); errWrite != nil { + t.Fatalf("WriteFile() error = %v", errWrite) + } + result, errInstall := InstallArchive(makeZip(t, map[string]string{ + "sample-provider.dylib": "new", + }), testPlugin(), InstallOptions{PluginsDir: root, GOOS: "darwin", GOARCH: "arm64"}) + if errInstall != nil { + t.Fatalf("InstallArchive() error = %v", errInstall) + } + if !result.Overwritten { + t.Fatal("Overwritten = false, want true") + } +} + +func TestInstallArchiveOverwritesRuntimeSelectedPlugin(t *testing.T) { + t.Parallel() + + root := t.TempDir() + existingPath := filepath.Join(root, "sample-provider"+pluginhost.PluginExtension(runtime.GOOS)) + if errWrite := os.WriteFile(existingPath, []byte("old"), 0o644); errWrite != nil { + t.Fatalf("WriteFile() error = %v", errWrite) + } + + result, errInstall := InstallArchive(makeZip(t, map[string]string{ + "sample-provider" + pluginhost.PluginExtension(runtime.GOOS): "new", + }), testPlugin(), InstallOptions{PluginsDir: root, GOOS: runtime.GOOS, GOARCH: runtime.GOARCH}) + if errInstall != nil { + t.Fatalf("InstallArchive() error = %v", errInstall) + } + if result.Path != existingPath { + t.Fatalf("Path = %q, want selected runtime plugin %q", result.Path, existingPath) + } + if !result.Overwritten { + t.Fatal("Overwritten = false, want true") + } + data, errRead := os.ReadFile(existingPath) + if errRead != nil { + t.Fatalf("ReadFile() error = %v", errRead) + } + if string(data) != "new" { + t.Fatalf("installed data = %q, want new", data) + } +} + +func TestInstallArchiveRejectsUnsafeArchives(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + files map[string]string + wantErr string + }{ + { + name: "zip slip", + files: map[string]string{"../sample-provider.dylib": "library"}, + wantErr: "escapes archive root", + }, + { + name: "absolute path", + files: map[string]string{"/sample-provider.dylib": "library"}, + wantErr: "is absolute", + }, + { + name: "nested target", + files: map[string]string{"nested/sample-provider.dylib": "library"}, + wantErr: "zip root", + }, + { + name: "extension mismatch", + files: map[string]string{"sample-provider.so": "library"}, + wantErr: "sample-provider.dylib", + }, + { + name: "filename mismatch", + files: map[string]string{"other.dylib": "library"}, + wantErr: "sample-provider.dylib", + }, + { + name: "missing target", + files: map[string]string{"README.md": "library"}, + wantErr: "does not contain", + }, + { + name: "multiple targets", + files: map[string]string{ + "sample-provider.dylib": "library", + "copy.dylib": "library", + }, + wantErr: "sample-provider.dylib", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + _, errInstall := InstallArchive(makeZip(t, tt.files), testPlugin(), InstallOptions{PluginsDir: t.TempDir(), GOOS: "darwin", GOARCH: "arm64"}) + if errInstall == nil { + t.Fatal("InstallArchive() error = nil") + } + if !strings.Contains(errInstall.Error(), tt.wantErr) { + t.Fatalf("InstallArchive() error = %v, want substring %q", errInstall, tt.wantErr) + } + }) + } +} + +func makeZip(t *testing.T, files map[string]string) []byte { + t.Helper() + + var buffer bytes.Buffer + writer := zip.NewWriter(&buffer) + for name, content := range files { + file, errCreate := writer.Create(name) + if errCreate != nil { + t.Fatalf("Create(%s) error = %v", name, errCreate) + } + if _, errWrite := file.Write([]byte(content)); errWrite != nil { + t.Fatalf("Write(%s) error = %v", name, errWrite) + } + } + if errClose := writer.Close(); errClose != nil { + t.Fatalf("Close() error = %v", errClose) + } + return buffer.Bytes() +} + +type failingHTTPDoer struct{} + +func (failingHTTPDoer) Do(*http.Request) (*http.Response, error) { + return nil, errors.New("network unavailable") +} + +func testPlugin() Plugin { + return Plugin{ + ID: "sample-provider", + Name: "Sample Provider", + Description: "Adds sample provider support.", + Author: "author-name", + Version: "0.1.0", + Repository: "https://github.com/author-name/cliproxy-sample-provider-plugin", + } +} diff --git a/internal/pluginstore/registry.go b/internal/pluginstore/registry.go new file mode 100644 index 000000000..6a20fabce --- /dev/null +++ b/internal/pluginstore/registry.go @@ -0,0 +1,156 @@ +package pluginstore + +import ( + "bytes" + "encoding/json" + "fmt" + "net/url" + "regexp" + "strings" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/pluginhost" +) + +const ( + DefaultRegistryURL = "https://raw.githubusercontent.com/router-for-me/CLIProxyAPI-Plugins-Store/main/registry.json" + SchemaVersion = 1 +) + +var pluginVersionPattern = regexp.MustCompile(`^[0-9][0-9A-Za-z.+-]*$`) + +type Registry struct { + SchemaVersion int `json:"schema_version"` + Plugins []Plugin `json:"plugins"` +} + +type Plugin struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Author string `json:"author"` + Version string `json:"version"` + Repository string `json:"repository"` + Logo string `json:"logo,omitempty"` + Homepage string `json:"homepage,omitempty"` + License string `json:"license,omitempty"` + Tags []string `json:"tags,omitempty"` +} + +func ParseRegistry(data []byte) (Registry, error) { + var registry Registry + decoder := json.NewDecoder(bytes.NewReader(data)) + if errDecode := decoder.Decode(®istry); errDecode != nil { + return Registry{}, fmt.Errorf("decode registry: %w", errDecode) + } + normalizeRegistry(®istry) + if errValidate := ValidateRegistry(registry); errValidate != nil { + return Registry{}, errValidate + } + return registry, nil +} + +func normalizeRegistry(registry *Registry) { + if registry == nil { + return + } + for index := range registry.Plugins { + plugin := ®istry.Plugins[index] + plugin.ID = strings.TrimSpace(plugin.ID) + plugin.Name = strings.TrimSpace(plugin.Name) + plugin.Description = strings.TrimSpace(plugin.Description) + plugin.Author = strings.TrimSpace(plugin.Author) + plugin.Version = strings.TrimSpace(plugin.Version) + plugin.Repository = strings.TrimSpace(plugin.Repository) + plugin.Logo = strings.TrimSpace(plugin.Logo) + plugin.Homepage = strings.TrimSpace(plugin.Homepage) + plugin.License = strings.TrimSpace(plugin.License) + for tagIndex := range plugin.Tags { + plugin.Tags[tagIndex] = strings.TrimSpace(plugin.Tags[tagIndex]) + } + } +} + +func ValidateRegistry(registry Registry) error { + if registry.SchemaVersion != SchemaVersion { + return fmt.Errorf("unsupported schema_version %d", registry.SchemaVersion) + } + seen := make(map[string]struct{}, len(registry.Plugins)) + for index, plugin := range registry.Plugins { + if errValidate := ValidatePlugin(plugin); errValidate != nil { + return fmt.Errorf("plugins[%d]: %w", index, errValidate) + } + id := strings.TrimSpace(plugin.ID) + if _, exists := seen[id]; exists { + return fmt.Errorf("plugins[%d]: duplicate plugin id %q", index, id) + } + seen[id] = struct{}{} + } + return nil +} + +func ValidatePlugin(plugin Plugin) error { + required := map[string]string{ + "id": plugin.ID, + "name": plugin.Name, + "description": plugin.Description, + "author": plugin.Author, + "version": plugin.Version, + "repository": plugin.Repository, + } + for field, value := range required { + if strings.TrimSpace(value) == "" { + return fmt.Errorf("missing required field %s", field) + } + } + if !pluginhost.ValidatePluginID(strings.TrimSpace(plugin.ID)) { + return fmt.Errorf("invalid plugin id %q", plugin.ID) + } + if !validPluginVersion(strings.TrimSpace(plugin.Version)) { + return fmt.Errorf("invalid plugin version %q", plugin.Version) + } + if _, _, errRepository := GitHubRepositoryParts(plugin.Repository); errRepository != nil { + return errRepository + } + return nil +} + +func validPluginVersion(version string) bool { + return version != "" && !strings.HasPrefix(version, "v") && pluginVersionPattern.MatchString(version) +} + +func GitHubRepositoryParts(repository string) (string, string, error) { + repository = strings.TrimSpace(repository) + parsed, errParse := url.Parse(repository) + if errParse != nil { + return "", "", fmt.Errorf("invalid repository URL: %w", errParse) + } + if parsed.Scheme != "https" || parsed.Host != "github.com" || parsed.RawQuery != "" || parsed.Fragment != "" { + return "", "", fmt.Errorf("repository must be https://github.com/{owner}/{repo}") + } + segments := strings.Split(strings.Trim(parsed.EscapedPath(), "/"), "/") + if len(segments) != 2 || segments[0] == "" || segments[1] == "" { + return "", "", fmt.Errorf("repository must be https://github.com/{owner}/{repo}") + } + owner, errOwner := url.PathUnescape(segments[0]) + if errOwner != nil { + return "", "", fmt.Errorf("invalid repository owner: %w", errOwner) + } + repo, errRepo := url.PathUnescape(segments[1]) + if errRepo != nil { + return "", "", fmt.Errorf("invalid repository name: %w", errRepo) + } + if strings.HasSuffix(repo, ".git") { + return "", "", fmt.Errorf("repository must be https://github.com/{owner}/{repo}") + } + return owner, repo, nil +} + +func (r Registry) PluginByID(id string) (Plugin, bool) { + id = strings.TrimSpace(id) + for _, plugin := range r.Plugins { + if strings.TrimSpace(plugin.ID) == id { + return plugin, true + } + } + return Plugin{}, false +} diff --git a/internal/pluginstore/registry_test.go b/internal/pluginstore/registry_test.go new file mode 100644 index 000000000..d8c89e8d6 --- /dev/null +++ b/internal/pluginstore/registry_test.go @@ -0,0 +1,167 @@ +package pluginstore + +import ( + "strings" + "testing" +) + +func TestParseRegistryValidatesRegistry(t *testing.T) { + t.Parallel() + + registry, errParse := ParseRegistry([]byte(`{ + "schema_version": 1, + "plugins": [{ + "id": "sample-provider", + "name": "Sample Provider", + "description": "Adds sample provider support.", + "author": "author-name", + "version": "0.1.0", + "repository": "https://github.com/author-name/cliproxy-sample-provider-plugin", + "logo": "https://example.com/logo.png", + "homepage": "https://github.com/author-name/cliproxy-sample-provider-plugin", + "license": "MIT", + "tags": ["provider"] + }] + }`)) + if errParse != nil { + t.Fatalf("ParseRegistry() error = %v", errParse) + } + plugin, ok := registry.PluginByID("sample-provider") + if !ok { + t.Fatal("PluginByID(sample-provider) missing") + } + if plugin.Version != "0.1.0" { + t.Fatalf("plugin version = %q, want 0.1.0", plugin.Version) + } +} + +func TestParseRegistryNormalizesPluginFields(t *testing.T) { + t.Parallel() + + registry, errParse := ParseRegistry([]byte(`{ + "schema_version": 1, + "plugins": [{ + "id": " sample-provider ", + "name": " Sample Provider ", + "description": " Adds sample provider support. ", + "author": " author-name ", + "version": " 0.1.0 ", + "repository": " https://github.com/author-name/cliproxy-sample-provider-plugin ", + "logo": " https://example.com/logo.png ", + "homepage": " https://github.com/author-name/cliproxy-sample-provider-plugin ", + "license": " MIT ", + "tags": [" provider "] + }] + }`)) + if errParse != nil { + t.Fatalf("ParseRegistry() error = %v", errParse) + } + plugin, ok := registry.PluginByID("sample-provider") + if !ok { + t.Fatal("PluginByID(sample-provider) missing") + } + if plugin.ID != "sample-provider" || plugin.Version != "0.1.0" || plugin.Repository != "https://github.com/author-name/cliproxy-sample-provider-plugin" { + t.Fatalf("plugin not normalized: %#v", plugin) + } + if plugin.Name != "Sample Provider" || plugin.Tags[0] != "provider" { + t.Fatalf("plugin display fields not normalized: %#v", plugin) + } +} + +func TestValidateRegistryRejectsInvalidEntries(t *testing.T) { + t.Parallel() + + valid := Plugin{ + ID: "sample-provider", + Name: "Sample Provider", + Description: "Adds sample provider support.", + Author: "author-name", + Version: "0.1.0", + Repository: "https://github.com/author-name/cliproxy-sample-provider-plugin", + } + tests := []struct { + name string + mutate func(*Registry) + wantErr string + }{ + { + name: "schema version", + mutate: func(registry *Registry) { + registry.SchemaVersion = 2 + }, + wantErr: "unsupported schema_version", + }, + { + name: "missing required field", + mutate: func(registry *Registry) { + registry.Plugins[0].Name = "" + }, + wantErr: "missing required field name", + }, + { + name: "duplicate id", + mutate: func(registry *Registry) { + registry.Plugins = append(registry.Plugins, valid) + }, + wantErr: "duplicate plugin id", + }, + { + name: "invalid id", + mutate: func(registry *Registry) { + registry.Plugins[0].ID = "../sample-provider" + }, + wantErr: "invalid plugin id", + }, + { + name: "v-prefixed version", + mutate: func(registry *Registry) { + registry.Plugins[0].Version = "v0.1.0" + }, + wantErr: "invalid plugin version", + }, + { + name: "invalid repository", + mutate: func(registry *Registry) { + registry.Plugins[0].Repository = "https://example.com/author/repo" + }, + wantErr: "repository must be", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + registry := Registry{SchemaVersion: 1, Plugins: []Plugin{valid}} + tt.mutate(®istry) + errValidate := ValidateRegistry(registry) + if errValidate == nil { + t.Fatal("ValidateRegistry() error = nil") + } + if !strings.Contains(errValidate.Error(), tt.wantErr) { + t.Fatalf("ValidateRegistry() error = %v, want substring %q", errValidate, tt.wantErr) + } + }) + } +} + +func TestGitHubRepositoryPartsRejectsNonRepositoryURLs(t *testing.T) { + t.Parallel() + + tests := []string{ + "http://github.com/owner/repo", + "https://github.com/owner", + "https://github.com/owner/repo/issues", + "https://github.com/owner/repo.git", + "https://github.com/owner/repo?tab=readme", + } + for _, repository := range tests { + t.Run(repository, func(t *testing.T) { + t.Parallel() + + if _, _, errParse := GitHubRepositoryParts(repository); errParse == nil { + t.Fatalf("GitHubRepositoryParts(%q) error = nil", repository) + } + }) + } +} diff --git a/internal/pluginstore/version.go b/internal/pluginstore/version.go new file mode 100644 index 000000000..4ad95d83e --- /dev/null +++ b/internal/pluginstore/version.go @@ -0,0 +1,69 @@ +package pluginstore + +import ( + "strconv" + "strings" +) + +// UpdateAvailable reports whether latest should be offered as an upgrade over +// installed. A leading "v"/"V" is ignored on both sides. Versions are compared +// numerically when both are dotted release numbers, so an installed version +// newer than the registry one is not reported as an update; otherwise any +// difference counts as an update. +func UpdateAvailable(installed, latest string) bool { + installed = normalizeVersion(installed) + latest = normalizeVersion(latest) + if installed == "" || latest == "" || installed == latest { + return false + } + comparison, comparable := compareVersions(installed, latest) + if !comparable { + return true + } + return comparison < 0 +} + +func normalizeVersion(version string) string { + version = strings.TrimSpace(version) + if len(version) > 1 && (version[0] == 'v' || version[0] == 'V') { + version = version[1:] + } + return version +} + +// compareVersions compares dotted numeric versions segment by segment, with +// missing segments treated as zero. It reports false when either version +// contains a non-numeric segment. +func compareVersions(a, b string) (int, bool) { + segmentsA := strings.Split(a, ".") + segmentsB := strings.Split(b, ".") + length := len(segmentsA) + if len(segmentsB) > length { + length = len(segmentsB) + } + for index := 0; index < length; index++ { + numberA, okA := versionSegment(segmentsA, index) + numberB, okB := versionSegment(segmentsB, index) + if !okA || !okB { + return 0, false + } + if numberA != numberB { + if numberA < numberB { + return -1, true + } + return 1, true + } + } + return 0, true +} + +func versionSegment(segments []string, index int) (int64, bool) { + if index >= len(segments) { + return 0, true + } + number, errParse := strconv.ParseInt(segments[index], 10, 64) + if errParse != nil || number < 0 { + return 0, false + } + return number, true +} diff --git a/internal/pluginstore/version_test.go b/internal/pluginstore/version_test.go new file mode 100644 index 000000000..e2a518560 --- /dev/null +++ b/internal/pluginstore/version_test.go @@ -0,0 +1,34 @@ +package pluginstore + +import "testing" + +func TestUpdateAvailable(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + installed string + latest string + want bool + }{ + {name: "unknown installed", installed: "", latest: "0.2.0", want: false}, + {name: "same version", installed: "0.1.0", latest: "0.1.0", want: false}, + {name: "same version with v prefix", installed: "v0.1.0", latest: "0.1.0", want: false}, + {name: "newer registry version", installed: "0.1.0", latest: "0.2.0", want: true}, + {name: "newer registry version with v prefix", installed: "v0.1.0", latest: "0.2.0", want: true}, + {name: "numeric not lexicographic", installed: "0.1.9", latest: "0.1.10", want: true}, + {name: "installed newer than registry", installed: "0.2.0", latest: "0.1.0", want: false}, + {name: "missing segments treated as zero", installed: "0.1", latest: "0.1.0", want: false}, + {name: "prerelease falls back to inequality", installed: "0.1.0-rc1", latest: "0.1.0", want: true}, + {name: "non numeric falls back to inequality", installed: "dev", latest: "0.1.0", want: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + if got := UpdateAvailable(tt.installed, tt.latest); got != tt.want { + t.Fatalf("UpdateAvailable(%q, %q) = %v, want %v", tt.installed, tt.latest, got, tt.want) + } + }) + } +} diff --git a/internal/thinking/validate.go b/internal/thinking/validate.go index 46038a698..909a2eeaa 100644 --- a/internal/thinking/validate.go +++ b/internal/thinking/validate.go @@ -339,7 +339,7 @@ func normalizeLevels(levels []string) []string { // These providers may also support level-based thinking (hybrid models). func isBudgetCapableProvider(provider string) bool { switch provider { - case "gemini", "gemini-cli", "antigravity", "claude", "qoder": + case "gemini", "gemini-cli", "antigravity", "claude": return true default: return false diff --git a/sdk/cliproxy/auth/oauth_model_alias_test.go b/sdk/cliproxy/auth/oauth_model_alias_test.go index 8e9f19420..7f6e2325d 100644 --- a/sdk/cliproxy/auth/oauth_model_alias_test.go +++ b/sdk/cliproxy/auth/oauth_model_alias_test.go @@ -175,10 +175,10 @@ func TestOAuthModelAliasChannel_Kimi(t *testing.T) { func TestOAuthModelAliasChannel_PluginProvider(t *testing.T) { t.Parallel() - if got := OAuthModelAliasChannel(" Qoder ", "oauth"); got != "qoder" { - t.Fatalf("OAuthModelAliasChannel() = %q, want %q", got, "qoder") + if got := OAuthModelAliasChannel(" Sample-Provider ", "oauth"); got != "sample-provider" { + t.Fatalf("OAuthModelAliasChannel() = %q, want %q", got, "sample-provider") } - if got := OAuthModelAliasChannel("qoder", "api_key"); got != "" { + if got := OAuthModelAliasChannel("sample-provider", "api_key"); got != "" { t.Fatalf("OAuthModelAliasChannel() = %q, want empty channel for API key", got) } } @@ -206,18 +206,18 @@ func TestApplyOAuthModelAlias_PluginProvider(t *testing.T) { t.Parallel() aliases := map[string][]internalconfig.OAuthModelAlias{ - "qoder": {{Name: "qmodel_latest", Alias: "qlatest"}}, + "sample-provider": {{Name: "sample-model-latest", Alias: "sample-latest"}}, } mgr := NewManager(nil, nil, nil) mgr.SetConfig(&internalconfig.Config{}) mgr.SetOAuthModelAlias(aliases) - auth := &Auth{ID: "qoder-auth", Provider: "qoder", Attributes: map[string]string{"auth_kind": "oauth"}} + auth := &Auth{ID: "sample-provider-auth", Provider: "sample-provider", Attributes: map[string]string{"auth_kind": "oauth"}} - resolvedModel := mgr.applyOAuthModelAlias(auth, "qlatest") - if resolvedModel != "qmodel_latest" { - t.Errorf("applyOAuthModelAlias() model = %q, want %q", resolvedModel, "qmodel_latest") + resolvedModel := mgr.applyOAuthModelAlias(auth, "sample-latest") + if resolvedModel != "sample-model-latest" { + t.Errorf("applyOAuthModelAlias() model = %q, want %q", resolvedModel, "sample-model-latest") } } @@ -225,17 +225,17 @@ func TestApplyOAuthModelAlias_PluginProviderSkipsAPIKey(t *testing.T) { t.Parallel() aliases := map[string][]internalconfig.OAuthModelAlias{ - "qoder": {{Name: "qmodel_latest", Alias: "qlatest"}}, + "sample-provider": {{Name: "sample-model-latest", Alias: "sample-latest"}}, } mgr := NewManager(nil, nil, nil) mgr.SetConfig(&internalconfig.Config{}) mgr.SetOAuthModelAlias(aliases) - auth := &Auth{ID: "qoder-auth", Provider: "qoder", Attributes: map[string]string{"auth_kind": "api_key"}} + auth := &Auth{ID: "sample-provider-auth", Provider: "sample-provider", Attributes: map[string]string{"auth_kind": "api_key"}} - resolvedModel := mgr.applyOAuthModelAlias(auth, "qlatest") - if resolvedModel != "qlatest" { - t.Errorf("applyOAuthModelAlias() model = %q, want %q", resolvedModel, "qlatest") + resolvedModel := mgr.applyOAuthModelAlias(auth, "sample-latest") + if resolvedModel != "sample-latest" { + t.Errorf("applyOAuthModelAlias() model = %q, want %q", resolvedModel, "sample-latest") } } diff --git a/sdk/cliproxy/service_oauth_model_alias_test.go b/sdk/cliproxy/service_oauth_model_alias_test.go index 17990dbc9..c39fbb7b1 100644 --- a/sdk/cliproxy/service_oauth_model_alias_test.go +++ b/sdk/cliproxy/service_oauth_model_alias_test.go @@ -94,41 +94,41 @@ func TestApplyOAuthModelAlias_ForkAddsMultipleAliases(t *testing.T) { func TestApplyOAuthModelAlias_PluginProvider(t *testing.T) { cfg := &config.Config{ OAuthModelAlias: map[string][]config.OAuthModelAlias{ - "qoder": { - {Name: "qmodel_latest", Alias: "qlatest"}, + "sample-provider": { + {Name: "sample-model-latest", Alias: "sample-latest"}, }, }, } models := []*ModelInfo{ - {ID: "qmodel_latest", Name: "models/qmodel_latest"}, + {ID: "sample-model-latest", Name: "models/sample-model-latest"}, } - out := applyOAuthModelAlias(cfg, "qoder", "oauth", models) + out := applyOAuthModelAlias(cfg, "sample-provider", "oauth", models) if len(out) != 1 { t.Fatalf("expected 1 model, got %d", len(out)) } - if out[0].ID != "qlatest" { - t.Fatalf("expected plugin alias id %q, got %q", "qlatest", out[0].ID) + if out[0].ID != "sample-latest" { + t.Fatalf("expected plugin alias id %q, got %q", "sample-latest", out[0].ID) } - if out[0].Name != "models/qlatest" { - t.Fatalf("expected plugin alias name %q, got %q", "models/qlatest", out[0].Name) + if out[0].Name != "models/sample-latest" { + t.Fatalf("expected plugin alias name %q, got %q", "models/sample-latest", out[0].Name) } } func TestApplyOAuthModelAlias_PluginProviderSkipsAPIKey(t *testing.T) { cfg := &config.Config{ OAuthModelAlias: map[string][]config.OAuthModelAlias{ - "qoder": { - {Name: "qmodel_latest", Alias: "qlatest"}, + "sample-provider": { + {Name: "sample-model-latest", Alias: "sample-latest"}, }, }, } models := []*ModelInfo{ - {ID: "qmodel_latest", Name: "models/qmodel_latest"}, + {ID: "sample-model-latest", Name: "models/sample-model-latest"}, } - out := applyOAuthModelAlias(cfg, "qoder", "api_key", models) - if len(out) != 1 || out[0].ID != "qmodel_latest" { + out := applyOAuthModelAlias(cfg, "sample-provider", "api_key", models) + if len(out) != 1 || out[0].ID != "sample-model-latest" { t.Fatalf("expected API key plugin model to remain unchanged, got %#v", out) } }