mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-06-13 10:12:50 +08:00
feat(pluginhost, jshandler): integrate HostCallbackID with interceptors and JS engine logging
- Added `HostCallbackID` to request, response, and stream chunk interceptors for enhanced context tracking. - Updated JavaScript engine to support custom console logging with `HostCallbackID` forwarding. - Introduced tests verifying proper integration of `HostCallbackID` in all interceptor flows and engine logging. - Enhanced logging and error handling for consistent callback-related logic implementation.
This commit is contained in:
@@ -92,6 +92,28 @@ type abiLifecycleRequest struct {
|
||||
PluginDir string `json:"plugin_dir,omitempty"`
|
||||
}
|
||||
|
||||
type abiRequestInterceptRequest struct {
|
||||
pluginapi.RequestInterceptRequest
|
||||
HostCallbackID string `json:"host_callback_id,omitempty"`
|
||||
}
|
||||
|
||||
type abiResponseInterceptRequest struct {
|
||||
pluginapi.ResponseInterceptRequest
|
||||
HostCallbackID string `json:"host_callback_id,omitempty"`
|
||||
}
|
||||
|
||||
type abiStreamChunkInterceptRequest struct {
|
||||
pluginapi.StreamChunkInterceptRequest
|
||||
HostCallbackID string `json:"host_callback_id,omitempty"`
|
||||
}
|
||||
|
||||
type abiHostLogRequest struct {
|
||||
HostCallbackID string `json:"host_callback_id,omitempty"`
|
||||
Level string `json:"level,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Fields map[string]any `json:"fields,omitempty"`
|
||||
}
|
||||
|
||||
type abiRegistration struct {
|
||||
SchemaVersion uint32 `json:"schema_version"`
|
||||
Metadata pluginapi.Metadata `json:"metadata"`
|
||||
@@ -192,25 +214,25 @@ func handleJSHandlerABIMethod(ctx context.Context, method string, request []byte
|
||||
defer done()
|
||||
switch method {
|
||||
case pluginabi.MethodRequestInterceptBefore:
|
||||
var req pluginapi.RequestInterceptRequest
|
||||
var req abiRequestInterceptRequest
|
||||
if errDecode := json.Unmarshal(request, &req); errDecode != nil {
|
||||
return nil, errDecode
|
||||
}
|
||||
resp, errCall := p.InterceptRequest(ctx, req)
|
||||
resp, errCall := p.interceptRequest(ctx, req.RequestInterceptRequest, req.HostCallbackID)
|
||||
return abiOKEnvelopeWithError(resp, errCall)
|
||||
case pluginabi.MethodResponseInterceptAfter:
|
||||
var req pluginapi.ResponseInterceptRequest
|
||||
var req abiResponseInterceptRequest
|
||||
if errDecode := json.Unmarshal(request, &req); errDecode != nil {
|
||||
return nil, errDecode
|
||||
}
|
||||
resp, errCall := p.InterceptResponse(ctx, req)
|
||||
resp, errCall := p.interceptResponse(ctx, req.ResponseInterceptRequest, req.HostCallbackID)
|
||||
return abiOKEnvelopeWithError(resp, errCall)
|
||||
case pluginabi.MethodResponseInterceptStreamChunk:
|
||||
var req pluginapi.StreamChunkInterceptRequest
|
||||
var req abiStreamChunkInterceptRequest
|
||||
if errDecode := json.Unmarshal(request, &req); errDecode != nil {
|
||||
return nil, errDecode
|
||||
}
|
||||
resp, errCall := p.InterceptStreamChunk(ctx, req)
|
||||
resp, errCall := p.interceptStreamChunk(ctx, req.StreamChunkInterceptRequest, req.HostCallbackID)
|
||||
return abiOKEnvelopeWithError(resp, errCall)
|
||||
default:
|
||||
return abiErrorEnvelope("unknown_method", "unknown method: "+method), nil
|
||||
@@ -289,3 +311,85 @@ func writeABIResponse(response *C.cliproxy_buffer, raw []byte) {
|
||||
response.ptr = ptr
|
||||
response.len = C.size_t(len(raw))
|
||||
}
|
||||
|
||||
func newHostJSConsoleLogger(hostCallbackID string) jsConsoleLogger {
|
||||
return func(message string) error {
|
||||
if errLog := writeHostJSConsoleLog(hostCallbackID, message); errLog != nil {
|
||||
return defaultJSConsoleLogger(message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func writeHostJSConsoleLog(hostCallbackID string, message string) error {
|
||||
raw, errMarshal := json.Marshal(abiHostLogRequest{
|
||||
HostCallbackID: hostCallbackID,
|
||||
Level: "info",
|
||||
Message: "JS console log: " + message,
|
||||
Fields: map[string]any{
|
||||
"plugin_id": pluginName,
|
||||
},
|
||||
})
|
||||
if errMarshal != nil {
|
||||
return errMarshal
|
||||
}
|
||||
|
||||
rawResp, errCall := callHost(pluginabi.MethodHostLog, raw)
|
||||
if errCall != nil {
|
||||
return errCall
|
||||
}
|
||||
if len(rawResp) == 0 {
|
||||
return nil
|
||||
}
|
||||
var resp abiEnvelope
|
||||
if errDecode := json.Unmarshal(rawResp, &resp); errDecode != nil {
|
||||
return fmt.Errorf("decode host log response: %w", errDecode)
|
||||
}
|
||||
if !resp.OK {
|
||||
if resp.Error != nil {
|
||||
return fmt.Errorf("host log failed: %s", resp.Error.Message)
|
||||
}
|
||||
return fmt.Errorf("host log failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func callHost(method string, payload []byte) ([]byte, error) {
|
||||
jsHandlerABIState.RLock()
|
||||
defer jsHandlerABIState.RUnlock()
|
||||
if jsHandlerABIState.host == nil {
|
||||
return nil, fmt.Errorf("host callback is unavailable")
|
||||
}
|
||||
|
||||
cMethod := C.CString(method)
|
||||
defer C.free(unsafe.Pointer(cMethod))
|
||||
|
||||
var cPayload unsafe.Pointer
|
||||
if len(payload) > 0 {
|
||||
cPayload = C.CBytes(payload)
|
||||
if cPayload == nil {
|
||||
return nil, fmt.Errorf("allocate host callback payload")
|
||||
}
|
||||
defer C.free(cPayload)
|
||||
}
|
||||
|
||||
var response C.cliproxy_buffer
|
||||
rc := C.jshandler_call_host(
|
||||
jsHandlerABIState.host,
|
||||
cMethod,
|
||||
(*C.uint8_t)(cPayload),
|
||||
C.size_t(len(payload)),
|
||||
&response,
|
||||
)
|
||||
var out []byte
|
||||
if response.ptr != nil && response.len > 0 {
|
||||
out = C.GoBytes(response.ptr, C.int(response.len))
|
||||
}
|
||||
if response.ptr != nil {
|
||||
C.jshandler_free_host_buffer(jsHandlerABIState.host, response.ptr, response.len)
|
||||
}
|
||||
if rc != 0 {
|
||||
return nil, fmt.Errorf("host callback %s returned %d: %s", method, int(rc), string(out))
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -13,27 +14,43 @@ import (
|
||||
)
|
||||
|
||||
type jsEngine struct {
|
||||
vm *goja.Runtime
|
||||
vm *goja.Runtime
|
||||
consoleLogger jsConsoleLogger
|
||||
}
|
||||
|
||||
const maxJSScriptBytes = 8 * 1024 * 1024
|
||||
|
||||
func newJSEngine() *jsEngine {
|
||||
type jsConsoleLogger func(message string) error
|
||||
|
||||
func newJSEngine(loggers ...jsConsoleLogger) *jsEngine {
|
||||
consoleLogger := defaultJSConsoleLogger
|
||||
if len(loggers) > 0 && loggers[0] != nil {
|
||||
consoleLogger = loggers[0]
|
||||
}
|
||||
engine := &jsEngine{
|
||||
vm: goja.New(),
|
||||
vm: goja.New(),
|
||||
consoleLogger: consoleLogger,
|
||||
}
|
||||
engine.initConsole()
|
||||
return engine
|
||||
}
|
||||
|
||||
func defaultJSConsoleLogger(message string) error {
|
||||
log.Info("JS console log: ", message)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (engine *jsEngine) initConsole() {
|
||||
console := engine.vm.NewObject()
|
||||
consoleLogWrapper := func(call goja.FunctionCall) goja.Value {
|
||||
args := make([]interface{}, len(call.Arguments))
|
||||
args := make([]string, len(call.Arguments))
|
||||
for i, arg := range call.Arguments {
|
||||
args[i] = arg.Export()
|
||||
args[i] = fmt.Sprint(arg.Export())
|
||||
}
|
||||
message := strings.Join(args, " ")
|
||||
if errLog := engine.consoleLogger(message); errLog != nil {
|
||||
defaultJSConsoleLogger(message)
|
||||
}
|
||||
log.Info("JS console log: ", fmt.Sprint(args...))
|
||||
return goja.Undefined()
|
||||
}
|
||||
_ = console.Set("log", consoleLogWrapper)
|
||||
|
||||
@@ -1,10 +1,59 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func TestConsoleLogWritesToLogger(t *testing.T) {
|
||||
var out bytes.Buffer
|
||||
logger := log.StandardLogger()
|
||||
originalOut := logger.Out
|
||||
originalFormatter := logger.Formatter
|
||||
originalLevel := logger.Level
|
||||
log.SetOutput(&out)
|
||||
log.SetFormatter(&log.TextFormatter{
|
||||
DisableColors: true,
|
||||
DisableTimestamp: true,
|
||||
})
|
||||
log.SetLevel(log.InfoLevel)
|
||||
defer func() {
|
||||
log.SetOutput(originalOut)
|
||||
log.SetFormatter(originalFormatter)
|
||||
log.SetLevel(originalLevel)
|
||||
}()
|
||||
|
||||
engine := newJSEngine()
|
||||
_, errRun := engine.vm.RunString(`console.log("alpha", 42, true);`)
|
||||
if errRun != nil {
|
||||
t.Fatalf("RunString() error = %v", errRun)
|
||||
}
|
||||
|
||||
got := out.String()
|
||||
if !strings.Contains(got, "JS console log: alpha 42 true") {
|
||||
t.Fatalf("console.log output = %q, want logger output with JS message", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsoleLogUsesConfiguredLogger(t *testing.T) {
|
||||
var messages []string
|
||||
engine := newJSEngine(func(message string) error {
|
||||
messages = append(messages, message)
|
||||
return nil
|
||||
})
|
||||
_, errRun := engine.vm.RunString(`console.log("alpha", 42, true);`)
|
||||
if errRun != nil {
|
||||
t.Fatalf("RunString() error = %v", errRun)
|
||||
}
|
||||
if len(messages) != 1 || messages[0] != "alpha 42 true" {
|
||||
t.Fatalf("console log messages = %#v, want formatted message", messages)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopInterruptTimerClearsExpiredInterrupt(t *testing.T) {
|
||||
engine := newJSEngine()
|
||||
timer, done := engine.startInterruptTimer(time.Nanosecond)
|
||||
|
||||
@@ -45,6 +45,10 @@ func (p *jsHandlerPlugin) allScriptPaths() []string {
|
||||
}
|
||||
|
||||
func (p *jsHandlerPlugin) InterceptRequest(ctx context.Context, req pluginapi.RequestInterceptRequest) (pluginapi.RequestInterceptResponse, error) {
|
||||
return p.interceptRequest(ctx, req, "")
|
||||
}
|
||||
|
||||
func (p *jsHandlerPlugin) interceptRequest(ctx context.Context, req pluginapi.RequestInterceptRequest, hostCallbackID string) (pluginapi.RequestInterceptResponse, error) {
|
||||
resp := pluginapi.RequestInterceptResponse{}
|
||||
scriptPaths := p.allScriptPaths()
|
||||
if len(scriptPaths) == 0 {
|
||||
@@ -60,7 +64,7 @@ func (p *jsHandlerPlugin) InterceptRequest(ctx context.Context, req pluginapi.Re
|
||||
if scriptPath == "" {
|
||||
continue
|
||||
}
|
||||
processed, cleared, errJS := p.applyJSBeforeRequest(scriptPath, []byte(body), req.Model, req.SourceFormat, headers)
|
||||
processed, cleared, errJS := p.applyJSBeforeRequest(scriptPath, []byte(body), req.Model, req.SourceFormat, headers, hostCallbackID)
|
||||
if errJS != nil {
|
||||
log.Warnf("failed to execute JS request interceptor [%s]: %v", scriptPath, errJS)
|
||||
continue
|
||||
@@ -78,6 +82,10 @@ func (p *jsHandlerPlugin) InterceptRequest(ctx context.Context, req pluginapi.Re
|
||||
}
|
||||
|
||||
func (p *jsHandlerPlugin) InterceptResponse(ctx context.Context, req pluginapi.ResponseInterceptRequest) (pluginapi.ResponseInterceptResponse, error) {
|
||||
return p.interceptResponse(ctx, req, "")
|
||||
}
|
||||
|
||||
func (p *jsHandlerPlugin) interceptResponse(ctx context.Context, req pluginapi.ResponseInterceptRequest, hostCallbackID string) (pluginapi.ResponseInterceptResponse, error) {
|
||||
resp := pluginapi.ResponseInterceptResponse{}
|
||||
scriptPaths := p.allScriptPaths()
|
||||
if len(scriptPaths) == 0 {
|
||||
@@ -98,6 +106,7 @@ func (p *jsHandlerPlugin) InterceptResponse(ctx context.Context, req pluginapi.R
|
||||
scriptPath, req.Model, req.SourceFormat,
|
||||
reqHeadersMap, req.RequestBody,
|
||||
bodyStr, nil, respHeaders, false, nil,
|
||||
hostCallbackID,
|
||||
)
|
||||
if errJS != nil {
|
||||
log.Warnf("failed to execute JS response interceptor [%s]: %v", scriptPath, errJS)
|
||||
@@ -121,6 +130,10 @@ func (p *jsHandlerPlugin) InterceptResponse(ctx context.Context, req pluginapi.R
|
||||
}
|
||||
|
||||
func (p *jsHandlerPlugin) InterceptStreamChunk(ctx context.Context, req pluginapi.StreamChunkInterceptRequest) (pluginapi.StreamChunkInterceptResponse, error) {
|
||||
return p.interceptStreamChunk(ctx, req, "")
|
||||
}
|
||||
|
||||
func (p *jsHandlerPlugin) interceptStreamChunk(ctx context.Context, req pluginapi.StreamChunkInterceptRequest, hostCallbackID string) (pluginapi.StreamChunkInterceptResponse, error) {
|
||||
resp := pluginapi.StreamChunkInterceptResponse{}
|
||||
scriptPaths := p.allScriptPaths()
|
||||
if len(scriptPaths) == 0 {
|
||||
@@ -156,6 +169,7 @@ func (p *jsHandlerPlugin) InterceptStreamChunk(ctx context.Context, req pluginap
|
||||
scriptPath, req.Model, req.SourceFormat,
|
||||
reqHeadersMap, req.RequestBody,
|
||||
"", chunkPtr, respHeaders, !isHeaderInit, historyStrings,
|
||||
hostCallbackID,
|
||||
)
|
||||
if errJS != nil {
|
||||
log.Warnf("failed to execute JS stream chunk interceptor [%s]: %v", scriptPath, errJS)
|
||||
@@ -183,13 +197,13 @@ func (p *jsHandlerPlugin) InterceptStreamChunk(ctx context.Context, req pluginap
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (p *jsHandlerPlugin) applyJSBeforeRequest(scriptPath string, payloadBytes []byte, model, protocol string, headers http.Header) ([]byte, []string, error) {
|
||||
func (p *jsHandlerPlugin) applyJSBeforeRequest(scriptPath string, payloadBytes []byte, model, protocol string, headers http.Header, hostCallbackID string) ([]byte, []string, error) {
|
||||
program, err := getJSProgram(scriptPath)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
engine := newJSEngine()
|
||||
engine := newJSEngine(newHostJSConsoleLogger(hostCallbackID))
|
||||
if errRun := engine.runProgram(program, p.cfg.Timeout); errRun != nil {
|
||||
return nil, nil, errRun
|
||||
}
|
||||
@@ -246,13 +260,14 @@ func (p *jsHandlerPlugin) applyJSAfterResponse(
|
||||
reqHeadersMap map[string]any, reqBody []byte,
|
||||
bodyStr string, chunkStr *string,
|
||||
respHeaders http.Header, isStream bool, historyChunks []string,
|
||||
hostCallbackID string,
|
||||
) (string, *processedHeaders, bool, error) {
|
||||
program, err := getJSProgram(scriptPath)
|
||||
if err != nil {
|
||||
return bodyStr, nil, false, err
|
||||
}
|
||||
|
||||
engine := newJSEngine()
|
||||
engine := newJSEngine(newHostJSConsoleLogger(hostCallbackID))
|
||||
if errRun := engine.runProgram(program, p.cfg.Timeout); errRun != nil {
|
||||
return bodyStr, nil, false, errRun
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@ function on_before_request(ctx) {
|
||||
"gpt-test",
|
||||
"openai",
|
||||
headers,
|
||||
"",
|
||||
)
|
||||
if errApply != nil {
|
||||
t.Fatalf("applyJSBeforeRequest() error = %v", errApply)
|
||||
@@ -85,6 +86,7 @@ function on_after_stream_response(ctx) {
|
||||
http.Header{},
|
||||
true,
|
||||
[]string{`data: {"choices":[{"delta":{"tool_calls":[{"index":0}]}}]}`},
|
||||
"",
|
||||
)
|
||||
if errApply != nil {
|
||||
t.Fatalf("applyJSAfterResponse() error = %v", errApply)
|
||||
@@ -123,6 +125,7 @@ function on_after_nonstream_response(ctx) {
|
||||
http.Header{},
|
||||
false,
|
||||
nil,
|
||||
"",
|
||||
)
|
||||
if errApply != nil {
|
||||
t.Fatalf("applyJSAfterResponse() error = %v", errApply)
|
||||
|
||||
@@ -6,13 +6,16 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginabi"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginapi"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func TestHostHTTPDoCallbackUsesHostHTTPClient(t *testing.T) {
|
||||
@@ -213,3 +216,44 @@ func TestHostStreamCallbacksEmitAndClose(t *testing.T) {
|
||||
t.Fatalf("stream remains open after close")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostLogCallbackRestoresRegisteredRequestContext(t *testing.T) {
|
||||
host := New()
|
||||
ctx := logging.WithRequestID(context.Background(), "request-123")
|
||||
callbackID, closeCallback := host.openCallbackContext(ctx)
|
||||
defer closeCallback()
|
||||
|
||||
var out bytes.Buffer
|
||||
logger := log.StandardLogger()
|
||||
originalOut := logger.Out
|
||||
originalFormatter := logger.Formatter
|
||||
originalLevel := logger.Level
|
||||
log.SetOutput(&out)
|
||||
log.SetFormatter(&log.TextFormatter{
|
||||
DisableColors: true,
|
||||
DisableTimestamp: true,
|
||||
})
|
||||
log.SetLevel(log.InfoLevel)
|
||||
defer func() {
|
||||
log.SetOutput(originalOut)
|
||||
log.SetFormatter(originalFormatter)
|
||||
log.SetLevel(originalLevel)
|
||||
}()
|
||||
|
||||
rawReq, errMarshal := json.Marshal(rpcHostLogRequest{
|
||||
HostCallbackID: callbackID,
|
||||
Level: "info",
|
||||
Message: "plugin callback message",
|
||||
})
|
||||
if errMarshal != nil {
|
||||
t.Fatalf("marshal log request: %v", errMarshal)
|
||||
}
|
||||
if _, errCall := host.callFromPlugin(context.Background(), pluginabi.MethodHostLog, rawReq); errCall != nil {
|
||||
t.Fatalf("log callback error = %v", errCall)
|
||||
}
|
||||
|
||||
got := out.String()
|
||||
if !strings.Contains(got, "plugin callback message") || !strings.Contains(got, "request_id=request-123") {
|
||||
t.Fatalf("log output = %q, want message and request_id field", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginabi"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginapi"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
@@ -212,6 +213,47 @@ func TestInterceptorHelpersReturnErrorsWhenCallbackMissing(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRPCInterceptorsIncludeHostCallbackID(t *testing.T) {
|
||||
client := &capturePluginClient{}
|
||||
adapter := &rpcPluginAdapter{
|
||||
host: New(),
|
||||
client: client,
|
||||
}
|
||||
|
||||
if _, errReq := adapter.InterceptRequest(context.Background(), pluginapi.RequestInterceptRequest{Body: []byte("request")}); errReq != nil {
|
||||
t.Fatalf("InterceptRequest() error = %v", errReq)
|
||||
}
|
||||
var req rpcRequestInterceptRequest
|
||||
if errDecode := json.Unmarshal(client.requests[pluginabi.MethodRequestInterceptBefore], &req); errDecode != nil {
|
||||
t.Fatalf("decode request interceptor request: %v", errDecode)
|
||||
}
|
||||
if req.HostCallbackID == "" {
|
||||
t.Fatal("request interceptor host_callback_id is empty")
|
||||
}
|
||||
|
||||
if _, errResp := adapter.InterceptResponse(context.Background(), pluginapi.ResponseInterceptRequest{Body: []byte("response")}); errResp != nil {
|
||||
t.Fatalf("InterceptResponse() error = %v", errResp)
|
||||
}
|
||||
var resp rpcResponseInterceptRequest
|
||||
if errDecode := json.Unmarshal(client.requests[pluginabi.MethodResponseInterceptAfter], &resp); errDecode != nil {
|
||||
t.Fatalf("decode response interceptor request: %v", errDecode)
|
||||
}
|
||||
if resp.HostCallbackID == "" {
|
||||
t.Fatal("response interceptor host_callback_id is empty")
|
||||
}
|
||||
|
||||
if _, errChunk := adapter.InterceptStreamChunk(context.Background(), pluginapi.StreamChunkInterceptRequest{Body: []byte("chunk")}); errChunk != nil {
|
||||
t.Fatalf("InterceptStreamChunk() error = %v", errChunk)
|
||||
}
|
||||
var chunk rpcStreamChunkInterceptRequest
|
||||
if errDecode := json.Unmarshal(client.requests[pluginabi.MethodResponseInterceptStreamChunk], &chunk); errDecode != nil {
|
||||
t.Fatalf("decode stream chunk interceptor request: %v", errDecode)
|
||||
}
|
||||
if chunk.HostCallbackID == "" {
|
||||
t.Fatal("stream chunk interceptor host_callback_id is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizePluginRequestRemovesNonJSONMetadata(t *testing.T) {
|
||||
req := pluginapi.RequestInterceptRequest{
|
||||
Metadata: map[string]any{
|
||||
@@ -257,6 +299,19 @@ func TestSanitizePluginRequestRemovesNonJSONMetadata(t *testing.T) {
|
||||
if _, errMarshalExec := json.Marshal(sanitizePluginRequest(execReq)); errMarshalExec != nil {
|
||||
t.Fatalf("Marshal(sanitized executor request) error = %v", errMarshalExec)
|
||||
}
|
||||
|
||||
wrappedReq := rpcRequestInterceptRequest{
|
||||
RequestInterceptRequest: pluginapi.RequestInterceptRequest{
|
||||
Metadata: map[string]any{
|
||||
"keep": "value",
|
||||
"callback": func(string) {},
|
||||
},
|
||||
},
|
||||
HostCallbackID: "callback-1",
|
||||
}
|
||||
if _, errMarshalWrapped := json.Marshal(sanitizePluginRequest(wrappedReq)); errMarshalWrapped != nil {
|
||||
t.Fatalf("Marshal(sanitized wrapped request interceptor) error = %v", errMarshalWrapped)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostApplyConfig_ReconfigureCalledOnReload(t *testing.T) {
|
||||
@@ -407,3 +462,17 @@ func TestSortRecordsPriorityDescendingAndIDTieBreak(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type capturePluginClient struct {
|
||||
requests map[string][]byte
|
||||
}
|
||||
|
||||
func (c *capturePluginClient) Call(ctx context.Context, method string, request []byte) ([]byte, error) {
|
||||
if c.requests == nil {
|
||||
c.requests = make(map[string][]byte)
|
||||
}
|
||||
c.requests[method] = append([]byte(nil), request...)
|
||||
return marshalRPCResult(rpcEmptyResponse{})
|
||||
}
|
||||
|
||||
func (c *capturePluginClient) Shutdown() {}
|
||||
|
||||
@@ -169,6 +169,15 @@ func sanitizePluginRequest(request any) any {
|
||||
case pluginapi.StreamChunkInterceptRequest:
|
||||
req.Metadata = sanitizePluginMetadata(req.Metadata)
|
||||
return req
|
||||
case rpcRequestInterceptRequest:
|
||||
req.Metadata = sanitizePluginMetadata(req.Metadata)
|
||||
return req
|
||||
case rpcResponseInterceptRequest:
|
||||
req.Metadata = sanitizePluginMetadata(req.Metadata)
|
||||
return req
|
||||
case rpcStreamChunkInterceptRequest:
|
||||
req.Metadata = sanitizePluginMetadata(req.Metadata)
|
||||
return req
|
||||
case pluginapi.ExecutorHTTPRequest:
|
||||
req.HTTPClient = nil
|
||||
return req
|
||||
@@ -424,7 +433,12 @@ func (a *rpcPluginAdapter) NormalizeRequest(ctx context.Context, req pluginapi.R
|
||||
}
|
||||
|
||||
func (a *rpcPluginAdapter) InterceptRequest(ctx context.Context, req pluginapi.RequestInterceptRequest) (pluginapi.RequestInterceptResponse, error) {
|
||||
return callPlugin[pluginapi.RequestInterceptResponse](ctx, a.client, pluginabi.MethodRequestInterceptBefore, req)
|
||||
callbackID, closeCallback := a.openHostCallbackContext(ctx)
|
||||
defer closeCallback()
|
||||
return callPlugin[pluginapi.RequestInterceptResponse](ctx, a.client, pluginabi.MethodRequestInterceptBefore, rpcRequestInterceptRequest{
|
||||
RequestInterceptRequest: req,
|
||||
HostCallbackID: callbackID,
|
||||
})
|
||||
}
|
||||
|
||||
func (a *rpcPluginAdapter) TranslateResponse(ctx context.Context, req pluginapi.ResponseTransformRequest) (pluginapi.PayloadResponse, error) {
|
||||
@@ -436,11 +450,21 @@ func (a rpcResponseNormalizer) NormalizeResponse(ctx context.Context, req plugin
|
||||
}
|
||||
|
||||
func (a *rpcPluginAdapter) InterceptResponse(ctx context.Context, req pluginapi.ResponseInterceptRequest) (pluginapi.ResponseInterceptResponse, error) {
|
||||
return callPlugin[pluginapi.ResponseInterceptResponse](ctx, a.client, pluginabi.MethodResponseInterceptAfter, req)
|
||||
callbackID, closeCallback := a.openHostCallbackContext(ctx)
|
||||
defer closeCallback()
|
||||
return callPlugin[pluginapi.ResponseInterceptResponse](ctx, a.client, pluginabi.MethodResponseInterceptAfter, rpcResponseInterceptRequest{
|
||||
ResponseInterceptRequest: req,
|
||||
HostCallbackID: callbackID,
|
||||
})
|
||||
}
|
||||
|
||||
func (a *rpcPluginAdapter) InterceptStreamChunk(ctx context.Context, req pluginapi.StreamChunkInterceptRequest) (pluginapi.StreamChunkInterceptResponse, error) {
|
||||
return callPlugin[pluginapi.StreamChunkInterceptResponse](ctx, a.client, pluginabi.MethodResponseInterceptStreamChunk, req)
|
||||
callbackID, closeCallback := a.openHostCallbackContext(ctx)
|
||||
defer closeCallback()
|
||||
return callPlugin[pluginapi.StreamChunkInterceptResponse](ctx, a.client, pluginabi.MethodResponseInterceptStreamChunk, rpcStreamChunkInterceptRequest{
|
||||
StreamChunkInterceptRequest: req,
|
||||
HostCallbackID: callbackID,
|
||||
})
|
||||
}
|
||||
|
||||
func (a rpcThinkingApplier) ApplyThinking(ctx context.Context, req pluginapi.ThinkingApplyRequest) (pluginapi.PayloadResponse, error) {
|
||||
|
||||
@@ -82,6 +82,21 @@ type rpcExecutorHTTPRequest struct {
|
||||
HostCallbackID string `json:"host_callback_id,omitempty"`
|
||||
}
|
||||
|
||||
type rpcRequestInterceptRequest struct {
|
||||
pluginapi.RequestInterceptRequest
|
||||
HostCallbackID string `json:"host_callback_id,omitempty"`
|
||||
}
|
||||
|
||||
type rpcResponseInterceptRequest struct {
|
||||
pluginapi.ResponseInterceptRequest
|
||||
HostCallbackID string `json:"host_callback_id,omitempty"`
|
||||
}
|
||||
|
||||
type rpcStreamChunkInterceptRequest struct {
|
||||
pluginapi.StreamChunkInterceptRequest
|
||||
HostCallbackID string `json:"host_callback_id,omitempty"`
|
||||
}
|
||||
|
||||
type rpcThinkingApplyRequest struct {
|
||||
pluginapi.ThinkingApplyRequest
|
||||
HostCallbackID string `json:"host_callback_id,omitempty"`
|
||||
|
||||
Reference in New Issue
Block a user