mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-06-24 14:57:46 +08:00
Merge pull request #3818 from router-for-me/Plugin-stone
feat(pluginstore): add lightweight plugin store installer APIs
This commit is contained in:
@@ -392,9 +392,9 @@ nonstream-keepalive-interval: 0
|
||||
# xai:
|
||||
# - name: "grok-4.3"
|
||||
# alias: "grok-latest"
|
||||
# qoder: # plugin provider keys are supported for OAuth plugins
|
||||
# - name: "qmodel_latest"
|
||||
# alias: "qlatest"
|
||||
# sample-provider: # plugin provider keys are supported for OAuth plugins
|
||||
# - name: "sample-model-latest"
|
||||
# alias: "sample-latest"
|
||||
|
||||
# OAuth provider excluded models
|
||||
# oauth-excluded-models:
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/buildinfo"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/pluginhost"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/pluginstore"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
@@ -35,20 +36,22 @@ const attemptMaxIdleTime = 2 * time.Hour
|
||||
|
||||
// Handler aggregates config reference, persistence path and helpers.
|
||||
type Handler struct {
|
||||
cfg *config.Config
|
||||
configFilePath string
|
||||
mu sync.Mutex
|
||||
attemptsMu sync.Mutex
|
||||
failedAttempts map[string]*attemptInfo // keyed by client IP
|
||||
authManager *coreauth.Manager
|
||||
tokenStore coreauth.Store
|
||||
localPassword string
|
||||
allowRemoteOverride bool
|
||||
envSecret string
|
||||
logDir string
|
||||
postAuthHook coreauth.PostAuthHook
|
||||
postAuthPersistHook coreauth.PostAuthHook
|
||||
pluginHost *pluginhost.Host
|
||||
cfg *config.Config
|
||||
configFilePath string
|
||||
mu sync.Mutex
|
||||
attemptsMu sync.Mutex
|
||||
failedAttempts map[string]*attemptInfo // keyed by client IP
|
||||
authManager *coreauth.Manager
|
||||
tokenStore coreauth.Store
|
||||
localPassword string
|
||||
allowRemoteOverride bool
|
||||
envSecret string
|
||||
logDir string
|
||||
postAuthHook coreauth.PostAuthHook
|
||||
postAuthPersistHook coreauth.PostAuthHook
|
||||
pluginHost *pluginhost.Host
|
||||
pluginStoreRegistryURL string
|
||||
pluginStoreHTTPClient pluginstore.HTTPDoer
|
||||
}
|
||||
|
||||
// NewHandler creates a new management handler instance.
|
||||
|
||||
286
internal/api/handlers/management/plugin_store.go
Normal file
286
internal/api/handlers/management/plugin_store.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/pluginhost"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/pluginstore"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v7/sdk/config"
|
||||
)
|
||||
|
||||
type pluginStoreListResponse struct {
|
||||
PluginsEnabled bool `json:"plugins_enabled"`
|
||||
PluginsDir string `json:"plugins_dir"`
|
||||
Plugins []pluginStoreListEntry `json:"plugins"`
|
||||
}
|
||||
|
||||
type pluginStoreListEntry struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Author string `json:"author"`
|
||||
Version string `json:"version"`
|
||||
Repository string `json:"repository"`
|
||||
Logo string `json:"logo,omitempty"`
|
||||
Homepage string `json:"homepage,omitempty"`
|
||||
License string `json:"license,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
Installed bool `json:"installed"`
|
||||
InstalledVersion string `json:"installed_version"`
|
||||
Path string `json:"path"`
|
||||
Configured bool `json:"configured"`
|
||||
Registered bool `json:"registered"`
|
||||
Enabled bool `json:"enabled"`
|
||||
EffectiveEnabled bool `json:"effective_enabled"`
|
||||
UpdateAvailable bool `json:"update_available"`
|
||||
}
|
||||
|
||||
type pluginInstallResponse struct {
|
||||
Status string `json:"status"`
|
||||
ID string `json:"id"`
|
||||
Version string `json:"version"`
|
||||
Path string `json:"path"`
|
||||
PluginsEnabled bool `json:"plugins_enabled"`
|
||||
RestartRequired bool `json:"restart_required"`
|
||||
}
|
||||
|
||||
type pluginLocalStatus struct {
|
||||
Installed bool
|
||||
InstalledVersion string
|
||||
Path string
|
||||
Configured bool
|
||||
Registered bool
|
||||
Enabled bool
|
||||
EffectiveEnabled bool
|
||||
}
|
||||
|
||||
func (h *Handler) ListPluginStore(c *gin.Context) {
|
||||
pluginsEnabled, pluginsDir, proxyURL, configs, host := h.pluginStoreSnapshot()
|
||||
client := h.newPluginStoreClient(proxyURL)
|
||||
registry, errRegistry := client.FetchRegistry(c.Request.Context())
|
||||
if errRegistry != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "plugin_store_registry_failed", "message": errRegistry.Error()})
|
||||
return
|
||||
}
|
||||
statuses, errStatus := pluginLocalStatuses(pluginsEnabled, pluginsDir, configs, host)
|
||||
if errStatus != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "plugin_discovery_failed", "message": errStatus.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
entries := make([]pluginStoreListEntry, 0, len(registry.Plugins))
|
||||
for _, plugin := range registry.Plugins {
|
||||
status := statuses[plugin.ID]
|
||||
installedVersion := status.InstalledVersion
|
||||
entries = append(entries, pluginStoreListEntry{
|
||||
ID: plugin.ID,
|
||||
Name: plugin.Name,
|
||||
Description: plugin.Description,
|
||||
Author: plugin.Author,
|
||||
Version: plugin.Version,
|
||||
Repository: plugin.Repository,
|
||||
Logo: plugin.Logo,
|
||||
Homepage: plugin.Homepage,
|
||||
License: plugin.License,
|
||||
Tags: append([]string{}, plugin.Tags...),
|
||||
Installed: status.Installed,
|
||||
InstalledVersion: installedVersion,
|
||||
Path: status.Path,
|
||||
Configured: status.Configured,
|
||||
Registered: status.Registered,
|
||||
Enabled: status.Enabled,
|
||||
EffectiveEnabled: status.EffectiveEnabled,
|
||||
UpdateAvailable: pluginstore.UpdateAvailable(installedVersion, plugin.Version),
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, pluginStoreListResponse{
|
||||
PluginsEnabled: pluginsEnabled,
|
||||
PluginsDir: pluginsDir,
|
||||
Plugins: entries,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) InstallPluginFromStore(c *gin.Context) {
|
||||
h.installPluginFromStore(c, runtime.GOOS, runtime.GOARCH)
|
||||
}
|
||||
|
||||
func (h *Handler) installPluginFromStore(c *gin.Context, goos, goarch string) {
|
||||
id, okID := pluginIDFromRequest(c)
|
||||
if !okID {
|
||||
return
|
||||
}
|
||||
pluginsEnabled, pluginsDir, proxyURL, _, host := h.pluginStoreSnapshot()
|
||||
client := h.newPluginStoreClient(proxyURL)
|
||||
registry, errRegistry := client.FetchRegistry(c.Request.Context())
|
||||
if errRegistry != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "plugin_store_registry_failed", "message": errRegistry.Error()})
|
||||
return
|
||||
}
|
||||
plugin, okPlugin := registry.PluginByID(id)
|
||||
if !okPlugin {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "plugin_not_found", "message": "plugin not found in registry"})
|
||||
return
|
||||
}
|
||||
|
||||
pluginIsLoaded := func() bool { return pluginLoaded(host, id) }
|
||||
result, errInstall := client.Install(c.Request.Context(), plugin, pluginstore.InstallOptions{
|
||||
PluginsDir: pluginsDir,
|
||||
GOOS: goos,
|
||||
GOARCH: goarch,
|
||||
PluginLoaded: pluginIsLoaded,
|
||||
})
|
||||
if errInstall != nil {
|
||||
if errors.Is(errInstall, pluginstore.ErrLoadedPluginLocked) {
|
||||
c.JSON(http.StatusConflict, gin.H{
|
||||
"error": "plugin_update_requires_restart",
|
||||
"message": "loaded Windows plugins cannot be overwritten while the server is running",
|
||||
"restart_required": true,
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "plugin_install_failed", "message": errInstall.Error()})
|
||||
return
|
||||
}
|
||||
// Sample after the install so the response reflects the library state at
|
||||
// the time the new file landed on disk.
|
||||
restartRequired := pluginIsLoaded()
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if h.cfg == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "config_unavailable",
|
||||
"message": fmt.Sprintf("plugin file installed at %s but config is unavailable to enable it", result.Path),
|
||||
"path": result.Path,
|
||||
})
|
||||
return
|
||||
}
|
||||
if errEnable := h.enablePluginConfigLocked(id); errEnable != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "config_update_failed",
|
||||
"message": fmt.Sprintf("plugin file installed at %s but enabling it in config failed: %s", result.Path, errEnable.Error()),
|
||||
"path": result.Path,
|
||||
})
|
||||
return
|
||||
}
|
||||
if errSave := config.SaveConfigPreserveComments(h.configFilePath, h.cfg); errSave != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "config_save_failed",
|
||||
"message": fmt.Sprintf("plugin file installed at %s but saving config failed: %s", result.Path, errSave.Error()),
|
||||
"path": result.Path,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, pluginInstallResponse{
|
||||
Status: "installed",
|
||||
ID: result.ID,
|
||||
Version: result.Version,
|
||||
Path: result.Path,
|
||||
PluginsEnabled: pluginsEnabled,
|
||||
RestartRequired: restartRequired,
|
||||
})
|
||||
}
|
||||
|
||||
// enablePluginConfigLocked sets plugins.configs.<id>.enabled to true while preserving
|
||||
// the rest of the plugin's raw configuration. Callers must hold h.mu.
|
||||
func (h *Handler) enablePluginConfigLocked(id string) error {
|
||||
ensurePluginConfigMap(h.cfg)
|
||||
node := pluginConfigNode(h.cfg.Plugins.Configs[id])
|
||||
setYAMLMappingValue(node, "enabled", boolYAMLNode(true))
|
||||
updated, errConfig := pluginInstanceConfigFromNode(node)
|
||||
if errConfig != nil {
|
||||
return fmt.Errorf("decode plugin config: %w", errConfig)
|
||||
}
|
||||
h.cfg.Plugins.Configs[id] = updated
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) pluginStoreSnapshot() (bool, string, string, map[string]config.PluginInstanceConfig, *pluginhost.Host) {
|
||||
if h == nil || h.cfg == nil {
|
||||
return false, "plugins", "", map[string]config.PluginInstanceConfig{}, nil
|
||||
}
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
pluginsEnabled := h.cfg.Plugins.Enabled
|
||||
pluginsDir := normalizedPluginsDir(h.cfg.Plugins.Dir)
|
||||
proxyURL := strings.TrimSpace(h.cfg.ProxyURL)
|
||||
configs := make(map[string]config.PluginInstanceConfig, len(h.cfg.Plugins.Configs))
|
||||
for id, item := range h.cfg.Plugins.Configs {
|
||||
configs[id] = item
|
||||
}
|
||||
return pluginsEnabled, pluginsDir, proxyURL, configs, h.pluginHost
|
||||
}
|
||||
|
||||
func (h *Handler) newPluginStoreClient(proxyURL string) pluginstore.Client {
|
||||
registryURL := ""
|
||||
var httpClient pluginstore.HTTPDoer
|
||||
if h != nil {
|
||||
registryURL = strings.TrimSpace(h.pluginStoreRegistryURL)
|
||||
httpClient = h.pluginStoreHTTPClient
|
||||
}
|
||||
if registryURL == "" {
|
||||
registryURL = pluginstore.DefaultRegistryURL
|
||||
}
|
||||
if httpClient != nil {
|
||||
return pluginstore.Client{HTTPClient: httpClient, RegistryURL: registryURL}
|
||||
}
|
||||
client := &http.Client{}
|
||||
if strings.TrimSpace(proxyURL) != "" {
|
||||
util.SetProxy(&sdkconfig.SDKConfig{ProxyURL: strings.TrimSpace(proxyURL)}, client)
|
||||
}
|
||||
return pluginstore.Client{HTTPClient: client, RegistryURL: registryURL}
|
||||
}
|
||||
|
||||
func pluginLocalStatuses(pluginsEnabled bool, pluginsDir string, configs map[string]config.PluginInstanceConfig, host *pluginhost.Host) (map[string]pluginLocalStatus, error) {
|
||||
statuses := map[string]pluginLocalStatus{}
|
||||
files, errDiscover := pluginhost.DiscoverPluginFiles(pluginsDir)
|
||||
if errDiscover != nil {
|
||||
return nil, errDiscover
|
||||
}
|
||||
for _, file := range files {
|
||||
status := statuses[file.ID]
|
||||
status.Installed = true
|
||||
status.Path = file.Path
|
||||
status.Enabled = true
|
||||
statuses[file.ID] = status
|
||||
}
|
||||
for id, item := range configs {
|
||||
status := statuses[id]
|
||||
status.Configured = true
|
||||
status.Enabled = pluginInstanceEnabled(item)
|
||||
statuses[id] = status
|
||||
}
|
||||
if host != nil {
|
||||
for _, info := range host.RegisteredPlugins() {
|
||||
status := statuses[info.ID]
|
||||
status.Installed = true
|
||||
status.Registered = true
|
||||
status.InstalledVersion = strings.TrimSpace(info.Metadata.Version)
|
||||
if _, configured := configs[info.ID]; !configured && !status.Enabled {
|
||||
status.Enabled = true
|
||||
}
|
||||
statuses[info.ID] = status
|
||||
}
|
||||
}
|
||||
for id, status := range statuses {
|
||||
status.EffectiveEnabled = pluginsEnabled && status.Enabled && status.Registered
|
||||
statuses[id] = status
|
||||
}
|
||||
return statuses, nil
|
||||
}
|
||||
|
||||
func pluginLoaded(host *pluginhost.Host, id string) bool {
|
||||
if host == nil {
|
||||
return false
|
||||
}
|
||||
return host.PluginLoaded(id)
|
||||
}
|
||||
258
internal/api/handlers/management/plugin_store_test.go
Normal file
258
internal/api/handlers/management/plugin_store_test.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||
)
|
||||
|
||||
func TestListPluginStoreMergesInstalledStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
pluginsDir := writeManagementPluginFile(t, "sample-provider")
|
||||
h := &Handler{
|
||||
cfg: &config.Config{
|
||||
Plugins: config.PluginsConfig{
|
||||
Enabled: true,
|
||||
Dir: pluginsDir,
|
||||
Configs: map[string]config.PluginInstanceConfig{
|
||||
"sample-provider": pluginConfigFromYAML(t, "enabled: true\nmode: fast\n"),
|
||||
},
|
||||
},
|
||||
},
|
||||
configFilePath: writeTestConfigFile(t),
|
||||
pluginStoreRegistryURL: "https://registry.example/registry.json",
|
||||
pluginStoreHTTPClient: fakePluginStoreHTTPClient{
|
||||
"https://registry.example/registry.json": registryJSON(t),
|
||||
},
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/v0/management/plugin-store", nil)
|
||||
|
||||
h.ListPluginStore(c)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
var body pluginStoreListResponse
|
||||
if errDecode := json.Unmarshal(rec.Body.Bytes(), &body); errDecode != nil {
|
||||
t.Fatalf("Unmarshal() error = %v; body=%s", errDecode, rec.Body.String())
|
||||
}
|
||||
if !body.PluginsEnabled {
|
||||
t.Fatal("plugins_enabled = false, want true")
|
||||
}
|
||||
if len(body.Plugins) != 1 {
|
||||
t.Fatalf("plugins len = %d, want 1", len(body.Plugins))
|
||||
}
|
||||
entry := body.Plugins[0]
|
||||
if !entry.Installed || !entry.Configured || !entry.Enabled {
|
||||
t.Fatalf("store entry status = %#v, want installed configured enabled", entry)
|
||||
}
|
||||
if entry.Registered || entry.EffectiveEnabled {
|
||||
t.Fatalf("runtime status = registered %v effective %v, want false false", entry.Registered, entry.EffectiveEnabled)
|
||||
}
|
||||
if entry.InstalledVersion != "" {
|
||||
t.Fatalf("installed_version = %q, want empty for unregistered plugin", entry.InstalledVersion)
|
||||
}
|
||||
if entry.UpdateAvailable {
|
||||
t.Fatal("update_available = true, want false when installed version is unknown")
|
||||
}
|
||||
if entry.Path == "" {
|
||||
t.Fatal("path is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstallPluginFromStoreWritesFileAndEnablesConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
pluginsDir := t.TempDir()
|
||||
archiveData := makeManagementPluginStoreZip(t, "sample-provider"+managementPluginExtension(runtime.GOOS), "library-data")
|
||||
archiveName := "sample-provider_0.1.0_" + runtime.GOOS + "_" + runtime.GOARCH + ".zip"
|
||||
checksum := sha256.Sum256(archiveData)
|
||||
h := &Handler{
|
||||
cfg: &config.Config{
|
||||
Plugins: config.PluginsConfig{
|
||||
Enabled: false,
|
||||
Dir: pluginsDir,
|
||||
Configs: map[string]config.PluginInstanceConfig{
|
||||
"sample-provider": pluginConfigFromYAML(t, "enabled: false\nmode: fast\n"),
|
||||
},
|
||||
},
|
||||
},
|
||||
configFilePath: writeTestConfigFile(t),
|
||||
pluginStoreRegistryURL: "https://registry.example/registry.json",
|
||||
pluginStoreHTTPClient: fakePluginStoreHTTPClient{
|
||||
"https://registry.example/registry.json": registryJSON(t),
|
||||
"https://api.github.com/repos/author-name/cliproxy-sample-provider-plugin/releases/tags/v0.1.0": []byte(`{
|
||||
"tag_name": "v0.1.0",
|
||||
"assets": [
|
||||
{"name": "` + archiveName + `", "browser_download_url": "https://downloads.example/` + archiveName + `"},
|
||||
{"name": "checksums.txt", "browser_download_url": "https://downloads.example/checksums.txt"}
|
||||
]
|
||||
}`),
|
||||
"https://downloads.example/" + archiveName: archiveData,
|
||||
"https://downloads.example/checksums.txt": []byte(hex.EncodeToString(checksum[:]) + " " + archiveName + "\n"),
|
||||
},
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = gin.Params{{Key: "id", Value: "sample-provider"}}
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v0/management/plugin-store/sample-provider/install", nil)
|
||||
|
||||
h.InstallPluginFromStore(c)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
var body pluginInstallResponse
|
||||
if errDecode := json.Unmarshal(rec.Body.Bytes(), &body); errDecode != nil {
|
||||
t.Fatalf("Unmarshal() error = %v; body=%s", errDecode, rec.Body.String())
|
||||
}
|
||||
if body.Status != "installed" || body.ID != "sample-provider" || body.Version != "0.1.0" {
|
||||
t.Fatalf("install response = %#v", body)
|
||||
}
|
||||
if body.PluginsEnabled {
|
||||
t.Fatal("plugins_enabled = true, want false")
|
||||
}
|
||||
if body.RestartRequired {
|
||||
t.Fatal("restart_required = true, want false")
|
||||
}
|
||||
targetPath := filepath.Join(pluginsDir, runtime.GOOS, runtime.GOARCH, "sample-provider"+managementPluginExtension(runtime.GOOS))
|
||||
data, errRead := os.ReadFile(targetPath)
|
||||
if errRead != nil {
|
||||
t.Fatalf("ReadFile(%s) error = %v", targetPath, errRead)
|
||||
}
|
||||
if string(data) != "library-data" {
|
||||
t.Fatalf("installed file = %q, want library-data", data)
|
||||
}
|
||||
item := h.cfg.Plugins.Configs["sample-provider"]
|
||||
if item.Enabled == nil || !*item.Enabled {
|
||||
t.Fatalf("plugin enabled = %#v, want true", item.Enabled)
|
||||
}
|
||||
if h.cfg.Plugins.Enabled {
|
||||
t.Fatal("global plugins.enabled changed to true")
|
||||
}
|
||||
raw := marshalPluginRaw(t, item)
|
||||
if !strings.Contains(raw, "mode: fast") {
|
||||
t.Fatalf("plugin raw config lost custom field:\n%s", raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnablePluginConfigLockedPreservesExistingFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := &Handler{
|
||||
cfg: &config.Config{
|
||||
Plugins: config.PluginsConfig{
|
||||
Enabled: false,
|
||||
Configs: map[string]config.PluginInstanceConfig{
|
||||
"sample-provider": pluginConfigFromYAML(t, "enabled: false\npriority: 5\nmode: fast\n"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if errEnable := h.enablePluginConfigLocked("sample-provider"); errEnable != nil {
|
||||
t.Fatalf("enablePluginConfigLocked() error = %v", errEnable)
|
||||
}
|
||||
if h.cfg.Plugins.Enabled {
|
||||
t.Fatal("global Plugins.Enabled changed to true")
|
||||
}
|
||||
item := h.cfg.Plugins.Configs["sample-provider"]
|
||||
if item.Enabled == nil || !*item.Enabled {
|
||||
t.Fatalf("plugin enabled = %#v, want true", item.Enabled)
|
||||
}
|
||||
if item.Priority != 5 {
|
||||
t.Fatalf("plugin priority = %d, want 5", item.Priority)
|
||||
}
|
||||
raw := marshalPluginRaw(t, item)
|
||||
if !strings.Contains(raw, "mode: fast") {
|
||||
t.Fatalf("plugin raw config lost custom field:\n%s", raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnablePluginConfigLockedCreatesMissingConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := &Handler{cfg: &config.Config{}}
|
||||
if errEnable := h.enablePluginConfigLocked("sample-provider"); errEnable != nil {
|
||||
t.Fatalf("enablePluginConfigLocked() error = %v", errEnable)
|
||||
}
|
||||
item := h.cfg.Plugins.Configs["sample-provider"]
|
||||
if item.Enabled == nil || !*item.Enabled {
|
||||
t.Fatalf("plugin enabled = %#v, want true", item.Enabled)
|
||||
}
|
||||
}
|
||||
|
||||
type fakePluginStoreHTTPClient map[string][]byte
|
||||
|
||||
func (c fakePluginStoreHTTPClient) Do(req *http.Request) (*http.Response, error) {
|
||||
body, ok := c[req.URL.String()]
|
||||
if !ok {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusNotFound,
|
||||
Body: io.NopCloser(strings.NewReader("not found")),
|
||||
Header: make(http.Header),
|
||||
Request: req,
|
||||
}, nil
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
Header: make(http.Header),
|
||||
Request: req,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func registryJSON(t *testing.T) []byte {
|
||||
t.Helper()
|
||||
|
||||
return []byte(`{
|
||||
"schema_version": 1,
|
||||
"plugins": [{
|
||||
"id": "sample-provider",
|
||||
"name": "Sample Provider",
|
||||
"description": "Adds sample provider support.",
|
||||
"author": "author-name",
|
||||
"version": "0.1.0",
|
||||
"repository": "https://github.com/author-name/cliproxy-sample-provider-plugin",
|
||||
"tags": ["provider"]
|
||||
}]
|
||||
}`)
|
||||
}
|
||||
|
||||
func makeManagementPluginStoreZip(t *testing.T, name string, content string) []byte {
|
||||
t.Helper()
|
||||
|
||||
var buffer bytes.Buffer
|
||||
writer := zip.NewWriter(&buffer)
|
||||
file, errCreate := writer.Create(name)
|
||||
if errCreate != nil {
|
||||
t.Fatalf("Create(%s) error = %v", name, errCreate)
|
||||
}
|
||||
if _, errWrite := file.Write([]byte(content)); errWrite != nil {
|
||||
t.Fatalf("Write(%s) error = %v", name, errWrite)
|
||||
}
|
||||
if errClose := writer.Close(); errClose != nil {
|
||||
t.Fatalf("Close() error = %v", errClose)
|
||||
}
|
||||
return buffer.Bytes()
|
||||
}
|
||||
@@ -603,6 +603,8 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML)
|
||||
mgmt.GET("/latest-version", s.mgmt.GetLatestVersion)
|
||||
mgmt.GET("/plugins", s.mgmt.ListPlugins)
|
||||
mgmt.GET("/plugin-store", s.mgmt.ListPluginStore)
|
||||
mgmt.POST("/plugin-store/:id/install", s.mgmt.InstallPluginFromStore)
|
||||
mgmt.PATCH("/plugins/:id/enabled", s.mgmt.PatchPluginEnabled)
|
||||
mgmt.PUT("/plugins/:id/config", s.mgmt.PutPluginConfig)
|
||||
mgmt.PATCH("/plugins/:id/config", s.mgmt.PatchPluginConfig)
|
||||
|
||||
62
internal/httpfetch/httpfetch.go
Normal file
62
internal/httpfetch/httpfetch.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package httpfetch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Doer abstracts the HTTP client used to execute requests.
|
||||
type Doer interface {
|
||||
Do(*http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
// GetBytes performs a GET request with the supplied headers, requires a
|
||||
// success status, and returns the response body. When maxSize is positive
|
||||
// the body is rejected once it exceeds maxSize bytes.
|
||||
func GetBytes(ctx context.Context, client Doer, requestURL string, headers map[string]string, maxSize int64) ([]byte, error) {
|
||||
if client == nil {
|
||||
client = http.DefaultClient
|
||||
}
|
||||
req, errRequest := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
|
||||
if errRequest != nil {
|
||||
return nil, fmt.Errorf("create request: %w", errRequest)
|
||||
}
|
||||
for key, value := range headers {
|
||||
if value != "" {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
resp, errDo := client.Do(req)
|
||||
if errDo != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", errDo)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.WithError(errClose).Debug("failed to close response body")
|
||||
}
|
||||
}()
|
||||
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
return nil, fmt.Errorf("unexpected status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
reader := io.Reader(resp.Body)
|
||||
if maxSize > 0 {
|
||||
reader = io.LimitReader(resp.Body, maxSize+1)
|
||||
}
|
||||
data, errRead := io.ReadAll(reader)
|
||||
if errRead != nil {
|
||||
return nil, fmt.Errorf("read response: %w", errRead)
|
||||
}
|
||||
if maxSize > 0 && int64(len(data)) > maxSize {
|
||||
return nil, fmt.Errorf("response exceeds maximum allowed size of %d bytes", maxSize)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
67
internal/httpfetch/httpfetch_test.go
Normal file
67
internal/httpfetch/httpfetch_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package httpfetch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetBytesReturnsBodyAndSendsHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("User-Agent") != "agent" || r.Header.Get("Accept") != "application/json" {
|
||||
http.Error(w, "missing headers", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
_, _ = w.Write([]byte("payload"))
|
||||
}))
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
data, errGet := GetBytes(context.Background(), server.Client(), server.URL, map[string]string{
|
||||
"User-Agent": "agent",
|
||||
"Accept": "application/json",
|
||||
}, 0)
|
||||
if errGet != nil {
|
||||
t.Fatalf("GetBytes() error = %v", errGet)
|
||||
}
|
||||
if string(data) != "payload" {
|
||||
t.Fatalf("GetBytes() = %q, want payload", data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetBytesRejectsErrorStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
http.Error(w, "missing", http.StatusNotFound)
|
||||
}))
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
_, errGet := GetBytes(context.Background(), server.Client(), server.URL, nil, 0)
|
||||
if errGet == nil {
|
||||
t.Fatal("GetBytes() error = nil")
|
||||
}
|
||||
if !strings.Contains(errGet.Error(), "unexpected status 404") {
|
||||
t.Fatalf("GetBytes() error = %v, want status 404", errGet)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetBytesEnforcesMaxSize(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte("0123456789"))
|
||||
}))
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
_, errGet := GetBytes(context.Background(), server.Client(), server.URL, nil, 4)
|
||||
if errGet == nil {
|
||||
t.Fatal("GetBytes() error = nil")
|
||||
}
|
||||
if !strings.Contains(errGet.Error(), "maximum allowed size") {
|
||||
t.Fatalf("GetBytes() error = %v, want size limit error", errGet)
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/httpfetch"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v7/sdk/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -345,32 +346,22 @@ func fetchLatestAsset(ctx context.Context, client *http.Client, releaseURL strin
|
||||
releaseURL = defaultManagementReleaseURL
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, releaseURL, nil)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("create release request: %w", err)
|
||||
headers := map[string]string{
|
||||
"Accept": "application/vnd.github+json",
|
||||
"User-Agent": httpUserAgent,
|
||||
}
|
||||
req.Header.Set("Accept", "application/vnd.github+json")
|
||||
req.Header.Set("User-Agent", httpUserAgent)
|
||||
gitURL := strings.ToLower(strings.TrimSpace(os.Getenv("GITSTORE_GIT_URL")))
|
||||
if tok := strings.TrimSpace(os.Getenv("GITSTORE_GIT_TOKEN")); tok != "" && strings.Contains(gitURL, "github.com") {
|
||||
req.Header.Set("Authorization", "Bearer "+tok)
|
||||
headers["Authorization"] = "Bearer " + tok
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
data, err := httpfetch.GetBytes(ctx, client, releaseURL, headers, 0)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("execute release request: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return nil, "", fmt.Errorf("unexpected release status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
return nil, "", fmt.Errorf("fetch release: %w", err)
|
||||
}
|
||||
|
||||
var release releaseResponse
|
||||
if err = json.NewDecoder(resp.Body).Decode(&release); err != nil {
|
||||
if err = json.Unmarshal(data, &release); err != nil {
|
||||
return nil, "", fmt.Errorf("decode release response: %w", err)
|
||||
}
|
||||
|
||||
@@ -390,31 +381,9 @@ func downloadAsset(ctx context.Context, client *http.Client, downloadURL string)
|
||||
return nil, "", fmt.Errorf("empty download url")
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadURL, nil)
|
||||
data, err := httpfetch.GetBytes(ctx, client, downloadURL, map[string]string{"User-Agent": httpUserAgent}, maxAssetDownloadSize)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("create download request: %w", err)
|
||||
}
|
||||
req.Header.Set("User-Agent", httpUserAgent)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("execute download request: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return nil, "", fmt.Errorf("unexpected download status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(io.LimitReader(resp.Body, maxAssetDownloadSize+1))
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("read download body: %w", err)
|
||||
}
|
||||
if int64(len(data)) > maxAssetDownloadSize {
|
||||
return nil, "", fmt.Errorf("download exceeds maximum allowed size of %d bytes", maxAssetDownloadSize)
|
||||
return nil, "", fmt.Errorf("download asset: %w", err)
|
||||
}
|
||||
|
||||
sum := sha256.Sum256(data)
|
||||
|
||||
@@ -623,25 +623,25 @@ func TestRegisterExecutorsOAuthScopeSkipsStaticModelClientButRegistersExecutor(t
|
||||
manager := newFakeExecutorManager()
|
||||
staticCalled := false
|
||||
host := newHostWithRecords(capabilityRecord{
|
||||
id: "qoder",
|
||||
id: "sample-provider",
|
||||
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
|
||||
AuthProvider: fakeAuthProvider{identifier: "qoder"},
|
||||
AuthProvider: fakeAuthProvider{identifier: "sample-provider"},
|
||||
ModelProvider: modelProviderFunc{
|
||||
staticModels: func(ctx context.Context, req pluginapi.StaticModelRequest) (pluginapi.ModelResponse, error) {
|
||||
staticCalled = true
|
||||
return pluginapi.ModelResponse{
|
||||
Provider: "qoder",
|
||||
Provider: "sample-provider",
|
||||
Models: []pluginapi.ModelInfo{{ID: "static-model"}},
|
||||
}, nil
|
||||
},
|
||||
modelsForAuth: func(ctx context.Context, req pluginapi.AuthModelRequest) (pluginapi.ModelResponse, error) {
|
||||
return pluginapi.ModelResponse{
|
||||
Provider: "qoder",
|
||||
Provider: "sample-provider",
|
||||
Models: []pluginapi.ModelInfo{{ID: "oauth-model"}},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
Executor: &fakeExecutor{identifier: "qoder"},
|
||||
Executor: &fakeExecutor{identifier: "sample-provider"},
|
||||
ExecutorModelScope: pluginapi.ExecutorModelScopeOAuth,
|
||||
}},
|
||||
})
|
||||
@@ -652,21 +652,21 @@ func TestRegisterExecutorsOAuthScopeSkipsStaticModelClientButRegistersExecutor(t
|
||||
if staticCalled {
|
||||
t.Fatal("StaticModels was called for an OAuth-only executor")
|
||||
}
|
||||
if _, okExecutor := manager.executors["qoder"]; !okExecutor {
|
||||
if _, okExecutor := manager.executors["sample-provider"]; !okExecutor {
|
||||
t.Fatal("OAuth-only executor was not registered")
|
||||
}
|
||||
if _, okClient := modelRegistry.clients[pluginExecutorModelClientID("qoder", "qoder")]; okClient {
|
||||
if _, okClient := modelRegistry.clients[pluginExecutorModelClientID("sample-provider", "sample-provider")]; okClient {
|
||||
t.Fatal("OAuth-only executor registered a static model client")
|
||||
}
|
||||
if got := host.ModelsForProvider("qoder"); len(got) != 0 {
|
||||
if got := host.ModelsForProvider("sample-provider"); len(got) != 0 {
|
||||
t.Fatalf("OAuth-only provider models = %#v, want none", got)
|
||||
}
|
||||
|
||||
result := host.ModelsForAuth(context.Background(), &coreauth.Auth{
|
||||
ID: "qoder-auth",
|
||||
Provider: "qoder",
|
||||
ID: "sample-provider-auth",
|
||||
Provider: "sample-provider",
|
||||
})
|
||||
if !result.Handled || result.Provider != "qoder" || len(result.Models) != 1 || result.Models[0].ID != "oauth-model" {
|
||||
if !result.Handled || result.Provider != "sample-provider" || len(result.Models) != 1 || result.Models[0].ID != "oauth-model" {
|
||||
t.Fatalf("OAuth model result = %#v, want oauth-model", result)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -127,16 +127,16 @@ func TestExecuteCommandLinePersistsReturnedAuths(t *testing.T) {
|
||||
response: pluginapi.CommandLineExecutionResponse{
|
||||
Stdout: []byte("login ok\n"),
|
||||
Auths: []pluginapi.AuthData{{
|
||||
Provider: "Qoder",
|
||||
ID: "qoder.json",
|
||||
FileName: "qoder.json",
|
||||
Provider: "Sample-Provider",
|
||||
ID: "sample-provider.json",
|
||||
FileName: "sample-provider.json",
|
||||
Label: "Luis",
|
||||
StorageJSON: []byte(`{"token":"secret"}`),
|
||||
}},
|
||||
},
|
||||
}
|
||||
host := newHostWithRecords(capabilityRecord{
|
||||
id: "qoder",
|
||||
id: "sample-provider",
|
||||
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{CommandLinePlugin: plugin}},
|
||||
})
|
||||
host.runtimeConfig = &config.Config{AuthDir: authDir}
|
||||
@@ -160,13 +160,13 @@ func TestExecuteCommandLinePersistsReturnedAuths(t *testing.T) {
|
||||
t.Fatalf("saved auths = %d, want 1", len(store.saved))
|
||||
}
|
||||
saved := store.saved[0]
|
||||
if saved.Provider != "qoder" || saved.ID != "qoder.json" || saved.FileName != "qoder.json" {
|
||||
t.Fatalf("saved auth = %#v, want normalized qoder auth", saved)
|
||||
if saved.Provider != "sample-provider" || saved.ID != "sample-provider.json" || saved.FileName != "sample-provider.json" {
|
||||
t.Fatalf("saved auth = %#v, want normalized sample provider auth", saved)
|
||||
}
|
||||
if saved.Storage == nil {
|
||||
t.Fatal("saved auth storage = nil, want plugin token storage")
|
||||
}
|
||||
if store.paths[0] != filepath.Join(authDir, "qoder.json") {
|
||||
if store.paths[0] != filepath.Join(authDir, "sample-provider.json") {
|
||||
t.Fatalf("saved path = %q, want auth dir path", store.paths[0])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,6 +114,21 @@ func (h *Host) Snapshot() *Snapshot {
|
||||
return emptySnapshot()
|
||||
}
|
||||
|
||||
// PluginLoaded reports whether a plugin dynamic library is still loaded by the host.
|
||||
func (h *Host) PluginLoaded(id string) bool {
|
||||
if h == nil {
|
||||
return false
|
||||
}
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
return false
|
||||
}
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
_, ok := h.loaded[id]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (h *Host) ApplyConfig(ctx context.Context, cfg *config.Config) {
|
||||
if h == nil {
|
||||
return
|
||||
|
||||
@@ -64,6 +64,55 @@ func TestHostApplyConfig_DisabledPluginSkipsCapability(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPluginLoadedTracksLoadedPluginAfterDisabled(t *testing.T) {
|
||||
disabled := false
|
||||
loader := newTestSymbolLoader()
|
||||
plugin := &testPlugin{
|
||||
registerResult: validTestPlugin("alpha"),
|
||||
reconfigureResult: validTestPlugin("alpha"),
|
||||
}
|
||||
loader.lookups["alpha"] = newTestSymbolLookup(plugin)
|
||||
h := NewForTest(loader)
|
||||
t.Cleanup(h.ShutdownAll)
|
||||
pluginsDir := makePluginDir(t, "alpha")
|
||||
|
||||
h.ApplyConfig(context.Background(), &config.Config{
|
||||
Plugins: config.PluginsConfig{
|
||||
Enabled: true,
|
||||
Dir: pluginsDir,
|
||||
},
|
||||
})
|
||||
|
||||
if !h.PluginLoaded("alpha") {
|
||||
t.Fatal("PluginLoaded(alpha) = false, want true after load")
|
||||
}
|
||||
if len(h.RegisteredPlugins()) != 1 {
|
||||
t.Fatalf("RegisteredPlugins() len = %d, want 1", len(h.RegisteredPlugins()))
|
||||
}
|
||||
|
||||
h.ApplyConfig(context.Background(), &config.Config{
|
||||
Plugins: config.PluginsConfig{
|
||||
Enabled: true,
|
||||
Dir: pluginsDir,
|
||||
Configs: map[string]config.PluginInstanceConfig{
|
||||
"alpha": {Enabled: &disabled},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
if len(h.RegisteredPlugins()) != 0 {
|
||||
t.Fatalf("RegisteredPlugins() len = %d, want 0 after disable", len(h.RegisteredPlugins()))
|
||||
}
|
||||
if !h.PluginLoaded("alpha") {
|
||||
t.Fatal("PluginLoaded(alpha) = false, want true while library remains loaded")
|
||||
}
|
||||
|
||||
h.ShutdownAll()
|
||||
if h.PluginLoaded("alpha") {
|
||||
t.Fatal("PluginLoaded(alpha) = true, want false after ShutdownAll")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostApplyConfigRegistersPluginThinkingApplier(t *testing.T) {
|
||||
loader := newTestSymbolLoader()
|
||||
plugin := &testPlugin{
|
||||
|
||||
@@ -44,6 +44,11 @@ func pluginIDFromPath(path string) string {
|
||||
return base
|
||||
}
|
||||
|
||||
// PluginExtension returns the dynamic library file extension used for goos.
|
||||
func PluginExtension(goos string) string {
|
||||
return pluginExtension(goos)
|
||||
}
|
||||
|
||||
func pluginExtension(goos string) string {
|
||||
switch goos {
|
||||
case "darwin":
|
||||
|
||||
45
internal/pluginstore/checksum.go
Normal file
45
internal/pluginstore/checksum.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package pluginstore
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func ParseChecksums(data []byte) (map[string]string, error) {
|
||||
out := map[string]string{}
|
||||
for lineNumber, rawLine := range strings.Split(string(data), "\n") {
|
||||
line := strings.TrimSpace(rawLine)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 2 {
|
||||
return nil, fmt.Errorf("line %d: invalid checksum entry", lineNumber+1)
|
||||
}
|
||||
hash := strings.ToLower(strings.TrimSpace(fields[0]))
|
||||
if len(hash) != sha256.Size*2 {
|
||||
return nil, fmt.Errorf("line %d: invalid sha256 length", lineNumber+1)
|
||||
}
|
||||
if _, errDecode := hex.DecodeString(hash); errDecode != nil {
|
||||
return nil, fmt.Errorf("line %d: invalid sha256: %w", lineNumber+1, errDecode)
|
||||
}
|
||||
name := strings.TrimPrefix(strings.TrimSpace(fields[1]), "*")
|
||||
out[name] = hash
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func VerifyChecksum(name string, data []byte, checksums map[string]string) error {
|
||||
expected := strings.ToLower(strings.TrimSpace(checksums[name]))
|
||||
if expected == "" {
|
||||
return fmt.Errorf("checksum for %s not found", name)
|
||||
}
|
||||
actualBytes := sha256.Sum256(data)
|
||||
actual := hex.EncodeToString(actualBytes[:])
|
||||
if actual != expected {
|
||||
return fmt.Errorf("checksum mismatch for %s", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
130
internal/pluginstore/github.go
Normal file
130
internal/pluginstore/github.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package pluginstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/httpfetch"
|
||||
)
|
||||
|
||||
const userAgent = "CLIProxyAPI"
|
||||
|
||||
// HTTPDoer abstracts the HTTP client used to execute requests.
|
||||
type HTTPDoer = httpfetch.Doer
|
||||
|
||||
type Client struct {
|
||||
HTTPClient HTTPDoer
|
||||
RegistryURL string
|
||||
UserAgent string
|
||||
}
|
||||
|
||||
type Release struct {
|
||||
TagName string `json:"tag_name"`
|
||||
Assets []ReleaseAsset `json:"assets"`
|
||||
}
|
||||
|
||||
type ReleaseAsset struct {
|
||||
Name string `json:"name"`
|
||||
BrowserDownloadURL string `json:"browser_download_url"`
|
||||
}
|
||||
|
||||
func (c Client) FetchRegistry(ctx context.Context) (Registry, error) {
|
||||
registryURL := strings.TrimSpace(c.RegistryURL)
|
||||
if registryURL == "" {
|
||||
registryURL = DefaultRegistryURL
|
||||
}
|
||||
data, errDownload := c.get(ctx, registryURL, "application/json")
|
||||
if errDownload != nil {
|
||||
return Registry{}, errDownload
|
||||
}
|
||||
registry, errParse := ParseRegistry(data)
|
||||
if errParse != nil {
|
||||
return Registry{}, errParse
|
||||
}
|
||||
return registry, nil
|
||||
}
|
||||
|
||||
func (c Client) FetchRelease(ctx context.Context, plugin Plugin) (Release, error) {
|
||||
owner, repo, errRepository := GitHubRepositoryParts(plugin.Repository)
|
||||
if errRepository != nil {
|
||||
return Release{}, errRepository
|
||||
}
|
||||
releaseURL := fmt.Sprintf(
|
||||
"https://api.github.com/repos/%s/%s/releases/tags/%s",
|
||||
url.PathEscape(owner),
|
||||
url.PathEscape(repo),
|
||||
url.PathEscape("v"+strings.TrimSpace(plugin.Version)),
|
||||
)
|
||||
data, errDownload := c.get(ctx, releaseURL, "application/vnd.github+json")
|
||||
if errDownload != nil {
|
||||
return Release{}, errDownload
|
||||
}
|
||||
var release Release
|
||||
if errDecode := json.Unmarshal(data, &release); errDecode != nil {
|
||||
return Release{}, fmt.Errorf("decode release: %w", errDecode)
|
||||
}
|
||||
return release, nil
|
||||
}
|
||||
|
||||
func (c Client) DownloadAsset(ctx context.Context, asset ReleaseAsset) ([]byte, error) {
|
||||
if strings.TrimSpace(asset.BrowserDownloadURL) == "" {
|
||||
return nil, fmt.Errorf("asset %q missing browser_download_url", asset.Name)
|
||||
}
|
||||
return c.get(ctx, asset.BrowserDownloadURL, "application/octet-stream")
|
||||
}
|
||||
|
||||
func (c Client) get(ctx context.Context, requestURL string, accept string) ([]byte, error) {
|
||||
return httpfetch.GetBytes(ctx, c.httpClient(), requestURL, map[string]string{
|
||||
"Accept": accept,
|
||||
"User-Agent": c.userAgent(),
|
||||
}, 0)
|
||||
}
|
||||
|
||||
func (c Client) httpClient() HTTPDoer {
|
||||
if c.HTTPClient != nil {
|
||||
return c.HTTPClient
|
||||
}
|
||||
return http.DefaultClient
|
||||
}
|
||||
|
||||
func (c Client) userAgent() string {
|
||||
if strings.TrimSpace(c.UserAgent) != "" {
|
||||
return strings.TrimSpace(c.UserAgent)
|
||||
}
|
||||
return userAgent
|
||||
}
|
||||
|
||||
func SelectReleaseAssets(release Release, id, version, goos, goarch string) (ReleaseAsset, ReleaseAsset, error) {
|
||||
archiveName := ArchiveName(id, version, goos, goarch)
|
||||
var archiveAsset ReleaseAsset
|
||||
var checksumAsset ReleaseAsset
|
||||
for _, asset := range release.Assets {
|
||||
switch strings.TrimSpace(asset.Name) {
|
||||
case archiveName:
|
||||
archiveAsset = asset
|
||||
case "checksums.txt":
|
||||
checksumAsset = asset
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(archiveAsset.Name) == "" {
|
||||
return ReleaseAsset{}, ReleaseAsset{}, fmt.Errorf("release asset %s not found", archiveName)
|
||||
}
|
||||
if strings.TrimSpace(checksumAsset.Name) == "" {
|
||||
return ReleaseAsset{}, ReleaseAsset{}, fmt.Errorf("release asset checksums.txt not found")
|
||||
}
|
||||
return archiveAsset, checksumAsset, nil
|
||||
}
|
||||
|
||||
func ArchiveName(id, version, goos, goarch string) string {
|
||||
return fmt.Sprintf(
|
||||
"%s_%s_%s_%s.zip",
|
||||
strings.TrimSpace(id),
|
||||
strings.TrimSpace(version),
|
||||
strings.TrimSpace(goos),
|
||||
strings.TrimSpace(goarch),
|
||||
)
|
||||
}
|
||||
93
internal/pluginstore/github_test.go
Normal file
93
internal/pluginstore/github_test.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package pluginstore
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSelectReleaseAssets(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
release := Release{Assets: []ReleaseAsset{
|
||||
{Name: "sample-provider_0.1.0_darwin_arm64.zip", BrowserDownloadURL: "https://example.com/sample-provider.zip"},
|
||||
{Name: "checksums.txt", BrowserDownloadURL: "https://example.com/checksums.txt"},
|
||||
}}
|
||||
archiveAsset, checksumAsset, errSelect := SelectReleaseAssets(release, "sample-provider", "0.1.0", "darwin", "arm64")
|
||||
if errSelect != nil {
|
||||
t.Fatalf("SelectReleaseAssets() error = %v", errSelect)
|
||||
}
|
||||
if archiveAsset.BrowserDownloadURL != "https://example.com/sample-provider.zip" {
|
||||
t.Fatalf("archive URL = %q", archiveAsset.BrowserDownloadURL)
|
||||
}
|
||||
if checksumAsset.BrowserDownloadURL != "https://example.com/checksums.txt" {
|
||||
t.Fatalf("checksum URL = %q", checksumAsset.BrowserDownloadURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectReleaseAssetsRejectsMissingAssets(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
release Release
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "missing zip",
|
||||
release: Release{Assets: []ReleaseAsset{
|
||||
{Name: "checksums.txt", BrowserDownloadURL: "https://example.com/checksums.txt"},
|
||||
}},
|
||||
wantErr: "sample-provider_0.1.0_darwin_arm64.zip",
|
||||
},
|
||||
{
|
||||
name: "missing checksum",
|
||||
release: Release{Assets: []ReleaseAsset{
|
||||
{Name: "sample-provider_0.1.0_darwin_arm64.zip", BrowserDownloadURL: "https://example.com/sample-provider.zip"},
|
||||
}},
|
||||
wantErr: "checksums.txt",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, _, errSelect := SelectReleaseAssets(tt.release, "sample-provider", "0.1.0", "darwin", "arm64")
|
||||
if errSelect == nil {
|
||||
t.Fatal("SelectReleaseAssets() error = nil")
|
||||
}
|
||||
if !strings.Contains(errSelect.Error(), tt.wantErr) {
|
||||
t.Fatalf("SelectReleaseAssets() error = %v, want substring %q", errSelect, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseChecksumsAndVerifyChecksum(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
data := []byte("zip-data")
|
||||
sum := sha256.Sum256(data)
|
||||
checksumText := hex.EncodeToString(sum[:]) + " sample-provider_0.1.0_darwin_arm64.zip\n"
|
||||
checksums, errParse := ParseChecksums([]byte(checksumText))
|
||||
if errParse != nil {
|
||||
t.Fatalf("ParseChecksums() error = %v", errParse)
|
||||
}
|
||||
if errVerify := VerifyChecksum("sample-provider_0.1.0_darwin_arm64.zip", data, checksums); errVerify != nil {
|
||||
t.Fatalf("VerifyChecksum() error = %v", errVerify)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyChecksumRejectsMissingAndMismatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sum := sha256.Sum256([]byte("zip-data"))
|
||||
checksums := map[string]string{"sample-provider.zip": hex.EncodeToString(sum[:])}
|
||||
if errVerify := VerifyChecksum("missing.zip", []byte("zip-data"), checksums); errVerify == nil {
|
||||
t.Fatal("VerifyChecksum() missing checksum error = nil")
|
||||
}
|
||||
if errVerify := VerifyChecksum("sample-provider.zip", []byte("other"), checksums); errVerify == nil {
|
||||
t.Fatal("VerifyChecksum() mismatch error = nil")
|
||||
}
|
||||
}
|
||||
277
internal/pluginstore/install.go
Normal file
277
internal/pluginstore/install.go
Normal file
@@ -0,0 +1,277 @@
|
||||
package pluginstore
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/pluginhost"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type InstallOptions struct {
|
||||
PluginsDir string
|
||||
GOOS string
|
||||
GOARCH string
|
||||
// PluginLoaded reports whether the plugin's dynamic library is currently
|
||||
// loaded by the running host. Loaded libraries cannot be overwritten on
|
||||
// Windows, so installs targeting Windows are rejected while it returns true.
|
||||
PluginLoaded func() bool
|
||||
}
|
||||
|
||||
// ErrLoadedPluginLocked is returned when an install would overwrite a plugin
|
||||
// library that is loaded by the running process on Windows.
|
||||
var ErrLoadedPluginLocked = errors.New("loaded plugin library cannot be overwritten while the server is running")
|
||||
|
||||
type InstallResult struct {
|
||||
ID string `json:"id"`
|
||||
Version string `json:"version"`
|
||||
Path string `json:"path"`
|
||||
Overwritten bool `json:"overwritten"`
|
||||
}
|
||||
|
||||
func (c Client) Install(ctx context.Context, plugin Plugin, options InstallOptions) (InstallResult, error) {
|
||||
if errValidate := ValidatePlugin(plugin); errValidate != nil {
|
||||
return InstallResult{}, errValidate
|
||||
}
|
||||
options = normalizeInstallOptions(options)
|
||||
if loadedPluginInstallBlocked(options) {
|
||||
return InstallResult{}, ErrLoadedPluginLocked
|
||||
}
|
||||
release, errRelease := c.FetchRelease(ctx, plugin)
|
||||
if errRelease != nil {
|
||||
return InstallResult{}, errRelease
|
||||
}
|
||||
archiveAsset, checksumAsset, errAssets := SelectReleaseAssets(release, plugin.ID, plugin.Version, options.GOOS, options.GOARCH)
|
||||
if errAssets != nil {
|
||||
return InstallResult{}, errAssets
|
||||
}
|
||||
archiveData, errArchive := c.DownloadAsset(ctx, archiveAsset)
|
||||
if errArchive != nil {
|
||||
return InstallResult{}, fmt.Errorf("download %s: %w", archiveAsset.Name, errArchive)
|
||||
}
|
||||
checksumData, errChecksum := c.DownloadAsset(ctx, checksumAsset)
|
||||
if errChecksum != nil {
|
||||
return InstallResult{}, fmt.Errorf("download checksums.txt: %w", errChecksum)
|
||||
}
|
||||
checksums, errParse := ParseChecksums(checksumData)
|
||||
if errParse != nil {
|
||||
return InstallResult{}, errParse
|
||||
}
|
||||
if errVerify := VerifyChecksum(archiveAsset.Name, archiveData, checksums); errVerify != nil {
|
||||
return InstallResult{}, errVerify
|
||||
}
|
||||
return InstallArchive(archiveData, plugin, options)
|
||||
}
|
||||
|
||||
func InstallArchive(archiveData []byte, plugin Plugin, options InstallOptions) (InstallResult, error) {
|
||||
options = normalizeInstallOptions(options)
|
||||
id := strings.TrimSpace(plugin.ID)
|
||||
if !pluginhost.ValidatePluginID(id) {
|
||||
return InstallResult{}, fmt.Errorf("invalid plugin id %q", plugin.ID)
|
||||
}
|
||||
reader, errZip := zip.NewReader(bytes.NewReader(archiveData), int64(len(archiveData)))
|
||||
if errZip != nil {
|
||||
return InstallResult{}, fmt.Errorf("open zip: %w", errZip)
|
||||
}
|
||||
|
||||
libraryData, mode, errLibrary := readTargetLibrary(reader, id, options.GOOS)
|
||||
if errLibrary != nil {
|
||||
return InstallResult{}, errLibrary
|
||||
}
|
||||
|
||||
targetPath, errTarget := installTargetPath(options, id)
|
||||
if errTarget != nil {
|
||||
return InstallResult{}, errTarget
|
||||
}
|
||||
overwritten := false
|
||||
if _, errStat := os.Stat(targetPath); errStat == nil {
|
||||
overwritten = true
|
||||
} else if !errors.Is(errStat, os.ErrNotExist) {
|
||||
return InstallResult{}, fmt.Errorf("stat target plugin: %w", errStat)
|
||||
}
|
||||
// Re-check immediately before writing: the plugin may have been loaded
|
||||
// while the archive was being downloaded and verified.
|
||||
if loadedPluginInstallBlocked(options) {
|
||||
return InstallResult{}, ErrLoadedPluginLocked
|
||||
}
|
||||
if errWrite := writeFileAtomic(targetPath, libraryData, mode); errWrite != nil {
|
||||
return InstallResult{}, errWrite
|
||||
}
|
||||
return InstallResult{
|
||||
ID: id,
|
||||
Version: strings.TrimSpace(plugin.Version),
|
||||
Path: targetPath,
|
||||
Overwritten: overwritten,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func installTargetPath(options InstallOptions, id string) (string, error) {
|
||||
defaultPath := filepath.Join(options.PluginsDir, options.GOOS, options.GOARCH, id+pluginhost.PluginExtension(options.GOOS))
|
||||
if options.GOOS != runtime.GOOS || options.GOARCH != runtime.GOARCH {
|
||||
return defaultPath, nil
|
||||
}
|
||||
files, errDiscover := pluginhost.DiscoverPluginFiles(options.PluginsDir)
|
||||
if errDiscover != nil {
|
||||
return "", fmt.Errorf("discover current plugin files: %w", errDiscover)
|
||||
}
|
||||
for _, file := range files {
|
||||
if file.ID == id && strings.TrimSpace(file.Path) != "" {
|
||||
return file.Path, nil
|
||||
}
|
||||
}
|
||||
return defaultPath, nil
|
||||
}
|
||||
|
||||
func readTargetLibrary(reader *zip.Reader, id string, goos string) ([]byte, os.FileMode, error) {
|
||||
targetName := strings.TrimSpace(id) + pluginhost.PluginExtension(goos)
|
||||
var target *zip.File
|
||||
for _, file := range reader.File {
|
||||
cleanedName, errClean := cleanZipName(file.Name)
|
||||
if errClean != nil {
|
||||
return nil, 0, errClean
|
||||
}
|
||||
if file.FileInfo().IsDir() {
|
||||
continue
|
||||
}
|
||||
if !regularZipFile(file) {
|
||||
return nil, 0, fmt.Errorf("zip entry %s is not a regular file", file.Name)
|
||||
}
|
||||
if !hasDynamicLibraryExtension(cleanedName) {
|
||||
continue
|
||||
}
|
||||
if cleanedName != targetName {
|
||||
if path.Base(cleanedName) == targetName {
|
||||
return nil, 0, fmt.Errorf("target dynamic library must be at zip root")
|
||||
}
|
||||
return nil, 0, fmt.Errorf("dynamic library filename must be %s", targetName)
|
||||
}
|
||||
if target != nil {
|
||||
return nil, 0, fmt.Errorf("zip contains multiple target dynamic libraries")
|
||||
}
|
||||
target = file
|
||||
}
|
||||
if target == nil {
|
||||
return nil, 0, fmt.Errorf("zip does not contain %s", targetName)
|
||||
}
|
||||
|
||||
handle, errOpen := target.Open()
|
||||
if errOpen != nil {
|
||||
return nil, 0, fmt.Errorf("open %s: %w", targetName, errOpen)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := handle.Close(); errClose != nil {
|
||||
log.WithError(errClose).Debug("failed to close plugin archive entry")
|
||||
}
|
||||
}()
|
||||
data, errRead := io.ReadAll(handle)
|
||||
if errRead != nil {
|
||||
return nil, 0, fmt.Errorf("read %s: %w", targetName, errRead)
|
||||
}
|
||||
mode := target.FileInfo().Mode().Perm()
|
||||
if mode == 0 {
|
||||
mode = 0o755
|
||||
}
|
||||
return data, mode, nil
|
||||
}
|
||||
|
||||
func cleanZipName(name string) (string, error) {
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return "", fmt.Errorf("zip entry has empty name")
|
||||
}
|
||||
if strings.Contains(name, `\`) {
|
||||
return "", fmt.Errorf("zip entry %s uses backslash path separators", name)
|
||||
}
|
||||
if path.IsAbs(name) {
|
||||
return "", fmt.Errorf("zip entry %s is absolute", name)
|
||||
}
|
||||
cleaned := path.Clean(name)
|
||||
if cleaned == "." || cleaned == ".." || strings.HasPrefix(cleaned, "../") {
|
||||
return "", fmt.Errorf("zip entry %s escapes archive root", name)
|
||||
}
|
||||
return cleaned, nil
|
||||
}
|
||||
|
||||
func regularZipFile(file *zip.File) bool {
|
||||
mode := file.FileInfo().Mode()
|
||||
return mode.IsRegular() || mode.Type() == 0
|
||||
}
|
||||
|
||||
func hasDynamicLibraryExtension(name string) bool {
|
||||
lowerName := strings.ToLower(name)
|
||||
return strings.HasSuffix(lowerName, ".dylib") || strings.HasSuffix(lowerName, ".so") || strings.HasSuffix(lowerName, ".dll")
|
||||
}
|
||||
|
||||
func writeFileAtomic(targetPath string, data []byte, mode os.FileMode) error {
|
||||
targetDir := filepath.Dir(targetPath)
|
||||
if errMkdir := os.MkdirAll(targetDir, 0o755); errMkdir != nil {
|
||||
return fmt.Errorf("create plugin directory: %w", errMkdir)
|
||||
}
|
||||
|
||||
temp, errTemp := os.CreateTemp(targetDir, "."+filepath.Base(targetPath)+".tmp-*")
|
||||
if errTemp != nil {
|
||||
return fmt.Errorf("create temp plugin file: %w", errTemp)
|
||||
}
|
||||
tempPath := temp.Name()
|
||||
removeTemp := true
|
||||
closed := false
|
||||
defer func() {
|
||||
if !closed {
|
||||
if errClose := temp.Close(); errClose != nil {
|
||||
log.WithError(errClose).Debug("failed to close temp plugin file")
|
||||
}
|
||||
}
|
||||
if removeTemp {
|
||||
if errRemove := os.Remove(tempPath); errRemove != nil && !errors.Is(errRemove, os.ErrNotExist) {
|
||||
log.WithError(errRemove).Debug("failed to remove temp plugin file")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if errChmod := temp.Chmod(mode); errChmod != nil {
|
||||
return fmt.Errorf("chmod temp plugin file: %w", errChmod)
|
||||
}
|
||||
if _, errWrite := temp.Write(data); errWrite != nil {
|
||||
return fmt.Errorf("write temp plugin file: %w", errWrite)
|
||||
}
|
||||
if errSync := temp.Sync(); errSync != nil {
|
||||
return fmt.Errorf("sync temp plugin file: %w", errSync)
|
||||
}
|
||||
if errClose := temp.Close(); errClose != nil {
|
||||
return fmt.Errorf("close temp plugin file: %w", errClose)
|
||||
}
|
||||
closed = true
|
||||
if errRename := os.Rename(tempPath, targetPath); errRename != nil {
|
||||
return fmt.Errorf("install plugin file: %w", errRename)
|
||||
}
|
||||
removeTemp = false
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadedPluginInstallBlocked(options InstallOptions) bool {
|
||||
return options.PluginLoaded != nil && strings.EqualFold(options.GOOS, "windows") && options.PluginLoaded()
|
||||
}
|
||||
|
||||
func normalizeInstallOptions(options InstallOptions) InstallOptions {
|
||||
options.PluginsDir = strings.TrimSpace(options.PluginsDir)
|
||||
if options.PluginsDir == "" {
|
||||
options.PluginsDir = "plugins"
|
||||
}
|
||||
options.GOOS = strings.TrimSpace(options.GOOS)
|
||||
if options.GOOS == "" {
|
||||
options.GOOS = runtime.GOOS
|
||||
}
|
||||
options.GOARCH = strings.TrimSpace(options.GOARCH)
|
||||
if options.GOARCH == "" {
|
||||
options.GOARCH = runtime.GOARCH
|
||||
}
|
||||
return options
|
||||
}
|
||||
241
internal/pluginstore/install_test.go
Normal file
241
internal/pluginstore/install_test.go
Normal file
@@ -0,0 +1,241 @@
|
||||
package pluginstore
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/pluginhost"
|
||||
)
|
||||
|
||||
func TestInstallBlocksLoadedWindowsPlugin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
goos string
|
||||
loaded bool
|
||||
wantBlocked bool
|
||||
}{
|
||||
{name: "windows loaded", goos: "windows", loaded: true, wantBlocked: true},
|
||||
{name: "windows not loaded", goos: "windows", loaded: false, wantBlocked: false},
|
||||
{name: "linux loaded", goos: "linux", loaded: true, wantBlocked: false},
|
||||
{name: "darwin loaded", goos: "darwin", loaded: true, wantBlocked: false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, errInstall := Client{HTTPClient: failingHTTPDoer{}}.Install(context.Background(), testPlugin(), InstallOptions{
|
||||
PluginsDir: t.TempDir(),
|
||||
GOOS: tt.goos,
|
||||
GOARCH: "amd64",
|
||||
PluginLoaded: func() bool { return tt.loaded },
|
||||
})
|
||||
if errInstall == nil {
|
||||
t.Fatal("Install() error = nil")
|
||||
}
|
||||
if gotBlocked := errors.Is(errInstall, ErrLoadedPluginLocked); gotBlocked != tt.wantBlocked {
|
||||
t.Fatalf("Install() error = %v, blocked = %v, want %v", errInstall, gotBlocked, tt.wantBlocked)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstallArchiveBlocksLoadedWindowsPluginBeforeWrite(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, errInstall := InstallArchive(makeZip(t, map[string]string{
|
||||
"sample-provider.dll": "library-data",
|
||||
}), testPlugin(), InstallOptions{
|
||||
PluginsDir: t.TempDir(),
|
||||
GOOS: "windows",
|
||||
GOARCH: "amd64",
|
||||
PluginLoaded: func() bool { return true },
|
||||
})
|
||||
if !errors.Is(errInstall, ErrLoadedPluginLocked) {
|
||||
t.Fatalf("InstallArchive() error = %v, want ErrLoadedPluginLocked", errInstall)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstallArchiveWritesPlatformPlugin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
root := t.TempDir()
|
||||
result, errInstall := InstallArchive(makeZip(t, map[string]string{
|
||||
"README.md": "ignored",
|
||||
"sample-provider.dylib": "library-data",
|
||||
}), testPlugin(), InstallOptions{PluginsDir: root, GOOS: "darwin", GOARCH: "arm64"})
|
||||
if errInstall != nil {
|
||||
t.Fatalf("InstallArchive() error = %v", errInstall)
|
||||
}
|
||||
wantPath := filepath.Join(root, "darwin", "arm64", "sample-provider.dylib")
|
||||
if result.Path != wantPath {
|
||||
t.Fatalf("Path = %q, want %q", result.Path, wantPath)
|
||||
}
|
||||
data, errRead := os.ReadFile(wantPath)
|
||||
if errRead != nil {
|
||||
t.Fatalf("ReadFile() error = %v", errRead)
|
||||
}
|
||||
if string(data) != "library-data" {
|
||||
t.Fatalf("installed data = %q", data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstallArchiveReportsOverwrite(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
root := t.TempDir()
|
||||
targetDir := filepath.Join(root, "darwin", "arm64")
|
||||
if errMkdir := os.MkdirAll(targetDir, 0o755); errMkdir != nil {
|
||||
t.Fatalf("MkdirAll() error = %v", errMkdir)
|
||||
}
|
||||
if errWrite := os.WriteFile(filepath.Join(targetDir, "sample-provider.dylib"), []byte("old"), 0o644); errWrite != nil {
|
||||
t.Fatalf("WriteFile() error = %v", errWrite)
|
||||
}
|
||||
result, errInstall := InstallArchive(makeZip(t, map[string]string{
|
||||
"sample-provider.dylib": "new",
|
||||
}), testPlugin(), InstallOptions{PluginsDir: root, GOOS: "darwin", GOARCH: "arm64"})
|
||||
if errInstall != nil {
|
||||
t.Fatalf("InstallArchive() error = %v", errInstall)
|
||||
}
|
||||
if !result.Overwritten {
|
||||
t.Fatal("Overwritten = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstallArchiveOverwritesRuntimeSelectedPlugin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
root := t.TempDir()
|
||||
existingPath := filepath.Join(root, "sample-provider"+pluginhost.PluginExtension(runtime.GOOS))
|
||||
if errWrite := os.WriteFile(existingPath, []byte("old"), 0o644); errWrite != nil {
|
||||
t.Fatalf("WriteFile() error = %v", errWrite)
|
||||
}
|
||||
|
||||
result, errInstall := InstallArchive(makeZip(t, map[string]string{
|
||||
"sample-provider" + pluginhost.PluginExtension(runtime.GOOS): "new",
|
||||
}), testPlugin(), InstallOptions{PluginsDir: root, GOOS: runtime.GOOS, GOARCH: runtime.GOARCH})
|
||||
if errInstall != nil {
|
||||
t.Fatalf("InstallArchive() error = %v", errInstall)
|
||||
}
|
||||
if result.Path != existingPath {
|
||||
t.Fatalf("Path = %q, want selected runtime plugin %q", result.Path, existingPath)
|
||||
}
|
||||
if !result.Overwritten {
|
||||
t.Fatal("Overwritten = false, want true")
|
||||
}
|
||||
data, errRead := os.ReadFile(existingPath)
|
||||
if errRead != nil {
|
||||
t.Fatalf("ReadFile() error = %v", errRead)
|
||||
}
|
||||
if string(data) != "new" {
|
||||
t.Fatalf("installed data = %q, want new", data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstallArchiveRejectsUnsafeArchives(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
files map[string]string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "zip slip",
|
||||
files: map[string]string{"../sample-provider.dylib": "library"},
|
||||
wantErr: "escapes archive root",
|
||||
},
|
||||
{
|
||||
name: "absolute path",
|
||||
files: map[string]string{"/sample-provider.dylib": "library"},
|
||||
wantErr: "is absolute",
|
||||
},
|
||||
{
|
||||
name: "nested target",
|
||||
files: map[string]string{"nested/sample-provider.dylib": "library"},
|
||||
wantErr: "zip root",
|
||||
},
|
||||
{
|
||||
name: "extension mismatch",
|
||||
files: map[string]string{"sample-provider.so": "library"},
|
||||
wantErr: "sample-provider.dylib",
|
||||
},
|
||||
{
|
||||
name: "filename mismatch",
|
||||
files: map[string]string{"other.dylib": "library"},
|
||||
wantErr: "sample-provider.dylib",
|
||||
},
|
||||
{
|
||||
name: "missing target",
|
||||
files: map[string]string{"README.md": "library"},
|
||||
wantErr: "does not contain",
|
||||
},
|
||||
{
|
||||
name: "multiple targets",
|
||||
files: map[string]string{
|
||||
"sample-provider.dylib": "library",
|
||||
"copy.dylib": "library",
|
||||
},
|
||||
wantErr: "sample-provider.dylib",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, errInstall := InstallArchive(makeZip(t, tt.files), testPlugin(), InstallOptions{PluginsDir: t.TempDir(), GOOS: "darwin", GOARCH: "arm64"})
|
||||
if errInstall == nil {
|
||||
t.Fatal("InstallArchive() error = nil")
|
||||
}
|
||||
if !strings.Contains(errInstall.Error(), tt.wantErr) {
|
||||
t.Fatalf("InstallArchive() error = %v, want substring %q", errInstall, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func makeZip(t *testing.T, files map[string]string) []byte {
|
||||
t.Helper()
|
||||
|
||||
var buffer bytes.Buffer
|
||||
writer := zip.NewWriter(&buffer)
|
||||
for name, content := range files {
|
||||
file, errCreate := writer.Create(name)
|
||||
if errCreate != nil {
|
||||
t.Fatalf("Create(%s) error = %v", name, errCreate)
|
||||
}
|
||||
if _, errWrite := file.Write([]byte(content)); errWrite != nil {
|
||||
t.Fatalf("Write(%s) error = %v", name, errWrite)
|
||||
}
|
||||
}
|
||||
if errClose := writer.Close(); errClose != nil {
|
||||
t.Fatalf("Close() error = %v", errClose)
|
||||
}
|
||||
return buffer.Bytes()
|
||||
}
|
||||
|
||||
type failingHTTPDoer struct{}
|
||||
|
||||
func (failingHTTPDoer) Do(*http.Request) (*http.Response, error) {
|
||||
return nil, errors.New("network unavailable")
|
||||
}
|
||||
|
||||
func testPlugin() Plugin {
|
||||
return Plugin{
|
||||
ID: "sample-provider",
|
||||
Name: "Sample Provider",
|
||||
Description: "Adds sample provider support.",
|
||||
Author: "author-name",
|
||||
Version: "0.1.0",
|
||||
Repository: "https://github.com/author-name/cliproxy-sample-provider-plugin",
|
||||
}
|
||||
}
|
||||
156
internal/pluginstore/registry.go
Normal file
156
internal/pluginstore/registry.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package pluginstore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/pluginhost"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultRegistryURL = "https://raw.githubusercontent.com/router-for-me/CLIProxyAPI-Plugins-Store/main/registry.json"
|
||||
SchemaVersion = 1
|
||||
)
|
||||
|
||||
var pluginVersionPattern = regexp.MustCompile(`^[0-9][0-9A-Za-z.+-]*$`)
|
||||
|
||||
type Registry struct {
|
||||
SchemaVersion int `json:"schema_version"`
|
||||
Plugins []Plugin `json:"plugins"`
|
||||
}
|
||||
|
||||
type Plugin struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Author string `json:"author"`
|
||||
Version string `json:"version"`
|
||||
Repository string `json:"repository"`
|
||||
Logo string `json:"logo,omitempty"`
|
||||
Homepage string `json:"homepage,omitempty"`
|
||||
License string `json:"license,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
}
|
||||
|
||||
func ParseRegistry(data []byte) (Registry, error) {
|
||||
var registry Registry
|
||||
decoder := json.NewDecoder(bytes.NewReader(data))
|
||||
if errDecode := decoder.Decode(®istry); errDecode != nil {
|
||||
return Registry{}, fmt.Errorf("decode registry: %w", errDecode)
|
||||
}
|
||||
normalizeRegistry(®istry)
|
||||
if errValidate := ValidateRegistry(registry); errValidate != nil {
|
||||
return Registry{}, errValidate
|
||||
}
|
||||
return registry, nil
|
||||
}
|
||||
|
||||
func normalizeRegistry(registry *Registry) {
|
||||
if registry == nil {
|
||||
return
|
||||
}
|
||||
for index := range registry.Plugins {
|
||||
plugin := ®istry.Plugins[index]
|
||||
plugin.ID = strings.TrimSpace(plugin.ID)
|
||||
plugin.Name = strings.TrimSpace(plugin.Name)
|
||||
plugin.Description = strings.TrimSpace(plugin.Description)
|
||||
plugin.Author = strings.TrimSpace(plugin.Author)
|
||||
plugin.Version = strings.TrimSpace(plugin.Version)
|
||||
plugin.Repository = strings.TrimSpace(plugin.Repository)
|
||||
plugin.Logo = strings.TrimSpace(plugin.Logo)
|
||||
plugin.Homepage = strings.TrimSpace(plugin.Homepage)
|
||||
plugin.License = strings.TrimSpace(plugin.License)
|
||||
for tagIndex := range plugin.Tags {
|
||||
plugin.Tags[tagIndex] = strings.TrimSpace(plugin.Tags[tagIndex])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ValidateRegistry(registry Registry) error {
|
||||
if registry.SchemaVersion != SchemaVersion {
|
||||
return fmt.Errorf("unsupported schema_version %d", registry.SchemaVersion)
|
||||
}
|
||||
seen := make(map[string]struct{}, len(registry.Plugins))
|
||||
for index, plugin := range registry.Plugins {
|
||||
if errValidate := ValidatePlugin(plugin); errValidate != nil {
|
||||
return fmt.Errorf("plugins[%d]: %w", index, errValidate)
|
||||
}
|
||||
id := strings.TrimSpace(plugin.ID)
|
||||
if _, exists := seen[id]; exists {
|
||||
return fmt.Errorf("plugins[%d]: duplicate plugin id %q", index, id)
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidatePlugin(plugin Plugin) error {
|
||||
required := map[string]string{
|
||||
"id": plugin.ID,
|
||||
"name": plugin.Name,
|
||||
"description": plugin.Description,
|
||||
"author": plugin.Author,
|
||||
"version": plugin.Version,
|
||||
"repository": plugin.Repository,
|
||||
}
|
||||
for field, value := range required {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return fmt.Errorf("missing required field %s", field)
|
||||
}
|
||||
}
|
||||
if !pluginhost.ValidatePluginID(strings.TrimSpace(plugin.ID)) {
|
||||
return fmt.Errorf("invalid plugin id %q", plugin.ID)
|
||||
}
|
||||
if !validPluginVersion(strings.TrimSpace(plugin.Version)) {
|
||||
return fmt.Errorf("invalid plugin version %q", plugin.Version)
|
||||
}
|
||||
if _, _, errRepository := GitHubRepositoryParts(plugin.Repository); errRepository != nil {
|
||||
return errRepository
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validPluginVersion(version string) bool {
|
||||
return version != "" && !strings.HasPrefix(version, "v") && pluginVersionPattern.MatchString(version)
|
||||
}
|
||||
|
||||
func GitHubRepositoryParts(repository string) (string, string, error) {
|
||||
repository = strings.TrimSpace(repository)
|
||||
parsed, errParse := url.Parse(repository)
|
||||
if errParse != nil {
|
||||
return "", "", fmt.Errorf("invalid repository URL: %w", errParse)
|
||||
}
|
||||
if parsed.Scheme != "https" || parsed.Host != "github.com" || parsed.RawQuery != "" || parsed.Fragment != "" {
|
||||
return "", "", fmt.Errorf("repository must be https://github.com/{owner}/{repo}")
|
||||
}
|
||||
segments := strings.Split(strings.Trim(parsed.EscapedPath(), "/"), "/")
|
||||
if len(segments) != 2 || segments[0] == "" || segments[1] == "" {
|
||||
return "", "", fmt.Errorf("repository must be https://github.com/{owner}/{repo}")
|
||||
}
|
||||
owner, errOwner := url.PathUnescape(segments[0])
|
||||
if errOwner != nil {
|
||||
return "", "", fmt.Errorf("invalid repository owner: %w", errOwner)
|
||||
}
|
||||
repo, errRepo := url.PathUnescape(segments[1])
|
||||
if errRepo != nil {
|
||||
return "", "", fmt.Errorf("invalid repository name: %w", errRepo)
|
||||
}
|
||||
if strings.HasSuffix(repo, ".git") {
|
||||
return "", "", fmt.Errorf("repository must be https://github.com/{owner}/{repo}")
|
||||
}
|
||||
return owner, repo, nil
|
||||
}
|
||||
|
||||
func (r Registry) PluginByID(id string) (Plugin, bool) {
|
||||
id = strings.TrimSpace(id)
|
||||
for _, plugin := range r.Plugins {
|
||||
if strings.TrimSpace(plugin.ID) == id {
|
||||
return plugin, true
|
||||
}
|
||||
}
|
||||
return Plugin{}, false
|
||||
}
|
||||
167
internal/pluginstore/registry_test.go
Normal file
167
internal/pluginstore/registry_test.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package pluginstore
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseRegistryValidatesRegistry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
registry, errParse := ParseRegistry([]byte(`{
|
||||
"schema_version": 1,
|
||||
"plugins": [{
|
||||
"id": "sample-provider",
|
||||
"name": "Sample Provider",
|
||||
"description": "Adds sample provider support.",
|
||||
"author": "author-name",
|
||||
"version": "0.1.0",
|
||||
"repository": "https://github.com/author-name/cliproxy-sample-provider-plugin",
|
||||
"logo": "https://example.com/logo.png",
|
||||
"homepage": "https://github.com/author-name/cliproxy-sample-provider-plugin",
|
||||
"license": "MIT",
|
||||
"tags": ["provider"]
|
||||
}]
|
||||
}`))
|
||||
if errParse != nil {
|
||||
t.Fatalf("ParseRegistry() error = %v", errParse)
|
||||
}
|
||||
plugin, ok := registry.PluginByID("sample-provider")
|
||||
if !ok {
|
||||
t.Fatal("PluginByID(sample-provider) missing")
|
||||
}
|
||||
if plugin.Version != "0.1.0" {
|
||||
t.Fatalf("plugin version = %q, want 0.1.0", plugin.Version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRegistryNormalizesPluginFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
registry, errParse := ParseRegistry([]byte(`{
|
||||
"schema_version": 1,
|
||||
"plugins": [{
|
||||
"id": " sample-provider ",
|
||||
"name": " Sample Provider ",
|
||||
"description": " Adds sample provider support. ",
|
||||
"author": " author-name ",
|
||||
"version": " 0.1.0 ",
|
||||
"repository": " https://github.com/author-name/cliproxy-sample-provider-plugin ",
|
||||
"logo": " https://example.com/logo.png ",
|
||||
"homepage": " https://github.com/author-name/cliproxy-sample-provider-plugin ",
|
||||
"license": " MIT ",
|
||||
"tags": [" provider "]
|
||||
}]
|
||||
}`))
|
||||
if errParse != nil {
|
||||
t.Fatalf("ParseRegistry() error = %v", errParse)
|
||||
}
|
||||
plugin, ok := registry.PluginByID("sample-provider")
|
||||
if !ok {
|
||||
t.Fatal("PluginByID(sample-provider) missing")
|
||||
}
|
||||
if plugin.ID != "sample-provider" || plugin.Version != "0.1.0" || plugin.Repository != "https://github.com/author-name/cliproxy-sample-provider-plugin" {
|
||||
t.Fatalf("plugin not normalized: %#v", plugin)
|
||||
}
|
||||
if plugin.Name != "Sample Provider" || plugin.Tags[0] != "provider" {
|
||||
t.Fatalf("plugin display fields not normalized: %#v", plugin)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRegistryRejectsInvalidEntries(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
valid := Plugin{
|
||||
ID: "sample-provider",
|
||||
Name: "Sample Provider",
|
||||
Description: "Adds sample provider support.",
|
||||
Author: "author-name",
|
||||
Version: "0.1.0",
|
||||
Repository: "https://github.com/author-name/cliproxy-sample-provider-plugin",
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
mutate func(*Registry)
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "schema version",
|
||||
mutate: func(registry *Registry) {
|
||||
registry.SchemaVersion = 2
|
||||
},
|
||||
wantErr: "unsupported schema_version",
|
||||
},
|
||||
{
|
||||
name: "missing required field",
|
||||
mutate: func(registry *Registry) {
|
||||
registry.Plugins[0].Name = ""
|
||||
},
|
||||
wantErr: "missing required field name",
|
||||
},
|
||||
{
|
||||
name: "duplicate id",
|
||||
mutate: func(registry *Registry) {
|
||||
registry.Plugins = append(registry.Plugins, valid)
|
||||
},
|
||||
wantErr: "duplicate plugin id",
|
||||
},
|
||||
{
|
||||
name: "invalid id",
|
||||
mutate: func(registry *Registry) {
|
||||
registry.Plugins[0].ID = "../sample-provider"
|
||||
},
|
||||
wantErr: "invalid plugin id",
|
||||
},
|
||||
{
|
||||
name: "v-prefixed version",
|
||||
mutate: func(registry *Registry) {
|
||||
registry.Plugins[0].Version = "v0.1.0"
|
||||
},
|
||||
wantErr: "invalid plugin version",
|
||||
},
|
||||
{
|
||||
name: "invalid repository",
|
||||
mutate: func(registry *Registry) {
|
||||
registry.Plugins[0].Repository = "https://example.com/author/repo"
|
||||
},
|
||||
wantErr: "repository must be",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
registry := Registry{SchemaVersion: 1, Plugins: []Plugin{valid}}
|
||||
tt.mutate(®istry)
|
||||
errValidate := ValidateRegistry(registry)
|
||||
if errValidate == nil {
|
||||
t.Fatal("ValidateRegistry() error = nil")
|
||||
}
|
||||
if !strings.Contains(errValidate.Error(), tt.wantErr) {
|
||||
t.Fatalf("ValidateRegistry() error = %v, want substring %q", errValidate, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitHubRepositoryPartsRejectsNonRepositoryURLs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []string{
|
||||
"http://github.com/owner/repo",
|
||||
"https://github.com/owner",
|
||||
"https://github.com/owner/repo/issues",
|
||||
"https://github.com/owner/repo.git",
|
||||
"https://github.com/owner/repo?tab=readme",
|
||||
}
|
||||
for _, repository := range tests {
|
||||
t.Run(repository, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if _, _, errParse := GitHubRepositoryParts(repository); errParse == nil {
|
||||
t.Fatalf("GitHubRepositoryParts(%q) error = nil", repository)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
69
internal/pluginstore/version.go
Normal file
69
internal/pluginstore/version.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package pluginstore
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// UpdateAvailable reports whether latest should be offered as an upgrade over
|
||||
// installed. A leading "v"/"V" is ignored on both sides. Versions are compared
|
||||
// numerically when both are dotted release numbers, so an installed version
|
||||
// newer than the registry one is not reported as an update; otherwise any
|
||||
// difference counts as an update.
|
||||
func UpdateAvailable(installed, latest string) bool {
|
||||
installed = normalizeVersion(installed)
|
||||
latest = normalizeVersion(latest)
|
||||
if installed == "" || latest == "" || installed == latest {
|
||||
return false
|
||||
}
|
||||
comparison, comparable := compareVersions(installed, latest)
|
||||
if !comparable {
|
||||
return true
|
||||
}
|
||||
return comparison < 0
|
||||
}
|
||||
|
||||
func normalizeVersion(version string) string {
|
||||
version = strings.TrimSpace(version)
|
||||
if len(version) > 1 && (version[0] == 'v' || version[0] == 'V') {
|
||||
version = version[1:]
|
||||
}
|
||||
return version
|
||||
}
|
||||
|
||||
// compareVersions compares dotted numeric versions segment by segment, with
|
||||
// missing segments treated as zero. It reports false when either version
|
||||
// contains a non-numeric segment.
|
||||
func compareVersions(a, b string) (int, bool) {
|
||||
segmentsA := strings.Split(a, ".")
|
||||
segmentsB := strings.Split(b, ".")
|
||||
length := len(segmentsA)
|
||||
if len(segmentsB) > length {
|
||||
length = len(segmentsB)
|
||||
}
|
||||
for index := 0; index < length; index++ {
|
||||
numberA, okA := versionSegment(segmentsA, index)
|
||||
numberB, okB := versionSegment(segmentsB, index)
|
||||
if !okA || !okB {
|
||||
return 0, false
|
||||
}
|
||||
if numberA != numberB {
|
||||
if numberA < numberB {
|
||||
return -1, true
|
||||
}
|
||||
return 1, true
|
||||
}
|
||||
}
|
||||
return 0, true
|
||||
}
|
||||
|
||||
func versionSegment(segments []string, index int) (int64, bool) {
|
||||
if index >= len(segments) {
|
||||
return 0, true
|
||||
}
|
||||
number, errParse := strconv.ParseInt(segments[index], 10, 64)
|
||||
if errParse != nil || number < 0 {
|
||||
return 0, false
|
||||
}
|
||||
return number, true
|
||||
}
|
||||
34
internal/pluginstore/version_test.go
Normal file
34
internal/pluginstore/version_test.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package pluginstore
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestUpdateAvailable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
installed string
|
||||
latest string
|
||||
want bool
|
||||
}{
|
||||
{name: "unknown installed", installed: "", latest: "0.2.0", want: false},
|
||||
{name: "same version", installed: "0.1.0", latest: "0.1.0", want: false},
|
||||
{name: "same version with v prefix", installed: "v0.1.0", latest: "0.1.0", want: false},
|
||||
{name: "newer registry version", installed: "0.1.0", latest: "0.2.0", want: true},
|
||||
{name: "newer registry version with v prefix", installed: "v0.1.0", latest: "0.2.0", want: true},
|
||||
{name: "numeric not lexicographic", installed: "0.1.9", latest: "0.1.10", want: true},
|
||||
{name: "installed newer than registry", installed: "0.2.0", latest: "0.1.0", want: false},
|
||||
{name: "missing segments treated as zero", installed: "0.1", latest: "0.1.0", want: false},
|
||||
{name: "prerelease falls back to inequality", installed: "0.1.0-rc1", latest: "0.1.0", want: true},
|
||||
{name: "non numeric falls back to inequality", installed: "dev", latest: "0.1.0", want: true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if got := UpdateAvailable(tt.installed, tt.latest); got != tt.want {
|
||||
t.Fatalf("UpdateAvailable(%q, %q) = %v, want %v", tt.installed, tt.latest, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -339,7 +339,7 @@ func normalizeLevels(levels []string) []string {
|
||||
// These providers may also support level-based thinking (hybrid models).
|
||||
func isBudgetCapableProvider(provider string) bool {
|
||||
switch provider {
|
||||
case "gemini", "gemini-cli", "antigravity", "claude", "qoder":
|
||||
case "gemini", "gemini-cli", "antigravity", "claude":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
|
||||
@@ -175,10 +175,10 @@ func TestOAuthModelAliasChannel_Kimi(t *testing.T) {
|
||||
func TestOAuthModelAliasChannel_PluginProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if got := OAuthModelAliasChannel(" Qoder ", "oauth"); got != "qoder" {
|
||||
t.Fatalf("OAuthModelAliasChannel() = %q, want %q", got, "qoder")
|
||||
if got := OAuthModelAliasChannel(" Sample-Provider ", "oauth"); got != "sample-provider" {
|
||||
t.Fatalf("OAuthModelAliasChannel() = %q, want %q", got, "sample-provider")
|
||||
}
|
||||
if got := OAuthModelAliasChannel("qoder", "api_key"); got != "" {
|
||||
if got := OAuthModelAliasChannel("sample-provider", "api_key"); got != "" {
|
||||
t.Fatalf("OAuthModelAliasChannel() = %q, want empty channel for API key", got)
|
||||
}
|
||||
}
|
||||
@@ -206,18 +206,18 @@ func TestApplyOAuthModelAlias_PluginProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
aliases := map[string][]internalconfig.OAuthModelAlias{
|
||||
"qoder": {{Name: "qmodel_latest", Alias: "qlatest"}},
|
||||
"sample-provider": {{Name: "sample-model-latest", Alias: "sample-latest"}},
|
||||
}
|
||||
|
||||
mgr := NewManager(nil, nil, nil)
|
||||
mgr.SetConfig(&internalconfig.Config{})
|
||||
mgr.SetOAuthModelAlias(aliases)
|
||||
|
||||
auth := &Auth{ID: "qoder-auth", Provider: "qoder", Attributes: map[string]string{"auth_kind": "oauth"}}
|
||||
auth := &Auth{ID: "sample-provider-auth", Provider: "sample-provider", Attributes: map[string]string{"auth_kind": "oauth"}}
|
||||
|
||||
resolvedModel := mgr.applyOAuthModelAlias(auth, "qlatest")
|
||||
if resolvedModel != "qmodel_latest" {
|
||||
t.Errorf("applyOAuthModelAlias() model = %q, want %q", resolvedModel, "qmodel_latest")
|
||||
resolvedModel := mgr.applyOAuthModelAlias(auth, "sample-latest")
|
||||
if resolvedModel != "sample-model-latest" {
|
||||
t.Errorf("applyOAuthModelAlias() model = %q, want %q", resolvedModel, "sample-model-latest")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -225,17 +225,17 @@ func TestApplyOAuthModelAlias_PluginProviderSkipsAPIKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
aliases := map[string][]internalconfig.OAuthModelAlias{
|
||||
"qoder": {{Name: "qmodel_latest", Alias: "qlatest"}},
|
||||
"sample-provider": {{Name: "sample-model-latest", Alias: "sample-latest"}},
|
||||
}
|
||||
|
||||
mgr := NewManager(nil, nil, nil)
|
||||
mgr.SetConfig(&internalconfig.Config{})
|
||||
mgr.SetOAuthModelAlias(aliases)
|
||||
|
||||
auth := &Auth{ID: "qoder-auth", Provider: "qoder", Attributes: map[string]string{"auth_kind": "api_key"}}
|
||||
auth := &Auth{ID: "sample-provider-auth", Provider: "sample-provider", Attributes: map[string]string{"auth_kind": "api_key"}}
|
||||
|
||||
resolvedModel := mgr.applyOAuthModelAlias(auth, "qlatest")
|
||||
if resolvedModel != "qlatest" {
|
||||
t.Errorf("applyOAuthModelAlias() model = %q, want %q", resolvedModel, "qlatest")
|
||||
resolvedModel := mgr.applyOAuthModelAlias(auth, "sample-latest")
|
||||
if resolvedModel != "sample-latest" {
|
||||
t.Errorf("applyOAuthModelAlias() model = %q, want %q", resolvedModel, "sample-latest")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,41 +94,41 @@ func TestApplyOAuthModelAlias_ForkAddsMultipleAliases(t *testing.T) {
|
||||
func TestApplyOAuthModelAlias_PluginProvider(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
OAuthModelAlias: map[string][]config.OAuthModelAlias{
|
||||
"qoder": {
|
||||
{Name: "qmodel_latest", Alias: "qlatest"},
|
||||
"sample-provider": {
|
||||
{Name: "sample-model-latest", Alias: "sample-latest"},
|
||||
},
|
||||
},
|
||||
}
|
||||
models := []*ModelInfo{
|
||||
{ID: "qmodel_latest", Name: "models/qmodel_latest"},
|
||||
{ID: "sample-model-latest", Name: "models/sample-model-latest"},
|
||||
}
|
||||
|
||||
out := applyOAuthModelAlias(cfg, "qoder", "oauth", models)
|
||||
out := applyOAuthModelAlias(cfg, "sample-provider", "oauth", models)
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("expected 1 model, got %d", len(out))
|
||||
}
|
||||
if out[0].ID != "qlatest" {
|
||||
t.Fatalf("expected plugin alias id %q, got %q", "qlatest", out[0].ID)
|
||||
if out[0].ID != "sample-latest" {
|
||||
t.Fatalf("expected plugin alias id %q, got %q", "sample-latest", out[0].ID)
|
||||
}
|
||||
if out[0].Name != "models/qlatest" {
|
||||
t.Fatalf("expected plugin alias name %q, got %q", "models/qlatest", out[0].Name)
|
||||
if out[0].Name != "models/sample-latest" {
|
||||
t.Fatalf("expected plugin alias name %q, got %q", "models/sample-latest", out[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyOAuthModelAlias_PluginProviderSkipsAPIKey(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
OAuthModelAlias: map[string][]config.OAuthModelAlias{
|
||||
"qoder": {
|
||||
{Name: "qmodel_latest", Alias: "qlatest"},
|
||||
"sample-provider": {
|
||||
{Name: "sample-model-latest", Alias: "sample-latest"},
|
||||
},
|
||||
},
|
||||
}
|
||||
models := []*ModelInfo{
|
||||
{ID: "qmodel_latest", Name: "models/qmodel_latest"},
|
||||
{ID: "sample-model-latest", Name: "models/sample-model-latest"},
|
||||
}
|
||||
|
||||
out := applyOAuthModelAlias(cfg, "qoder", "api_key", models)
|
||||
if len(out) != 1 || out[0].ID != "qmodel_latest" {
|
||||
out := applyOAuthModelAlias(cfg, "sample-provider", "api_key", models)
|
||||
if len(out) != 1 || out[0].ID != "sample-model-latest" {
|
||||
t.Fatalf("expected API key plugin model to remain unchanged, got %#v", out)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user