mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-06-13 19:11:40 +08:00
- Updated host model callback logic to skip originating plugin's interceptors during nested model executions. - Added `SkipInterceptorPluginID` field to plugin API structs for controlling interceptor bypass behavior. - Introduced supporting logic in host API handlers, plugin host registry, and callback contexts to identify and skip specific plugins. - Enhanced unit tests across plugin host, API handlers, and execution paths to verify interceptor skipping behavior and plugin isolation. - Revised documentation to clarify non-recursive behavior of host model callbacks and the use of `SkipInterceptorPluginID`.
750 lines
26 KiB
Go
750 lines
26 KiB
Go
package pluginhost
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces"
|
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/logging"
|
|
"github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers"
|
|
"github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginabi"
|
|
"github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginapi"
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
type fakeHostModelExecutor struct {
|
|
executeModel func(context.Context, handlers.ModelExecutionRequest) (handlers.ModelExecutionResponse, *interfaces.ErrorMessage)
|
|
executeModelStream func(context.Context, handlers.ModelExecutionRequest) (handlers.ModelExecutionStream, *interfaces.ErrorMessage)
|
|
}
|
|
|
|
func (e *fakeHostModelExecutor) ExecuteModel(ctx context.Context, req handlers.ModelExecutionRequest) (handlers.ModelExecutionResponse, *interfaces.ErrorMessage) {
|
|
return e.executeModel(ctx, req)
|
|
}
|
|
|
|
func (e *fakeHostModelExecutor) ExecuteModelStream(ctx context.Context, req handlers.ModelExecutionRequest) (handlers.ModelExecutionStream, *interfaces.ErrorMessage) {
|
|
return e.executeModelStream(ctx, req)
|
|
}
|
|
|
|
func TestHostHTTPDoCallbackUsesHostHTTPClient(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
t.Fatalf("method = %s, want POST", r.Method)
|
|
}
|
|
w.Header().Set("X-Test", "ok")
|
|
_, _ = w.Write([]byte(`{"ok":true}`))
|
|
}))
|
|
defer server.Close()
|
|
|
|
req := pluginapi.HTTPRequest{
|
|
Method: http.MethodPost,
|
|
URL: server.URL,
|
|
Body: []byte(`{"request":true}`),
|
|
}
|
|
rawReq, errMarshal := json.Marshal(req)
|
|
if errMarshal != nil {
|
|
t.Fatalf("marshal request: %v", errMarshal)
|
|
}
|
|
|
|
rawResp, errCall := New().callFromPlugin(context.Background(), pluginabi.MethodHostHTTPDo, rawReq)
|
|
if errCall != nil {
|
|
t.Fatalf("callFromPlugin() error = %v", errCall)
|
|
}
|
|
|
|
resp, errDecode := decodeRPCEnvelope[pluginapi.HTTPResponse](rawResp)
|
|
if errDecode != nil {
|
|
t.Fatalf("decode response: %v", errDecode)
|
|
}
|
|
if resp.StatusCode != http.StatusOK || string(resp.Body) != `{"ok":true}` {
|
|
t.Fatalf("response = %#v, want status 200 body", resp)
|
|
}
|
|
if resp.Headers.Get("X-Test") != "ok" {
|
|
t.Fatalf("X-Test = %q, want ok", resp.Headers.Get("X-Test"))
|
|
}
|
|
}
|
|
|
|
func TestHostHTTPDoCallbackRestoresRegisteredRequestContext(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder())
|
|
ctx := context.WithValue(context.Background(), "gin", ginCtx)
|
|
|
|
host := New()
|
|
host.mu.Lock()
|
|
host.runtimeConfig = &config.Config{SDKConfig: config.SDKConfig{RequestLog: true}}
|
|
host.mu.Unlock()
|
|
callbackID, closeCallback := host.openCallbackContext(ctx)
|
|
defer closeCallback()
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Context().Err() != nil {
|
|
t.Fatalf("request context error = %v", r.Context().Err())
|
|
}
|
|
w.Header().Set("X-Upstream", "ok")
|
|
_, _ = w.Write([]byte("upstream-body"))
|
|
}))
|
|
defer server.Close()
|
|
|
|
rawReq, errMarshal := json.Marshal(rpcHostHTTPRequest{
|
|
HostCallbackID: callbackID,
|
|
Method: http.MethodPost,
|
|
URL: server.URL,
|
|
Body: []byte(`{"request":true}`),
|
|
})
|
|
if errMarshal != nil {
|
|
t.Fatalf("marshal request: %v", errMarshal)
|
|
}
|
|
if _, errCall := host.callFromPlugin(context.Background(), pluginabi.MethodHostHTTPDo, rawReq); errCall != nil {
|
|
t.Fatalf("callFromPlugin() error = %v", errCall)
|
|
}
|
|
|
|
rawAPIRequest, okRequest := ginCtx.Get("API_REQUEST")
|
|
if !okRequest {
|
|
t.Fatal("API_REQUEST was not captured on the original Gin context")
|
|
}
|
|
apiRequest, _ := rawAPIRequest.([]byte)
|
|
if !bytes.Contains(apiRequest, []byte("=== API REQUEST 1 ===")) || !bytes.Contains(apiRequest, []byte(`{"request":true}`)) {
|
|
t.Fatalf("API_REQUEST = %q, want upstream request details", apiRequest)
|
|
}
|
|
|
|
rawAPIResponse, okResponse := ginCtx.Get("API_RESPONSE")
|
|
if !okResponse {
|
|
t.Fatal("API_RESPONSE was not captured on the original Gin context")
|
|
}
|
|
apiResponse, _ := rawAPIResponse.([]byte)
|
|
if !bytes.Contains(apiResponse, []byte("=== API RESPONSE 1 ===")) || !bytes.Contains(apiResponse, []byte("upstream-body")) {
|
|
t.Fatalf("API_RESPONSE = %q, want upstream response details", apiResponse)
|
|
}
|
|
}
|
|
|
|
func TestHostHTTPDoStreamCallbackReturnsBeforeUpstreamCompletes(t *testing.T) {
|
|
release := make(chan struct{})
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte("first"))
|
|
if flusher, ok := w.(http.Flusher); ok {
|
|
flusher.Flush()
|
|
}
|
|
<-release
|
|
_, _ = w.Write([]byte("second"))
|
|
}))
|
|
defer server.Close()
|
|
defer close(release)
|
|
|
|
rawReq, errMarshal := json.Marshal(pluginapi.HTTPRequest{
|
|
Method: http.MethodGet,
|
|
URL: server.URL,
|
|
})
|
|
if errMarshal != nil {
|
|
t.Fatalf("marshal request: %v", errMarshal)
|
|
}
|
|
|
|
type callResult struct {
|
|
raw []byte
|
|
err error
|
|
}
|
|
done := make(chan callResult, 1)
|
|
host := New()
|
|
go func() {
|
|
rawResp, errCall := host.callFromPlugin(context.Background(), pluginabi.MethodHostHTTPDoStream, rawReq)
|
|
done <- callResult{raw: rawResp, err: errCall}
|
|
}()
|
|
|
|
var result callResult
|
|
select {
|
|
case result = <-done:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("host.http.do_stream waited for the whole upstream response")
|
|
}
|
|
if result.err != nil {
|
|
t.Fatalf("callFromPlugin() error = %v", result.err)
|
|
}
|
|
|
|
resp, errDecode := decodeRPCEnvelope[rpcHostHTTPStreamResponse](result.raw)
|
|
if errDecode != nil {
|
|
t.Fatalf("decode response: %v", errDecode)
|
|
}
|
|
if resp.StreamID == "" {
|
|
t.Fatalf("stream id is empty: %#v", resp)
|
|
}
|
|
readReq, errMarshal := json.Marshal(rpcHostHTTPStreamReadRequest{StreamID: resp.StreamID})
|
|
if errMarshal != nil {
|
|
t.Fatalf("marshal read request: %v", errMarshal)
|
|
}
|
|
rawRead, errRead := host.callFromPlugin(context.Background(), pluginabi.MethodHostHTTPStreamRead, readReq)
|
|
if errRead != nil {
|
|
t.Fatalf("read callback error = %v", errRead)
|
|
}
|
|
chunk, errDecode := decodeRPCEnvelope[rpcHostHTTPStreamReadResponse](rawRead)
|
|
if errDecode != nil {
|
|
t.Fatalf("decode read response: %v", errDecode)
|
|
}
|
|
if string(chunk.Payload) != "first" || chunk.Done || chunk.Error != "" {
|
|
t.Fatalf("read chunk = %#v, want first payload", chunk)
|
|
}
|
|
|
|
closeReq, errMarshal := json.Marshal(rpcHostHTTPStreamCloseRequest{StreamID: resp.StreamID})
|
|
if errMarshal != nil {
|
|
t.Fatalf("marshal close request: %v", errMarshal)
|
|
}
|
|
if _, errClose := host.callFromPlugin(context.Background(), pluginabi.MethodHostHTTPStreamClose, closeReq); errClose != nil {
|
|
t.Fatalf("close callback error = %v", errClose)
|
|
}
|
|
}
|
|
|
|
func TestHostStreamCallbacksEmitAndClose(t *testing.T) {
|
|
host := New()
|
|
streamID, chunks, cleanup := host.streams.open(context.Background())
|
|
defer cleanup()
|
|
|
|
emitReq, errMarshal := json.Marshal(rpcStreamEmitRequest{StreamID: streamID, Payload: []byte("chunk")})
|
|
if errMarshal != nil {
|
|
t.Fatalf("marshal emit request: %v", errMarshal)
|
|
}
|
|
if _, errEmit := host.callFromPlugin(context.Background(), pluginabi.MethodHostStreamEmit, emitReq); errEmit != nil {
|
|
t.Fatalf("emit callback error = %v", errEmit)
|
|
}
|
|
|
|
closeReq, errMarshal := json.Marshal(rpcStreamCloseRequest{StreamID: streamID})
|
|
if errMarshal != nil {
|
|
t.Fatalf("marshal close request: %v", errMarshal)
|
|
}
|
|
if _, errClose := host.callFromPlugin(context.Background(), pluginabi.MethodHostStreamClose, closeReq); errClose != nil {
|
|
t.Fatalf("close callback error = %v", errClose)
|
|
}
|
|
|
|
chunk, ok := <-chunks
|
|
if !ok {
|
|
t.Fatalf("stream closed before chunk")
|
|
}
|
|
if string(chunk.Payload) != "chunk" || chunk.Err != nil {
|
|
t.Fatalf("chunk = %#v, want payload chunk", chunk)
|
|
}
|
|
if _, ok = <-chunks; ok {
|
|
t.Fatalf("stream remains open after close")
|
|
}
|
|
}
|
|
|
|
func TestHostModelExecuteCallback(t *testing.T) {
|
|
host := New()
|
|
var got handlers.ModelExecutionRequest
|
|
host.SetModelExecutor(&fakeHostModelExecutor{
|
|
executeModel: func(ctx context.Context, req handlers.ModelExecutionRequest) (handlers.ModelExecutionResponse, *interfaces.ErrorMessage) {
|
|
got = req
|
|
return handlers.ModelExecutionResponse{
|
|
StatusCode: http.StatusAccepted,
|
|
Headers: http.Header{"X-Model": []string{"ok"}},
|
|
Body: []byte(`{"response":true}`),
|
|
}, nil
|
|
},
|
|
})
|
|
|
|
rawReq, errMarshal := json.Marshal(rpcHostModelExecutionRequest{
|
|
HostModelExecutionRequest: pluginapi.HostModelExecutionRequest{
|
|
EntryProtocol: "openai",
|
|
ExitProtocol: "claude",
|
|
Model: "model-1",
|
|
Body: []byte(`{"request":true}`),
|
|
Headers: http.Header{"X-Request": []string{"yes"}},
|
|
Query: url.Values{"alt": []string{"sse"}},
|
|
Alt: "raw",
|
|
},
|
|
})
|
|
if errMarshal != nil {
|
|
t.Fatalf("marshal request: %v", errMarshal)
|
|
}
|
|
rawResp, errCall := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelExecute, rawReq)
|
|
if errCall != nil {
|
|
t.Fatalf("callFromPlugin() error = %v", errCall)
|
|
}
|
|
|
|
resp, errDecode := decodeRPCEnvelope[pluginapi.HostModelExecutionResponse](rawResp)
|
|
if errDecode != nil {
|
|
t.Fatalf("decode response: %v", errDecode)
|
|
}
|
|
if resp.StatusCode != http.StatusAccepted || string(resp.Body) != `{"response":true}` {
|
|
t.Fatalf("response = %#v, want accepted body", resp)
|
|
}
|
|
if resp.Headers.Get("X-Model") != "ok" {
|
|
t.Fatalf("X-Model = %q, want ok", resp.Headers.Get("X-Model"))
|
|
}
|
|
if got.EntryProtocol != "openai" || got.ExitProtocol != "claude" || got.Model != "model-1" || got.Stream {
|
|
t.Fatalf("request protocols/model/stream = %#v", got)
|
|
}
|
|
if string(got.Body) != `{"request":true}` {
|
|
t.Fatalf("request body = %q, want original body", got.Body)
|
|
}
|
|
if got.Headers.Get("X-Request") != "yes" {
|
|
t.Fatalf("request header = %q, want yes", got.Headers.Get("X-Request"))
|
|
}
|
|
if got.Query.Get("alt") != "sse" {
|
|
t.Fatalf("query alt = %q, want sse", got.Query.Get("alt"))
|
|
}
|
|
if got.Alt != "raw" {
|
|
t.Fatalf("alt = %q, want raw", got.Alt)
|
|
}
|
|
}
|
|
|
|
func TestHostModelExecuteCallbackCarriesCallerPluginSkipID(t *testing.T) {
|
|
host := New()
|
|
var got handlers.ModelExecutionRequest
|
|
host.SetModelExecutor(&fakeHostModelExecutor{
|
|
executeModel: func(ctx context.Context, req handlers.ModelExecutionRequest) (handlers.ModelExecutionResponse, *interfaces.ErrorMessage) {
|
|
got = req
|
|
return handlers.ModelExecutionResponse{StatusCode: http.StatusOK, Body: []byte(`{"ok":true}`)}, nil
|
|
},
|
|
})
|
|
callbackID, closeCallback := host.openCallbackContextForPlugin(context.Background(), "origin-plugin")
|
|
defer closeCallback()
|
|
|
|
rawReq, errMarshal := json.Marshal(rpcHostModelExecutionRequest{
|
|
HostModelExecutionRequest: pluginapi.HostModelExecutionRequest{
|
|
EntryProtocol: "openai",
|
|
ExitProtocol: "openai",
|
|
Model: "model-1",
|
|
Body: []byte(`{"request":true}`),
|
|
},
|
|
HostCallbackID: callbackID,
|
|
})
|
|
if errMarshal != nil {
|
|
t.Fatalf("marshal request: %v", errMarshal)
|
|
}
|
|
if _, errCall := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelExecute, rawReq); errCall != nil {
|
|
t.Fatalf("callFromPlugin() error = %v", errCall)
|
|
}
|
|
if got.SkipInterceptorPluginID != "origin-plugin" {
|
|
t.Fatalf("SkipInterceptorPluginID = %q, want origin-plugin", got.SkipInterceptorPluginID)
|
|
}
|
|
}
|
|
|
|
func TestHostModelStreamClosesWithCallbackScope(t *testing.T) {
|
|
host := New()
|
|
ctxSeen := make(chan context.Context, 1)
|
|
host.SetModelExecutor(&fakeHostModelExecutor{
|
|
executeModelStream: func(ctx context.Context, req handlers.ModelExecutionRequest) (handlers.ModelExecutionStream, *interfaces.ErrorMessage) {
|
|
ctxSeen <- ctx
|
|
return handlers.ModelExecutionStream{
|
|
StatusCode: http.StatusOK,
|
|
Headers: http.Header{"X-Stream": []string{"ok"}},
|
|
Chunks: make(chan handlers.ModelExecutionChunk),
|
|
}, nil
|
|
},
|
|
})
|
|
callbackID, closeCallback := host.openCallbackContext(context.Background())
|
|
defer closeCallback()
|
|
|
|
rawReq, errMarshal := json.Marshal(rpcHostModelExecutionRequest{
|
|
HostModelExecutionRequest: pluginapi.HostModelExecutionRequest{
|
|
EntryProtocol: "openai",
|
|
ExitProtocol: "openai",
|
|
Model: "model-1",
|
|
Stream: true,
|
|
Body: []byte(`{"stream":true}`),
|
|
},
|
|
HostCallbackID: callbackID,
|
|
})
|
|
if errMarshal != nil {
|
|
t.Fatalf("marshal request: %v", errMarshal)
|
|
}
|
|
rawResp, errCall := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelExecuteStream, rawReq)
|
|
if errCall != nil {
|
|
t.Fatalf("callFromPlugin() error = %v", errCall)
|
|
}
|
|
resp, errDecode := decodeRPCEnvelope[pluginapi.HostModelStreamResponse](rawResp)
|
|
if errDecode != nil {
|
|
t.Fatalf("decode response: %v", errDecode)
|
|
}
|
|
if resp.StreamID == "" {
|
|
t.Fatalf("stream id is empty: %#v", resp)
|
|
}
|
|
|
|
var streamCtx context.Context
|
|
select {
|
|
case streamCtx = <-ctxSeen:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("model executor was not called")
|
|
}
|
|
closeCallback()
|
|
select {
|
|
case <-streamCtx.Done():
|
|
case <-time.After(time.Second):
|
|
t.Fatal("stream context was not canceled after callback scope closed")
|
|
}
|
|
}
|
|
|
|
func TestHostModelStreamReadAfterCallbackCloseReturnsDone(t *testing.T) {
|
|
host := New()
|
|
chunks := make(chan handlers.ModelExecutionChunk)
|
|
host.SetModelExecutor(&fakeHostModelExecutor{
|
|
executeModelStream: func(ctx context.Context, req handlers.ModelExecutionRequest) (handlers.ModelExecutionStream, *interfaces.ErrorMessage) {
|
|
return handlers.ModelExecutionStream{
|
|
StatusCode: http.StatusOK,
|
|
Chunks: chunks,
|
|
}, nil
|
|
},
|
|
})
|
|
callbackID, closeCallback := host.openCallbackContext(context.Background())
|
|
|
|
rawReq, errMarshal := json.Marshal(rpcHostModelExecutionRequest{
|
|
HostModelExecutionRequest: pluginapi.HostModelExecutionRequest{
|
|
EntryProtocol: "openai",
|
|
ExitProtocol: "openai",
|
|
Model: "model-1",
|
|
Stream: true,
|
|
Body: []byte(`{"stream":true}`),
|
|
},
|
|
HostCallbackID: callbackID,
|
|
})
|
|
if errMarshal != nil {
|
|
t.Fatalf("marshal request: %v", errMarshal)
|
|
}
|
|
rawResp, errCall := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelExecuteStream, rawReq)
|
|
if errCall != nil {
|
|
t.Fatalf("execute stream callback error = %v", errCall)
|
|
}
|
|
resp, errDecode := decodeRPCEnvelope[pluginapi.HostModelStreamResponse](rawResp)
|
|
if errDecode != nil {
|
|
t.Fatalf("decode stream response: %v", errDecode)
|
|
}
|
|
if resp.StreamID == "" {
|
|
t.Fatalf("stream id is empty: %#v", resp)
|
|
}
|
|
|
|
closeCallback()
|
|
readReq, errMarshal := json.Marshal(pluginapi.HostModelStreamReadRequest{StreamID: resp.StreamID})
|
|
if errMarshal != nil {
|
|
t.Fatalf("marshal read request: %v", errMarshal)
|
|
}
|
|
readDone := make(chan pluginapi.HostModelStreamReadResponse, 1)
|
|
readErr := make(chan error, 1)
|
|
go func() {
|
|
rawRead, errRead := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelStreamRead, readReq)
|
|
if errRead != nil {
|
|
readErr <- errRead
|
|
return
|
|
}
|
|
doneResp, errDecodeRead := decodeRPCEnvelope[pluginapi.HostModelStreamReadResponse](rawRead)
|
|
if errDecodeRead != nil {
|
|
readErr <- errDecodeRead
|
|
return
|
|
}
|
|
readDone <- doneResp
|
|
}()
|
|
select {
|
|
case errRead := <-readErr:
|
|
t.Fatalf("read after callback close error = %v", errRead)
|
|
case doneResp := <-readDone:
|
|
if !doneResp.Done || len(doneResp.Payload) != 0 || doneResp.Error != "" {
|
|
t.Fatalf("read after callback close = %#v, want done without payload/error", doneResp)
|
|
}
|
|
case <-time.After(time.Second):
|
|
t.Fatal("read after callback close blocked")
|
|
}
|
|
}
|
|
|
|
func TestHostModelExecuteStreamStartupErrorCleansUp(t *testing.T) {
|
|
host := New()
|
|
ctxSeen := make(chan context.Context, 1)
|
|
host.SetModelExecutor(&fakeHostModelExecutor{
|
|
executeModelStream: func(ctx context.Context, req handlers.ModelExecutionRequest) (handlers.ModelExecutionStream, *interfaces.ErrorMessage) {
|
|
ctxSeen <- ctx
|
|
return handlers.ModelExecutionStream{}, &interfaces.ErrorMessage{
|
|
StatusCode: http.StatusBadGateway,
|
|
}
|
|
},
|
|
})
|
|
|
|
rawReq, errMarshal := json.Marshal(pluginapi.HostModelExecutionRequest{
|
|
EntryProtocol: "openai",
|
|
ExitProtocol: "openai",
|
|
Model: "model-1",
|
|
Stream: true,
|
|
Body: []byte(`{"stream":true}`),
|
|
})
|
|
if errMarshal != nil {
|
|
t.Fatalf("marshal request: %v", errMarshal)
|
|
}
|
|
rawResp, errCall := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelExecuteStream, rawReq)
|
|
if errCall == nil {
|
|
t.Fatalf("execute stream callback error is nil, raw response = %q", rawResp)
|
|
}
|
|
if rawResp != nil {
|
|
t.Fatalf("raw response = %q, want nil on startup error", rawResp)
|
|
}
|
|
if !strings.Contains(errCall.Error(), "status 502") {
|
|
t.Fatalf("execute stream callback error = %v, want status 502", errCall)
|
|
}
|
|
|
|
var streamCtx context.Context
|
|
select {
|
|
case streamCtx = <-ctxSeen:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("model executor was not called")
|
|
}
|
|
select {
|
|
case <-streamCtx.Done():
|
|
case <-time.After(time.Second):
|
|
t.Fatal("stream context was not canceled after startup error")
|
|
}
|
|
gotCount := hostModelStreamCountForTest(t, host)
|
|
if gotCount != 0 {
|
|
t.Fatalf("model stream count = %d, want 0", gotCount)
|
|
}
|
|
}
|
|
|
|
func TestHostModelCallbacksValidateStreamMode(t *testing.T) {
|
|
host := New()
|
|
|
|
rawExecuteReq, errMarshal := json.Marshal(pluginapi.HostModelExecutionRequest{
|
|
EntryProtocol: "openai",
|
|
ExitProtocol: "openai",
|
|
Model: "model-1",
|
|
Stream: true,
|
|
})
|
|
if errMarshal != nil {
|
|
t.Fatalf("marshal execute request: %v", errMarshal)
|
|
}
|
|
_, errCall := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelExecute, rawExecuteReq)
|
|
if errCall == nil || !strings.Contains(errCall.Error(), "host.model.execute requires stream=false") {
|
|
t.Fatalf("execute callback error = %v, want stream=false validation error", errCall)
|
|
}
|
|
|
|
rawStreamReq, errMarshal := json.Marshal(pluginapi.HostModelExecutionRequest{
|
|
EntryProtocol: "openai",
|
|
ExitProtocol: "openai",
|
|
Model: "model-1",
|
|
Stream: false,
|
|
})
|
|
if errMarshal != nil {
|
|
t.Fatalf("marshal execute stream request: %v", errMarshal)
|
|
}
|
|
_, errCall = host.callFromPlugin(context.Background(), pluginabi.MethodHostModelExecuteStream, rawStreamReq)
|
|
if errCall == nil || !strings.Contains(errCall.Error(), "host.model.execute_stream requires stream=true") {
|
|
t.Fatalf("execute stream callback error = %v, want stream=true validation error", errCall)
|
|
}
|
|
}
|
|
|
|
func TestHostModelCallbacksRequireExecutor(t *testing.T) {
|
|
host := New()
|
|
|
|
rawExecuteReq, errMarshal := json.Marshal(pluginapi.HostModelExecutionRequest{
|
|
EntryProtocol: "openai",
|
|
ExitProtocol: "openai",
|
|
Model: "model-1",
|
|
})
|
|
if errMarshal != nil {
|
|
t.Fatalf("marshal execute request: %v", errMarshal)
|
|
}
|
|
_, errCall := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelExecute, rawExecuteReq)
|
|
if errCall == nil || !strings.Contains(errCall.Error(), "host model executor is unavailable") {
|
|
t.Fatalf("execute callback error = %v, want unavailable executor error", errCall)
|
|
}
|
|
|
|
rawStreamReq, errMarshal := json.Marshal(pluginapi.HostModelExecutionRequest{
|
|
EntryProtocol: "openai",
|
|
ExitProtocol: "openai",
|
|
Model: "model-1",
|
|
Stream: true,
|
|
})
|
|
if errMarshal != nil {
|
|
t.Fatalf("marshal execute stream request: %v", errMarshal)
|
|
}
|
|
_, errCall = host.callFromPlugin(context.Background(), pluginabi.MethodHostModelExecuteStream, rawStreamReq)
|
|
if errCall == nil || !strings.Contains(errCall.Error(), "host model executor is unavailable") {
|
|
t.Fatalf("execute stream callback error = %v, want unavailable executor error", errCall)
|
|
}
|
|
}
|
|
|
|
func TestHostModelStreamReadAndCloseValidateStreamID(t *testing.T) {
|
|
host := New()
|
|
|
|
rawReadReq, errMarshal := json.Marshal(pluginapi.HostModelStreamReadRequest{})
|
|
if errMarshal != nil {
|
|
t.Fatalf("marshal read request: %v", errMarshal)
|
|
}
|
|
_, errRead := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelStreamRead, rawReadReq)
|
|
if errRead == nil || !strings.Contains(errRead.Error(), "model stream id is required") {
|
|
t.Fatalf("read callback error = %v, want required stream id error", errRead)
|
|
}
|
|
|
|
rawCloseReq, errMarshal := json.Marshal(pluginapi.HostModelStreamCloseRequest{})
|
|
if errMarshal != nil {
|
|
t.Fatalf("marshal close request: %v", errMarshal)
|
|
}
|
|
rawClose, errClose := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelStreamClose, rawCloseReq)
|
|
if errClose != nil {
|
|
t.Fatalf("close callback error = %v", errClose)
|
|
}
|
|
_, errDecode := decodeRPCEnvelope[rpcEmptyResponse](rawClose)
|
|
if errDecode != nil {
|
|
t.Fatalf("decode close response: %v", errDecode)
|
|
}
|
|
}
|
|
|
|
func TestHostModelStreamReadReturnsPayloadAndTerminalError(t *testing.T) {
|
|
host := New()
|
|
chunks := make(chan handlers.ModelExecutionChunk, 2)
|
|
chunks <- handlers.ModelExecutionChunk{Payload: []byte("first")}
|
|
chunks <- handlers.ModelExecutionChunk{Err: &handlers.ModelExecutionStreamError{
|
|
StatusCode: http.StatusBadGateway,
|
|
Message: "terminal boom",
|
|
}}
|
|
host.SetModelExecutor(&fakeHostModelExecutor{
|
|
executeModelStream: func(ctx context.Context, req handlers.ModelExecutionRequest) (handlers.ModelExecutionStream, *interfaces.ErrorMessage) {
|
|
return handlers.ModelExecutionStream{
|
|
StatusCode: http.StatusOK,
|
|
Headers: http.Header{"X-Stream": []string{"ok"}},
|
|
Chunks: chunks,
|
|
}, nil
|
|
},
|
|
})
|
|
|
|
streamID := openHostModelStreamForTest(t, host)
|
|
readReq, errMarshal := json.Marshal(pluginapi.HostModelStreamReadRequest{StreamID: streamID})
|
|
if errMarshal != nil {
|
|
t.Fatalf("marshal read request: %v", errMarshal)
|
|
}
|
|
rawRead, errRead := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelStreamRead, readReq)
|
|
if errRead != nil {
|
|
t.Fatalf("read callback error = %v", errRead)
|
|
}
|
|
first, errDecode := decodeRPCEnvelope[pluginapi.HostModelStreamReadResponse](rawRead)
|
|
if errDecode != nil {
|
|
t.Fatalf("decode read response: %v", errDecode)
|
|
}
|
|
if string(first.Payload) != "first" || first.Done || first.Error != "" {
|
|
t.Fatalf("first read = %#v, want payload without done", first)
|
|
}
|
|
|
|
rawRead, errRead = host.callFromPlugin(context.Background(), pluginabi.MethodHostModelStreamRead, readReq)
|
|
if errRead != nil {
|
|
t.Fatalf("terminal read callback error = %v", errRead)
|
|
}
|
|
terminal, errDecode := decodeRPCEnvelope[pluginapi.HostModelStreamReadResponse](rawRead)
|
|
if errDecode != nil {
|
|
t.Fatalf("decode terminal response: %v", errDecode)
|
|
}
|
|
if !terminal.Done || terminal.Error != "terminal boom" || len(terminal.Payload) != 0 {
|
|
t.Fatalf("terminal read = %#v, want done terminal error", terminal)
|
|
}
|
|
}
|
|
|
|
func TestHostModelStreamExplicitCloseCancelsStream(t *testing.T) {
|
|
host := New()
|
|
ctxSeen := make(chan context.Context, 1)
|
|
host.SetModelExecutor(&fakeHostModelExecutor{
|
|
executeModelStream: func(ctx context.Context, req handlers.ModelExecutionRequest) (handlers.ModelExecutionStream, *interfaces.ErrorMessage) {
|
|
ctxSeen <- ctx
|
|
return handlers.ModelExecutionStream{
|
|
StatusCode: http.StatusOK,
|
|
Chunks: make(chan handlers.ModelExecutionChunk),
|
|
}, nil
|
|
},
|
|
})
|
|
|
|
streamID := openHostModelStreamForTest(t, host)
|
|
var streamCtx context.Context
|
|
select {
|
|
case streamCtx = <-ctxSeen:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("model executor was not called")
|
|
}
|
|
closeReq, errMarshal := json.Marshal(pluginapi.HostModelStreamCloseRequest{StreamID: streamID})
|
|
if errMarshal != nil {
|
|
t.Fatalf("marshal close request: %v", errMarshal)
|
|
}
|
|
if _, errClose := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelStreamClose, closeReq); errClose != nil {
|
|
t.Fatalf("close callback error = %v", errClose)
|
|
}
|
|
select {
|
|
case <-streamCtx.Done():
|
|
case <-time.After(time.Second):
|
|
t.Fatal("stream context was not canceled after explicit close")
|
|
}
|
|
if _, errClose := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelStreamClose, closeReq); errClose != nil {
|
|
t.Fatalf("second close callback error = %v", errClose)
|
|
}
|
|
}
|
|
|
|
func openHostModelStreamForTest(t *testing.T, host *Host) string {
|
|
t.Helper()
|
|
rawReq, errMarshal := json.Marshal(pluginapi.HostModelExecutionRequest{
|
|
EntryProtocol: "openai",
|
|
ExitProtocol: "openai",
|
|
Model: "model-1",
|
|
Stream: true,
|
|
Body: []byte(`{"stream":true}`),
|
|
})
|
|
if errMarshal != nil {
|
|
t.Fatalf("marshal request: %v", errMarshal)
|
|
}
|
|
rawResp, errCall := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelExecuteStream, rawReq)
|
|
if errCall != nil {
|
|
t.Fatalf("execute stream callback error = %v", errCall)
|
|
}
|
|
resp, errDecode := decodeRPCEnvelope[pluginapi.HostModelStreamResponse](rawResp)
|
|
if errDecode != nil {
|
|
t.Fatalf("decode stream response: %v", errDecode)
|
|
}
|
|
if resp.StreamID == "" {
|
|
t.Fatalf("stream id is empty: %#v", resp)
|
|
}
|
|
return resp.StreamID
|
|
}
|
|
|
|
func hostModelStreamCountForTest(t *testing.T, host *Host) int {
|
|
t.Helper()
|
|
host.modelStreams.mu.Lock()
|
|
defer host.modelStreams.mu.Unlock()
|
|
return len(host.modelStreams.streams)
|
|
}
|
|
|
|
func TestHostLogCallbackRestoresRegisteredRequestContext(t *testing.T) {
|
|
host := New()
|
|
ctx := logging.WithRequestID(context.Background(), "request-123")
|
|
callbackID, closeCallback := host.openCallbackContext(ctx)
|
|
defer closeCallback()
|
|
|
|
var out bytes.Buffer
|
|
logger := log.StandardLogger()
|
|
originalOut := logger.Out
|
|
originalFormatter := logger.Formatter
|
|
originalLevel := logger.Level
|
|
log.SetOutput(&out)
|
|
log.SetFormatter(&log.TextFormatter{
|
|
DisableColors: true,
|
|
DisableTimestamp: true,
|
|
})
|
|
log.SetLevel(log.InfoLevel)
|
|
defer func() {
|
|
log.SetOutput(originalOut)
|
|
log.SetFormatter(originalFormatter)
|
|
log.SetLevel(originalLevel)
|
|
}()
|
|
|
|
rawReq, errMarshal := json.Marshal(rpcHostLogRequest{
|
|
HostCallbackID: callbackID,
|
|
Level: "info",
|
|
Message: "plugin callback message",
|
|
})
|
|
if errMarshal != nil {
|
|
t.Fatalf("marshal log request: %v", errMarshal)
|
|
}
|
|
if _, errCall := host.callFromPlugin(context.Background(), pluginabi.MethodHostLog, rawReq); errCall != nil {
|
|
t.Fatalf("log callback error = %v", errCall)
|
|
}
|
|
|
|
got := out.String()
|
|
if !strings.Contains(got, "plugin callback message") || !strings.Contains(got, "request_id=request-123") {
|
|
t.Fatalf("log output = %q, want message and request_id field", got)
|
|
}
|
|
}
|