Files
CLIProxyAPI/internal/pluginhost/callback_contexts.go
Luis Pater 538e3416db feat(plugin, api): prevent plugin recursion on host model callbacks, enable targeted interceptor skipping
- Updated host model callback logic to skip originating plugin's interceptors during nested model executions.
- Added `SkipInterceptorPluginID` field to plugin API structs for controlling interceptor bypass behavior.
- Introduced supporting logic in host API handlers, plugin host registry, and callback contexts to identify and skip specific plugins.
- Enhanced unit tests across plugin host, API handlers, and execution paths to verify interceptor skipping behavior and plugin isolation.
- Revised documentation to clarify non-recursive behavior of host model callbacks and the use of `SkipInterceptorPluginID`.
2026-06-12 02:38:51 +08:00

140 lines
3.0 KiB
Go

package pluginhost
import (
"context"
"strconv"
"strings"
"sync"
"sync/atomic"
)
type callbackContextRegistry struct {
next atomic.Uint64
mu sync.RWMutex
contexts map[string]callbackContextEntry
}
type callbackContextEntry struct {
ctx context.Context
pluginID string
cleanup []func()
}
func newCallbackContextRegistry() *callbackContextRegistry {
return &callbackContextRegistry{contexts: make(map[string]callbackContextEntry)}
}
func (r *callbackContextRegistry) open(ctx context.Context, pluginID string) (string, func()) {
if r == nil {
return "", func() {}
}
if ctx == nil {
ctx = context.Background()
}
pluginID = strings.TrimSpace(pluginID)
ctx = withHostCallbackPluginID(ctx, pluginID)
id := strconv.FormatUint(r.next.Add(1), 10)
r.mu.Lock()
r.contexts[id] = callbackContextEntry{ctx: ctx, pluginID: pluginID}
r.mu.Unlock()
var once sync.Once
return id, func() {
once.Do(func() {
var cleanup []func()
r.mu.Lock()
entry := r.contexts[id]
delete(r.contexts, id)
r.mu.Unlock()
cleanup = entry.cleanup
for _, fn := range cleanup {
if fn != nil {
fn()
}
}
})
}
}
func (r *callbackContextRegistry) pluginID(id string) string {
if r == nil || id == "" {
return ""
}
r.mu.RLock()
entry := r.contexts[id]
r.mu.RUnlock()
return strings.TrimSpace(entry.pluginID)
}
func (r *callbackContextRegistry) addCleanup(id string, cleanup func()) bool {
if r == nil || id == "" || cleanup == nil {
return false
}
r.mu.Lock()
entry, ok := r.contexts[id]
if ok {
entry.cleanup = append(entry.cleanup, cleanup)
r.contexts[id] = entry
}
r.mu.Unlock()
if !ok {
cleanup()
return false
}
return true
}
func (r *callbackContextRegistry) resolve(id string, fallback context.Context) context.Context {
if fallback == nil {
fallback = context.Background()
}
if r == nil || id == "" {
return fallback
}
r.mu.RLock()
ctx := r.contexts[id].ctx
r.mu.RUnlock()
if ctx == nil {
return fallback
}
return ctx
}
func (h *Host) openCallbackContext(ctx context.Context) (string, func()) {
return h.openCallbackContextForPlugin(ctx, "")
}
func (h *Host) openCallbackContextForPlugin(ctx context.Context, pluginID string) (string, func()) {
if h == nil || h.callbackContexts == nil {
return "", func() {}
}
return h.callbackContexts.open(ctx, pluginID)
}
func (h *Host) addCallbackCleanup(id string, cleanup func()) bool {
if h == nil || h.callbackContexts == nil {
if id != "" && cleanup != nil {
cleanup()
}
return false
}
return h.callbackContexts.addCleanup(id, cleanup)
}
func (h *Host) resolveCallbackContext(id string, fallback context.Context) context.Context {
if h == nil || h.callbackContexts == nil {
if fallback == nil {
return context.Background()
}
return fallback
}
return h.callbackContexts.resolve(id, fallback)
}
func (h *Host) callbackContextPluginID(id string) string {
if h == nil || h.callbackContexts == nil {
return ""
}
return h.callbackContexts.pluginID(id)
}