mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-06-24 21:09:16 +08:00
feat(pluginhost, jshandler): integrate HostCallbackID with interceptors and JS engine logging
- Added `HostCallbackID` to request, response, and stream chunk interceptors for enhanced context tracking. - Updated JavaScript engine to support custom console logging with `HostCallbackID` forwarding. - Introduced tests verifying proper integration of `HostCallbackID` in all interceptor flows and engine logging. - Enhanced logging and error handling for consistent callback-related logic implementation.
This commit is contained in:
@@ -92,6 +92,28 @@ type abiLifecycleRequest struct {
|
||||
PluginDir string `json:"plugin_dir,omitempty"`
|
||||
}
|
||||
|
||||
type abiRequestInterceptRequest struct {
|
||||
pluginapi.RequestInterceptRequest
|
||||
HostCallbackID string `json:"host_callback_id,omitempty"`
|
||||
}
|
||||
|
||||
type abiResponseInterceptRequest struct {
|
||||
pluginapi.ResponseInterceptRequest
|
||||
HostCallbackID string `json:"host_callback_id,omitempty"`
|
||||
}
|
||||
|
||||
type abiStreamChunkInterceptRequest struct {
|
||||
pluginapi.StreamChunkInterceptRequest
|
||||
HostCallbackID string `json:"host_callback_id,omitempty"`
|
||||
}
|
||||
|
||||
type abiHostLogRequest struct {
|
||||
HostCallbackID string `json:"host_callback_id,omitempty"`
|
||||
Level string `json:"level,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Fields map[string]any `json:"fields,omitempty"`
|
||||
}
|
||||
|
||||
type abiRegistration struct {
|
||||
SchemaVersion uint32 `json:"schema_version"`
|
||||
Metadata pluginapi.Metadata `json:"metadata"`
|
||||
@@ -192,25 +214,25 @@ func handleJSHandlerABIMethod(ctx context.Context, method string, request []byte
|
||||
defer done()
|
||||
switch method {
|
||||
case pluginabi.MethodRequestInterceptBefore:
|
||||
var req pluginapi.RequestInterceptRequest
|
||||
var req abiRequestInterceptRequest
|
||||
if errDecode := json.Unmarshal(request, &req); errDecode != nil {
|
||||
return nil, errDecode
|
||||
}
|
||||
resp, errCall := p.InterceptRequest(ctx, req)
|
||||
resp, errCall := p.interceptRequest(ctx, req.RequestInterceptRequest, req.HostCallbackID)
|
||||
return abiOKEnvelopeWithError(resp, errCall)
|
||||
case pluginabi.MethodResponseInterceptAfter:
|
||||
var req pluginapi.ResponseInterceptRequest
|
||||
var req abiResponseInterceptRequest
|
||||
if errDecode := json.Unmarshal(request, &req); errDecode != nil {
|
||||
return nil, errDecode
|
||||
}
|
||||
resp, errCall := p.InterceptResponse(ctx, req)
|
||||
resp, errCall := p.interceptResponse(ctx, req.ResponseInterceptRequest, req.HostCallbackID)
|
||||
return abiOKEnvelopeWithError(resp, errCall)
|
||||
case pluginabi.MethodResponseInterceptStreamChunk:
|
||||
var req pluginapi.StreamChunkInterceptRequest
|
||||
var req abiStreamChunkInterceptRequest
|
||||
if errDecode := json.Unmarshal(request, &req); errDecode != nil {
|
||||
return nil, errDecode
|
||||
}
|
||||
resp, errCall := p.InterceptStreamChunk(ctx, req)
|
||||
resp, errCall := p.interceptStreamChunk(ctx, req.StreamChunkInterceptRequest, req.HostCallbackID)
|
||||
return abiOKEnvelopeWithError(resp, errCall)
|
||||
default:
|
||||
return abiErrorEnvelope("unknown_method", "unknown method: "+method), nil
|
||||
@@ -289,3 +311,85 @@ func writeABIResponse(response *C.cliproxy_buffer, raw []byte) {
|
||||
response.ptr = ptr
|
||||
response.len = C.size_t(len(raw))
|
||||
}
|
||||
|
||||
func newHostJSConsoleLogger(hostCallbackID string) jsConsoleLogger {
|
||||
return func(message string) error {
|
||||
if errLog := writeHostJSConsoleLog(hostCallbackID, message); errLog != nil {
|
||||
return defaultJSConsoleLogger(message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func writeHostJSConsoleLog(hostCallbackID string, message string) error {
|
||||
raw, errMarshal := json.Marshal(abiHostLogRequest{
|
||||
HostCallbackID: hostCallbackID,
|
||||
Level: "info",
|
||||
Message: "JS console log: " + message,
|
||||
Fields: map[string]any{
|
||||
"plugin_id": pluginName,
|
||||
},
|
||||
})
|
||||
if errMarshal != nil {
|
||||
return errMarshal
|
||||
}
|
||||
|
||||
rawResp, errCall := callHost(pluginabi.MethodHostLog, raw)
|
||||
if errCall != nil {
|
||||
return errCall
|
||||
}
|
||||
if len(rawResp) == 0 {
|
||||
return nil
|
||||
}
|
||||
var resp abiEnvelope
|
||||
if errDecode := json.Unmarshal(rawResp, &resp); errDecode != nil {
|
||||
return fmt.Errorf("decode host log response: %w", errDecode)
|
||||
}
|
||||
if !resp.OK {
|
||||
if resp.Error != nil {
|
||||
return fmt.Errorf("host log failed: %s", resp.Error.Message)
|
||||
}
|
||||
return fmt.Errorf("host log failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func callHost(method string, payload []byte) ([]byte, error) {
|
||||
jsHandlerABIState.RLock()
|
||||
defer jsHandlerABIState.RUnlock()
|
||||
if jsHandlerABIState.host == nil {
|
||||
return nil, fmt.Errorf("host callback is unavailable")
|
||||
}
|
||||
|
||||
cMethod := C.CString(method)
|
||||
defer C.free(unsafe.Pointer(cMethod))
|
||||
|
||||
var cPayload unsafe.Pointer
|
||||
if len(payload) > 0 {
|
||||
cPayload = C.CBytes(payload)
|
||||
if cPayload == nil {
|
||||
return nil, fmt.Errorf("allocate host callback payload")
|
||||
}
|
||||
defer C.free(cPayload)
|
||||
}
|
||||
|
||||
var response C.cliproxy_buffer
|
||||
rc := C.jshandler_call_host(
|
||||
jsHandlerABIState.host,
|
||||
cMethod,
|
||||
(*C.uint8_t)(cPayload),
|
||||
C.size_t(len(payload)),
|
||||
&response,
|
||||
)
|
||||
var out []byte
|
||||
if response.ptr != nil && response.len > 0 {
|
||||
out = C.GoBytes(response.ptr, C.int(response.len))
|
||||
}
|
||||
if response.ptr != nil {
|
||||
C.jshandler_free_host_buffer(jsHandlerABIState.host, response.ptr, response.len)
|
||||
}
|
||||
if rc != 0 {
|
||||
return nil, fmt.Errorf("host callback %s returned %d: %s", method, int(rc), string(out))
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -13,27 +14,43 @@ import (
|
||||
)
|
||||
|
||||
type jsEngine struct {
|
||||
vm *goja.Runtime
|
||||
vm *goja.Runtime
|
||||
consoleLogger jsConsoleLogger
|
||||
}
|
||||
|
||||
const maxJSScriptBytes = 8 * 1024 * 1024
|
||||
|
||||
func newJSEngine() *jsEngine {
|
||||
type jsConsoleLogger func(message string) error
|
||||
|
||||
func newJSEngine(loggers ...jsConsoleLogger) *jsEngine {
|
||||
consoleLogger := defaultJSConsoleLogger
|
||||
if len(loggers) > 0 && loggers[0] != nil {
|
||||
consoleLogger = loggers[0]
|
||||
}
|
||||
engine := &jsEngine{
|
||||
vm: goja.New(),
|
||||
vm: goja.New(),
|
||||
consoleLogger: consoleLogger,
|
||||
}
|
||||
engine.initConsole()
|
||||
return engine
|
||||
}
|
||||
|
||||
func defaultJSConsoleLogger(message string) error {
|
||||
log.Info("JS console log: ", message)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (engine *jsEngine) initConsole() {
|
||||
console := engine.vm.NewObject()
|
||||
consoleLogWrapper := func(call goja.FunctionCall) goja.Value {
|
||||
args := make([]interface{}, len(call.Arguments))
|
||||
args := make([]string, len(call.Arguments))
|
||||
for i, arg := range call.Arguments {
|
||||
args[i] = arg.Export()
|
||||
args[i] = fmt.Sprint(arg.Export())
|
||||
}
|
||||
message := strings.Join(args, " ")
|
||||
if errLog := engine.consoleLogger(message); errLog != nil {
|
||||
defaultJSConsoleLogger(message)
|
||||
}
|
||||
log.Info("JS console log: ", fmt.Sprint(args...))
|
||||
return goja.Undefined()
|
||||
}
|
||||
_ = console.Set("log", consoleLogWrapper)
|
||||
|
||||
@@ -1,10 +1,59 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func TestConsoleLogWritesToLogger(t *testing.T) {
|
||||
var out bytes.Buffer
|
||||
logger := log.StandardLogger()
|
||||
originalOut := logger.Out
|
||||
originalFormatter := logger.Formatter
|
||||
originalLevel := logger.Level
|
||||
log.SetOutput(&out)
|
||||
log.SetFormatter(&log.TextFormatter{
|
||||
DisableColors: true,
|
||||
DisableTimestamp: true,
|
||||
})
|
||||
log.SetLevel(log.InfoLevel)
|
||||
defer func() {
|
||||
log.SetOutput(originalOut)
|
||||
log.SetFormatter(originalFormatter)
|
||||
log.SetLevel(originalLevel)
|
||||
}()
|
||||
|
||||
engine := newJSEngine()
|
||||
_, errRun := engine.vm.RunString(`console.log("alpha", 42, true);`)
|
||||
if errRun != nil {
|
||||
t.Fatalf("RunString() error = %v", errRun)
|
||||
}
|
||||
|
||||
got := out.String()
|
||||
if !strings.Contains(got, "JS console log: alpha 42 true") {
|
||||
t.Fatalf("console.log output = %q, want logger output with JS message", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsoleLogUsesConfiguredLogger(t *testing.T) {
|
||||
var messages []string
|
||||
engine := newJSEngine(func(message string) error {
|
||||
messages = append(messages, message)
|
||||
return nil
|
||||
})
|
||||
_, errRun := engine.vm.RunString(`console.log("alpha", 42, true);`)
|
||||
if errRun != nil {
|
||||
t.Fatalf("RunString() error = %v", errRun)
|
||||
}
|
||||
if len(messages) != 1 || messages[0] != "alpha 42 true" {
|
||||
t.Fatalf("console log messages = %#v, want formatted message", messages)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopInterruptTimerClearsExpiredInterrupt(t *testing.T) {
|
||||
engine := newJSEngine()
|
||||
timer, done := engine.startInterruptTimer(time.Nanosecond)
|
||||
|
||||
@@ -45,6 +45,10 @@ func (p *jsHandlerPlugin) allScriptPaths() []string {
|
||||
}
|
||||
|
||||
func (p *jsHandlerPlugin) InterceptRequest(ctx context.Context, req pluginapi.RequestInterceptRequest) (pluginapi.RequestInterceptResponse, error) {
|
||||
return p.interceptRequest(ctx, req, "")
|
||||
}
|
||||
|
||||
func (p *jsHandlerPlugin) interceptRequest(ctx context.Context, req pluginapi.RequestInterceptRequest, hostCallbackID string) (pluginapi.RequestInterceptResponse, error) {
|
||||
resp := pluginapi.RequestInterceptResponse{}
|
||||
scriptPaths := p.allScriptPaths()
|
||||
if len(scriptPaths) == 0 {
|
||||
@@ -60,7 +64,7 @@ func (p *jsHandlerPlugin) InterceptRequest(ctx context.Context, req pluginapi.Re
|
||||
if scriptPath == "" {
|
||||
continue
|
||||
}
|
||||
processed, cleared, errJS := p.applyJSBeforeRequest(scriptPath, []byte(body), req.Model, req.SourceFormat, headers)
|
||||
processed, cleared, errJS := p.applyJSBeforeRequest(scriptPath, []byte(body), req.Model, req.SourceFormat, headers, hostCallbackID)
|
||||
if errJS != nil {
|
||||
log.Warnf("failed to execute JS request interceptor [%s]: %v", scriptPath, errJS)
|
||||
continue
|
||||
@@ -78,6 +82,10 @@ func (p *jsHandlerPlugin) InterceptRequest(ctx context.Context, req pluginapi.Re
|
||||
}
|
||||
|
||||
func (p *jsHandlerPlugin) InterceptResponse(ctx context.Context, req pluginapi.ResponseInterceptRequest) (pluginapi.ResponseInterceptResponse, error) {
|
||||
return p.interceptResponse(ctx, req, "")
|
||||
}
|
||||
|
||||
func (p *jsHandlerPlugin) interceptResponse(ctx context.Context, req pluginapi.ResponseInterceptRequest, hostCallbackID string) (pluginapi.ResponseInterceptResponse, error) {
|
||||
resp := pluginapi.ResponseInterceptResponse{}
|
||||
scriptPaths := p.allScriptPaths()
|
||||
if len(scriptPaths) == 0 {
|
||||
@@ -98,6 +106,7 @@ func (p *jsHandlerPlugin) InterceptResponse(ctx context.Context, req pluginapi.R
|
||||
scriptPath, req.Model, req.SourceFormat,
|
||||
reqHeadersMap, req.RequestBody,
|
||||
bodyStr, nil, respHeaders, false, nil,
|
||||
hostCallbackID,
|
||||
)
|
||||
if errJS != nil {
|
||||
log.Warnf("failed to execute JS response interceptor [%s]: %v", scriptPath, errJS)
|
||||
@@ -121,6 +130,10 @@ func (p *jsHandlerPlugin) InterceptResponse(ctx context.Context, req pluginapi.R
|
||||
}
|
||||
|
||||
func (p *jsHandlerPlugin) InterceptStreamChunk(ctx context.Context, req pluginapi.StreamChunkInterceptRequest) (pluginapi.StreamChunkInterceptResponse, error) {
|
||||
return p.interceptStreamChunk(ctx, req, "")
|
||||
}
|
||||
|
||||
func (p *jsHandlerPlugin) interceptStreamChunk(ctx context.Context, req pluginapi.StreamChunkInterceptRequest, hostCallbackID string) (pluginapi.StreamChunkInterceptResponse, error) {
|
||||
resp := pluginapi.StreamChunkInterceptResponse{}
|
||||
scriptPaths := p.allScriptPaths()
|
||||
if len(scriptPaths) == 0 {
|
||||
@@ -156,6 +169,7 @@ func (p *jsHandlerPlugin) InterceptStreamChunk(ctx context.Context, req pluginap
|
||||
scriptPath, req.Model, req.SourceFormat,
|
||||
reqHeadersMap, req.RequestBody,
|
||||
"", chunkPtr, respHeaders, !isHeaderInit, historyStrings,
|
||||
hostCallbackID,
|
||||
)
|
||||
if errJS != nil {
|
||||
log.Warnf("failed to execute JS stream chunk interceptor [%s]: %v", scriptPath, errJS)
|
||||
@@ -183,13 +197,13 @@ func (p *jsHandlerPlugin) InterceptStreamChunk(ctx context.Context, req pluginap
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (p *jsHandlerPlugin) applyJSBeforeRequest(scriptPath string, payloadBytes []byte, model, protocol string, headers http.Header) ([]byte, []string, error) {
|
||||
func (p *jsHandlerPlugin) applyJSBeforeRequest(scriptPath string, payloadBytes []byte, model, protocol string, headers http.Header, hostCallbackID string) ([]byte, []string, error) {
|
||||
program, err := getJSProgram(scriptPath)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
engine := newJSEngine()
|
||||
engine := newJSEngine(newHostJSConsoleLogger(hostCallbackID))
|
||||
if errRun := engine.runProgram(program, p.cfg.Timeout); errRun != nil {
|
||||
return nil, nil, errRun
|
||||
}
|
||||
@@ -246,13 +260,14 @@ func (p *jsHandlerPlugin) applyJSAfterResponse(
|
||||
reqHeadersMap map[string]any, reqBody []byte,
|
||||
bodyStr string, chunkStr *string,
|
||||
respHeaders http.Header, isStream bool, historyChunks []string,
|
||||
hostCallbackID string,
|
||||
) (string, *processedHeaders, bool, error) {
|
||||
program, err := getJSProgram(scriptPath)
|
||||
if err != nil {
|
||||
return bodyStr, nil, false, err
|
||||
}
|
||||
|
||||
engine := newJSEngine()
|
||||
engine := newJSEngine(newHostJSConsoleLogger(hostCallbackID))
|
||||
if errRun := engine.runProgram(program, p.cfg.Timeout); errRun != nil {
|
||||
return bodyStr, nil, false, errRun
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@ function on_before_request(ctx) {
|
||||
"gpt-test",
|
||||
"openai",
|
||||
headers,
|
||||
"",
|
||||
)
|
||||
if errApply != nil {
|
||||
t.Fatalf("applyJSBeforeRequest() error = %v", errApply)
|
||||
@@ -85,6 +86,7 @@ function on_after_stream_response(ctx) {
|
||||
http.Header{},
|
||||
true,
|
||||
[]string{`data: {"choices":[{"delta":{"tool_calls":[{"index":0}]}}]}`},
|
||||
"",
|
||||
)
|
||||
if errApply != nil {
|
||||
t.Fatalf("applyJSAfterResponse() error = %v", errApply)
|
||||
@@ -123,6 +125,7 @@ function on_after_nonstream_response(ctx) {
|
||||
http.Header{},
|
||||
false,
|
||||
nil,
|
||||
"",
|
||||
)
|
||||
if errApply != nil {
|
||||
t.Fatalf("applyJSAfterResponse() error = %v", errApply)
|
||||
|
||||
Reference in New Issue
Block a user