Files
CLIProxyAPI/internal/pluginhost/adapters_test.go
Luis Pater fabf06154f feat(access, pluginhost): add support for exclusive frontend auth providers
- Introduced `FrontendAuthProviderExclusive` capability to restrict authentication to a single selected provider.
- Added `SetExclusiveProvider` and `ClearExclusiveProvider` methods for managing exclusive providers in the access registry.
- Updated `pluginhost` to prioritize and enforce exclusive providers based on plugin priority and ID.
- Enhanced RPC capabilities schema to include `FrontendAuthProviderExclusive` field.
- Added example plugin and tests for exclusive frontend auth behavior.
2026-06-09 10:56:58 +08:00

3047 lines
114 KiB
Go

package pluginhost
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strings"
"testing"
"time"
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
sdkaccess "github.com/router-for-me/CLIProxyAPI/v7/sdk/access"
coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
coreexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
coreusage "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/usage"
"github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginapi"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator"
)
func TestPluginModelInfoToRegistryModelInfoClonesThinkingAndSlices(t *testing.T) {
model := pluginapi.ModelInfo{
ID: "model-1",
Object: "model",
Created: 123,
OwnedBy: "owner",
Type: "plugin",
DisplayName: "Model One",
Name: "provider-model",
Version: "v1",
Description: "desc",
InputTokenLimit: 100,
OutputTokenLimit: 200,
SupportedGenerationMethods: []string{"generate"},
ContextLength: 300,
MaxCompletionTokens: 400,
SupportedParameters: []string{"temperature"},
SupportedInputModalities: []string{"text"},
SupportedOutputModalities: []string{"image"},
Thinking: &pluginapi.ThinkingSupport{
Min: 1,
Max: 2,
ZeroAllowed: true,
DynamicAllowed: true,
Levels: []string{"low", "high"},
},
UserDefined: true,
}
got := pluginModelInfoToRegistryModelInfo(model)
if got.ID != model.ID || got.Object != model.Object || got.Created != model.Created || got.OwnedBy != model.OwnedBy || got.Type != model.Type ||
got.DisplayName != model.DisplayName || got.Name != model.Name || got.Version != model.Version || got.Description != model.Description ||
got.InputTokenLimit != int(model.InputTokenLimit) || got.OutputTokenLimit != int(model.OutputTokenLimit) ||
got.ContextLength != int(model.ContextLength) || got.MaxCompletionTokens != int(model.MaxCompletionTokens) || !got.UserDefined {
t.Fatalf("converted model = %#v, want fields copied from %#v", got, model)
}
if got.Thinking == nil {
t.Fatal("Thinking = nil, want converted thinking support")
}
if got.Thinking.Min != 1 || got.Thinking.Max != 2 || !got.Thinking.ZeroAllowed || !got.Thinking.DynamicAllowed || fmt.Sprint(got.Thinking.Levels) != "[low high]" {
t.Fatalf("Thinking = %#v, want copied thinking support", got.Thinking)
}
model.SupportedGenerationMethods[0] = "mutated"
model.SupportedParameters[0] = "mutated"
model.SupportedInputModalities[0] = "mutated"
model.SupportedOutputModalities[0] = "mutated"
model.Thinking.Levels[0] = "mutated"
if got.SupportedGenerationMethods[0] != "generate" || got.SupportedParameters[0] != "temperature" ||
got.SupportedInputModalities[0] != "text" || got.SupportedOutputModalities[0] != "image" ||
got.Thinking.Levels[0] != "low" {
t.Fatalf("converted model kept aliases to plugin slices: %#v", got)
}
}
func TestRegisterModelsRegistersProviderModelsAndClientID(t *testing.T) {
modelRegistry := newFakeModelRegistry()
host := newHostWithRecords(capabilityRecord{
id: "alpha",
meta: pluginapi.Metadata{Name: "Alpha", Version: "1.0.0"},
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ModelRegistrar: modelRegistrarFunc(func(ctx context.Context, req pluginapi.ModelRegistrationRequest) (pluginapi.ModelRegistrationResponse, error) {
if req.Plugin.Name != "Alpha" || req.Plugin.Version != "1.0.0" {
t.Fatalf("RegisterModels request plugin = %#v, want Alpha metadata", req.Plugin)
}
return pluginapi.ModelRegistrationResponse{
Provider: " MixedProvider ",
Models: []pluginapi.ModelInfo{{
ID: " model-1 ",
Object: "model",
Created: 123,
OwnedBy: "owner",
Type: "chat",
DisplayName: "Model One",
Name: "native-model-1",
Version: "v1",
Description: "description",
InputTokenLimit: 100,
OutputTokenLimit: 200,
SupportedGenerationMethods: []string{"generate"},
ContextLength: 300,
MaxCompletionTokens: 400,
SupportedParameters: []string{"temperature"},
SupportedInputModalities: []string{"text"},
SupportedOutputModalities: []string{"text"},
Thinking: &pluginapi.ThinkingSupport{
Min: 1,
Max: 2,
ZeroAllowed: true,
DynamicAllowed: true,
Levels: []string{"low"},
},
UserDefined: true,
}},
}, nil
}),
}},
})
host.RegisterModels(context.Background(), modelRegistry)
reg := modelRegistry.clients["plugin:alpha:mixedprovider"]
if reg == nil {
t.Fatal("plugin:alpha:mixedprovider was not registered")
}
if reg.provider != "mixedprovider" {
t.Fatalf("registered provider = %q, want mixedprovider", reg.provider)
}
if len(reg.models) != 1 {
t.Fatalf("registered model count = %d, want 1", len(reg.models))
}
model := reg.models[0]
if model.ID != "model-1" || model.Object != "model" || model.Created != 123 || model.OwnedBy != "owner" || model.Type != "chat" ||
model.DisplayName != "Model One" || model.Name != "native-model-1" || model.Version != "v1" || model.Description != "description" ||
model.InputTokenLimit != 100 || model.OutputTokenLimit != 200 || model.ContextLength != 300 || model.MaxCompletionTokens != 400 ||
model.SupportedGenerationMethods[0] != "generate" || model.SupportedParameters[0] != "temperature" ||
model.SupportedInputModalities[0] != "text" || model.SupportedOutputModalities[0] != "text" || !model.UserDefined {
t.Fatalf("registered model = %#v, want converted fields", model)
}
if model.Thinking == nil || model.Thinking.Min != 1 || model.Thinking.Max != 2 || !model.Thinking.ZeroAllowed ||
!model.Thinking.DynamicAllowed || model.Thinking.Levels[0] != "low" {
t.Fatalf("registered thinking = %#v, want converted thinking", model.Thinking)
}
}
func TestRegisterModelsUsesModelProviderStaticModels(t *testing.T) {
modelRegistry := newFakeModelRegistry()
called := false
host := newHostWithRecords(capabilityRecord{
id: "alpha",
meta: pluginapi.Metadata{Name: "Alpha", Version: "1.0.0"},
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ModelProvider: modelProviderFunc{
staticModels: func(ctx context.Context, req pluginapi.StaticModelRequest) (pluginapi.ModelResponse, error) {
called = true
if req.Plugin.Name != "Alpha" || req.Plugin.Version != "1.0.0" {
t.Fatalf("StaticModels request plugin = %#v, want Alpha metadata", req.Plugin)
}
if req.Host.AuthDir != "/tmp/plugin-auth" || req.Host.ProxyURL != "http://proxy.local" || !req.Host.ForceModelPrefix {
t.Fatalf("StaticModels host = %#v, want configured summary", req.Host)
}
if len(req.Host.OAuthModelAlias["plugin-provider"]) != 1 || req.Host.OAuthModelAlias["plugin-provider"][0].Alias != "alias-model" {
t.Fatalf("StaticModels OAuthModelAlias = %#v, want configured alias", req.Host.OAuthModelAlias)
}
if len(req.Host.ExcludedModels["plugin-provider"]) != 1 || req.Host.ExcludedModels["plugin-provider"][0] != "hidden-model" {
t.Fatalf("StaticModels ExcludedModels = %#v, want configured exclusion", req.Host.ExcludedModels)
}
return pluginapi.ModelResponse{
Provider: " Plugin-Provider ",
Models: []pluginapi.ModelInfo{{
ID: " model-static ",
Object: "model",
DisplayName: "Static Model",
}},
}, nil
},
},
ModelRegistrar: staticModelRegistrar("legacy-provider", "legacy-model"),
}},
})
host.runtimeConfig = &config.Config{
SDKConfig: config.SDKConfig{
ProxyURL: "http://proxy.local",
ForceModelPrefix: true,
},
AuthDir: "/tmp/plugin-auth",
OAuthModelAlias: map[string][]config.OAuthModelAlias{
"plugin-provider": []config.OAuthModelAlias{{Name: "upstream-model", Alias: "alias-model"}},
},
OAuthExcludedModels: map[string][]string{
"plugin-provider": []string{"hidden-model"},
},
}
host.RegisterModels(context.Background(), modelRegistry)
if !called {
t.Fatal("ModelProvider.StaticModels was not called")
}
reg := modelRegistry.clients["plugin:alpha:plugin-provider"]
if reg == nil {
t.Fatal("plugin:alpha:plugin-provider was not registered")
}
if reg.provider != "plugin-provider" {
t.Fatalf("registered provider = %q, want plugin-provider", reg.provider)
}
if len(reg.models) != 1 || reg.models[0].ID != "model-static" || reg.models[0].DisplayName != "Static Model" {
t.Fatalf("registered models = %#v, want static model", reg.models)
}
if _, okLegacy := modelRegistry.clients["plugin:alpha:legacy-provider"]; okLegacy {
t.Fatal("legacy ModelRegistrar path was used despite ModelProvider.StaticModels")
}
}
func TestRegisterModelsSkipsErrorEmptyAndInvalidModels(t *testing.T) {
modelRegistry := newFakeModelRegistry()
host := newHostWithRecords(
capabilityRecord{
id: "error",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ModelRegistrar: modelRegistrarFunc(func(ctx context.Context, req pluginapi.ModelRegistrationRequest) (pluginapi.ModelRegistrationResponse, error) {
return pluginapi.ModelRegistrationResponse{}, errors.New("register failed")
}),
}},
},
capabilityRecord{
id: "empty-provider",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ModelRegistrar: modelRegistrarFunc(func(ctx context.Context, req pluginapi.ModelRegistrationRequest) (pluginapi.ModelRegistrationResponse, error) {
return pluginapi.ModelRegistrationResponse{Provider: " ", Models: []pluginapi.ModelInfo{{ID: "model"}}}, nil
}),
}},
},
capabilityRecord{
id: "empty-models",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ModelRegistrar: modelRegistrarFunc(func(ctx context.Context, req pluginapi.ModelRegistrationRequest) (pluginapi.ModelRegistrationResponse, error) {
return pluginapi.ModelRegistrationResponse{Provider: "provider"}, nil
}),
}},
},
capabilityRecord{
id: "invalid-models",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ModelRegistrar: modelRegistrarFunc(func(ctx context.Context, req pluginapi.ModelRegistrationRequest) (pluginapi.ModelRegistrationResponse, error) {
return pluginapi.ModelRegistrationResponse{Provider: "provider", Models: []pluginapi.ModelInfo{{ID: " "}}}, nil
}),
}},
},
)
host.RegisterModels(context.Background(), modelRegistry)
if len(modelRegistry.clients) != 0 {
t.Fatalf("registered clients = %#v, want none", modelRegistry.clients)
}
}
func TestRegisterModelsPrunesStaleClientAfterSnapshotChange(t *testing.T) {
modelRegistry := newFakeModelRegistry()
host := newHostWithRecords(capabilityRecord{
id: "alpha",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ModelRegistrar: staticModelRegistrar("provider-a", "model-a"),
}},
})
host.RegisterModels(context.Background(), modelRegistry)
host.snapshot.Store(&Snapshot{enabled: true, records: []capabilityRecord{{
id: "bravo",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ModelRegistrar: staticModelRegistrar("provider-b", "model-b"),
}},
}}})
host.RegisterModels(context.Background(), modelRegistry)
if _, okClient := modelRegistry.clients["plugin:alpha:provider-a"]; okClient {
t.Fatal("stale alpha client is still registered")
}
if modelRegistry.unregisters[0] != "plugin:alpha:provider-a" {
t.Fatalf("unregistered clients = %#v, want alpha client first", modelRegistry.unregisters)
}
if _, okClient := modelRegistry.clients["plugin:bravo:provider-b"]; !okClient {
t.Fatal("bravo client was not registered")
}
}
func TestRegisterModelsDropsResultsWhenSnapshotChangesDuringRegistration(t *testing.T) {
modelRegistry := newFakeModelRegistry()
host := New()
oldSnap := &Snapshot{enabled: true, records: []capabilityRecord{{
id: "alpha",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ModelRegistrar: modelRegistrarFunc(func(ctx context.Context, req pluginapi.ModelRegistrationRequest) (pluginapi.ModelRegistrationResponse, error) {
host.snapshot.Store(&Snapshot{enabled: true, records: []capabilityRecord{{
id: "bravo",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ModelRegistrar: staticModelRegistrar("provider-b", "model-b"),
}},
}}})
return pluginapi.ModelRegistrationResponse{
Provider: "provider-a",
Models: []pluginapi.ModelInfo{{
ID: "model-a",
}},
}, nil
}),
}},
}}}
host.snapshot.Store(oldSnap)
host.modelProviders["alpha"] = "existing-provider"
host.RegisterModels(context.Background(), modelRegistry)
if len(modelRegistry.clients) != 0 {
t.Fatalf("registered clients = %#v, want none after stale snapshot", modelRegistry.clients)
}
if len(modelRegistry.unregisters) != 0 {
t.Fatalf("unregistered clients = %#v, want none after stale snapshot", modelRegistry.unregisters)
}
if host.modelProvider("alpha") != "existing-provider" {
t.Fatalf("model provider = %q, want existing-provider", host.modelProvider("alpha"))
}
}
func TestRegisterModelsPanicFusesPluginAndSkipsLaterCalls(t *testing.T) {
calls := 0
modelRegistry := newFakeModelRegistry()
host := newHostWithRecords(capabilityRecord{
id: "panic-plugin",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ModelRegistrar: modelRegistrarFunc(func(ctx context.Context, req pluginapi.ModelRegistrationRequest) (pluginapi.ModelRegistrationResponse, error) {
calls++
panic("register models panic")
}),
}},
})
host.RegisterModels(context.Background(), modelRegistry)
host.RegisterModels(context.Background(), modelRegistry)
if calls != 1 {
t.Fatalf("RegisterModels calls = %d, want 1", calls)
}
if !host.isPluginFused("panic-plugin") {
t.Fatal("panic-plugin was not fused")
}
if len(modelRegistry.clients) != 0 {
t.Fatalf("registered clients = %#v, want none", modelRegistry.clients)
}
}
func TestRegisterExecutorsDoesNotOverwriteExistingExecutor(t *testing.T) {
manager := newFakeExecutorManager()
existing := &fakeProviderExecutor{provider: "provider"}
manager.RegisterExecutor(existing)
host := newHostWithRecords(capabilityRecord{
id: "alpha",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
Executor: &fakeExecutor{identifier: "provider"},
}},
})
host.RegisterExecutors(manager, nil)
if manager.registerCalls != 1 {
t.Fatalf("RegisterExecutor calls = %d, want only existing registration", manager.registerCalls)
}
got, _ := manager.Executor("provider")
if got != existing {
t.Fatalf("registered executor = %#v, want existing executor", got)
}
}
func TestRegisterExecutorsSameProviderKeepsFirstSnapshotCandidate(t *testing.T) {
manager := newFakeExecutorManager()
first := &fakeExecutor{identifier: "provider"}
second := &fakeExecutor{identifier: "provider"}
host := newHostWithRecords(
capabilityRecord{
id: "low",
priority: 10,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
Executor: second,
}},
},
capabilityRecord{
id: "high",
priority: 20,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
Executor: first,
}},
},
)
host.RegisterExecutors(manager, nil)
if manager.registerCalls != 1 {
t.Fatalf("RegisterExecutor calls = %d, want 1", manager.registerCalls)
}
adapter, okAdapter := manager.executors["provider"].(*executorAdapter)
if !okAdapter {
t.Fatalf("registered executor = %#v, want executorAdapter", manager.executors["provider"])
}
if adapter.pluginID != "high" || adapter.executor != first {
t.Fatalf("registered adapter = %#v, want high priority executor", adapter)
}
}
func TestRegisterExecutorsIdentifierPanicFusesPlugin(t *testing.T) {
manager := newFakeExecutorManager()
host := newHostWithRecords(capabilityRecord{
id: "panic-identifier",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
Executor: &fakeExecutor{panicIdentifier: true},
}},
})
host.RegisterExecutors(manager, nil)
if !host.isPluginFused("panic-identifier") {
t.Fatal("panic-identifier was not fused")
}
if manager.registerCalls != 0 {
t.Fatalf("RegisterExecutor calls = %d, want 0", manager.registerCalls)
}
}
func TestRegisterExecutorsSelectsHighestPriorityPluginExecutorPerModel(t *testing.T) {
modelRegistry := newFakeModelRegistry()
manager := newFakeExecutorManager()
host := newHostWithRecords(
capabilityRecord{
id: "low",
priority: 10,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ModelRegistrar: staticModelRegistrar("low-provider", "shared-model"),
Executor: &fakeExecutor{identifier: "low-provider"},
}},
},
capabilityRecord{
id: "high",
priority: 20,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ModelRegistrar: staticModelRegistrar("high-provider", "shared-model"),
Executor: &fakeExecutor{identifier: "high-provider"},
}},
},
)
host.RegisterModels(context.Background(), modelRegistry)
host.RegisterExecutors(manager, modelRegistry)
if _, okLow := manager.executors["low-provider"]; okLow {
t.Fatal("low priority executor was registered for shared-model")
}
if _, okHigh := manager.executors["high-provider"]; !okHigh {
t.Fatal("high priority executor was not registered for shared-model")
}
if got := host.ModelsForProvider("low-provider"); len(got) != 0 {
t.Fatalf("low provider models = %#v, want none", got)
}
got := host.ModelsForProvider("high-provider")
if len(got) != 1 || got[0].ID != "shared-model" {
t.Fatalf("high provider models = %#v, want shared-model", got)
}
}
func TestRegisterExecutorsKeepsPluginModelsForNativeProviderWithoutOverwritingExecutor(t *testing.T) {
modelRegistry := newFakeModelRegistry()
manager := newFakeExecutorManager()
native := &fakeProviderExecutor{provider: "native-provider"}
manager.RegisterExecutor(native)
host := newHostWithRecords(capabilityRecord{
id: "native-extension",
priority: 20,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ModelRegistrar: staticModelRegistrar("native-provider", "native-extension-model"),
Executor: &fakeExecutor{identifier: "native-provider"},
}},
})
host.RegisterModels(context.Background(), modelRegistry)
host.RegisterExecutors(manager, modelRegistry)
if manager.registerCalls != 1 {
t.Fatalf("RegisterExecutor calls = %d, want only native registration", manager.registerCalls)
}
gotExecutor, _ := manager.Executor("native-provider")
if gotExecutor != native {
t.Fatalf("native provider executor = %#v, want native executor", gotExecutor)
}
gotModels := host.ModelsForProvider("native-provider")
if len(gotModels) != 1 || gotModels[0].ID != "native-extension-model" {
t.Fatalf("native provider plugin models = %#v, want native-extension-model", gotModels)
}
}
func TestRegisterExecutorsSkipsPluginModelWhenModelAlreadyHasNativeExecutor(t *testing.T) {
modelRegistry := newFakeModelRegistry()
modelRegistry.RegisterClient("native-auth", "native-provider", []*registry.ModelInfo{{ID: "shared-model"}})
manager := newFakeExecutorManager()
manager.RegisterExecutor(&fakeProviderExecutor{provider: "native-provider"})
host := newHostWithRecords(capabilityRecord{
id: "plugin-executor",
priority: 20,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ModelRegistrar: staticModelRegistrar("plugin-provider", "shared-model"),
Executor: &fakeExecutor{identifier: "plugin-provider"},
}},
})
host.RegisterModels(context.Background(), modelRegistry)
host.RegisterExecutors(manager, modelRegistry)
if _, okPlugin := manager.executors["plugin-provider"]; okPlugin {
t.Fatal("plugin executor was registered for a model that already has a native executor")
}
if got := host.ModelsForProvider("plugin-provider"); len(got) != 0 {
t.Fatalf("plugin provider models = %#v, want none", got)
}
}
func TestRegisterExecutorsUsesRegisteredModelProviderBeforeFallback(t *testing.T) {
modelRegistry := newFakeModelRegistry()
manager := newFakeExecutorManager()
exec := &fakeExecutor{identifier: "fallback-provider"}
host := newHostWithRecords(capabilityRecord{
id: "alpha",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ModelRegistrar: staticModelRegistrar("registered-provider", "model"),
Executor: exec,
}},
})
host.RegisterModels(context.Background(), modelRegistry)
host.RegisterExecutors(manager, modelRegistry)
adapter, okAdapter := manager.executors["registered-provider"].(*executorAdapter)
if !okAdapter {
t.Fatalf("registered executor = %#v, want executorAdapter", manager.executors["registered-provider"])
}
if adapter.provider != "registered-provider" || adapter.executor != exec {
t.Fatalf("adapter = %#v, want registered provider executor", adapter)
}
if _, okFallback := manager.executors["fallback-provider"]; okFallback {
t.Fatal("fallback provider was registered despite model provider cache")
}
}
func TestRegisterExecutorsExposesExecutorModelsForUserAuthBinding(t *testing.T) {
modelRegistry := newFakeModelRegistry()
manager := newFakeExecutorManager()
exec := &fakeExecutor{identifier: "plugin-provider"}
host := newHostWithRecords(capabilityRecord{
id: "alpha",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ModelRegistrar: staticModelRegistrar("plugin-provider", "plugin-model"),
Executor: exec,
}},
})
host.RegisterModels(context.Background(), modelRegistry)
if len(modelRegistry.clients) != 0 {
t.Fatalf("registered model clients = %#v, want none until a matching auth binds provider models", modelRegistry.clients)
}
host.RegisterExecutors(manager, modelRegistry)
if _, okExecutor := manager.executors["plugin-provider"]; !okExecutor {
t.Fatal("plugin provider executor was not registered")
}
models := host.ModelsForProvider("plugin-provider")
if len(models) != 1 || models[0].ID != "plugin-model" {
t.Fatalf("provider models = %#v, want plugin-model for user auth binding", models)
}
clientID := pluginExecutorModelClientID("alpha", "plugin-provider")
reg := modelRegistry.clients[clientID]
if reg == nil {
t.Fatalf("executor model client %s was not registered", clientID)
}
if reg.provider != "plugin-provider" || len(reg.models) != 1 || reg.models[0].ID != "plugin-model" {
t.Fatalf("executor model registry client = %#v, want plugin-provider/plugin-model", reg)
}
if providers := modelRegistry.GetModelProviders("plugin-model"); len(providers) != 1 || providers[0] != "plugin-provider" {
t.Fatalf("providers for plugin-model = %#v, want plugin-provider", providers)
}
}
func TestRegisterExecutorsOAuthScopeSkipsStaticModelClientButRegistersExecutor(t *testing.T) {
modelRegistry := newFakeModelRegistry()
manager := newFakeExecutorManager()
staticCalled := false
host := newHostWithRecords(capabilityRecord{
id: "qoder",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
AuthProvider: fakeAuthProvider{identifier: "qoder"},
ModelProvider: modelProviderFunc{
staticModels: func(ctx context.Context, req pluginapi.StaticModelRequest) (pluginapi.ModelResponse, error) {
staticCalled = true
return pluginapi.ModelResponse{
Provider: "qoder",
Models: []pluginapi.ModelInfo{{ID: "static-model"}},
}, nil
},
modelsForAuth: func(ctx context.Context, req pluginapi.AuthModelRequest) (pluginapi.ModelResponse, error) {
return pluginapi.ModelResponse{
Provider: "qoder",
Models: []pluginapi.ModelInfo{{ID: "oauth-model"}},
}, nil
},
},
Executor: &fakeExecutor{identifier: "qoder"},
ExecutorModelScope: pluginapi.ExecutorModelScopeOAuth,
}},
})
host.RegisterModels(context.Background(), modelRegistry)
host.RegisterExecutors(manager, modelRegistry)
if staticCalled {
t.Fatal("StaticModels was called for an OAuth-only executor")
}
if _, okExecutor := manager.executors["qoder"]; !okExecutor {
t.Fatal("OAuth-only executor was not registered")
}
if _, okClient := modelRegistry.clients[pluginExecutorModelClientID("qoder", "qoder")]; okClient {
t.Fatal("OAuth-only executor registered a static model client")
}
if got := host.ModelsForProvider("qoder"); len(got) != 0 {
t.Fatalf("OAuth-only provider models = %#v, want none", got)
}
result := host.ModelsForAuth(context.Background(), &coreauth.Auth{
ID: "qoder-auth",
Provider: "qoder",
})
if !result.Handled || result.Provider != "qoder" || len(result.Models) != 1 || result.Models[0].ID != "oauth-model" {
t.Fatalf("OAuth model result = %#v, want oauth-model", result)
}
}
func TestModelsForAuthOAuthScopeFallsBackToExecutorIdentifier(t *testing.T) {
host := newHostWithRecords(capabilityRecord{
id: "alpha",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ModelProvider: modelProviderFunc{
modelsForAuth: func(ctx context.Context, req pluginapi.AuthModelRequest) (pluginapi.ModelResponse, error) {
return pluginapi.ModelResponse{
Provider: "plugin-provider",
Models: []pluginapi.ModelInfo{{ID: "oauth-model"}},
}, nil
},
},
Executor: &fakeExecutor{identifier: "plugin-provider"},
ExecutorModelScope: pluginapi.ExecutorModelScopeOAuth,
}},
})
result := host.ModelsForAuth(context.Background(), &coreauth.Auth{
ID: "plugin-auth",
Provider: "plugin-provider",
})
if !result.Handled || result.Provider != "plugin-provider" || len(result.Models) != 1 || result.Models[0].ID != "oauth-model" {
t.Fatalf("OAuth model result = %#v, want executor-identifier match", result)
}
}
func TestRegisterExecutorsStaticScopeSkipsModelsForAuth(t *testing.T) {
modelRegistry := newFakeModelRegistry()
manager := newFakeExecutorManager()
modelsForAuthCalled := false
host := newHostWithRecords(capabilityRecord{
id: "alpha",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
AuthProvider: fakeAuthProvider{identifier: "plugin-provider"},
ModelProvider: modelProviderFunc{
staticModels: func(ctx context.Context, req pluginapi.StaticModelRequest) (pluginapi.ModelResponse, error) {
return pluginapi.ModelResponse{
Provider: "plugin-provider",
Models: []pluginapi.ModelInfo{{ID: "static-model"}},
}, nil
},
modelsForAuth: func(ctx context.Context, req pluginapi.AuthModelRequest) (pluginapi.ModelResponse, error) {
modelsForAuthCalled = true
return pluginapi.ModelResponse{
Provider: "plugin-provider",
Models: []pluginapi.ModelInfo{{ID: "oauth-model"}},
}, nil
},
},
Executor: &fakeExecutor{identifier: "plugin-provider"},
ExecutorModelScope: pluginapi.ExecutorModelScopeStatic,
}},
})
host.RegisterModels(context.Background(), modelRegistry)
host.RegisterExecutors(manager, modelRegistry)
clientID := pluginExecutorModelClientID("alpha", "plugin-provider")
reg := modelRegistry.clients[clientID]
if reg == nil || reg.provider != "plugin-provider" || len(reg.models) != 1 || reg.models[0].ID != "static-model" {
t.Fatalf("static executor model client = %#v, want static-model", reg)
}
result := host.ModelsForAuth(context.Background(), &coreauth.Auth{
ID: "plugin-auth",
Provider: "plugin-provider",
})
if result.Handled {
t.Fatalf("static-only executor handled per-auth models: %#v", result)
}
if modelsForAuthCalled {
t.Fatal("ModelsForAuth was called for a static-only executor")
}
}
func TestRegisterExecutorsBothScopeKeepsStaticAndOAuthModels(t *testing.T) {
modelRegistry := newFakeModelRegistry()
manager := newFakeExecutorManager()
host := newHostWithRecords(capabilityRecord{
id: "alpha",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
AuthProvider: fakeAuthProvider{identifier: "plugin-provider"},
ModelProvider: modelProviderFunc{
staticModels: func(ctx context.Context, req pluginapi.StaticModelRequest) (pluginapi.ModelResponse, error) {
return pluginapi.ModelResponse{
Provider: "plugin-provider",
Models: []pluginapi.ModelInfo{{ID: "static-model"}},
}, nil
},
modelsForAuth: func(ctx context.Context, req pluginapi.AuthModelRequest) (pluginapi.ModelResponse, error) {
return pluginapi.ModelResponse{
Provider: "plugin-provider",
Models: []pluginapi.ModelInfo{{ID: "oauth-model"}},
}, nil
},
},
Executor: &fakeExecutor{identifier: "plugin-provider"},
ExecutorModelScope: pluginapi.ExecutorModelScopeBoth,
}},
})
host.RegisterModels(context.Background(), modelRegistry)
host.RegisterExecutors(manager, modelRegistry)
clientID := pluginExecutorModelClientID("alpha", "plugin-provider")
reg := modelRegistry.clients[clientID]
if reg == nil || reg.provider != "plugin-provider" || len(reg.models) != 1 || reg.models[0].ID != "static-model" {
t.Fatalf("both-scope static model client = %#v, want static-model", reg)
}
result := host.ModelsForAuth(context.Background(), &coreauth.Auth{
ID: "plugin-auth",
Provider: "plugin-provider",
})
if !result.Handled || result.Provider != "plugin-provider" || len(result.Models) != 1 || result.Models[0].ID != "oauth-model" {
t.Fatalf("both-scope OAuth model result = %#v, want oauth-model", result)
}
}
func TestRegisterExecutorsDropsResultsWhenSnapshotChangesBeforeCommit(t *testing.T) {
manager := newFakeExecutorManager()
host := New()
staleExecutor := &executorAdapter{
host: host,
pluginID: "stale",
provider: "stale-provider",
}
manager.executors["stale-provider"] = staleExecutor
host.executorProviders["stale-provider"] = struct{}{}
changedSnapshot := false
exec := &fakeExecutor{
identifierFunc: func() string {
if !changedSnapshot {
changedSnapshot = true
host.snapshot.Store(&Snapshot{enabled: true})
}
return "provider-a"
},
}
host.snapshot.Store(&Snapshot{enabled: true, records: []capabilityRecord{{
id: "alpha",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
Executor: exec,
}},
}}})
host.RegisterExecutors(manager, nil)
if manager.registerCalls != 0 {
t.Fatalf("RegisterExecutor calls = %d, want none for stale snapshot", manager.registerCalls)
}
if _, okProvider := manager.executors["provider-a"]; okProvider {
t.Fatal("provider-a executor was registered from a stale snapshot")
}
if manager.executors["stale-provider"] != staleExecutor {
t.Fatalf("stale-provider executor = %#v, want existing executor preserved", manager.executors["stale-provider"])
}
if _, okProvider := host.executorProviders["stale-provider"]; !okProvider {
t.Fatal("stale-provider ownership was pruned by a stale snapshot")
}
}
func TestRegisterExecutorsFallbackUsesExecutorIdentifier(t *testing.T) {
manager := newFakeExecutorManager()
exec := &fakeExecutor{identifier: " FallbackProvider "}
host := newHostWithRecords(capabilityRecord{
id: "alpha",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
Executor: exec,
}},
})
host.RegisterExecutors(manager, nil)
adapter, okAdapter := manager.executors["fallbackprovider"].(*executorAdapter)
if !okAdapter {
t.Fatalf("registered executor = %#v, want fallback executorAdapter", manager.executors["fallbackprovider"])
}
if adapter.provider != "fallbackprovider" || adapter.executor != exec {
t.Fatalf("adapter = %#v, want fallback provider executor", adapter)
}
}
func TestRegisterExecutorsPrunesStaleProviderAfterMigration(t *testing.T) {
modelRegistry := newFakeModelRegistry()
manager := newFakeExecutorManager()
exec := &fakeExecutor{identifier: "fallback-provider"}
host := newHostWithRecords(capabilityRecord{
id: "alpha",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ModelRegistrar: staticModelRegistrar("provider-a", "plugin-model"),
Executor: exec,
}},
})
host.modelProviders["alpha"] = "provider-a"
host.modelRegistrations["alpha"] = pluginModelRegistration{
pluginID: "alpha",
provider: "provider-a",
models: []*registry.ModelInfo{{ID: "plugin-model"}},
hasExecutor: true,
}
host.RegisterExecutors(manager, modelRegistry)
host.modelProviders["alpha"] = "provider-b"
host.modelRegistrations["alpha"] = pluginModelRegistration{
pluginID: "alpha",
provider: "provider-b",
models: []*registry.ModelInfo{{ID: "plugin-model"}},
hasExecutor: true,
}
host.RegisterExecutors(manager, modelRegistry)
if _, okProvider := manager.executors["provider-a"]; okProvider {
t.Fatal("provider-a executor is still registered")
}
if manager.unregisters[0] != "provider-a" {
t.Fatalf("unregistered providers = %#v, want provider-a", manager.unregisters)
}
adapter, okAdapter := manager.executors["provider-b"].(*executorAdapter)
if !okAdapter {
t.Fatalf("provider-b executor = %#v, want executorAdapter", manager.executors["provider-b"])
}
if adapter.executor != exec {
t.Fatalf("provider-b adapter executor = %#v, want migrated executor", adapter.executor)
}
if _, okClient := modelRegistry.clients[pluginExecutorModelClientID("alpha", "provider-a")]; okClient {
t.Fatal("provider-a executor model client is still registered")
}
if _, okClient := modelRegistry.clients[pluginExecutorModelClientID("alpha", "provider-b")]; !okClient {
t.Fatal("provider-b executor model client was not registered")
}
}
func TestRegisterExecutorsDoesNotUnregisterStaleProviderOwnedExternally(t *testing.T) {
manager := newFakeExecutorManager()
exec := &fakeExecutor{identifier: "fallback-provider"}
host := newHostWithRecords(capabilityRecord{
id: "alpha",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
Executor: exec,
}},
})
host.modelProviders["alpha"] = "provider-a"
host.RegisterExecutors(manager, nil)
external := &fakeProviderExecutor{provider: "provider-a"}
manager.executors["provider-a"] = external
host.modelProviders["alpha"] = "provider-b"
host.RegisterExecutors(manager, nil)
if len(manager.unregisters) != 0 {
t.Fatalf("unregistered providers = %#v, want none for external owner", manager.unregisters)
}
if manager.executors["provider-a"] != external {
t.Fatalf("provider-a executor = %#v, want external executor", manager.executors["provider-a"])
}
if _, okProvider := manager.executors["provider-b"]; !okProvider {
t.Fatal("provider-b executor was not registered")
}
}
func TestNormalizeRequestChainsByPriority(t *testing.T) {
host := newHostWithRecords(
capabilityRecord{
id: "high",
priority: 20,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
RequestNormalizer: requestNormalizerFunc(func(ctx context.Context, req pluginapi.RequestTransformRequest) (pluginapi.PayloadResponse, error) {
return pluginapi.PayloadResponse{Body: append(req.Body, []byte("|high")...)}, nil
}),
}},
},
capabilityRecord{
id: "low",
priority: 10,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
RequestNormalizer: requestNormalizerFunc(func(ctx context.Context, req pluginapi.RequestTransformRequest) (pluginapi.PayloadResponse, error) {
return pluginapi.PayloadResponse{Body: append(req.Body, []byte("|low")...)}, nil
}),
}},
},
)
got := host.NormalizeRequest(context.Background(), sdktranslator.FormatOpenAI, sdktranslator.FormatClaude, "model", []byte("start"), false)
if string(got) != "start|high|low" {
t.Fatalf("NormalizeRequest() = %q, want %q", got, "start|high|low")
}
}
func TestTranslateRequestStopsAtFirstSuccessfulCandidate(t *testing.T) {
calls := make([]string, 0, 2)
host := newHostWithRecords(
capabilityRecord{
id: "high",
priority: 20,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
RequestTranslator: requestTranslatorFunc(func(ctx context.Context, req pluginapi.RequestTransformRequest) (pluginapi.PayloadResponse, error) {
calls = append(calls, "high")
return pluginapi.PayloadResponse{Body: []byte("translated-high")}, nil
}),
}},
},
capabilityRecord{
id: "low",
priority: 10,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
RequestTranslator: requestTranslatorFunc(func(ctx context.Context, req pluginapi.RequestTransformRequest) (pluginapi.PayloadResponse, error) {
calls = append(calls, "low")
return pluginapi.PayloadResponse{Body: []byte("translated-low")}, nil
}),
}},
},
)
got, ok := host.TranslateRequest(context.Background(), sdktranslator.FormatOpenAI, sdktranslator.FormatClaude, "model", []byte("input"), false)
if !ok {
t.Fatal("TranslateRequest() ok = false, want true")
}
if string(got) != "translated-high" {
t.Fatalf("TranslateRequest() = %q, want %q", got, "translated-high")
}
if fmt.Sprint(calls) != "[high]" {
t.Fatalf("calls = %v, want [high]", calls)
}
}
func TestAdaptersKeepPayloadOrTryNextOnErrorAndEmptyBody(t *testing.T) {
host := newHostWithRecords(
capabilityRecord{
id: "normalizer-error",
priority: 30,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
RequestNormalizer: requestNormalizerFunc(func(ctx context.Context, req pluginapi.RequestTransformRequest) (pluginapi.PayloadResponse, error) {
return pluginapi.PayloadResponse{}, fmt.Errorf("normalize failed")
}),
}},
},
capabilityRecord{
id: "normalizer-empty",
priority: 20,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
RequestNormalizer: requestNormalizerFunc(func(ctx context.Context, req pluginapi.RequestTransformRequest) (pluginapi.PayloadResponse, error) {
return pluginapi.PayloadResponse{}, nil
}),
}},
},
capabilityRecord{
id: "normalizer-success",
priority: 10,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
RequestNormalizer: requestNormalizerFunc(func(ctx context.Context, req pluginapi.RequestTransformRequest) (pluginapi.PayloadResponse, error) {
return pluginapi.PayloadResponse{Body: []byte("kept-then-success")}, nil
}),
}},
},
)
normalized := host.NormalizeRequest(context.Background(), sdktranslator.FormatOpenAI, sdktranslator.FormatClaude, "model", []byte("original"), false)
if string(normalized) != "kept-then-success" {
t.Fatalf("NormalizeRequest() = %q, want %q", normalized, "kept-then-success")
}
translatorHost := newHostWithRecords(
capabilityRecord{
id: "translator-error",
priority: 30,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
RequestTranslator: requestTranslatorFunc(func(ctx context.Context, req pluginapi.RequestTransformRequest) (pluginapi.PayloadResponse, error) {
return pluginapi.PayloadResponse{}, fmt.Errorf("translate failed")
}),
}},
},
capabilityRecord{
id: "translator-empty",
priority: 20,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
RequestTranslator: requestTranslatorFunc(func(ctx context.Context, req pluginapi.RequestTransformRequest) (pluginapi.PayloadResponse, error) {
return pluginapi.PayloadResponse{}, nil
}),
}},
},
capabilityRecord{
id: "translator-success",
priority: 10,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
RequestTranslator: requestTranslatorFunc(func(ctx context.Context, req pluginapi.RequestTransformRequest) (pluginapi.PayloadResponse, error) {
return pluginapi.PayloadResponse{Body: []byte("translated")}, nil
}),
}},
},
)
translated, ok := translatorHost.TranslateRequest(context.Background(), sdktranslator.FormatOpenAI, sdktranslator.FormatClaude, "model", []byte("original"), false)
if !ok {
t.Fatal("TranslateRequest() ok = false, want true")
}
if string(translated) != "translated" {
t.Fatalf("TranslateRequest() = %q, want %q", translated, "translated")
}
}
func TestTranslatorPanicFusesPlugin(t *testing.T) {
host := newHostWithRecords(
capabilityRecord{
id: "panic-plugin",
priority: 20,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
RequestNormalizer: requestNormalizerFunc(func(ctx context.Context, req pluginapi.RequestTransformRequest) (pluginapi.PayloadResponse, error) {
panic("normalize panic")
}),
}},
},
capabilityRecord{
id: "next-plugin",
priority: 10,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
RequestNormalizer: requestNormalizerFunc(func(ctx context.Context, req pluginapi.RequestTransformRequest) (pluginapi.PayloadResponse, error) {
return pluginapi.PayloadResponse{Body: append(req.Body, []byte("|next")...)}, nil
}),
}},
},
)
got := host.NormalizeRequest(context.Background(), sdktranslator.FormatOpenAI, sdktranslator.FormatClaude, "model", []byte("original"), false)
if string(got) != "original|next" {
t.Fatalf("NormalizeRequest() = %q, want %q", got, "original|next")
}
if !host.isPluginFused("panic-plugin") {
t.Fatal("panic-plugin was not fused")
}
}
func TestTranslatorPanicFusesEveryHookPath(t *testing.T) {
cases := []struct {
name string
pluginID string
call func(*Host) ([]byte, bool)
}{
{
name: "request translator",
pluginID: "request-translator-panic",
call: func(host *Host) ([]byte, bool) {
host.snapshot.Store(&Snapshot{enabled: true, records: []capabilityRecord{{
id: "request-translator-panic",
priority: 10,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
RequestTranslator: requestTranslatorFunc(func(ctx context.Context, req pluginapi.RequestTransformRequest) (pluginapi.PayloadResponse, error) {
panic("request translator panic")
}),
}},
}}})
return host.TranslateRequest(context.Background(), sdktranslator.FormatOpenAI, sdktranslator.FormatClaude, "model", []byte("body"), false)
},
},
{
name: "response before normalizer",
pluginID: "response-before-panic",
call: func(host *Host) ([]byte, bool) {
host.snapshot.Store(&Snapshot{enabled: true, records: []capabilityRecord{{
id: "response-before-panic",
priority: 10,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ResponseBeforeTranslator: responseNormalizerFunc(func(ctx context.Context, req pluginapi.ResponseTransformRequest) (pluginapi.PayloadResponse, error) {
panic("response before panic")
}),
}},
}}})
return host.NormalizeResponseBefore(context.Background(), sdktranslator.FormatOpenAI, sdktranslator.FormatClaude, "model", nil, nil, []byte("body"), false), false
},
},
{
name: "response translator",
pluginID: "response-translator-panic",
call: func(host *Host) ([]byte, bool) {
host.snapshot.Store(&Snapshot{enabled: true, records: []capabilityRecord{{
id: "response-translator-panic",
priority: 10,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ResponseTranslator: responseTranslatorFunc(func(ctx context.Context, req pluginapi.ResponseTransformRequest) (pluginapi.PayloadResponse, error) {
panic("response translator panic")
}),
}},
}}})
return host.TranslateResponse(context.Background(), sdktranslator.FormatOpenAI, sdktranslator.FormatClaude, "model", nil, nil, []byte("body"), false)
},
},
{
name: "response after normalizer",
pluginID: "response-after-panic",
call: func(host *Host) ([]byte, bool) {
host.snapshot.Store(&Snapshot{enabled: true, records: []capabilityRecord{{
id: "response-after-panic",
priority: 10,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ResponseAfterTranslator: responseNormalizerFunc(func(ctx context.Context, req pluginapi.ResponseTransformRequest) (pluginapi.PayloadResponse, error) {
panic("response after panic")
}),
}},
}}})
return host.NormalizeResponseAfter(context.Background(), sdktranslator.FormatOpenAI, sdktranslator.FormatClaude, "model", nil, nil, []byte("body"), false), false
},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
host := New()
got, _ := tt.call(host)
if string(got) != "body" {
t.Fatalf("hook result = %q, want original body", got)
}
if !host.isPluginFused(tt.pluginID) {
t.Fatalf("%s was not fused", tt.pluginID)
}
})
}
}
func TestResponseNormalizersChainByPriority(t *testing.T) {
host := newHostWithRecords(
capabilityRecord{
id: "high",
priority: 20,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ResponseBeforeTranslator: responseNormalizerFunc(func(ctx context.Context, req pluginapi.ResponseTransformRequest) (pluginapi.PayloadResponse, error) {
return pluginapi.PayloadResponse{Body: append(req.Body, []byte("|before-high")...)}, nil
}),
ResponseAfterTranslator: responseNormalizerFunc(func(ctx context.Context, req pluginapi.ResponseTransformRequest) (pluginapi.PayloadResponse, error) {
return pluginapi.PayloadResponse{Body: append(req.Body, []byte("|after-high")...)}, nil
}),
}},
},
capabilityRecord{
id: "low",
priority: 10,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ResponseBeforeTranslator: responseNormalizerFunc(func(ctx context.Context, req pluginapi.ResponseTransformRequest) (pluginapi.PayloadResponse, error) {
return pluginapi.PayloadResponse{Body: append(req.Body, []byte("|before-low")...)}, nil
}),
ResponseAfterTranslator: responseNormalizerFunc(func(ctx context.Context, req pluginapi.ResponseTransformRequest) (pluginapi.PayloadResponse, error) {
return pluginapi.PayloadResponse{Body: append(req.Body, []byte("|after-low")...)}, nil
}),
}},
},
)
before := host.NormalizeResponseBefore(context.Background(), sdktranslator.FormatOpenAI, sdktranslator.FormatClaude, "model", []byte("original-request"), []byte("translated-request"), []byte("body"), true)
if string(before) != "body|before-high|before-low" {
t.Fatalf("NormalizeResponseBefore() = %q, want %q", before, "body|before-high|before-low")
}
after := host.NormalizeResponseAfter(context.Background(), sdktranslator.FormatOpenAI, sdktranslator.FormatClaude, "model", []byte("original-request"), []byte("translated-request"), []byte("body"), true)
if string(after) != "body|after-high|after-low" {
t.Fatalf("NormalizeResponseAfter() = %q, want %q", after, "body|after-high|after-low")
}
}
func TestTranslateResponseStopsAtFirstSuccessfulCandidate(t *testing.T) {
calls := make([]string, 0, 2)
host := newHostWithRecords(
capabilityRecord{
id: "high",
priority: 20,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ResponseTranslator: responseTranslatorFunc(func(ctx context.Context, req pluginapi.ResponseTransformRequest) (pluginapi.PayloadResponse, error) {
calls = append(calls, "high")
return pluginapi.PayloadResponse{Body: []byte("response-high")}, nil
}),
}},
},
capabilityRecord{
id: "low",
priority: 10,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ResponseTranslator: responseTranslatorFunc(func(ctx context.Context, req pluginapi.ResponseTransformRequest) (pluginapi.PayloadResponse, error) {
calls = append(calls, "low")
return pluginapi.PayloadResponse{Body: []byte("response-low")}, nil
}),
}},
},
)
got, ok := host.TranslateResponse(context.Background(), sdktranslator.FormatOpenAI, sdktranslator.FormatClaude, "model", nil, nil, []byte("input"), false)
if !ok {
t.Fatal("TranslateResponse() ok = false, want true")
}
if string(got) != "response-high" {
t.Fatalf("TranslateResponse() = %q, want %q", got, "response-high")
}
if fmt.Sprint(calls) != "[high]" {
t.Fatalf("calls = %v, want [high]", calls)
}
}
func TestInterceptRequestChainsByPriorityAndHeaders(t *testing.T) {
host := newHostWithRecords(
capabilityRecord{
id: "high",
priority: 20,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
RequestInterceptor: requestInterceptorFunc(func(ctx context.Context, req pluginapi.RequestInterceptRequest) (pluginapi.RequestInterceptResponse, error) {
if req.SourceFormat != "openai" || req.Model != "normalized" || req.RequestedModel != "requested" {
t.Fatalf("unexpected request context: %#v", req)
}
return pluginapi.RequestInterceptResponse{
Headers: http.Header{"X-Plugin": []string{"high"}},
Body: append(req.Body, []byte("|high")...),
}, nil
}),
}},
},
capabilityRecord{
id: "low",
priority: 10,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
RequestInterceptor: requestInterceptorFunc(func(ctx context.Context, req pluginapi.RequestInterceptRequest) (pluginapi.RequestInterceptResponse, error) {
return pluginapi.RequestInterceptResponse{
Headers: http.Header{"X-Plugin": []string{"low"}, "X-Low": []string{"1"}},
Body: append(req.Body, []byte("|low")...),
ClearHeaders: []string{"X-Remove"},
}, nil
}),
}},
},
)
headers := http.Header{"X-Remove": []string{"yes"}}
got := host.InterceptRequest(context.Background(), pluginapi.RequestInterceptRequest{
SourceFormat: "openai",
Model: "normalized",
RequestedModel: "requested",
Stream: false,
Headers: headers,
Body: []byte("start"),
})
if string(got.Body) != "start|high|low" {
t.Fatalf("body = %q, want %q", got.Body, "start|high|low")
}
if got.Headers.Get("X-Plugin") != "low" || got.Headers.Get("X-Low") != "1" || got.Headers.Get("X-Remove") != "" {
t.Fatalf("headers = %#v", got.Headers)
}
if headers.Get("X-Plugin") != "" {
t.Fatalf("input headers were mutated: %#v", headers)
}
}
func TestResponseInterceptorsChainAndStreamHistory(t *testing.T) {
var seenHistory [][]byte
var sawSecondResponse bool
var sawSecondStream bool
host := newHostWithRecords(
capabilityRecord{
id: "high",
priority: 20,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ResponseInterceptor: responseInterceptorFunc{
interceptResponse: func(ctx context.Context, req pluginapi.ResponseInterceptRequest) (pluginapi.ResponseInterceptResponse, error) {
return pluginapi.ResponseInterceptResponse{
Headers: http.Header{"X-Response": []string{"high"}},
Body: append(req.Body, []byte("|high")...),
}, nil
},
},
StreamChunkInterceptor: responseInterceptorFunc{
interceptStreamChunk: func(ctx context.Context, req pluginapi.StreamChunkInterceptRequest) (pluginapi.StreamChunkInterceptResponse, error) {
seenHistory = req.HistoryChunks
return pluginapi.StreamChunkInterceptResponse{
Headers: http.Header{"X-Stream": []string{"high"}},
Body: append(req.Body, []byte("|high")...),
}, nil
},
},
}},
},
capabilityRecord{
id: "low",
priority: 10,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ResponseInterceptor: responseInterceptorFunc{
interceptResponse: func(ctx context.Context, req pluginapi.ResponseInterceptRequest) (pluginapi.ResponseInterceptResponse, error) {
if string(req.Body) != "body|high" {
t.Fatalf("second response interceptor body = %q, want body|high", req.Body)
}
if req.ResponseHeaders.Get("X-Response") != "high" {
t.Fatalf("second response interceptor headers = %#v, want high header", req.ResponseHeaders)
}
sawSecondResponse = true
return pluginapi.ResponseInterceptResponse{
Headers: http.Header{"X-Response": []string{"low"}, "X-Low": []string{"1"}},
ClearHeaders: []string{"X-Remove"},
Body: append(req.Body, []byte("|low")...),
}, nil
},
},
StreamChunkInterceptor: responseInterceptorFunc{
interceptStreamChunk: func(ctx context.Context, req pluginapi.StreamChunkInterceptRequest) (pluginapi.StreamChunkInterceptResponse, error) {
if string(req.Body) != "chunk|high" {
t.Fatalf("second stream interceptor body = %q, want chunk|high", req.Body)
}
if req.ResponseHeaders.Get("X-Stream") != "high" {
t.Fatalf("second stream interceptor headers = %#v, want high header", req.ResponseHeaders)
}
if len(req.HistoryChunks) != 1 || string(req.HistoryChunks[0]) != "first" {
t.Fatalf("second stream interceptor history = %#v", req.HistoryChunks)
}
seenHistory = req.HistoryChunks
sawSecondStream = true
return pluginapi.StreamChunkInterceptResponse{
Headers: http.Header{"X-Stream": []string{"low"}, "X-Low": []string{"1"}},
ClearHeaders: []string{"X-Remove"},
Body: append(req.Body, []byte("|low")...),
}, nil
},
},
}},
},
)
nonStream := host.InterceptResponse(context.Background(), pluginapi.ResponseInterceptRequest{
SourceFormat: "openai",
Model: "normalized",
RequestedModel: "requested",
ResponseHeaders: http.Header{"Content-Type": []string{"application/json"}, "X-Remove": []string{"yes"}},
Body: []byte("body"),
StatusCode: http.StatusOK,
})
if string(nonStream.Body) != "body|high|low" || nonStream.Headers.Get("X-Response") != "low" || nonStream.Headers.Get("X-Low") != "1" {
t.Fatalf("non-stream result = %#v", nonStream)
}
if nonStream.Headers.Get("X-Remove") != "" {
t.Fatalf("non-stream headers kept cleared value: %#v", nonStream.Headers)
}
if !sawSecondResponse {
t.Fatal("second response interceptor was not called")
}
stream := host.InterceptStreamChunk(context.Background(), pluginapi.StreamChunkInterceptRequest{
SourceFormat: "openai",
Model: "normalized",
RequestedModel: "requested",
ResponseHeaders: http.Header{"Content-Type": []string{"text/event-stream"}, "X-Remove": []string{"yes"}},
Body: []byte("chunk"),
HistoryChunks: [][]byte{[]byte("first")},
ChunkIndex: 1,
})
if string(stream.Body) != "chunk|high|low" || stream.Headers.Get("X-Stream") != "low" || stream.Headers.Get("X-Low") != "1" {
t.Fatalf("stream result = %#v", stream)
}
if stream.Headers.Get("X-Remove") != "" {
t.Fatalf("stream headers kept cleared value: %#v", stream.Headers)
}
if len(seenHistory) != 1 || string(seenHistory[0]) != "first" {
t.Fatalf("history = %#v", seenHistory)
}
if !sawSecondStream {
t.Fatal("second stream interceptor was not called")
}
}
func TestInterceptorsSkipErrorsAndFusePanics(t *testing.T) {
host := newHostWithRecords(
capabilityRecord{
id: "error",
priority: 30,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
RequestInterceptor: requestInterceptorFunc(func(ctx context.Context, req pluginapi.RequestInterceptRequest) (pluginapi.RequestInterceptResponse, error) {
return pluginapi.RequestInterceptResponse{}, fmt.Errorf("request failed")
}),
}},
},
capabilityRecord{
id: "panic",
priority: 20,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
RequestInterceptor: requestInterceptorFunc(func(ctx context.Context, req pluginapi.RequestInterceptRequest) (pluginapi.RequestInterceptResponse, error) {
panic("request panic")
}),
}},
},
capabilityRecord{
id: "success",
priority: 10,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
RequestInterceptor: requestInterceptorFunc(func(ctx context.Context, req pluginapi.RequestInterceptRequest) (pluginapi.RequestInterceptResponse, error) {
return pluginapi.RequestInterceptResponse{Body: append(req.Body, []byte("|success")...)}, nil
}),
}},
},
)
got := host.InterceptRequest(context.Background(), pluginapi.RequestInterceptRequest{Body: []byte("body")})
if string(got.Body) != "body|success" {
t.Fatalf("body = %q, want body|success", got.Body)
}
if !host.isPluginFused("panic") {
t.Fatal("panic plugin was not fused")
}
}
func TestStreamInterceptorsDropChunkStopsChain(t *testing.T) {
var lowCalled bool
host := newHostWithRecords(
capabilityRecord{
id: "high",
priority: 20,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
StreamChunkInterceptor: responseInterceptorFunc{
interceptStreamChunk: func(ctx context.Context, req pluginapi.StreamChunkInterceptRequest) (pluginapi.StreamChunkInterceptResponse, error) {
return pluginapi.StreamChunkInterceptResponse{
Headers: http.Header{"X-Stream": []string{"high"}},
Body: append(req.Body, []byte("|high")...),
DropChunk: true,
ClearHeaders: nil,
}, nil
},
},
}},
},
capabilityRecord{
id: "low",
priority: 10,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
StreamChunkInterceptor: responseInterceptorFunc{
interceptStreamChunk: func(ctx context.Context, req pluginapi.StreamChunkInterceptRequest) (pluginapi.StreamChunkInterceptResponse, error) {
lowCalled = true
return pluginapi.StreamChunkInterceptResponse{
Headers: http.Header{"X-Stream": []string{"low"}},
Body: append(req.Body, []byte("|low")...),
}, nil
},
},
}},
},
)
got := host.InterceptStreamChunk(context.Background(), pluginapi.StreamChunkInterceptRequest{
SourceFormat: "openai",
Model: "normalized",
RequestedModel: "requested",
Body: []byte("chunk"),
})
if lowCalled {
t.Fatal("low-priority stream interceptor should not be called after DropChunk")
}
if !got.DropChunk {
t.Fatal("DropChunk = false, want true")
}
if string(got.Body) != "chunk|high" {
t.Fatalf("body = %q, want chunk|high", got.Body)
}
if got.Headers.Get("X-Stream") != "high" {
t.Fatalf("headers = %#v, want high header", got.Headers)
}
}
func TestHasStreamInterceptorsReflectsActiveStreamInterceptors(t *testing.T) {
requestOnly := newHostWithRecords(capabilityRecord{
id: "request",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
RequestInterceptor: requestInterceptorFunc(func(ctx context.Context, req pluginapi.RequestInterceptRequest) (pluginapi.RequestInterceptResponse, error) {
return pluginapi.RequestInterceptResponse{Body: req.Body}, nil
}),
}},
})
if requestOnly.HasStreamInterceptors() {
t.Fatal("HasStreamInterceptors() = true, want false for request-only plugins")
}
responseOnly := newHostWithRecords(capabilityRecord{
id: "response",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ResponseInterceptor: responseInterceptorFunc{
interceptResponse: func(ctx context.Context, req pluginapi.ResponseInterceptRequest) (pluginapi.ResponseInterceptResponse, error) {
return pluginapi.ResponseInterceptResponse{Body: req.Body}, nil
},
},
}},
})
if responseOnly.HasStreamInterceptors() {
t.Fatal("HasStreamInterceptors() = true, want false for response-only plugins")
}
streamHost := newHostWithRecords(capabilityRecord{
id: "stream",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
StreamChunkInterceptor: responseInterceptorFunc{
interceptStreamChunk: func(ctx context.Context, req pluginapi.StreamChunkInterceptRequest) (pluginapi.StreamChunkInterceptResponse, error) {
return pluginapi.StreamChunkInterceptResponse{Body: req.Body}, nil
},
},
}},
})
if !streamHost.HasStreamInterceptors() {
t.Fatal("HasStreamInterceptors() = false, want true for stream interceptors")
}
streamHost.mu.Lock()
streamHost.fused["stream"] = "test fused"
streamHost.mu.Unlock()
if streamHost.HasStreamInterceptors() {
t.Fatal("HasStreamInterceptors() = true, want false after interceptor plugin is fused")
}
}
func TestInterceptorsDoNotMutateInputs(t *testing.T) {
t.Run("request", func(t *testing.T) {
headers := http.Header{"X-Request": []string{"input"}}
metadata := map[string]any{
"nested": map[string]any{"value": "original"},
"items": []any{map[string]any{"value": "original"}},
"strings": []string{"original"},
"bytes": []byte("original"),
"labels": map[string]string{"name": "original"},
"values": url.Values{"name": []string{"original"}},
"mapSlice": map[string][]string{"name": []string{"original"}},
"sliceMap": []map[string]string{{"name": "original"}},
"aliasMap": stringSliceAlias{"original"},
"aliasList": mapSliceAlias{{"name": "original"}},
"key": "value",
}
body := []byte("request-body")
host := newHostWithRecords(capabilityRecord{
id: "request",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
RequestInterceptor: requestInterceptorFunc(func(ctx context.Context, req pluginapi.RequestInterceptRequest) (pluginapi.RequestInterceptResponse, error) {
req.Headers.Set("X-Request", "mutated")
req.Body[0] = 'R'
req.Metadata["key"] = "mutated"
req.Metadata["nested"].(map[string]any)["value"] = "mutated"
req.Metadata["items"].([]any)[0].(map[string]any)["value"] = "mutated"
req.Metadata["strings"].([]string)[0] = "mutated"
req.Metadata["bytes"].([]byte)[0] = 'M'
req.Metadata["labels"].(map[string]string)["name"] = "mutated"
req.Metadata["values"].(url.Values)["name"][0] = "mutated"
req.Metadata["mapSlice"].(map[string][]string)["name"][0] = "mutated"
req.Metadata["sliceMap"].([]map[string]string)[0]["name"] = "mutated"
req.Metadata["aliasMap"].(stringSliceAlias)[0] = "mutated"
req.Metadata["aliasList"].(mapSliceAlias)[0]["name"] = "mutated"
return pluginapi.RequestInterceptResponse{Body: append(req.Body, []byte("|ok")...)}, nil
}),
}},
})
got := host.InterceptRequest(context.Background(), pluginapi.RequestInterceptRequest{
Headers: headers,
Body: body,
Metadata: metadata,
})
if headers.Get("X-Request") != "input" {
t.Fatalf("request headers mutated: %#v", headers)
}
if string(body) != "request-body" {
t.Fatalf("request body mutated: %q", body)
}
if metadata["key"] != "value" {
t.Fatalf("request metadata mutated: %#v", metadata)
}
if metadata["nested"].(map[string]any)["value"] != "original" || metadata["items"].([]any)[0].(map[string]any)["value"] != "original" {
t.Fatalf("request nested metadata mutated: %#v", metadata)
}
if metadata["strings"].([]string)[0] != "original" || string(metadata["bytes"].([]byte)) != "original" || metadata["labels"].(map[string]string)["name"] != "original" {
t.Fatalf("request nested metadata aliases mutated: %#v", metadata)
}
if metadata["values"].(url.Values)["name"][0] != "original" || metadata["mapSlice"].(map[string][]string)["name"][0] != "original" {
t.Fatalf("request map/slice metadata mutated: %#v", metadata)
}
if metadata["sliceMap"].([]map[string]string)[0]["name"] != "original" || metadata["aliasMap"].(stringSliceAlias)[0] != "original" || metadata["aliasList"].(mapSliceAlias)[0]["name"] != "original" {
t.Fatalf("request alias metadata mutated: %#v", metadata)
}
if !strings.HasSuffix(string(got.Body), "|ok") {
t.Fatalf("request result body = %q", got.Body)
}
})
t.Run("response", func(t *testing.T) {
requestHeaders := http.Header{"X-Request": []string{"input"}}
responseHeaders := http.Header{"X-Response": []string{"input"}}
originalRequest := []byte("original")
requestBody := []byte("request")
body := []byte("body")
metadata := map[string]any{
"nested": map[string]any{"value": "original"},
"items": []any{map[string]any{"value": "original"}},
"strings": []string{"original"},
"bytes": []byte("original"),
"labels": map[string]string{"name": "original"},
"values": url.Values{"name": []string{"original"}},
"mapSlice": map[string][]string{"name": []string{"original"}},
"sliceMap": []map[string]string{{"name": "original"}},
"aliasMap": stringSliceAlias{"original"},
"aliasList": mapSliceAlias{{"name": "original"}},
"key": "value",
}
host := newHostWithRecords(capabilityRecord{
id: "response",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ResponseInterceptor: responseInterceptorFunc{
interceptResponse: func(ctx context.Context, req pluginapi.ResponseInterceptRequest) (pluginapi.ResponseInterceptResponse, error) {
req.RequestHeaders.Set("X-Request", "mutated")
req.ResponseHeaders.Set("X-Response", "mutated")
req.OriginalRequest[0] = 'O'
req.RequestBody[0] = 'R'
req.Body[0] = 'B'
req.Metadata["key"] = "mutated"
req.Metadata["nested"].(map[string]any)["value"] = "mutated"
req.Metadata["items"].([]any)[0].(map[string]any)["value"] = "mutated"
req.Metadata["strings"].([]string)[0] = "mutated"
req.Metadata["bytes"].([]byte)[0] = 'M'
req.Metadata["labels"].(map[string]string)["name"] = "mutated"
req.Metadata["values"].(url.Values)["name"][0] = "mutated"
req.Metadata["mapSlice"].(map[string][]string)["name"][0] = "mutated"
req.Metadata["sliceMap"].([]map[string]string)[0]["name"] = "mutated"
req.Metadata["aliasMap"].(stringSliceAlias)[0] = "mutated"
req.Metadata["aliasList"].(mapSliceAlias)[0]["name"] = "mutated"
return pluginapi.ResponseInterceptResponse{Body: append(req.Body, []byte("|ok")...)}, nil
},
},
}},
})
got := host.InterceptResponse(context.Background(), pluginapi.ResponseInterceptRequest{
RequestHeaders: requestHeaders,
ResponseHeaders: responseHeaders,
OriginalRequest: originalRequest,
RequestBody: requestBody,
Body: body,
Metadata: metadata,
})
if requestHeaders.Get("X-Request") != "input" {
t.Fatalf("request headers mutated: %#v", requestHeaders)
}
if responseHeaders.Get("X-Response") != "input" {
t.Fatalf("response headers mutated: %#v", responseHeaders)
}
if string(originalRequest) != "original" {
t.Fatalf("original request mutated: %q", originalRequest)
}
if string(requestBody) != "request" {
t.Fatalf("request body mutated: %q", requestBody)
}
if string(body) != "body" {
t.Fatalf("response body mutated: %q", body)
}
if metadata["key"] != "value" {
t.Fatalf("response metadata mutated: %#v", metadata)
}
if metadata["nested"].(map[string]any)["value"] != "original" || metadata["items"].([]any)[0].(map[string]any)["value"] != "original" {
t.Fatalf("response nested metadata mutated: %#v", metadata)
}
if metadata["strings"].([]string)[0] != "original" || string(metadata["bytes"].([]byte)) != "original" || metadata["labels"].(map[string]string)["name"] != "original" {
t.Fatalf("response nested metadata aliases mutated: %#v", metadata)
}
if metadata["values"].(url.Values)["name"][0] != "original" || metadata["mapSlice"].(map[string][]string)["name"][0] != "original" {
t.Fatalf("response map/slice metadata mutated: %#v", metadata)
}
if metadata["sliceMap"].([]map[string]string)[0]["name"] != "original" || metadata["aliasMap"].(stringSliceAlias)[0] != "original" || metadata["aliasList"].(mapSliceAlias)[0]["name"] != "original" {
t.Fatalf("response alias metadata mutated: %#v", metadata)
}
if !strings.HasSuffix(string(got.Body), "|ok") {
t.Fatalf("response result body = %q", got.Body)
}
})
t.Run("stream", func(t *testing.T) {
requestHeaders := http.Header{"X-Request": []string{"input"}}
responseHeaders := http.Header{"X-Response": []string{"input"}}
originalRequest := []byte("original")
requestBody := []byte("request")
body := []byte("chunk")
history := [][]byte{[]byte("first")}
metadata := map[string]any{
"nested": map[string]any{"value": "original"},
"items": []any{map[string]any{"value": "original"}},
"strings": []string{"original"},
"bytes": []byte("original"),
"labels": map[string]string{"name": "original"},
"values": url.Values{"name": []string{"original"}},
"mapSlice": map[string][]string{"name": []string{"original"}},
"sliceMap": []map[string]string{{"name": "original"}},
"aliasMap": stringSliceAlias{"original"},
"aliasList": mapSliceAlias{{"name": "original"}},
"key": "value",
}
host := newHostWithRecords(capabilityRecord{
id: "stream",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
StreamChunkInterceptor: responseInterceptorFunc{
interceptStreamChunk: func(ctx context.Context, req pluginapi.StreamChunkInterceptRequest) (pluginapi.StreamChunkInterceptResponse, error) {
req.RequestHeaders.Set("X-Request", "mutated")
req.ResponseHeaders.Set("X-Response", "mutated")
req.OriginalRequest[0] = 'O'
req.RequestBody[0] = 'R'
req.Body[0] = 'C'
req.HistoryChunks[0][0] = 'F'
req.Metadata["key"] = "mutated"
req.Metadata["nested"].(map[string]any)["value"] = "mutated"
req.Metadata["items"].([]any)[0].(map[string]any)["value"] = "mutated"
req.Metadata["strings"].([]string)[0] = "mutated"
req.Metadata["bytes"].([]byte)[0] = 'M'
req.Metadata["labels"].(map[string]string)["name"] = "mutated"
req.Metadata["values"].(url.Values)["name"][0] = "mutated"
req.Metadata["mapSlice"].(map[string][]string)["name"][0] = "mutated"
req.Metadata["sliceMap"].([]map[string]string)[0]["name"] = "mutated"
req.Metadata["aliasMap"].(stringSliceAlias)[0] = "mutated"
req.Metadata["aliasList"].(mapSliceAlias)[0]["name"] = "mutated"
return pluginapi.StreamChunkInterceptResponse{Body: append(req.Body, []byte("|ok")...)}, nil
},
},
}},
})
got := host.InterceptStreamChunk(context.Background(), pluginapi.StreamChunkInterceptRequest{
RequestHeaders: requestHeaders,
ResponseHeaders: responseHeaders,
OriginalRequest: originalRequest,
RequestBody: requestBody,
Body: body,
HistoryChunks: history,
Metadata: metadata,
})
if requestHeaders.Get("X-Request") != "input" {
t.Fatalf("request headers mutated: %#v", requestHeaders)
}
if responseHeaders.Get("X-Response") != "input" {
t.Fatalf("response headers mutated: %#v", responseHeaders)
}
if string(originalRequest) != "original" {
t.Fatalf("original request mutated: %q", originalRequest)
}
if string(requestBody) != "request" {
t.Fatalf("request body mutated: %q", requestBody)
}
if string(body) != "chunk" {
t.Fatalf("stream body mutated: %q", body)
}
if string(history[0]) != "first" {
t.Fatalf("history mutated: %#v", history)
}
if metadata["key"] != "value" {
t.Fatalf("stream metadata mutated: %#v", metadata)
}
if metadata["nested"].(map[string]any)["value"] != "original" || metadata["items"].([]any)[0].(map[string]any)["value"] != "original" {
t.Fatalf("stream nested metadata mutated: %#v", metadata)
}
if metadata["strings"].([]string)[0] != "original" || string(metadata["bytes"].([]byte)) != "original" || metadata["labels"].(map[string]string)["name"] != "original" {
t.Fatalf("stream nested metadata aliases mutated: %#v", metadata)
}
if metadata["values"].(url.Values)["name"][0] != "original" || metadata["mapSlice"].(map[string][]string)["name"][0] != "original" {
t.Fatalf("stream map/slice metadata mutated: %#v", metadata)
}
if metadata["sliceMap"].([]map[string]string)[0]["name"] != "original" || metadata["aliasMap"].(stringSliceAlias)[0] != "original" || metadata["aliasList"].(mapSliceAlias)[0]["name"] != "original" {
t.Fatalf("stream alias metadata mutated: %#v", metadata)
}
if !strings.HasSuffix(string(got.Body), "|ok") {
t.Fatalf("stream result body = %q", got.Body)
}
})
t.Run("pointers-and-cycle", func(t *testing.T) {
type pointerMetadata struct {
Value string
Items []string
}
structValue := &pointerMetadata{Value: "original", Items: []string{"original"}}
mapValue := &map[string][]string{"names": []string{"original"}}
sliceValue := &[]string{"original"}
aliasMapValue := &mapSliceAlias{{"name": "original"}}
var ifaceValue any = &pointerMetadata{Value: "original", Items: []string{"original"}}
cycle := map[string]any{}
cycle["self"] = cycle
metadata := map[string]any{
"struct_ptr": structValue,
"map_ptr": mapValue,
"slice_ptr": sliceValue,
"alias_ptr": aliasMapValue,
"iface_ptr": ifaceValue,
"cycle": cycle,
}
host := newHostWithRecords(capabilityRecord{
id: "pointer",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
RequestInterceptor: requestInterceptorFunc(func(ctx context.Context, req pluginapi.RequestInterceptRequest) (pluginapi.RequestInterceptResponse, error) {
req.Metadata["struct_ptr"].(*pointerMetadata).Value = "mutated"
req.Metadata["struct_ptr"].(*pointerMetadata).Items[0] = "mutated"
(*req.Metadata["map_ptr"].(*map[string][]string))["names"][0] = "mutated"
(*req.Metadata["slice_ptr"].(*[]string))[0] = "mutated"
(*req.Metadata["alias_ptr"].(*mapSliceAlias))[0]["name"] = "mutated"
req.Metadata["iface_ptr"].(*pointerMetadata).Value = "mutated"
if clonedCycle, ok := req.Metadata["cycle"].(map[string]any); ok {
clonedCycle["marker"] = "mutated"
clonedCycle["self"] = "mutated"
}
return pluginapi.RequestInterceptResponse{Body: []byte("ok")}, nil
}),
}},
})
_ = host.InterceptRequest(context.Background(), pluginapi.RequestInterceptRequest{Metadata: metadata})
if structValue.Value != "original" || structValue.Items[0] != "original" {
t.Fatalf("struct pointer metadata mutated: %#v", structValue)
}
if (*mapValue)["names"][0] != "original" {
t.Fatalf("map pointer metadata mutated: %#v", mapValue)
}
if (*sliceValue)[0] != "original" {
t.Fatalf("slice pointer metadata mutated: %#v", sliceValue)
}
if (*aliasMapValue)[0]["name"] != "original" {
t.Fatalf("alias pointer metadata mutated: %#v", aliasMapValue)
}
if ifaceStruct, ok := ifaceValue.(*pointerMetadata); !ok || ifaceStruct.Value != "original" || ifaceStruct.Items[0] != "original" {
t.Fatalf("interface pointer metadata mutated: %#v", ifaceValue)
}
if _, ok := cycle["self"].(map[string]any); !ok {
t.Fatalf("cycle metadata structure changed unexpectedly: %#v", cycle)
}
if _, ok := cycle["marker"]; ok {
t.Fatalf("cycle metadata mutated: %#v", cycle)
}
})
}
func TestResponseHooksKeepPayloadOrTryNextOnErrorAndEmptyBody(t *testing.T) {
normalizerHost := newHostWithRecords(
capabilityRecord{
id: "before-error",
priority: 30,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ResponseBeforeTranslator: responseNormalizerFunc(func(ctx context.Context, req pluginapi.ResponseTransformRequest) (pluginapi.PayloadResponse, error) {
return pluginapi.PayloadResponse{}, fmt.Errorf("before failed")
}),
ResponseAfterTranslator: responseNormalizerFunc(func(ctx context.Context, req pluginapi.ResponseTransformRequest) (pluginapi.PayloadResponse, error) {
return pluginapi.PayloadResponse{}, fmt.Errorf("after failed")
}),
}},
},
capabilityRecord{
id: "before-empty",
priority: 20,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ResponseBeforeTranslator: responseNormalizerFunc(func(ctx context.Context, req pluginapi.ResponseTransformRequest) (pluginapi.PayloadResponse, error) {
return pluginapi.PayloadResponse{}, nil
}),
ResponseAfterTranslator: responseNormalizerFunc(func(ctx context.Context, req pluginapi.ResponseTransformRequest) (pluginapi.PayloadResponse, error) {
return pluginapi.PayloadResponse{}, nil
}),
}},
},
capabilityRecord{
id: "before-success",
priority: 10,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ResponseBeforeTranslator: responseNormalizerFunc(func(ctx context.Context, req pluginapi.ResponseTransformRequest) (pluginapi.PayloadResponse, error) {
return pluginapi.PayloadResponse{Body: []byte("before-success")}, nil
}),
ResponseAfterTranslator: responseNormalizerFunc(func(ctx context.Context, req pluginapi.ResponseTransformRequest) (pluginapi.PayloadResponse, error) {
return pluginapi.PayloadResponse{Body: []byte("after-success")}, nil
}),
}},
},
)
before := normalizerHost.NormalizeResponseBefore(context.Background(), sdktranslator.FormatOpenAI, sdktranslator.FormatClaude, "model", nil, nil, []byte("original"), false)
if string(before) != "before-success" {
t.Fatalf("NormalizeResponseBefore() = %q, want %q", before, "before-success")
}
after := normalizerHost.NormalizeResponseAfter(context.Background(), sdktranslator.FormatOpenAI, sdktranslator.FormatClaude, "model", nil, nil, []byte("original"), false)
if string(after) != "after-success" {
t.Fatalf("NormalizeResponseAfter() = %q, want %q", after, "after-success")
}
translatorHost := newHostWithRecords(
capabilityRecord{
id: "translator-error",
priority: 30,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ResponseTranslator: responseTranslatorFunc(func(ctx context.Context, req pluginapi.ResponseTransformRequest) (pluginapi.PayloadResponse, error) {
return pluginapi.PayloadResponse{}, fmt.Errorf("translate failed")
}),
}},
},
capabilityRecord{
id: "translator-empty",
priority: 20,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ResponseTranslator: responseTranslatorFunc(func(ctx context.Context, req pluginapi.ResponseTransformRequest) (pluginapi.PayloadResponse, error) {
return pluginapi.PayloadResponse{}, nil
}),
}},
},
capabilityRecord{
id: "translator-success",
priority: 10,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
ResponseTranslator: responseTranslatorFunc(func(ctx context.Context, req pluginapi.ResponseTransformRequest) (pluginapi.PayloadResponse, error) {
return pluginapi.PayloadResponse{Body: []byte("response-translated")}, nil
}),
}},
},
)
translated, ok := translatorHost.TranslateResponse(context.Background(), sdktranslator.FormatOpenAI, sdktranslator.FormatClaude, "model", nil, nil, []byte("original"), false)
if !ok {
t.Fatal("TranslateResponse() ok = false, want true")
}
if string(translated) != "response-translated" {
t.Fatalf("TranslateResponse() = %q, want %q", translated, "response-translated")
}
}
func TestUsageAdapterPanicFusesPlugin(t *testing.T) {
host := newHostWithRecords(capabilityRecord{
id: "usage-panic",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
UsagePlugin: usagePluginFunc(func(ctx context.Context, record pluginapi.UsageRecord) {
panic("usage panic")
}),
}},
})
adapter := &usageAdapter{
host: host,
pluginID: "usage-panic",
}
adapter.HandleUsage(context.Background(), coreusage.Record{Provider: "plugin-provider"})
if !host.isPluginFused("usage-panic") {
t.Fatal("usage-panic was not fused")
}
}
func TestUsageManagerRegisterNamedReplacesWithoutDuplicateDispatch(t *testing.T) {
manager := coreusage.NewManager(0)
defer manager.Stop()
calls := make(chan string, 2)
manager.RegisterNamed("plugin:alpha", coreUsagePluginFunc(func(ctx context.Context, record coreusage.Record) {
calls <- "first"
}))
manager.RegisterNamed("plugin:alpha", coreUsagePluginFunc(func(ctx context.Context, record coreusage.Record) {
calls <- "second"
}))
manager.Publish(context.Background(), coreusage.Record{Provider: "provider"})
select {
case got := <-calls:
if got != "second" {
t.Fatalf("first dispatch = %q, want second", got)
}
case <-time.After(100 * time.Millisecond):
t.Fatal("timed out waiting for usage dispatch")
}
select {
case got := <-calls:
t.Fatalf("unexpected duplicate dispatch from %q", got)
case <-time.After(50 * time.Millisecond):
}
}
func TestRegisterFrontendAuthProvidersPrunesStaleKeys(t *testing.T) {
const key = "plugin:auth-active:custom-auth"
sdkaccess.UnregisterProvider(key)
defer sdkaccess.UnregisterProvider(key)
host := newHostWithRecords(capabilityRecord{
id: "auth-active",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
FrontendAuthProvider: frontendAuthProviderFunc{
identifier: "custom-auth",
authenticate: func(ctx context.Context, req pluginapi.FrontendAuthRequest) (pluginapi.FrontendAuthResponse, error) {
return pluginapi.FrontendAuthResponse{Authenticated: true}, nil
},
},
}},
})
host.RegisterFrontendAuthProviders()
if !registeredProviderIdentifier(key) {
t.Fatalf("registered providers did not include %q", key)
}
host.snapshot.Store(&Snapshot{enabled: true})
host.RegisterFrontendAuthProviders()
if registeredProviderIdentifier(key) {
t.Fatalf("registered providers still included stale key %q", key)
}
}
func TestRegisterFrontendAuthProvidersIdentifierPanicFusesPlugin(t *testing.T) {
host := newHostWithRecords(capabilityRecord{
id: "auth-identifier-panic",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
FrontendAuthProvider: panicFrontendAuthProvider{},
}},
})
host.RegisterFrontendAuthProviders()
if !host.isPluginFused("auth-identifier-panic") {
t.Fatal("auth-identifier-panic was not fused")
}
}
func TestRegisterFrontendAuthProvidersSelectsHighestPriorityExclusiveProvider(t *testing.T) {
lowKey := "plugin:exclusive-low:custom-auth"
highKey := "plugin:exclusive-high:custom-auth"
normalKey := "plugin:normal-auth:custom-auth"
for _, key := range []string{lowKey, highKey, normalKey} {
sdkaccess.UnregisterProvider(key)
defer sdkaccess.UnregisterProvider(key)
}
sdkaccess.ClearExclusiveProvider()
defer sdkaccess.ClearExclusiveProvider()
host := newHostWithRecords(
capabilityRecord{
id: "exclusive-low",
priority: 1,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
FrontendAuthProvider: frontendAuthProviderFunc{identifier: "custom-auth"},
FrontendAuthProviderExclusive: true,
}},
},
capabilityRecord{
id: "exclusive-high",
priority: 10,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
FrontendAuthProvider: frontendAuthProviderFunc{identifier: "custom-auth"},
FrontendAuthProviderExclusive: true,
}},
},
capabilityRecord{
id: "normal-auth",
priority: 20,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
FrontendAuthProvider: frontendAuthProviderFunc{identifier: "custom-auth"},
}},
},
)
host.RegisterFrontendAuthProviders()
providers := sdkaccess.RegisteredProviders()
if len(providers) != 1 {
t.Fatalf("RegisteredProviders() len = %d, want 1", len(providers))
}
if providers[0].Identifier() != highKey {
t.Fatalf("exclusive provider = %q, want %q", providers[0].Identifier(), highKey)
}
}
func TestRegisterFrontendAuthProvidersSelectsExclusiveProviderByPluginIDWhenPriorityTies(t *testing.T) {
alphaKey := "plugin:alpha-auth:custom-auth"
betaKey := "plugin:beta-auth:custom-auth"
for _, key := range []string{alphaKey, betaKey} {
sdkaccess.UnregisterProvider(key)
defer sdkaccess.UnregisterProvider(key)
}
sdkaccess.ClearExclusiveProvider()
defer sdkaccess.ClearExclusiveProvider()
host := newHostWithRecords(
capabilityRecord{
id: "beta-auth",
priority: 5,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
FrontendAuthProvider: frontendAuthProviderFunc{identifier: "custom-auth"},
FrontendAuthProviderExclusive: true,
}},
},
capabilityRecord{
id: "alpha-auth",
priority: 5,
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
FrontendAuthProvider: frontendAuthProviderFunc{identifier: "custom-auth"},
FrontendAuthProviderExclusive: true,
}},
},
)
host.RegisterFrontendAuthProviders()
providers := sdkaccess.RegisteredProviders()
if len(providers) != 1 {
t.Fatalf("RegisteredProviders() len = %d, want 1", len(providers))
}
if providers[0].Identifier() != alphaKey {
t.Fatalf("exclusive provider = %q, want %q", providers[0].Identifier(), alphaKey)
}
}
func TestRegisterFrontendAuthProvidersClearsExclusiveProviderWhenExclusivePluginRemoved(t *testing.T) {
exclusiveKey := "plugin:exclusive-auth:custom-auth"
normalKey := "plugin:normal-auth:custom-auth"
for _, key := range []string{exclusiveKey, normalKey} {
sdkaccess.UnregisterProvider(key)
defer sdkaccess.UnregisterProvider(key)
}
sdkaccess.ClearExclusiveProvider()
defer sdkaccess.ClearExclusiveProvider()
host := newHostWithRecords(
capabilityRecord{
id: "exclusive-auth",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
FrontendAuthProvider: frontendAuthProviderFunc{identifier: "custom-auth"},
FrontendAuthProviderExclusive: true,
}},
},
capabilityRecord{
id: "normal-auth",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
FrontendAuthProvider: frontendAuthProviderFunc{identifier: "custom-auth"},
}},
},
)
host.RegisterFrontendAuthProviders()
if got := sdkaccess.RegisteredProviders(); len(got) != 1 || got[0].Identifier() != exclusiveKey {
t.Fatalf("exclusive RegisteredProviders() = %#v, want only %q", got, exclusiveKey)
}
host.snapshot.Store(&Snapshot{enabled: true, records: []capabilityRecord{
{
id: "normal-auth",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
FrontendAuthProvider: frontendAuthProviderFunc{identifier: "custom-auth"},
}},
},
}})
host.RegisterFrontendAuthProviders()
providers := sdkaccess.RegisteredProviders()
if len(providers) != 1 {
t.Fatalf("RegisteredProviders() len = %d, want 1", len(providers))
}
if providers[0].Identifier() != normalKey {
t.Fatalf("restored provider = %q, want %q", providers[0].Identifier(), normalKey)
}
}
func TestRegisterFrontendAuthProvidersIgnoresExclusiveWithoutFrontendAuthProvider(t *testing.T) {
normalKey := "plugin:normal-auth:custom-auth"
sdkaccess.UnregisterProvider(normalKey)
sdkaccess.ClearExclusiveProvider()
defer sdkaccess.UnregisterProvider(normalKey)
defer sdkaccess.ClearExclusiveProvider()
host := newHostWithRecords(
capabilityRecord{
id: "exclusive-without-provider",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
FrontendAuthProviderExclusive: true,
}},
},
capabilityRecord{
id: "normal-auth",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
FrontendAuthProvider: frontendAuthProviderFunc{identifier: "custom-auth"},
}},
},
)
host.RegisterFrontendAuthProviders()
providers := sdkaccess.RegisteredProviders()
if len(providers) != 1 {
t.Fatalf("RegisteredProviders() len = %d, want 1", len(providers))
}
if providers[0].Identifier() != normalKey {
t.Fatalf("provider = %q, want %q", providers[0].Identifier(), normalKey)
}
}
func TestUsageAdapterUsesCurrentSnapshotCapability(t *testing.T) {
oldCalls := 0
newCalls := 0
oldPlugin := usagePluginFunc(func(ctx context.Context, record pluginapi.UsageRecord) {
oldCalls++
})
newPlugin := usagePluginFunc(func(ctx context.Context, record pluginapi.UsageRecord) {
newCalls++
})
host := newHostWithRecords(capabilityRecord{
id: "usage-active",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
UsagePlugin: oldPlugin,
}},
})
adapter := &usageAdapter{
host: host,
pluginID: "usage-active",
plugin: oldPlugin,
}
host.snapshot.Store(&Snapshot{enabled: true, records: []capabilityRecord{{
id: "usage-active",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
UsagePlugin: newPlugin,
}},
}}})
adapter.HandleUsage(context.Background(), coreusage.Record{Provider: "provider"})
if oldCalls != 0 {
t.Fatalf("old usage plugin calls = %d, want 0", oldCalls)
}
if newCalls != 1 {
t.Fatalf("new usage plugin calls = %d, want 1", newCalls)
}
}
func TestRegisterUsagePluginsStaleAdapterSkipsRemovedCapability(t *testing.T) {
calls := 0
plugin := usagePluginFunc(func(ctx context.Context, record pluginapi.UsageRecord) {
calls++
})
host := newHostWithRecords(capabilityRecord{
id: "usage-active",
plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{
UsagePlugin: plugin,
}},
})
host.RegisterUsagePlugins()
adapter := &usageAdapter{
host: host,
pluginID: "usage-active",
plugin: plugin,
}
host.snapshot.Store(&Snapshot{enabled: true})
adapter.HandleUsage(context.Background(), coreusage.Record{Provider: "provider"})
if calls != 0 {
t.Fatalf("usage plugin calls = %d, want 0 after capability removal", calls)
}
}
func TestAccessAdapterUnauthenticatedReturnsNotHandled(t *testing.T) {
host := New()
adapter := &accessAdapter{
host: host,
pluginID: "auth-plugin",
provider: frontendAuthProviderFunc{
identifier: "custom-auth",
authenticate: func(ctx context.Context, req pluginapi.FrontendAuthRequest) (pluginapi.FrontendAuthResponse, error) {
return pluginapi.FrontendAuthResponse{Authenticated: false}, nil
},
},
}
req, errNewRequest := http.NewRequest(http.MethodGet, "http://example.test/v1/models", nil)
if errNewRequest != nil {
t.Fatalf("NewRequest() error = %v", errNewRequest)
}
result, authErr := adapter.Authenticate(context.Background(), req)
if result != nil {
t.Fatalf("Authenticate() result = %#v, want nil", result)
}
if !sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeNotHandled) {
t.Fatalf("Authenticate() error = %v, want not handled", authErr)
}
}
func TestAccessAdapterPanicFusesAndReturnsNotHandled(t *testing.T) {
host := New()
adapter := &accessAdapter{
host: host,
pluginID: "auth-panic",
provider: frontendAuthProviderFunc{
identifier: "custom-auth",
authenticate: func(ctx context.Context, req pluginapi.FrontendAuthRequest) (pluginapi.FrontendAuthResponse, error) {
panic("auth panic")
},
},
}
req, errNewRequest := http.NewRequest(http.MethodGet, "http://example.test/v1/models", nil)
if errNewRequest != nil {
t.Fatalf("NewRequest() error = %v", errNewRequest)
}
result, authErr := adapter.Authenticate(context.Background(), req)
if result != nil {
t.Fatalf("Authenticate() result = %#v, want nil", result)
}
if !sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeNotHandled) {
t.Fatalf("Authenticate() error = %v, want not handled", authErr)
}
if !host.isPluginFused("auth-panic") {
t.Fatal("auth-panic was not fused")
}
}
func TestAccessAdapterBodyReadFailureReturnsInternalError(t *testing.T) {
host := New()
called := false
adapter := &accessAdapter{
host: host,
pluginID: "auth-plugin",
provider: frontendAuthProviderFunc{
identifier: "custom-auth",
authenticate: func(ctx context.Context, req pluginapi.FrontendAuthRequest) (pluginapi.FrontendAuthResponse, error) {
called = true
return pluginapi.FrontendAuthResponse{Authenticated: true}, nil
},
},
}
req, errNewRequest := http.NewRequest(http.MethodPost, "http://example.test/v1/chat", nil)
if errNewRequest != nil {
t.Fatalf("NewRequest() error = %v", errNewRequest)
}
req.Body = failingReadCloser{}
result, authErr := adapter.Authenticate(context.Background(), req)
if result != nil {
t.Fatalf("Authenticate() result = %#v, want nil", result)
}
if !sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeInternal) {
t.Fatalf("Authenticate() error = %v, want internal auth error", authErr)
}
if called {
t.Fatal("plugin provider was called after body read failure")
}
}
func TestAccessAdapterErrorReturnsNotHandledAndRestoresBody(t *testing.T) {
host := New()
adapter := &accessAdapter{
host: host,
pluginID: "auth-plugin",
provider: frontendAuthProviderFunc{
identifier: "custom-auth",
authenticate: func(ctx context.Context, req pluginapi.FrontendAuthRequest) (pluginapi.FrontendAuthResponse, error) {
if string(req.Body) != "request-body" {
t.Fatalf("plugin request body = %q, want %q", req.Body, "request-body")
}
return pluginapi.FrontendAuthResponse{}, fmt.Errorf("not mine")
},
},
}
req, errNewRequest := http.NewRequest(http.MethodPost, "http://example.test/v1/chat?x=1", bytes.NewBufferString("request-body"))
if errNewRequest != nil {
t.Fatalf("NewRequest() error = %v", errNewRequest)
}
result, authErr := adapter.Authenticate(context.Background(), req)
if result != nil {
t.Fatalf("Authenticate() result = %#v, want nil", result)
}
if !sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeNotHandled) {
t.Fatalf("Authenticate() error = %v, want not handled", authErr)
}
restored, errReadAll := io.ReadAll(req.Body)
if errReadAll != nil {
t.Fatalf("ReadAll(restored body) error = %v", errReadAll)
}
if string(restored) != "request-body" {
t.Fatalf("restored body = %q, want %q", restored, "request-body")
}
}
func TestExecutorAdapterMethods(t *testing.T) {
streamChunks := make(chan pluginapi.ExecutorStreamChunk, 2)
streamErr := errors.New("stream failed")
streamChunks <- pluginapi.ExecutorStreamChunk{Payload: []byte("stream-1")}
streamChunks <- pluginapi.ExecutorStreamChunk{Err: streamErr}
close(streamChunks)
pluginHTTPBody := []byte("http-response")
pluginHTTPHeaders := http.Header{"X-Http": []string{"1"}}
authProvider := fakeAuthProvider{
identifier: "plugin-provider",
refreshAuth: func(ctx context.Context, req pluginapi.AuthRefreshRequest) (pluginapi.AuthRefreshResponse, error) {
if req.AuthID != "auth-1" || req.AuthProvider != "plugin-provider" || req.Metadata["old"] != "value" {
t.Fatalf("refresh request = %#v, want auth metadata", req)
}
if req.HTTPClient == nil {
t.Fatal("refresh request HTTPClient = nil, want host HTTP bridge")
}
return pluginapi.AuthRefreshResponse{
Auth: pluginapi.AuthData{
Metadata: map[string]any{"token": "new"},
},
}, nil
},
}
host := newHostWithRecords(capabilityRecord{
id: "auth-plugin",
plugin: pluginapi.Plugin{
Capabilities: pluginapi.Capabilities{
AuthProvider: authProvider,
},
},
})
exec := &fakeExecutor{
identifier: "ignored-by-adapter",
execute: func(ctx context.Context, req pluginapi.ExecutorRequest) (pluginapi.ExecutorResponse, error) {
assertExecutorRequest(t, req)
return pluginapi.ExecutorResponse{
Payload: []byte("execute-response"),
Headers: http.Header{"X-Execute": []string{"1"}},
Metadata: map[string]any{
"phase": "execute",
},
}, nil
},
executeStream: func(ctx context.Context, req pluginapi.ExecutorRequest) (pluginapi.ExecutorStreamResponse, error) {
assertExecutorRequest(t, req)
return pluginapi.ExecutorStreamResponse{
Headers: http.Header{"X-Stream": []string{"1"}},
Chunks: streamChunks,
}, nil
},
countTokens: func(ctx context.Context, req pluginapi.ExecutorRequest) (pluginapi.ExecutorResponse, error) {
assertExecutorRequest(t, req)
return pluginapi.ExecutorResponse{Payload: []byte(`{"total_tokens":3}`)}, nil
},
httpRequest: func(ctx context.Context, req pluginapi.ExecutorHTTPRequest) (pluginapi.ExecutorHTTPResponse, error) {
if req.AuthID != "auth-1" || req.AuthProvider != "plugin-provider" || req.Method != http.MethodPatch ||
req.URL != "http://example.test/v1/raw?x=1" || req.Headers.Get("X-Raw") != "yes" || string(req.Body) != "raw-body" {
t.Fatalf("http request = %#v, want mapped raw HTTP request", req)
}
if req.HTTPClient == nil {
t.Fatal("http request HTTPClient = nil, want host HTTP bridge")
}
return pluginapi.ExecutorHTTPResponse{
StatusCode: http.StatusAccepted,
Headers: pluginHTTPHeaders,
Body: pluginHTTPBody,
}, nil
},
}
adapter := &executorAdapter{
host: host,
pluginID: "executor-plugin",
provider: "plugin-provider",
executor: exec,
inputFormats: []sdktranslator.Format{sdktranslator.FormatOpenAI},
outputFormats: []sdktranslator.Format{sdktranslator.FormatOpenAI},
}
auth := &coreauth.Auth{
ID: "auth-1",
Provider: "plugin-provider",
Metadata: map[string]any{"old": "value"},
}
req := coreexecutor.Request{
Model: "model-1",
Format: sdktranslator.FormatOpenAI,
Payload: []byte("payload"),
Metadata: map[string]any{
"req": "metadata",
},
}
opts := coreexecutor.Options{
Stream: true,
Alt: "alt",
Headers: http.Header{"X-Request": []string{"yes"}},
OriginalRequest: []byte("original"),
SourceFormat: sdktranslator.FormatOpenAI,
Metadata: map[string]any{
"opt": "metadata",
},
}
if adapter.Identifier() != "plugin-provider" {
t.Fatalf("Identifier() = %q, want %q", adapter.Identifier(), "plugin-provider")
}
resp, errExecute := adapter.Execute(context.Background(), auth, req, opts)
if errExecute != nil {
t.Fatalf("Execute() error = %v", errExecute)
}
if string(resp.Payload) != "execute-response" || resp.Headers.Get("X-Execute") != "1" || resp.Metadata["phase"] != "execute" {
t.Fatalf("Execute() = %#v, want mapped response", resp)
}
stream, errExecuteStream := adapter.ExecuteStream(context.Background(), auth, req, opts)
if errExecuteStream != nil {
t.Fatalf("ExecuteStream() error = %v", errExecuteStream)
}
if stream.Headers.Get("X-Stream") != "1" {
t.Fatalf("ExecuteStream() headers = %#v, want X-Stream", stream.Headers)
}
first := <-stream.Chunks
if string(first.Payload) != "stream-1" || first.Err != nil {
t.Fatalf("first stream chunk = %#v, want payload chunk", first)
}
second := <-stream.Chunks
if second.Err != streamErr {
t.Fatalf("second stream chunk err = %v, want %v", second.Err, streamErr)
}
if _, ok := <-stream.Chunks; ok {
t.Fatal("stream chunks channel still open, want closed")
}
refreshed, errRefresh := adapter.Refresh(context.Background(), auth)
if errRefresh != nil {
t.Fatalf("Refresh() error = %v", errRefresh)
}
if refreshed == auth {
t.Fatal("Refresh() returned original auth pointer, want clone")
}
if refreshed.Metadata["token"] != "new" {
t.Fatalf("Refresh() metadata = %#v, want token=new", refreshed.Metadata)
}
count, errCountTokens := adapter.CountTokens(context.Background(), auth, req, opts)
if errCountTokens != nil {
t.Fatalf("CountTokens() error = %v", errCountTokens)
}
if string(count.Payload) != `{"total_tokens":3}` {
t.Fatalf("CountTokens() payload = %q, want token payload", count.Payload)
}
rawReq, errNewRawRequest := http.NewRequest(http.MethodPatch, "http://example.test/v1/raw?x=1", bytes.NewBufferString("raw-body"))
if errNewRawRequest != nil {
t.Fatalf("NewRequest(raw) error = %v", errNewRawRequest)
}
rawReq.Header.Set("X-Raw", "yes")
httpResp, errHTTPRequest := adapter.HttpRequest(context.Background(), auth, rawReq)
if errHTTPRequest != nil {
t.Fatalf("HttpRequest() error = %v", errHTTPRequest)
}
if httpResp.StatusCode != http.StatusAccepted || httpResp.Status != "202 Accepted" || httpResp.Header.Get("X-Http") != "1" {
t.Fatalf("HttpRequest() response = %#v, want mapped status/header", httpResp)
}
pluginHTTPBody[0] = 'X'
pluginHTTPHeaders.Set("X-Http", "mutated")
body, errReadBody := io.ReadAll(httpResp.Body)
if errReadBody != nil {
t.Fatalf("ReadAll(HttpRequest body) error = %v", errReadBody)
}
if string(body) != "http-response" || httpResp.Header.Get("X-Http") != "1" {
t.Fatalf("HttpRequest() response aliases plugin data: body=%q header=%q", body, httpResp.Header.Get("X-Http"))
}
restoredRawBody, errReadRawBody := io.ReadAll(rawReq.Body)
if errReadRawBody != nil {
t.Fatalf("ReadAll(restored raw request body) error = %v", errReadRawBody)
}
if string(restoredRawBody) != "raw-body" {
t.Fatalf("restored raw request body = %q, want raw-body", restoredRawBody)
}
nilResp, errNilRequest := adapter.HttpRequest(context.Background(), auth, nil)
if nilResp != nil {
t.Fatalf("HttpRequest(nil) response = %#v, want nil", nilResp)
}
if errNilRequest == nil || !strings.Contains(errNilRequest.Error(), "nil HTTP request") {
t.Fatalf("HttpRequest(nil) error = %v, want nil request error", errNilRequest)
}
}
func TestExecutorAdapterConsumesTranslatedStreamChunksWithoutOutput(t *testing.T) {
adapter := &executorAdapter{}
request := []byte(`{"model":"qmodel_latest","stream":true,"tool_choice":"auto","parallel_tool_calls":true}`)
prepared := preparedExecutorCall{
req: coreexecutor.Request{
Model: "qmodel_latest",
Payload: request,
},
opts: coreexecutor.Options{
OriginalRequest: request,
},
requestedFormat: sdktranslator.FormatOpenAIResponse,
outputFormat: sdktranslator.FormatOpenAI,
}
var param any
startPayload := []byte(`{"choices":[{"delta":{"content":"","tool_calls":[{"function":{"arguments":"","name":"get_weather"},"id":"call_69755759d70640e3b7a42805","index":0,"type":"function"}]},"index":0}],"created":1780767281,"id":"chatcmpl-ba492ed2-2901-9d1f-80e7-b6dfe97fefaa","model":"auto","object":"chat.completion.chunk"}`)
if got := adapter.translateExecutorStreamPayload(context.Background(), prepared, startPayload, &param); len(got) == 0 {
t.Fatal("tool call start payload was not translated")
}
emptyArgumentsPayload := []byte(`{"choices":[{"delta":{"content":"","tool_calls":[{"function":{"arguments":""},"id":"","index":0,"type":"function"}]},"index":0}],"created":1780767281,"id":"chatcmpl-ba492ed2-2901-9d1f-80e7-b6dfe97fefaa","model":"auto","object":"chat.completion.chunk"}`)
if got := adapter.translateExecutorStreamPayload(context.Background(), prepared, emptyArgumentsPayload, &param); len(got) != 0 {
t.Fatalf("empty arguments payload leaked through translation fallback: %q", got[0])
}
finishPayload := []byte(`{"choices":[{"delta":{},"finish_reason":"tool_calls","index":0}],"created":1780767281,"id":"chatcmpl-ba492ed2-2901-9d1f-80e7-b6dfe97fefaa","model":"auto","object":"chat.completion.chunk"}`)
if got := adapter.translateExecutorStreamPayload(context.Background(), prepared, finishPayload, &param); len(got) == 0 {
t.Fatal("finish payload was not translated")
}
usagePayload := []byte(`{"choices":[],"created":1780767281,"id":"chatcmpl-ba492ed2-2901-9d1f-80e7-b6dfe97fefaa","model":"auto","object":"chat.completion.chunk","usage":{"completion_tokens":179,"completion_tokens_details":{"reasoning_tokens":121},"prompt_tokens":331,"prompt_tokens_details":{"cached_tokens":0},"total_tokens":510}}`)
if got := adapter.translateExecutorStreamPayload(context.Background(), prepared, usagePayload, &param); len(got) != 0 {
t.Fatalf("usage-only payload leaked through translation fallback: %q", got[0])
}
donePayload := []byte(`data: [DONE]`)
doneFrames := adapter.translateExecutorStreamPayload(context.Background(), prepared, donePayload, &param)
if len(doneFrames) != 1 {
t.Fatalf("done payload translated to %d frames, want 1", len(doneFrames))
}
if !bytes.Contains(doneFrames[0], []byte("response.completed")) {
t.Fatalf("done payload did not produce response.completed: %q", doneFrames[0])
}
if !bytes.Contains(doneFrames[0], []byte(`"input_tokens":331`)) ||
!bytes.Contains(doneFrames[0], []byte(`"output_tokens":179`)) ||
!bytes.Contains(doneFrames[0], []byte(`"reasoning_tokens":121`)) ||
!bytes.Contains(doneFrames[0], []byte(`"total_tokens":510`)) {
t.Fatalf("completed payload did not preserve usage: %q", doneFrames[0])
}
}
func TestExecutorAdapterPanicFusesAndReturnsError(t *testing.T) {
host := New()
calls := 0
adapter := &executorAdapter{
host: host,
pluginID: "executor-panic",
provider: "plugin-provider",
inputFormats: []sdktranslator.Format{sdktranslator.FormatOpenAI},
outputFormats: []sdktranslator.Format{sdktranslator.FormatOpenAI},
executor: &fakeExecutor{
execute: func(ctx context.Context, req pluginapi.ExecutorRequest) (pluginapi.ExecutorResponse, error) {
calls++
panic("execute panic")
},
countTokens: func(ctx context.Context, req pluginapi.ExecutorRequest) (pluginapi.ExecutorResponse, error) {
calls++
return pluginapi.ExecutorResponse{Payload: []byte("should-not-run")}, nil
},
},
}
resp, errExecute := adapter.Execute(context.Background(), &coreauth.Auth{}, coreexecutor.Request{}, coreexecutor.Options{})
if errExecute == nil {
t.Fatal("Execute() error = nil, want panic converted to error")
}
if len(resp.Payload) != 0 {
t.Fatalf("Execute() response = %#v, want zero response", resp)
}
if !host.isPluginFused("executor-panic") {
t.Fatal("executor-panic was not fused")
}
if calls != 1 {
t.Fatalf("plugin calls after first Execute() = %d, want 1", calls)
}
count, errCountTokens := adapter.CountTokens(context.Background(), &coreauth.Auth{}, coreexecutor.Request{}, coreexecutor.Options{})
if errCountTokens == nil {
t.Fatal("CountTokens() error after fuse = nil, want unavailable error")
}
if len(count.Payload) != 0 {
t.Fatalf("CountTokens() response after fuse = %#v, want zero response", count)
}
if calls != 1 {
t.Fatalf("plugin calls after fused CountTokens() = %d, want 1", calls)
}
}
func TestMapExecutorStreamChunksExitsWhenContextCanceledWithoutDownstreamConsumer(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
in := make(chan pluginapi.ExecutorStreamChunk)
out := mapExecutorStreamChunks(ctx, in)
sent := make(chan struct{})
go func() {
in <- pluginapi.ExecutorStreamChunk{Payload: []byte("chunk")}
close(sent)
}()
select {
case <-sent:
case <-time.After(100 * time.Millisecond):
t.Fatal("input chunk was not accepted by bridge")
}
cancel()
time.Sleep(10 * time.Millisecond)
select {
case chunk, ok := <-out:
if ok {
t.Fatalf("output channel produced chunk after cancel: %#v", chunk)
}
case <-time.After(100 * time.Millisecond):
t.Fatal("output channel was not closed after context cancellation")
}
}
func newHostWithRecords(records ...capabilityRecord) *Host {
host := New()
sortRecords(records)
host.snapshot.Store(&Snapshot{enabled: true, records: records})
return host
}
type stringSliceAlias []string
type mapSliceAlias []map[string]string
type requestNormalizerFunc func(context.Context, pluginapi.RequestTransformRequest) (pluginapi.PayloadResponse, error)
func (f requestNormalizerFunc) NormalizeRequest(ctx context.Context, req pluginapi.RequestTransformRequest) (pluginapi.PayloadResponse, error) {
return f(ctx, req)
}
type requestTranslatorFunc func(context.Context, pluginapi.RequestTransformRequest) (pluginapi.PayloadResponse, error)
func (f requestTranslatorFunc) TranslateRequest(ctx context.Context, req pluginapi.RequestTransformRequest) (pluginapi.PayloadResponse, error) {
return f(ctx, req)
}
type responseNormalizerFunc func(context.Context, pluginapi.ResponseTransformRequest) (pluginapi.PayloadResponse, error)
func (f responseNormalizerFunc) NormalizeResponse(ctx context.Context, req pluginapi.ResponseTransformRequest) (pluginapi.PayloadResponse, error) {
return f(ctx, req)
}
type responseTranslatorFunc func(context.Context, pluginapi.ResponseTransformRequest) (pluginapi.PayloadResponse, error)
func (f responseTranslatorFunc) TranslateResponse(ctx context.Context, req pluginapi.ResponseTransformRequest) (pluginapi.PayloadResponse, error) {
return f(ctx, req)
}
type usagePluginFunc func(context.Context, pluginapi.UsageRecord)
func (f usagePluginFunc) HandleUsage(ctx context.Context, record pluginapi.UsageRecord) {
f(ctx, record)
}
type coreUsagePluginFunc func(context.Context, coreusage.Record)
func (f coreUsagePluginFunc) HandleUsage(ctx context.Context, record coreusage.Record) {
f(ctx, record)
}
type frontendAuthProviderFunc struct {
identifier string
authenticate func(context.Context, pluginapi.FrontendAuthRequest) (pluginapi.FrontendAuthResponse, error)
}
func (f frontendAuthProviderFunc) Identifier() string {
return f.identifier
}
func (f frontendAuthProviderFunc) Authenticate(ctx context.Context, req pluginapi.FrontendAuthRequest) (pluginapi.FrontendAuthResponse, error) {
return f.authenticate(ctx, req)
}
type panicFrontendAuthProvider struct{}
func (panicFrontendAuthProvider) Identifier() string {
panic("identifier panic")
}
func (panicFrontendAuthProvider) Authenticate(ctx context.Context, req pluginapi.FrontendAuthRequest) (pluginapi.FrontendAuthResponse, error) {
return pluginapi.FrontendAuthResponse{}, nil
}
type fakeAuthProvider struct {
identifier string
parseAuth func(context.Context, pluginapi.AuthParseRequest) (pluginapi.AuthParseResponse, error)
startLogin func(context.Context, pluginapi.AuthLoginStartRequest) (pluginapi.AuthLoginStartResponse, error)
pollLogin func(context.Context, pluginapi.AuthLoginPollRequest) (pluginapi.AuthLoginPollResponse, error)
refreshAuth func(context.Context, pluginapi.AuthRefreshRequest) (pluginapi.AuthRefreshResponse, error)
}
func (p fakeAuthProvider) Identifier() string {
return p.identifier
}
func (p fakeAuthProvider) ParseAuth(ctx context.Context, req pluginapi.AuthParseRequest) (pluginapi.AuthParseResponse, error) {
if p.parseAuth == nil {
return pluginapi.AuthParseResponse{}, nil
}
return p.parseAuth(ctx, req)
}
func (p fakeAuthProvider) StartLogin(ctx context.Context, req pluginapi.AuthLoginStartRequest) (pluginapi.AuthLoginStartResponse, error) {
if p.startLogin == nil {
return pluginapi.AuthLoginStartResponse{}, nil
}
return p.startLogin(ctx, req)
}
func (p fakeAuthProvider) PollLogin(ctx context.Context, req pluginapi.AuthLoginPollRequest) (pluginapi.AuthLoginPollResponse, error) {
if p.pollLogin == nil {
return pluginapi.AuthLoginPollResponse{}, nil
}
return p.pollLogin(ctx, req)
}
func (p fakeAuthProvider) RefreshAuth(ctx context.Context, req pluginapi.AuthRefreshRequest) (pluginapi.AuthRefreshResponse, error) {
if p.refreshAuth == nil {
return pluginapi.AuthRefreshResponse{}, nil
}
return p.refreshAuth(ctx, req)
}
type modelRegistrarFunc func(context.Context, pluginapi.ModelRegistrationRequest) (pluginapi.ModelRegistrationResponse, error)
func (f modelRegistrarFunc) RegisterModels(ctx context.Context, req pluginapi.ModelRegistrationRequest) (pluginapi.ModelRegistrationResponse, error) {
return f(ctx, req)
}
type modelProviderFunc struct {
staticModels func(context.Context, pluginapi.StaticModelRequest) (pluginapi.ModelResponse, error)
modelsForAuth func(context.Context, pluginapi.AuthModelRequest) (pluginapi.ModelResponse, error)
}
func (f modelProviderFunc) StaticModels(ctx context.Context, req pluginapi.StaticModelRequest) (pluginapi.ModelResponse, error) {
if f.staticModels == nil {
return pluginapi.ModelResponse{}, nil
}
return f.staticModels(ctx, req)
}
func (f modelProviderFunc) ModelsForAuth(ctx context.Context, req pluginapi.AuthModelRequest) (pluginapi.ModelResponse, error) {
if f.modelsForAuth == nil {
return pluginapi.ModelResponse{}, nil
}
return f.modelsForAuth(ctx, req)
}
func staticModelRegistrar(provider, modelID string) pluginapi.ModelRegistrar {
return modelRegistrarFunc(func(ctx context.Context, req pluginapi.ModelRegistrationRequest) (pluginapi.ModelRegistrationResponse, error) {
return pluginapi.ModelRegistrationResponse{
Provider: provider,
Models: []pluginapi.ModelInfo{{
ID: modelID,
}},
}, nil
})
}
func registeredProviderIdentifier(identifier string) bool {
for _, provider := range sdkaccess.RegisteredProviders() {
if provider != nil && provider.Identifier() == identifier {
return true
}
}
return false
}
type fakeModelRegistry struct {
clients map[string]*fakeModelClient
unregisters []string
}
type fakeModelClient struct {
provider string
models []*registry.ModelInfo
}
func newFakeModelRegistry() *fakeModelRegistry {
return &fakeModelRegistry{
clients: make(map[string]*fakeModelClient),
}
}
func (r *fakeModelRegistry) RegisterClient(clientID, clientProvider string, models []*registry.ModelInfo) {
r.clients[clientID] = &fakeModelClient{
provider: clientProvider,
models: models,
}
}
func (r *fakeModelRegistry) UnregisterClient(clientID string) {
delete(r.clients, clientID)
r.unregisters = append(r.unregisters, clientID)
}
func (r *fakeModelRegistry) GetModelProviders(modelID string) []string {
counts := make(map[string]int)
for _, client := range r.clients {
if client == nil || client.provider == "" {
continue
}
for _, model := range client.models {
if model != nil && model.ID == modelID {
counts[client.provider]++
}
}
}
providers := make([]string, 0, len(counts))
for provider := range counts {
providers = append(providers, provider)
}
sort.Strings(providers)
return providers
}
type fakeExecutorManager struct {
executors map[string]coreauth.ProviderExecutor
registerCalls int
unregisters []string
}
func newFakeExecutorManager() *fakeExecutorManager {
return &fakeExecutorManager{
executors: make(map[string]coreauth.ProviderExecutor),
}
}
func (m *fakeExecutorManager) Executor(provider string) (coreauth.ProviderExecutor, bool) {
executor, okExecutor := m.executors[provider]
return executor, okExecutor
}
func (m *fakeExecutorManager) RegisterExecutor(executor coreauth.ProviderExecutor) {
m.registerCalls++
m.executors[executor.Identifier()] = executor
}
func (m *fakeExecutorManager) UnregisterExecutor(provider string) {
delete(m.executors, provider)
m.unregisters = append(m.unregisters, provider)
}
type fakeProviderExecutor struct {
provider string
}
func (e *fakeProviderExecutor) Identifier() string {
return e.provider
}
func (e *fakeProviderExecutor) Execute(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, nil
}
func (e *fakeProviderExecutor) ExecuteStream(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (*coreexecutor.StreamResult, error) {
return nil, nil
}
func (e *fakeProviderExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
return auth, nil
}
func (e *fakeProviderExecutor) CountTokens(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, nil
}
func (e *fakeProviderExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
return nil, nil
}
type fakeExecutor struct {
identifier string
identifierFunc func() string
panicIdentifier bool
execute func(context.Context, pluginapi.ExecutorRequest) (pluginapi.ExecutorResponse, error)
executeStream func(context.Context, pluginapi.ExecutorRequest) (pluginapi.ExecutorStreamResponse, error)
countTokens func(context.Context, pluginapi.ExecutorRequest) (pluginapi.ExecutorResponse, error)
httpRequest func(context.Context, pluginapi.ExecutorHTTPRequest) (pluginapi.ExecutorHTTPResponse, error)
}
func (e *fakeExecutor) Identifier() string {
if e.panicIdentifier {
panic("identifier panic")
}
if e.identifierFunc != nil {
return e.identifierFunc()
}
return e.identifier
}
func (e *fakeExecutor) Execute(ctx context.Context, req pluginapi.ExecutorRequest) (pluginapi.ExecutorResponse, error) {
return e.execute(ctx, req)
}
func (e *fakeExecutor) ExecuteStream(ctx context.Context, req pluginapi.ExecutorRequest) (pluginapi.ExecutorStreamResponse, error) {
return e.executeStream(ctx, req)
}
func (e *fakeExecutor) CountTokens(ctx context.Context, req pluginapi.ExecutorRequest) (pluginapi.ExecutorResponse, error) {
return e.countTokens(ctx, req)
}
func (e *fakeExecutor) HttpRequest(ctx context.Context, req pluginapi.ExecutorHTTPRequest) (pluginapi.ExecutorHTTPResponse, error) {
if e.httpRequest == nil {
return pluginapi.ExecutorHTTPResponse{}, nil
}
return e.httpRequest(ctx, req)
}
func assertExecutorRequest(t *testing.T, req pluginapi.ExecutorRequest) {
t.Helper()
if req.AuthID != "auth-1" || req.AuthProvider != "plugin-provider" || req.Model != "model-1" || req.Format != sdktranslator.FormatOpenAI.String() ||
!req.Stream || req.Alt != "alt" || req.Headers.Get("X-Request") != "yes" || string(req.OriginalRequest) != "original" ||
req.SourceFormat != sdktranslator.FormatOpenAI.String() || string(req.Payload) != "payload" ||
req.Metadata["req"] != "metadata" || req.Metadata["opt"] != "metadata" {
t.Fatalf("executor request = %#v, want mapped request", req)
}
}
type failingReadCloser struct{}
func (failingReadCloser) Read(p []byte) (int, error) {
copy(p, []byte("partial"))
return len("partial"), errors.New("read failed")
}
func (failingReadCloser) Close() error {
return nil
}