From 2aeb41cecfa11fe032fd20b3abd7e1569ca7721f Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Tue, 9 Jun 2026 14:36:42 +0800 Subject: [PATCH] feat(pluginhost, jshandler): integrate HostCallbackID with interceptors and JS engine logging - Added `HostCallbackID` to request, response, and stream chunk interceptors for enhanced context tracking. - Updated JavaScript engine to support custom console logging with `HostCallbackID` forwarding. - Introduced tests verifying proper integration of `HostCallbackID` in all interceptor flows and engine logging. - Enhanced logging and error handling for consistent callback-related logic implementation. --- examples/plugin/jshandler/abi.go | 116 +++++++++++++++++- examples/plugin/jshandler/engine.go | 29 ++++- examples/plugin/jshandler/engine_test.go | 49 ++++++++ examples/plugin/jshandler/interceptor.go | 23 +++- examples/plugin/jshandler/interceptor_test.go | 3 + internal/pluginhost/host_callbacks_test.go | 44 +++++++ internal/pluginhost/host_test.go | 69 +++++++++++ internal/pluginhost/rpc_client.go | 30 ++++- internal/pluginhost/rpc_schema.go | 15 +++ 9 files changed, 359 insertions(+), 19 deletions(-) diff --git a/examples/plugin/jshandler/abi.go b/examples/plugin/jshandler/abi.go index d4e3b39e9..59c30c88a 100644 --- a/examples/plugin/jshandler/abi.go +++ b/examples/plugin/jshandler/abi.go @@ -92,6 +92,28 @@ type abiLifecycleRequest struct { PluginDir string `json:"plugin_dir,omitempty"` } +type abiRequestInterceptRequest struct { + pluginapi.RequestInterceptRequest + HostCallbackID string `json:"host_callback_id,omitempty"` +} + +type abiResponseInterceptRequest struct { + pluginapi.ResponseInterceptRequest + HostCallbackID string `json:"host_callback_id,omitempty"` +} + +type abiStreamChunkInterceptRequest struct { + pluginapi.StreamChunkInterceptRequest + HostCallbackID string `json:"host_callback_id,omitempty"` +} + +type abiHostLogRequest struct { + HostCallbackID string `json:"host_callback_id,omitempty"` + Level string `json:"level,omitempty"` + Message string `json:"message,omitempty"` + Fields map[string]any `json:"fields,omitempty"` +} + type abiRegistration struct { SchemaVersion uint32 `json:"schema_version"` Metadata pluginapi.Metadata `json:"metadata"` @@ -192,25 +214,25 @@ func handleJSHandlerABIMethod(ctx context.Context, method string, request []byte defer done() switch method { case pluginabi.MethodRequestInterceptBefore: - var req pluginapi.RequestInterceptRequest + var req abiRequestInterceptRequest if errDecode := json.Unmarshal(request, &req); errDecode != nil { return nil, errDecode } - resp, errCall := p.InterceptRequest(ctx, req) + resp, errCall := p.interceptRequest(ctx, req.RequestInterceptRequest, req.HostCallbackID) return abiOKEnvelopeWithError(resp, errCall) case pluginabi.MethodResponseInterceptAfter: - var req pluginapi.ResponseInterceptRequest + var req abiResponseInterceptRequest if errDecode := json.Unmarshal(request, &req); errDecode != nil { return nil, errDecode } - resp, errCall := p.InterceptResponse(ctx, req) + resp, errCall := p.interceptResponse(ctx, req.ResponseInterceptRequest, req.HostCallbackID) return abiOKEnvelopeWithError(resp, errCall) case pluginabi.MethodResponseInterceptStreamChunk: - var req pluginapi.StreamChunkInterceptRequest + var req abiStreamChunkInterceptRequest if errDecode := json.Unmarshal(request, &req); errDecode != nil { return nil, errDecode } - resp, errCall := p.InterceptStreamChunk(ctx, req) + resp, errCall := p.interceptStreamChunk(ctx, req.StreamChunkInterceptRequest, req.HostCallbackID) return abiOKEnvelopeWithError(resp, errCall) default: return abiErrorEnvelope("unknown_method", "unknown method: "+method), nil @@ -289,3 +311,85 @@ func writeABIResponse(response *C.cliproxy_buffer, raw []byte) { response.ptr = ptr response.len = C.size_t(len(raw)) } + +func newHostJSConsoleLogger(hostCallbackID string) jsConsoleLogger { + return func(message string) error { + if errLog := writeHostJSConsoleLog(hostCallbackID, message); errLog != nil { + return defaultJSConsoleLogger(message) + } + return nil + } +} + +func writeHostJSConsoleLog(hostCallbackID string, message string) error { + raw, errMarshal := json.Marshal(abiHostLogRequest{ + HostCallbackID: hostCallbackID, + Level: "info", + Message: "JS console log: " + message, + Fields: map[string]any{ + "plugin_id": pluginName, + }, + }) + if errMarshal != nil { + return errMarshal + } + + rawResp, errCall := callHost(pluginabi.MethodHostLog, raw) + if errCall != nil { + return errCall + } + if len(rawResp) == 0 { + return nil + } + var resp abiEnvelope + if errDecode := json.Unmarshal(rawResp, &resp); errDecode != nil { + return fmt.Errorf("decode host log response: %w", errDecode) + } + if !resp.OK { + if resp.Error != nil { + return fmt.Errorf("host log failed: %s", resp.Error.Message) + } + return fmt.Errorf("host log failed") + } + return nil +} + +func callHost(method string, payload []byte) ([]byte, error) { + jsHandlerABIState.RLock() + defer jsHandlerABIState.RUnlock() + if jsHandlerABIState.host == nil { + return nil, fmt.Errorf("host callback is unavailable") + } + + cMethod := C.CString(method) + defer C.free(unsafe.Pointer(cMethod)) + + var cPayload unsafe.Pointer + if len(payload) > 0 { + cPayload = C.CBytes(payload) + if cPayload == nil { + return nil, fmt.Errorf("allocate host callback payload") + } + defer C.free(cPayload) + } + + var response C.cliproxy_buffer + rc := C.jshandler_call_host( + jsHandlerABIState.host, + cMethod, + (*C.uint8_t)(cPayload), + C.size_t(len(payload)), + &response, + ) + var out []byte + if response.ptr != nil && response.len > 0 { + out = C.GoBytes(response.ptr, C.int(response.len)) + } + if response.ptr != nil { + C.jshandler_free_host_buffer(jsHandlerABIState.host, response.ptr, response.len) + } + if rc != 0 { + return nil, fmt.Errorf("host callback %s returned %d: %s", method, int(rc), string(out)) + } + return out, nil +} diff --git a/examples/plugin/jshandler/engine.go b/examples/plugin/jshandler/engine.go index 8da181ff6..5f076cd12 100644 --- a/examples/plugin/jshandler/engine.go +++ b/examples/plugin/jshandler/engine.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "path/filepath" + "strings" "sync" "time" @@ -13,27 +14,43 @@ import ( ) type jsEngine struct { - vm *goja.Runtime + vm *goja.Runtime + consoleLogger jsConsoleLogger } const maxJSScriptBytes = 8 * 1024 * 1024 -func newJSEngine() *jsEngine { +type jsConsoleLogger func(message string) error + +func newJSEngine(loggers ...jsConsoleLogger) *jsEngine { + consoleLogger := defaultJSConsoleLogger + if len(loggers) > 0 && loggers[0] != nil { + consoleLogger = loggers[0] + } engine := &jsEngine{ - vm: goja.New(), + vm: goja.New(), + consoleLogger: consoleLogger, } engine.initConsole() return engine } +func defaultJSConsoleLogger(message string) error { + log.Info("JS console log: ", message) + return nil +} + func (engine *jsEngine) initConsole() { console := engine.vm.NewObject() consoleLogWrapper := func(call goja.FunctionCall) goja.Value { - args := make([]interface{}, len(call.Arguments)) + args := make([]string, len(call.Arguments)) for i, arg := range call.Arguments { - args[i] = arg.Export() + args[i] = fmt.Sprint(arg.Export()) + } + message := strings.Join(args, " ") + if errLog := engine.consoleLogger(message); errLog != nil { + defaultJSConsoleLogger(message) } - log.Info("JS console log: ", fmt.Sprint(args...)) return goja.Undefined() } _ = console.Set("log", consoleLogWrapper) diff --git a/examples/plugin/jshandler/engine_test.go b/examples/plugin/jshandler/engine_test.go index 33bbd8c36..45c5f8d3d 100644 --- a/examples/plugin/jshandler/engine_test.go +++ b/examples/plugin/jshandler/engine_test.go @@ -1,10 +1,59 @@ package main import ( + "bytes" + "strings" "testing" "time" + + log "github.com/sirupsen/logrus" ) +func TestConsoleLogWritesToLogger(t *testing.T) { + var out bytes.Buffer + logger := log.StandardLogger() + originalOut := logger.Out + originalFormatter := logger.Formatter + originalLevel := logger.Level + log.SetOutput(&out) + log.SetFormatter(&log.TextFormatter{ + DisableColors: true, + DisableTimestamp: true, + }) + log.SetLevel(log.InfoLevel) + defer func() { + log.SetOutput(originalOut) + log.SetFormatter(originalFormatter) + log.SetLevel(originalLevel) + }() + + engine := newJSEngine() + _, errRun := engine.vm.RunString(`console.log("alpha", 42, true);`) + if errRun != nil { + t.Fatalf("RunString() error = %v", errRun) + } + + got := out.String() + if !strings.Contains(got, "JS console log: alpha 42 true") { + t.Fatalf("console.log output = %q, want logger output with JS message", got) + } +} + +func TestConsoleLogUsesConfiguredLogger(t *testing.T) { + var messages []string + engine := newJSEngine(func(message string) error { + messages = append(messages, message) + return nil + }) + _, errRun := engine.vm.RunString(`console.log("alpha", 42, true);`) + if errRun != nil { + t.Fatalf("RunString() error = %v", errRun) + } + if len(messages) != 1 || messages[0] != "alpha 42 true" { + t.Fatalf("console log messages = %#v, want formatted message", messages) + } +} + func TestStopInterruptTimerClearsExpiredInterrupt(t *testing.T) { engine := newJSEngine() timer, done := engine.startInterruptTimer(time.Nanosecond) diff --git a/examples/plugin/jshandler/interceptor.go b/examples/plugin/jshandler/interceptor.go index 347a59ee3..d866b7cf2 100644 --- a/examples/plugin/jshandler/interceptor.go +++ b/examples/plugin/jshandler/interceptor.go @@ -45,6 +45,10 @@ func (p *jsHandlerPlugin) allScriptPaths() []string { } func (p *jsHandlerPlugin) InterceptRequest(ctx context.Context, req pluginapi.RequestInterceptRequest) (pluginapi.RequestInterceptResponse, error) { + return p.interceptRequest(ctx, req, "") +} + +func (p *jsHandlerPlugin) interceptRequest(ctx context.Context, req pluginapi.RequestInterceptRequest, hostCallbackID string) (pluginapi.RequestInterceptResponse, error) { resp := pluginapi.RequestInterceptResponse{} scriptPaths := p.allScriptPaths() if len(scriptPaths) == 0 { @@ -60,7 +64,7 @@ func (p *jsHandlerPlugin) InterceptRequest(ctx context.Context, req pluginapi.Re if scriptPath == "" { continue } - processed, cleared, errJS := p.applyJSBeforeRequest(scriptPath, []byte(body), req.Model, req.SourceFormat, headers) + processed, cleared, errJS := p.applyJSBeforeRequest(scriptPath, []byte(body), req.Model, req.SourceFormat, headers, hostCallbackID) if errJS != nil { log.Warnf("failed to execute JS request interceptor [%s]: %v", scriptPath, errJS) continue @@ -78,6 +82,10 @@ func (p *jsHandlerPlugin) InterceptRequest(ctx context.Context, req pluginapi.Re } func (p *jsHandlerPlugin) InterceptResponse(ctx context.Context, req pluginapi.ResponseInterceptRequest) (pluginapi.ResponseInterceptResponse, error) { + return p.interceptResponse(ctx, req, "") +} + +func (p *jsHandlerPlugin) interceptResponse(ctx context.Context, req pluginapi.ResponseInterceptRequest, hostCallbackID string) (pluginapi.ResponseInterceptResponse, error) { resp := pluginapi.ResponseInterceptResponse{} scriptPaths := p.allScriptPaths() if len(scriptPaths) == 0 { @@ -98,6 +106,7 @@ func (p *jsHandlerPlugin) InterceptResponse(ctx context.Context, req pluginapi.R scriptPath, req.Model, req.SourceFormat, reqHeadersMap, req.RequestBody, bodyStr, nil, respHeaders, false, nil, + hostCallbackID, ) if errJS != nil { log.Warnf("failed to execute JS response interceptor [%s]: %v", scriptPath, errJS) @@ -121,6 +130,10 @@ func (p *jsHandlerPlugin) InterceptResponse(ctx context.Context, req pluginapi.R } func (p *jsHandlerPlugin) InterceptStreamChunk(ctx context.Context, req pluginapi.StreamChunkInterceptRequest) (pluginapi.StreamChunkInterceptResponse, error) { + return p.interceptStreamChunk(ctx, req, "") +} + +func (p *jsHandlerPlugin) interceptStreamChunk(ctx context.Context, req pluginapi.StreamChunkInterceptRequest, hostCallbackID string) (pluginapi.StreamChunkInterceptResponse, error) { resp := pluginapi.StreamChunkInterceptResponse{} scriptPaths := p.allScriptPaths() if len(scriptPaths) == 0 { @@ -156,6 +169,7 @@ func (p *jsHandlerPlugin) InterceptStreamChunk(ctx context.Context, req pluginap scriptPath, req.Model, req.SourceFormat, reqHeadersMap, req.RequestBody, "", chunkPtr, respHeaders, !isHeaderInit, historyStrings, + hostCallbackID, ) if errJS != nil { log.Warnf("failed to execute JS stream chunk interceptor [%s]: %v", scriptPath, errJS) @@ -183,13 +197,13 @@ func (p *jsHandlerPlugin) InterceptStreamChunk(ctx context.Context, req pluginap return resp, nil } -func (p *jsHandlerPlugin) applyJSBeforeRequest(scriptPath string, payloadBytes []byte, model, protocol string, headers http.Header) ([]byte, []string, error) { +func (p *jsHandlerPlugin) applyJSBeforeRequest(scriptPath string, payloadBytes []byte, model, protocol string, headers http.Header, hostCallbackID string) ([]byte, []string, error) { program, err := getJSProgram(scriptPath) if err != nil { return nil, nil, err } - engine := newJSEngine() + engine := newJSEngine(newHostJSConsoleLogger(hostCallbackID)) if errRun := engine.runProgram(program, p.cfg.Timeout); errRun != nil { return nil, nil, errRun } @@ -246,13 +260,14 @@ func (p *jsHandlerPlugin) applyJSAfterResponse( reqHeadersMap map[string]any, reqBody []byte, bodyStr string, chunkStr *string, respHeaders http.Header, isStream bool, historyChunks []string, + hostCallbackID string, ) (string, *processedHeaders, bool, error) { program, err := getJSProgram(scriptPath) if err != nil { return bodyStr, nil, false, err } - engine := newJSEngine() + engine := newJSEngine(newHostJSConsoleLogger(hostCallbackID)) if errRun := engine.runProgram(program, p.cfg.Timeout); errRun != nil { return bodyStr, nil, false, errRun } diff --git a/examples/plugin/jshandler/interceptor_test.go b/examples/plugin/jshandler/interceptor_test.go index 6d7b8481c..374473629 100644 --- a/examples/plugin/jshandler/interceptor_test.go +++ b/examples/plugin/jshandler/interceptor_test.go @@ -31,6 +31,7 @@ function on_before_request(ctx) { "gpt-test", "openai", headers, + "", ) if errApply != nil { t.Fatalf("applyJSBeforeRequest() error = %v", errApply) @@ -85,6 +86,7 @@ function on_after_stream_response(ctx) { http.Header{}, true, []string{`data: {"choices":[{"delta":{"tool_calls":[{"index":0}]}}]}`}, + "", ) if errApply != nil { t.Fatalf("applyJSAfterResponse() error = %v", errApply) @@ -123,6 +125,7 @@ function on_after_nonstream_response(ctx) { http.Header{}, false, nil, + "", ) if errApply != nil { t.Fatalf("applyJSAfterResponse() error = %v", errApply) diff --git a/internal/pluginhost/host_callbacks_test.go b/internal/pluginhost/host_callbacks_test.go index 50e58c760..a28f33da0 100644 --- a/internal/pluginhost/host_callbacks_test.go +++ b/internal/pluginhost/host_callbacks_test.go @@ -6,13 +6,16 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "strings" "testing" "time" "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + "github.com/router-for-me/CLIProxyAPI/v7/internal/logging" "github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginabi" "github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginapi" + log "github.com/sirupsen/logrus" ) func TestHostHTTPDoCallbackUsesHostHTTPClient(t *testing.T) { @@ -213,3 +216,44 @@ func TestHostStreamCallbacksEmitAndClose(t *testing.T) { t.Fatalf("stream remains open after close") } } + +func TestHostLogCallbackRestoresRegisteredRequestContext(t *testing.T) { + host := New() + ctx := logging.WithRequestID(context.Background(), "request-123") + callbackID, closeCallback := host.openCallbackContext(ctx) + defer closeCallback() + + var out bytes.Buffer + logger := log.StandardLogger() + originalOut := logger.Out + originalFormatter := logger.Formatter + originalLevel := logger.Level + log.SetOutput(&out) + log.SetFormatter(&log.TextFormatter{ + DisableColors: true, + DisableTimestamp: true, + }) + log.SetLevel(log.InfoLevel) + defer func() { + log.SetOutput(originalOut) + log.SetFormatter(originalFormatter) + log.SetLevel(originalLevel) + }() + + rawReq, errMarshal := json.Marshal(rpcHostLogRequest{ + HostCallbackID: callbackID, + Level: "info", + Message: "plugin callback message", + }) + if errMarshal != nil { + t.Fatalf("marshal log request: %v", errMarshal) + } + if _, errCall := host.callFromPlugin(context.Background(), pluginabi.MethodHostLog, rawReq); errCall != nil { + t.Fatalf("log callback error = %v", errCall) + } + + got := out.String() + if !strings.Contains(got, "plugin callback message") || !strings.Contains(got, "request_id=request-123") { + t.Fatalf("log output = %q, want message and request_id field", got) + } +} diff --git a/internal/pluginhost/host_test.go b/internal/pluginhost/host_test.go index fd65a11c8..90d2a7610 100644 --- a/internal/pluginhost/host_test.go +++ b/internal/pluginhost/host_test.go @@ -7,6 +7,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v7/internal/config" "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking" + "github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginabi" "github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginapi" "github.com/tidwall/gjson" ) @@ -212,6 +213,47 @@ func TestInterceptorHelpersReturnErrorsWhenCallbackMissing(t *testing.T) { } } +func TestRPCInterceptorsIncludeHostCallbackID(t *testing.T) { + client := &capturePluginClient{} + adapter := &rpcPluginAdapter{ + host: New(), + client: client, + } + + if _, errReq := adapter.InterceptRequest(context.Background(), pluginapi.RequestInterceptRequest{Body: []byte("request")}); errReq != nil { + t.Fatalf("InterceptRequest() error = %v", errReq) + } + var req rpcRequestInterceptRequest + if errDecode := json.Unmarshal(client.requests[pluginabi.MethodRequestInterceptBefore], &req); errDecode != nil { + t.Fatalf("decode request interceptor request: %v", errDecode) + } + if req.HostCallbackID == "" { + t.Fatal("request interceptor host_callback_id is empty") + } + + if _, errResp := adapter.InterceptResponse(context.Background(), pluginapi.ResponseInterceptRequest{Body: []byte("response")}); errResp != nil { + t.Fatalf("InterceptResponse() error = %v", errResp) + } + var resp rpcResponseInterceptRequest + if errDecode := json.Unmarshal(client.requests[pluginabi.MethodResponseInterceptAfter], &resp); errDecode != nil { + t.Fatalf("decode response interceptor request: %v", errDecode) + } + if resp.HostCallbackID == "" { + t.Fatal("response interceptor host_callback_id is empty") + } + + if _, errChunk := adapter.InterceptStreamChunk(context.Background(), pluginapi.StreamChunkInterceptRequest{Body: []byte("chunk")}); errChunk != nil { + t.Fatalf("InterceptStreamChunk() error = %v", errChunk) + } + var chunk rpcStreamChunkInterceptRequest + if errDecode := json.Unmarshal(client.requests[pluginabi.MethodResponseInterceptStreamChunk], &chunk); errDecode != nil { + t.Fatalf("decode stream chunk interceptor request: %v", errDecode) + } + if chunk.HostCallbackID == "" { + t.Fatal("stream chunk interceptor host_callback_id is empty") + } +} + func TestSanitizePluginRequestRemovesNonJSONMetadata(t *testing.T) { req := pluginapi.RequestInterceptRequest{ Metadata: map[string]any{ @@ -257,6 +299,19 @@ func TestSanitizePluginRequestRemovesNonJSONMetadata(t *testing.T) { if _, errMarshalExec := json.Marshal(sanitizePluginRequest(execReq)); errMarshalExec != nil { t.Fatalf("Marshal(sanitized executor request) error = %v", errMarshalExec) } + + wrappedReq := rpcRequestInterceptRequest{ + RequestInterceptRequest: pluginapi.RequestInterceptRequest{ + Metadata: map[string]any{ + "keep": "value", + "callback": func(string) {}, + }, + }, + HostCallbackID: "callback-1", + } + if _, errMarshalWrapped := json.Marshal(sanitizePluginRequest(wrappedReq)); errMarshalWrapped != nil { + t.Fatalf("Marshal(sanitized wrapped request interceptor) error = %v", errMarshalWrapped) + } } func TestHostApplyConfig_ReconfigureCalledOnReload(t *testing.T) { @@ -407,3 +462,17 @@ func TestSortRecordsPriorityDescendingAndIDTieBreak(t *testing.T) { } } } + +type capturePluginClient struct { + requests map[string][]byte +} + +func (c *capturePluginClient) Call(ctx context.Context, method string, request []byte) ([]byte, error) { + if c.requests == nil { + c.requests = make(map[string][]byte) + } + c.requests[method] = append([]byte(nil), request...) + return marshalRPCResult(rpcEmptyResponse{}) +} + +func (c *capturePluginClient) Shutdown() {} diff --git a/internal/pluginhost/rpc_client.go b/internal/pluginhost/rpc_client.go index 0d3817c28..d84adb8f2 100644 --- a/internal/pluginhost/rpc_client.go +++ b/internal/pluginhost/rpc_client.go @@ -169,6 +169,15 @@ func sanitizePluginRequest(request any) any { case pluginapi.StreamChunkInterceptRequest: req.Metadata = sanitizePluginMetadata(req.Metadata) return req + case rpcRequestInterceptRequest: + req.Metadata = sanitizePluginMetadata(req.Metadata) + return req + case rpcResponseInterceptRequest: + req.Metadata = sanitizePluginMetadata(req.Metadata) + return req + case rpcStreamChunkInterceptRequest: + req.Metadata = sanitizePluginMetadata(req.Metadata) + return req case pluginapi.ExecutorHTTPRequest: req.HTTPClient = nil return req @@ -424,7 +433,12 @@ func (a *rpcPluginAdapter) NormalizeRequest(ctx context.Context, req pluginapi.R } func (a *rpcPluginAdapter) InterceptRequest(ctx context.Context, req pluginapi.RequestInterceptRequest) (pluginapi.RequestInterceptResponse, error) { - return callPlugin[pluginapi.RequestInterceptResponse](ctx, a.client, pluginabi.MethodRequestInterceptBefore, req) + callbackID, closeCallback := a.openHostCallbackContext(ctx) + defer closeCallback() + return callPlugin[pluginapi.RequestInterceptResponse](ctx, a.client, pluginabi.MethodRequestInterceptBefore, rpcRequestInterceptRequest{ + RequestInterceptRequest: req, + HostCallbackID: callbackID, + }) } func (a *rpcPluginAdapter) TranslateResponse(ctx context.Context, req pluginapi.ResponseTransformRequest) (pluginapi.PayloadResponse, error) { @@ -436,11 +450,21 @@ func (a rpcResponseNormalizer) NormalizeResponse(ctx context.Context, req plugin } func (a *rpcPluginAdapter) InterceptResponse(ctx context.Context, req pluginapi.ResponseInterceptRequest) (pluginapi.ResponseInterceptResponse, error) { - return callPlugin[pluginapi.ResponseInterceptResponse](ctx, a.client, pluginabi.MethodResponseInterceptAfter, req) + callbackID, closeCallback := a.openHostCallbackContext(ctx) + defer closeCallback() + return callPlugin[pluginapi.ResponseInterceptResponse](ctx, a.client, pluginabi.MethodResponseInterceptAfter, rpcResponseInterceptRequest{ + ResponseInterceptRequest: req, + HostCallbackID: callbackID, + }) } func (a *rpcPluginAdapter) InterceptStreamChunk(ctx context.Context, req pluginapi.StreamChunkInterceptRequest) (pluginapi.StreamChunkInterceptResponse, error) { - return callPlugin[pluginapi.StreamChunkInterceptResponse](ctx, a.client, pluginabi.MethodResponseInterceptStreamChunk, req) + callbackID, closeCallback := a.openHostCallbackContext(ctx) + defer closeCallback() + return callPlugin[pluginapi.StreamChunkInterceptResponse](ctx, a.client, pluginabi.MethodResponseInterceptStreamChunk, rpcStreamChunkInterceptRequest{ + StreamChunkInterceptRequest: req, + HostCallbackID: callbackID, + }) } func (a rpcThinkingApplier) ApplyThinking(ctx context.Context, req pluginapi.ThinkingApplyRequest) (pluginapi.PayloadResponse, error) { diff --git a/internal/pluginhost/rpc_schema.go b/internal/pluginhost/rpc_schema.go index 61f474d44..eb2963fb1 100644 --- a/internal/pluginhost/rpc_schema.go +++ b/internal/pluginhost/rpc_schema.go @@ -82,6 +82,21 @@ type rpcExecutorHTTPRequest struct { HostCallbackID string `json:"host_callback_id,omitempty"` } +type rpcRequestInterceptRequest struct { + pluginapi.RequestInterceptRequest + HostCallbackID string `json:"host_callback_id,omitempty"` +} + +type rpcResponseInterceptRequest struct { + pluginapi.ResponseInterceptRequest + HostCallbackID string `json:"host_callback_id,omitempty"` +} + +type rpcStreamChunkInterceptRequest struct { + pluginapi.StreamChunkInterceptRequest + HostCallbackID string `json:"host_callback_id,omitempty"` +} + type rpcThinkingApplyRequest struct { pluginapi.ThinkingApplyRequest HostCallbackID string `json:"host_callback_id,omitempty"`