From 8e39db2ec7891d9831f61a96437b3ad142558882 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Fri, 12 Jun 2026 02:22:23 +0800 Subject: [PATCH] feat(plugin, api): introduce host model callback support with Go example and API handlers - Added an example plugin `host-model-callback` in Go to summarize host model callbacks. - Implemented `cliproxy_plugin_init`, `cliproxyPluginCall`, and other plugin functions for callback handling. - Introduced API handlers for `ModelExecution` and `ModelExecutionStream` with support for both streaming and non-streaming requests. - Included unit tests (`model_execution_test.go`) to validate execution logic and streaming responses. --- examples/plugin/README.md | 19 +- examples/plugin/README_CN.md | 19 +- examples/plugin/host-model-callback/README.md | 132 ++++ examples/plugin/host-model-callback/go/go.mod | 7 + .../plugin/host-model-callback/go/main.go | 725 ++++++++++++++++++ internal/api/server.go | 6 + internal/pluginhost/adapters.go | 54 +- internal/pluginhost/adapters_test.go | 153 +++- internal/pluginhost/callback_contexts.go | 49 +- internal/pluginhost/host.go | 29 + internal/pluginhost/host_callbacks.go | 141 ++++ internal/pluginhost/host_callbacks_test.go | 458 +++++++++++ internal/pluginhost/host_test.go | 35 + internal/pluginhost/model_stream_bridge.go | 91 +++ internal/pluginhost/rpc_client.go | 7 +- internal/pluginhost/rpc_schema.go | 5 + .../runtime/executor/aistudio_executor.go | 11 +- .../runtime/executor/antigravity_executor.go | 14 +- internal/runtime/executor/claude_executor.go | 13 +- internal/runtime/executor/codex_executor.go | 12 +- .../executor/codex_websockets_executor.go | 6 +- .../runtime/executor/gemini_cli_executor.go | 15 +- internal/runtime/executor/gemini_executor.go | 11 +- .../executor/gemini_vertex_executor.go | 23 +- internal/runtime/executor/kimi_executor.go | 8 +- .../executor/openai_compat_executor.go | 11 +- internal/runtime/executor/xai_executor.go | 9 +- sdk/api/handlers/handlers.go | 40 +- sdk/api/handlers/model_execution.go | 252 ++++++ sdk/api/handlers/model_execution_test.go | 392 ++++++++++ sdk/cliproxy/executor/types.go | 11 + sdk/cliproxy/executor/types_test.go | 26 + sdk/pluginabi/types.go | 18 +- sdk/pluginabi/types_test.go | 12 + sdk/pluginapi/types.go | 62 ++ sdk/pluginapi/types_test.go | 149 ++++ 36 files changed, 2935 insertions(+), 90 deletions(-) create mode 100644 examples/plugin/host-model-callback/README.md create mode 100644 examples/plugin/host-model-callback/go/go.mod create mode 100644 examples/plugin/host-model-callback/go/main.go create mode 100644 internal/pluginhost/model_stream_bridge.go create mode 100644 sdk/api/handlers/model_execution.go create mode 100644 sdk/api/handlers/model_execution_test.go create mode 100644 sdk/cliproxy/executor/types_test.go diff --git a/examples/plugin/README.md b/examples/plugin/README.md index 8f4891031..663054a17 100644 --- a/examples/plugin/README.md +++ b/examples/plugin/README.md @@ -13,7 +13,7 @@ This directory contains standard dynamic library plugin examples for the CLIProx - `protocol-format/`: minimal executor focused on input/output format declarations. - `request-translator/`: request translation capability only. - `request-normalizer/`: request normalization capability only. -- `codex-service-tier/`: Go-only request normalizer that sets Codex `gpt-5.4` requests to the priority service tier when enabled. +- `codex-service-tier/`: Go-only request normalizer that sets Codex `gpt-5.5` requests to the priority service tier when enabled. - `scheduler/`: Go-only scheduler that can select a configured auth ID, delegate to a built-in scheduler, or deny picks. - `response-translator/`: response translation capability only. - `response-normalizer/`: response normalization capability only. @@ -22,12 +22,13 @@ This directory contains standard dynamic library plugin examples for the CLIProx - `cli/`: command-line capability only. - `management-api/`: Management API and resource capability only. - `host-callback/`: minimal plugin resource that demonstrates host callbacks. +- `host-model-callback/`: Go-only plugin resource that calls the host model execution callbacks. Most standard capability examples contain `go/`, `c/`, and `rust/` subdirectories. Specialized examples may provide only the implementation language they need. ## Codex Service Tier -`codex-service-tier` declares the request normalization capability. When `fast` is `true`, it sets `service_tier` to `priority` for requests where `req.ToFormat` is `codex` and `req.Model` is `gpt-5.4`. +`codex-service-tier` declares the request normalization capability. When `fast` is `true`, it sets `service_tier` to `priority` for requests where `req.ToFormat` is `codex` and `req.Model` is `gpt-5.5`. ```yaml plugins: @@ -38,6 +39,20 @@ plugins: fast: false ``` +## Host Model Callback + +`host-model-callback` declares the Management API capability and exposes a browser resource named `Host Model Callback`. The resource calls `host.model.execute` for non-streaming requests and `host.model.execute_stream` plus `host.model.stream_read` for streaming requests. It demonstrates explicit stream close with `host.model.stream_close` and an `implicit_close=true` option for RPC-scope host cleanup. + +```yaml +plugins: + configs: + host-model-callback: + enabled: true + priority: 1 +``` + +The default example model is `gpt-5.5`, but the request succeeds only when the current CPA model and auth configuration can route that model. + ## Scheduler `scheduler` declares the scheduler capability. It can select a configured auth ID from the candidate list, delegate to the built-in `fill-first` or `round-robin` scheduler, or reject picks when `deny` is `true`. diff --git a/examples/plugin/README_CN.md b/examples/plugin/README_CN.md index 304fdbf3c..de8507421 100644 --- a/examples/plugin/README_CN.md +++ b/examples/plugin/README_CN.md @@ -13,7 +13,7 @@ - `protocol-format/`:使用最小执行器重点演示输入和输出格式声明。 - `request-translator/`:只演示请求转换能力。 - `request-normalizer/`:只演示请求规整能力。 -- `codex-service-tier/`:仅 Go 实现的请求规整插件,启用后会将 Codex `gpt-5.4` 请求设置为 priority service tier。 +- `codex-service-tier/`:仅 Go 实现的请求规整插件,启用后会将 Codex `gpt-5.5` 请求设置为 priority service tier。 - `scheduler/`:仅 Go 实现的调度插件,可选择指定 auth ID、委托内置调度器或拒绝调度。 - `response-translator/`:只演示响应转换能力。 - `response-normalizer/`:只演示响应规整能力。 @@ -22,12 +22,13 @@ - `cli/`:只演示命令行扩展能力。 - `management-api/`:只演示 Management API 和资源扩展能力。 - `host-callback/`:使用最小插件资源演示宿主回调。 +- `host-model-callback/`:仅 Go 实现的插件资源,演示调用宿主模型执行回调。 多数标准能力示例都包含 `go/`、`c/` 和 `rust/` 三个子目录。专用示例可能只提供所需的实现语言。 ## Codex Service Tier -`codex-service-tier` 声明请求规整能力。当 `fast` 为 `true` 时,如果 `req.ToFormat` 为 `codex` 且 `req.Model` 为 `gpt-5.4`,它会将 `service_tier` 设置为 `priority`。 +`codex-service-tier` 声明请求规整能力。当 `fast` 为 `true` 时,如果 `req.ToFormat` 为 `codex` 且 `req.Model` 为 `gpt-5.5`,它会将 `service_tier` 设置为 `priority`。 ```yaml plugins: @@ -38,6 +39,20 @@ plugins: fast: false ``` +## Host Model Callback + +`host-model-callback` 声明 Management API 能力,并暴露名为 `Host Model Callback` 的浏览器资源。该资源在非流式请求中调用 `host.model.execute`,在流式请求中调用 `host.model.execute_stream` 和 `host.model.stream_read`。它演示了通过 `host.model.stream_close` 显式关闭流,也提供 `implicit_close=true` 用于演示 RPC 作用域结束时的宿主隐式清理。 + +```yaml +plugins: + configs: + host-model-callback: + enabled: true + priority: 1 +``` + +默认示例模型是 `gpt-5.5`,但请求能否成功取决于当前 CPA 模型和认证配置是否可以路由该模型。 + ## Scheduler `scheduler` 声明调度能力。它可以从候选列表中选择配置的 auth ID,委托内置的 `fill-first` 或 `round-robin` 调度器,或在 `deny` 为 `true` 时拒绝调度。 diff --git a/examples/plugin/host-model-callback/README.md b/examples/plugin/host-model-callback/README.md new file mode 100644 index 000000000..a69e27e3a --- /dev/null +++ b/examples/plugin/host-model-callback/README.md @@ -0,0 +1,132 @@ +# Host Model Callback Plugin + +This Go-only plugin demonstrates how a plugin-owned browser resource can call the host model execution callbacks instead of sending any external HTTP request itself. + +## Purpose and Scope + +The plugin registers a Management API resource named `Host Model Callback` at `/status`. CPA exposes it under: + +```text +/v0/resource/plugins/host-model-callback/status +``` + +The resource examples are query-based. The resource reads URL query parameters, builds an OpenAI-compatible chat request, and calls: + +- `host.model.execute` for non-streaming model execution. +- `host.model.execute_stream`, `host.model.stream_read`, and `host.model.stream_close` for streaming execution. + +This example is intentionally limited to host model callbacks. It does not implement an executor, translator, normalizer, auth provider, scheduler, or any direct outbound HTTP client. + +## Build + +From this directory: + +```bash +cd go +go build -buildmode=c-shared -o host-model-callback.dylib . +rm -f host-model-callback.dylib host-model-callback.h +``` + +Use the platform extension expected by your target system: + +- `.dylib` on macOS +- `.so` on Linux +- `.dll` on Windows + +## Configuration + +Build the dynamic library and place it under the configured plugin directory with a basename that matches the plugin ID. For example, `plugins/host-model-callback.dylib` maps to `plugins.configs.host-model-callback`. + +```yaml +plugins: + enabled: true + dir: "plugins" + configs: + host-model-callback: + enabled: true + priority: 1 +``` + +This plugin does not define plugin-specific configuration fields. + +## Resource URL Examples + +Non-streaming request with defaults: + +```text +http://localhost:8080/v0/resource/plugins/host-model-callback/status +``` + +Non-streaming request with explicit protocol and prompt: + +```text +http://localhost:8080/v0/resource/plugins/host-model-callback/status?entry_protocol=openai&exit_protocol=openai&model=gpt-5.5&prompt=Say%20hello%20in%20one%20sentence +``` + +Streaming request with explicit close: + +```text +http://localhost:8080/v0/resource/plugins/host-model-callback/status?stream=true&model=gpt-5.5&prompt=Write%20three%20short%20tokens +``` + +Streaming request that relies on RPC-scope implicit close: + +```text +http://localhost:8080/v0/resource/plugins/host-model-callback/status?stream=true&implicit_close=true +``` + +The default model ID is `gpt-5.5` to match the current nearby Codex example documentation and code. It is only an example model identifier; the request succeeds only when your CPA configuration can route that model. + +## Parameters + +- `entry_protocol`: inbound client protocol passed to the host model execution path. The default is `openai`. +- `exit_protocol`: target provider protocol passed to the host model execution path. The default is `openai`. +- `model`: model identifier passed in the host model execution request. The default is `gpt-5.5`; availability depends on the configured model registry and auth records. +- `stream`: boolean flag. The default is `false`; set `stream=true` to use `host.model.execute_stream`. +- `prompt`: text used to build the default OpenAI-compatible request body. +- `body`: optional JSON string in the URL query used as the raw model request body. When `body` is provided, it replaces the generated body. +- `alt`: optional alternate route or mode suffix passed through the host model request. +- `implicit_close`: streaming-only boolean flag. The default is `false`. + +The generated default body is OpenAI-compatible: + +```json +{ + "model": "gpt-5.5", + "stream": false, + "messages": [ + { + "role": "user", + "content": "Summarize host model callbacks in one short sentence." + } + ] +} +``` + +For example, a URL-encoded `body` query value can provide the raw OpenAI-compatible request: + +```text +http://localhost:8080/v0/resource/plugins/host-model-callback/status?body=%7B%22model%22%3A%22gpt-5.5%22%2C%22stream%22%3Afalse%2C%22messages%22%3A%5B%7B%22role%22%3A%22user%22%2C%22content%22%3A%22Say%20hello%20in%20one%20sentence%22%7D%5D%7D +``` + +## Stream Close Semantics + +By default, streaming mode explicitly closes the host-owned stream with `host.model.stream_close` through a deferred close call. This is the preferred pattern for plugins because it releases stream resources as soon as the plugin has finished reading. + +When `implicit_close=true` is set, the plugin intentionally skips the explicit close call. CPA injects `host_callback_id` into the `management.handle` request, and this example forwards that callback ID to `host.model.execute_stream` so the host can close the stream when the `management.handle` RPC callback scope returns. This mode exists only to demonstrate host cleanup behavior; normal plugin code should explicitly close streams it opens. + +## Billing and Usage + +The callback uses the existing CPA model executor path. Usage collection, request accounting, and billing metadata are handled by the same executor and usage reporter path as normal proxied requests. The callback layer does not bill twice and does not create an additional usage record by itself. + +## Error Handling and Troubleshooting + +The page displays the model status, response headers, body, stream chunks, close mode, and any callback error returned by the host envelope. + +Common issues: + +- `host model executor is unavailable`: the host model executor path is not initialized for this plugin callback context. +- `unsupported model` or provider-specific routing errors: the `model` value is not routable with the current CPA model/auth configuration. +- `host.model.execute requires stream=false`: non-stream execution was called with a streaming request. +- `host.model.execute_stream requires stream=true`: streaming execution was called without `stream=true`. +- Empty or partial stream output: inspect the page error section and host logs; upstream stream errors are returned through `host.model.stream_read`. diff --git a/examples/plugin/host-model-callback/go/go.mod b/examples/plugin/host-model-callback/go/go.mod new file mode 100644 index 000000000..95672b7e6 --- /dev/null +++ b/examples/plugin/host-model-callback/go/go.mod @@ -0,0 +1,7 @@ +module github.com/router-for-me/CLIProxyAPI/v7/examples/plugin/host-model-callback/go + +go 1.26.0 + +require github.com/router-for-me/CLIProxyAPI/v7 v7.0.0 + +replace github.com/router-for-me/CLIProxyAPI/v7 => ../../../.. diff --git a/examples/plugin/host-model-callback/go/main.go b/examples/plugin/host-model-callback/go/main.go new file mode 100644 index 000000000..76cb1ae3f --- /dev/null +++ b/examples/plugin/host-model-callback/go/main.go @@ -0,0 +1,725 @@ +package main + +/* +#include +#include + +typedef struct { + void* ptr; + size_t len; +} cliproxy_buffer; + +typedef int (*cliproxy_host_call_fn)(void*, const char*, const uint8_t*, size_t, cliproxy_buffer*); +typedef void (*cliproxy_host_free_fn)(void*, size_t); + +typedef struct { + uint32_t abi_version; + void* host_ctx; + cliproxy_host_call_fn call; + cliproxy_host_free_fn free_buffer; +} cliproxy_host_api; + +typedef int (*cliproxy_plugin_call_fn)(char*, uint8_t*, size_t, cliproxy_buffer*); +typedef void (*cliproxy_plugin_free_fn)(void*, size_t); +typedef void (*cliproxy_plugin_shutdown_fn)(void); + +typedef struct { + uint32_t abi_version; + cliproxy_plugin_call_fn call; + cliproxy_plugin_free_fn free_buffer; + cliproxy_plugin_shutdown_fn shutdown; +} cliproxy_plugin_api; + +extern int cliproxyPluginCall(char*, uint8_t*, size_t, cliproxy_buffer*); +extern void cliproxyPluginFree(void*, size_t); +extern void cliproxyPluginShutdown(void); + +static const cliproxy_host_api* stored_host; + +static void store_host_api(const cliproxy_host_api* host) { + stored_host = host; +} + +static int call_host_api(const char* method, const uint8_t* request, size_t request_len, cliproxy_buffer* response) { + if (stored_host == NULL || stored_host->call == NULL) { + return 1; + } + return stored_host->call(stored_host->host_ctx, method, request, request_len, response); +} + +static void free_host_buffer(void* ptr, size_t len) { + if (stored_host != NULL && stored_host->free_buffer != NULL && ptr != NULL) { + stored_host->free_buffer(ptr, len); + } +} +*/ +import "C" + +import ( + "bytes" + "encoding/json" + "fmt" + "html" + "net/http" + "net/url" + "strconv" + "strings" + "unsafe" + + "github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginabi" + "github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginapi" +) + +const ( + defaultModel = "gpt-5.5" + defaultPrompt = "Summarize host model callbacks in one short sentence." + pluginName = "host-model-callback" + resourcePath = "/status" + resourceContentType = "text/html; charset=utf-8" +) + +type envelope struct { + OK bool `json:"ok"` + Result json.RawMessage `json:"result,omitempty"` + Error *envelopeError `json:"error,omitempty"` +} + +type envelopeError struct { + Code string `json:"code"` + Message string `json:"message"` +} + +type registration struct { + SchemaVersion uint32 `json:"schema_version"` + Metadata pluginapi.Metadata `json:"metadata"` + Capabilities registrationCapabilities `json:"capabilities"` +} + +type registrationCapabilities struct { + ManagementAPI bool `json:"management_api"` +} + +type managementRegistration struct { + Resources []managementResource `json:"resources,omitempty"` +} + +type managementResource struct { + Path string `json:"Path"` + Menu string `json:"Menu"` + Description string `json:"Description"` +} + +type managementRequest struct { + Method string + Path string + Headers http.Header + Query url.Values + Body []byte + HostCallbackID string `json:"host_callback_id,omitempty"` +} + +type managementResponse struct { + StatusCode int `json:"StatusCode"` + Headers http.Header `json:"Headers"` + Body []byte `json:"Body"` +} + +type managementBodyOptions struct { + Model string `json:"model"` + Mode string `json:"mode"` + EntryProtocol string `json:"entry_protocol"` + ExitProtocol string `json:"exit_protocol"` + Prompt string `json:"prompt"` + Stream *bool `json:"stream"` + Body json.RawMessage `json:"body"` + Headers http.Header `json:"headers"` + Query url.Values `json:"query"` + Alt string `json:"alt"` + ImplicitClose *bool `json:"implicit_close"` +} + +type hostModelExecutionRequest struct { + pluginapi.HostModelExecutionRequest + HostCallbackID string `json:"host_callback_id,omitempty"` +} + +type runOptions struct { + Model string + Mode string + EntryProtocol string + ExitProtocol string + Prompt string + Stream bool + Body []byte + Headers http.Header + Query url.Values + Alt string + ImplicitClose bool + HostCallbackID string +} + +type chatCompletionRequest struct { + Model string `json:"model"` + Stream bool `json:"stream"` + Messages []chatMessage `json:"messages"` +} + +type chatMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type streamPageData struct { + StatusCode int + Headers http.Header + StreamID string + Chunks []string + Error string + CloseMode string + CloseError string +} + +func main() {} + +//export cliproxy_plugin_init +func cliproxy_plugin_init(host *C.cliproxy_host_api, plugin *C.cliproxy_plugin_api) C.int { + if plugin == nil { + return 1 + } + C.store_host_api(host) + plugin.abi_version = C.uint32_t(pluginabi.ABIVersion) + plugin.call = C.cliproxy_plugin_call_fn(C.cliproxyPluginCall) + plugin.free_buffer = C.cliproxy_plugin_free_fn(C.cliproxyPluginFree) + plugin.shutdown = C.cliproxy_plugin_shutdown_fn(C.cliproxyPluginShutdown) + return 0 +} + +//export cliproxyPluginCall +func cliproxyPluginCall(method *C.char, request *C.uint8_t, requestLen C.size_t, response *C.cliproxy_buffer) C.int { + if response != nil { + response.ptr = nil + response.len = 0 + } + if method == nil { + writeResponse(response, errorEnvelope("invalid_method", "method is required")) + return 1 + } + var requestBytes []byte + if request != nil && requestLen > 0 { + requestBytes = C.GoBytes(unsafe.Pointer(request), C.int(requestLen)) + } + raw, errHandle := handleMethod(C.GoString(method), requestBytes) + if errHandle != nil { + writeResponse(response, errorEnvelope("plugin_error", errHandle.Error())) + return 1 + } + writeResponse(response, raw) + return 0 +} + +//export cliproxyPluginFree +func cliproxyPluginFree(ptr unsafe.Pointer, len C.size_t) { + if ptr != nil { + C.free(ptr) + } + _ = len +} + +//export cliproxyPluginShutdown +func cliproxyPluginShutdown() {} + +func handleMethod(method string, request []byte) ([]byte, error) { + switch method { + case pluginabi.MethodPluginRegister, pluginabi.MethodPluginReconfigure: + return okEnvelope(pluginRegistration()) + case pluginabi.MethodManagementRegister: + return okEnvelope(managementRegistration{ + Resources: []managementResource{{ + Path: resourcePath, + Menu: "Host Model Callback", + Description: "Runs a model request through host.model callbacks and displays the result.", + }}, + }) + case pluginabi.MethodManagementHandle: + return handleManagement(request) + default: + return errorEnvelope("unknown_method", "unknown method: "+method), nil + } +} + +func pluginRegistration() registration { + return registration{ + SchemaVersion: pluginabi.SchemaVersion, + Metadata: pluginapi.Metadata{ + Name: pluginName, + Version: "0.1.0", + Author: "router-for-me", + GitHubRepository: "https://github.com/router-for-me/CLIProxyAPI", + Logo: "https://raw.githubusercontent.com/router-for-me/CLIProxyAPI/main/docs/logo.png", + ConfigFields: []pluginapi.ConfigField{}, + }, + Capabilities: registrationCapabilities{ + ManagementAPI: true, + }, + } +} + +func handleManagement(raw []byte) ([]byte, error) { + var req managementRequest + if len(raw) > 0 { + if errUnmarshal := json.Unmarshal(raw, &req); errUnmarshal != nil { + return nil, fmt.Errorf("decode management request: %w", errUnmarshal) + } + } + opts, errOptions := optionsFromManagementRequest(req) + if errOptions != nil { + page := renderPage(opts, 0, nil, nil, nil, errOptions.Error(), "", "") + return okEnvelope(htmlResponse(http.StatusBadRequest, page)) + } + if opts.Stream { + data := executeStream(opts) + page := renderPage(opts, data.StatusCode, data.Headers, nil, data.Chunks, data.Error, data.CloseMode, data.CloseError) + return okEnvelope(htmlResponse(http.StatusOK, page)) + } + resp, errExecute := executeOnce(opts) + if errExecute != nil { + page := renderPage(opts, 0, nil, nil, nil, errExecute.Error(), "", "") + return okEnvelope(htmlResponse(http.StatusOK, page)) + } + page := renderPage(opts, resp.StatusCode, resp.Headers, resp.Body, nil, "", "", "") + return okEnvelope(htmlResponse(http.StatusOK, page)) +} + +func optionsFromManagementRequest(req managementRequest) (runOptions, error) { + opts := runOptions{ + Model: defaultModel, + Mode: "non-stream", + EntryProtocol: "openai", + ExitProtocol: "openai", + Prompt: defaultPrompt, + Headers: http.Header{}, + Query: url.Values{}, + } + opts.HostCallbackID = strings.TrimSpace(req.HostCallbackID) + if len(req.Body) > 0 { + if errApplyBody := applyBodyOptions(&opts, req.Body); errApplyBody != nil { + return opts, errApplyBody + } + } + if errApplyQuery := applyQueryOptions(&opts, req.Query); errApplyQuery != nil { + return opts, errApplyQuery + } + if opts.Stream { + opts.Mode = "stream" + } else { + opts.Mode = "non-stream" + } + return opts, nil +} + +func applyBodyOptions(opts *runOptions, raw []byte) error { + var bodyOpts managementBodyOptions + if errUnmarshal := json.Unmarshal(raw, &bodyOpts); errUnmarshal != nil { + return fmt.Errorf("decode JSON request body: %w", errUnmarshal) + } + if strings.TrimSpace(bodyOpts.Model) != "" { + opts.Model = strings.TrimSpace(bodyOpts.Model) + } + if strings.TrimSpace(bodyOpts.Mode) != "" { + applyMode(opts, bodyOpts.Mode) + } + if strings.TrimSpace(bodyOpts.EntryProtocol) != "" { + opts.EntryProtocol = strings.TrimSpace(bodyOpts.EntryProtocol) + } + if strings.TrimSpace(bodyOpts.ExitProtocol) != "" { + opts.ExitProtocol = strings.TrimSpace(bodyOpts.ExitProtocol) + } + if bodyOpts.Prompt != "" { + opts.Prompt = bodyOpts.Prompt + } + if bodyOpts.Stream != nil { + opts.Stream = *bodyOpts.Stream + } + if len(bodyOpts.Body) > 0 && string(bodyOpts.Body) != "null" { + if !json.Valid(bodyOpts.Body) { + return fmt.Errorf("body must be valid JSON") + } + opts.Body = append([]byte(nil), bodyOpts.Body...) + } + if bodyOpts.Headers != nil { + opts.Headers = cloneHeader(bodyOpts.Headers) + } + if bodyOpts.Query != nil { + opts.Query = cloneValues(bodyOpts.Query) + } + if bodyOpts.Alt != "" { + opts.Alt = bodyOpts.Alt + } + if bodyOpts.ImplicitClose != nil { + opts.ImplicitClose = *bodyOpts.ImplicitClose + } + return nil +} + +func applyQueryOptions(opts *runOptions, query url.Values) error { + if query == nil { + return nil + } + if raw := strings.TrimSpace(query.Get("model")); raw != "" { + opts.Model = raw + } + if raw := strings.TrimSpace(query.Get("mode")); raw != "" { + applyMode(opts, raw) + } + if raw := strings.TrimSpace(query.Get("entry_protocol")); raw != "" { + opts.EntryProtocol = raw + } + if raw := strings.TrimSpace(query.Get("exit_protocol")); raw != "" { + opts.ExitProtocol = raw + } + if raw := query.Get("prompt"); raw != "" { + opts.Prompt = raw + } + if raw := strings.TrimSpace(query.Get("body")); raw != "" { + body := []byte(raw) + if !json.Valid(body) { + return fmt.Errorf("query body must be valid JSON") + } + opts.Body = append([]byte(nil), body...) + } + if raw := strings.TrimSpace(query.Get("alt")); raw != "" { + opts.Alt = raw + } + if errStream := applyBoolQuery(query, "stream", &opts.Stream); errStream != nil { + return errStream + } + if errImplicitClose := applyBoolQuery(query, "implicit_close", &opts.ImplicitClose); errImplicitClose != nil { + return errImplicitClose + } + return nil +} + +func applyMode(opts *runOptions, mode string) { + normalized := strings.ToLower(strings.TrimSpace(mode)) + switch normalized { + case "stream", "streaming": + opts.Stream = true + case "non-stream", "non_stream", "nonstream", "sync": + opts.Stream = false + } +} + +func applyBoolQuery(query url.Values, key string, target *bool) error { + raw := strings.TrimSpace(query.Get(key)) + if raw == "" { + return nil + } + parsed, errParse := strconv.ParseBool(raw) + if errParse != nil { + return fmt.Errorf("%s must be a boolean: %w", key, errParse) + } + *target = parsed + return nil +} + +func executeOnce(opts runOptions) (pluginapi.HostModelExecutionResponse, error) { + body, errBody := modelRequestBody(opts) + if errBody != nil { + return pluginapi.HostModelExecutionResponse{}, errBody + } + result, errCall := callHost(pluginabi.MethodHostModelExecute, hostModelExecutionRequest{ + HostModelExecutionRequest: pluginapi.HostModelExecutionRequest{ + EntryProtocol: opts.EntryProtocol, + ExitProtocol: opts.ExitProtocol, + Model: opts.Model, + Stream: false, + Body: body, + Headers: cloneHeader(opts.Headers), + Query: cloneValues(opts.Query), + Alt: opts.Alt, + }, + HostCallbackID: opts.HostCallbackID, + }) + if errCall != nil { + return pluginapi.HostModelExecutionResponse{}, errCall + } + var resp pluginapi.HostModelExecutionResponse + if errUnmarshal := json.Unmarshal(result, &resp); errUnmarshal != nil { + return pluginapi.HostModelExecutionResponse{}, fmt.Errorf("decode host.model.execute result: %w", errUnmarshal) + } + return resp, nil +} + +func executeStream(opts runOptions) (data streamPageData) { + body, errBody := modelRequestBody(opts) + if errBody != nil { + data.Error = errBody.Error() + return data + } + result, errCall := callHost(pluginabi.MethodHostModelExecuteStream, hostModelExecutionRequest{ + HostModelExecutionRequest: pluginapi.HostModelExecutionRequest{ + EntryProtocol: opts.EntryProtocol, + ExitProtocol: opts.ExitProtocol, + Model: opts.Model, + Stream: true, + Body: body, + Headers: cloneHeader(opts.Headers), + Query: cloneValues(opts.Query), + Alt: opts.Alt, + }, + HostCallbackID: opts.HostCallbackID, + }) + if errCall != nil { + data.Error = errCall.Error() + return data + } + var resp pluginapi.HostModelStreamResponse + if errUnmarshal := json.Unmarshal(result, &resp); errUnmarshal != nil { + data.Error = fmt.Sprintf("decode host.model.execute_stream result: %v", errUnmarshal) + return data + } + data.StatusCode = resp.StatusCode + data.Headers = cloneHeader(resp.Headers) + data.StreamID = resp.StreamID + if resp.StreamID == "" { + data.Error = "host.model.execute_stream returned an empty stream_id" + return data + } + if opts.ImplicitClose { + // When implicit_close=true, the host closes this stream when the management.handle RPC callback scope returns. + data.CloseMode = "implicit close at management.handle return" + } else { + data.CloseMode = "explicit close through host.model.stream_close" + defer func() { + if errClose := closeHostModelStream(resp.StreamID); errClose != nil { + data.CloseError = errClose.Error() + } + }() + } + for { + chunk, errRead := readHostModelStream(resp.StreamID) + if errRead != nil { + data.Error = errRead.Error() + return data + } + if len(chunk.Payload) > 0 { + data.Chunks = append(data.Chunks, string(chunk.Payload)) + } + if chunk.Error != "" { + data.Error = chunk.Error + return data + } + if chunk.Done { + return data + } + } +} + +func readHostModelStream(streamID string) (pluginapi.HostModelStreamReadResponse, error) { + result, errCall := callHost(pluginabi.MethodHostModelStreamRead, pluginapi.HostModelStreamReadRequest{StreamID: streamID}) + if errCall != nil { + return pluginapi.HostModelStreamReadResponse{}, errCall + } + var resp pluginapi.HostModelStreamReadResponse + if errUnmarshal := json.Unmarshal(result, &resp); errUnmarshal != nil { + return pluginapi.HostModelStreamReadResponse{}, fmt.Errorf("decode host.model.stream_read result: %w", errUnmarshal) + } + return resp, nil +} + +func closeHostModelStream(streamID string) error { + _, errCall := callHost(pluginabi.MethodHostModelStreamClose, pluginapi.HostModelStreamCloseRequest{StreamID: streamID}) + return errCall +} + +func modelRequestBody(opts runOptions) ([]byte, error) { + if len(opts.Body) > 0 { + return append([]byte(nil), opts.Body...), nil + } + raw, errMarshal := json.Marshal(chatCompletionRequest{ + Model: opts.Model, + Stream: opts.Stream, + Messages: []chatMessage{{ + Role: "user", + Content: opts.Prompt, + }}, + }) + if errMarshal != nil { + return nil, fmt.Errorf("marshal OpenAI-compatible request body: %w", errMarshal) + } + return raw, nil +} + +func callHost(method string, payload any) (json.RawMessage, error) { + rawPayload, errMarshal := json.Marshal(payload) + if errMarshal != nil { + return nil, fmt.Errorf("marshal host callback payload %s: %w", method, errMarshal) + } + cMethod := C.CString(method) + defer C.free(unsafe.Pointer(cMethod)) + + var response C.cliproxy_buffer + var requestPtr *C.uint8_t + if len(rawPayload) > 0 { + cPayload := C.CBytes(rawPayload) + if cPayload == nil { + return nil, fmt.Errorf("allocate host callback payload %s", method) + } + defer C.free(cPayload) + requestPtr = (*C.uint8_t)(cPayload) + } + callCode := C.call_host_api(cMethod, requestPtr, C.size_t(len(rawPayload)), &response) + var rawResponse []byte + if response.ptr != nil && response.len > 0 { + rawResponse = C.GoBytes(response.ptr, C.int(response.len)) + } + if response.ptr != nil { + C.free_host_buffer(response.ptr, response.len) + } + if len(rawResponse) == 0 { + return nil, fmt.Errorf("host callback %s returned no response, code=%d", method, int(callCode)) + } + + var env envelope + if errUnmarshal := json.Unmarshal(rawResponse, &env); errUnmarshal != nil { + return nil, fmt.Errorf("decode host callback envelope %s: %w", method, errUnmarshal) + } + if !env.OK { + if env.Error != nil { + return nil, fmt.Errorf("%s: %s", env.Error.Code, env.Error.Message) + } + return nil, fmt.Errorf("host callback %s failed", method) + } + if callCode != 0 { + return nil, fmt.Errorf("host callback %s returned code=%d", method, int(callCode)) + } + return append(json.RawMessage(nil), env.Result...), nil +} + +func htmlResponse(statusCode int, body []byte) managementResponse { + return managementResponse{ + StatusCode: statusCode, + Headers: http.Header{ + "content-type": []string{resourceContentType}, + }, + Body: body, + } +} + +func renderPage(opts runOptions, status int, headers http.Header, body []byte, chunks []string, errText string, closeMode string, closeError string) []byte { + var out bytes.Buffer + out.WriteString("Host Model Callback") + out.WriteString("") + out.WriteString("
") + out.WriteString("

Host Model Callback

") + out.WriteString("
") + writeDefinition(&out, "model", opts.Model) + writeDefinition(&out, "mode", opts.Mode) + writeDefinition(&out, "entry_protocol", opts.EntryProtocol) + writeDefinition(&out, "exit_protocol", opts.ExitProtocol) + writeDefinition(&out, "stream", strconv.FormatBool(opts.Stream)) + writeDefinition(&out, "implicit_close", strconv.FormatBool(opts.ImplicitClose)) + if closeMode != "" { + writeDefinition(&out, "close", closeMode) + } + writeDefinition(&out, "status", strconv.Itoa(status)) + out.WriteString("
") + if errText != "" { + out.WriteString("

Error

")
+		out.WriteString(html.EscapeString(errText))
+		out.WriteString("
") + } + if closeError != "" { + out.WriteString("

Close Error

")
+		out.WriteString(html.EscapeString(closeError))
+		out.WriteString("
") + } + if headers != nil { + out.WriteString("

Headers

")
+		out.WriteString(html.EscapeString(prettyJSON(headers)))
+		out.WriteString("
") + } + if len(chunks) > 0 { + out.WriteString("

Stream Chunks

")
+		out.WriteString(html.EscapeString(strings.Join(chunks, "")))
+		out.WriteString("
") + } + if len(body) > 0 { + out.WriteString("

Body

")
+		out.WriteString(html.EscapeString(prettyBody(body)))
+		out.WriteString("
") + } + out.WriteString("
") + return out.Bytes() +} + +func writeDefinition(out *bytes.Buffer, key string, value string) { + out.WriteString("
") + out.WriteString(html.EscapeString(key)) + out.WriteString("
") + out.WriteString(html.EscapeString(value)) + out.WriteString("
") +} + +func prettyBody(raw []byte) string { + var buf bytes.Buffer + if errIndent := json.Indent(&buf, raw, "", " "); errIndent == nil { + return buf.String() + } + return string(raw) +} + +func prettyJSON(v any) string { + raw, errMarshal := json.MarshalIndent(v, "", " ") + if errMarshal != nil { + return fmt.Sprintf("%v", v) + } + return string(raw) +} + +func okEnvelope(v any) ([]byte, error) { + raw, errMarshal := json.Marshal(v) + if errMarshal != nil { + return nil, errMarshal + } + return json.Marshal(envelope{OK: true, Result: raw}) +} + +func errorEnvelope(code, message string) []byte { + raw, _ := json.Marshal(envelope{OK: false, Error: &envelopeError{Code: code, Message: message}}) + return raw +} + +func writeResponse(response *C.cliproxy_buffer, raw []byte) { + if response == nil || len(raw) == 0 { + return + } + ptr := C.CBytes(raw) + if ptr == nil { + return + } + response.ptr = ptr + response.len = C.size_t(len(raw)) +} + +func cloneHeader(headers http.Header) http.Header { + if headers == nil { + return nil + } + cloned := make(http.Header, len(headers)) + for key, values := range headers { + cloned[key] = append([]string(nil), values...) + } + return cloned +} + +func cloneValues(values url.Values) url.Values { + if values == nil { + return nil + } + cloned := make(url.Values, len(values)) + for key, items := range values { + cloned[key] = append([]string(nil), items...) + } + return cloned +} diff --git a/internal/api/server.go b/internal/api/server.go index d486c4f78..0c27bcb16 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -301,6 +301,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk } s.wsAuthEnabled.Store(cfg.WebsocketAuth) s.handlers.SetPluginHost(optionState.pluginHost) + if optionState.pluginHost != nil { + optionState.pluginHost.SetModelExecutor(s.handlers) + } // Save initial YAML snapshot s.oldConfigYaml, _ = yaml.Marshal(cfg) s.applyAccessConfig(nil, cfg) @@ -1586,6 +1589,9 @@ func (s *Server) UpdateClients(cfg *config.Config) { s.handlers.UpdateClients(effectiveSDKConfig(cfg)) s.handlers.SetPluginHost(s.pluginHost) + if s.pluginHost != nil { + s.pluginHost.SetModelExecutor(s.handlers) + } if s.mgmt != nil { s.mgmt.SetConfig(cfg) diff --git a/internal/pluginhost/adapters.go b/internal/pluginhost/adapters.go index a5801e224..33ca53f34 100644 --- a/internal/pluginhost/adapters.go +++ b/internal/pluginhost/adapters.go @@ -1307,14 +1307,16 @@ func (a *executorAdapter) Identifier() string { type preparedExecutorCall struct { req coreexecutor.Request opts coreexecutor.Options + inputRequested sdktranslator.Format requestedFormat sdktranslator.Format inputFormat sdktranslator.Format outputFormat sdktranslator.Format } func (a *executorAdapter) prepareExecutorCall(req coreexecutor.Request, opts coreexecutor.Options) (preparedExecutorCall, error) { + inputRequested := executorInputFormat(req, opts) requestedFormat := executorRequestedFormat(req, opts) - inputFormat, errInput := a.selectExecutorInputFormat(requestedFormat) + inputFormat, errInput := a.selectExecutorInputFormat(inputRequested) if errInput != nil { return preparedExecutorCall{}, errInput } @@ -1325,15 +1327,17 @@ func (a *executorAdapter) prepareExecutorCall(req coreexecutor.Request, opts cor nativeReq := req nativeOpts := opts - if requestedFormat != "" && requestedFormat != inputFormat { - nativeReq.Payload = sdktranslator.TranslateRequest(requestedFormat, inputFormat, req.Model, req.Payload, opts.Stream) + if inputRequested != "" && inputRequested != inputFormat { + nativeReq.Payload = sdktranslator.TranslateRequest(inputRequested, inputFormat, req.Model, req.Payload, opts.Stream) } nativeReq.Format = outputFormat nativeOpts.SourceFormat = inputFormat + nativeOpts.ResponseFormat = outputFormat return preparedExecutorCall{ req: nativeReq, opts: nativeOpts, + inputRequested: inputRequested, requestedFormat: requestedFormat, inputFormat: inputFormat, outputFormat: outputFormat, @@ -1344,15 +1348,15 @@ func (a *executorAdapter) RequestToFormat(req coreexecutor.Request, opts coreexe if a == nil { return "" } - requestedFormat := executorRequestedFormat(req, opts) - inputFormat, errInput := a.selectExecutorInputFormat(requestedFormat) + inputRequested := executorInputFormat(req, opts) + inputFormat, errInput := a.selectExecutorInputFormat(inputRequested) if errInput != nil { return "" } return inputFormat } -func executorRequestedFormat(req coreexecutor.Request, opts coreexecutor.Options) sdktranslator.Format { +func executorInputFormat(req coreexecutor.Request, opts coreexecutor.Options) sdktranslator.Format { if opts.SourceFormat != "" { return normalizeExecutorFormatName(opts.SourceFormat.String()) } @@ -1362,6 +1366,16 @@ func executorRequestedFormat(req coreexecutor.Request, opts coreexecutor.Options return sdktranslator.FormatOpenAI } +func executorRequestedFormat(req coreexecutor.Request, opts coreexecutor.Options) sdktranslator.Format { + if format := coreexecutor.ResponseFormatOrSource(opts); format != "" { + return normalizeExecutorFormatName(format.String()) + } + if req.Format != "" { + return normalizeExecutorFormatName(req.Format.String()) + } + return sdktranslator.FormatOpenAI +} + func (a *executorAdapter) selectExecutorInputFormat(requested sdktranslator.Format) (sdktranslator.Format, error) { if len(a.inputFormats) == 0 { return "", fmt.Errorf("plugin executor %s declares no input formats", a.Identifier()) @@ -1384,18 +1398,38 @@ func (a *executorAdapter) selectExecutorOutputFormat(requested, inputFormat sdkt if executorFormatContains(a.outputFormats, requested) { return requested, nil } - if executorFormatContains(a.outputFormats, inputFormat) && executorResponseTranslatorExists(inputFormat, requested) { + if executorFormatContains(a.outputFormats, inputFormat) && a.executorResponseTranslationAvailable(inputFormat, requested) { return inputFormat, nil } for _, format := range a.outputFormats { - if requested == "" || executorResponseTranslatorExists(format, requested) { + if requested == "" || a.executorResponseTranslationAvailable(format, requested) { return format, nil } } return "", fmt.Errorf("plugin executor %s does not support output format %q", a.Identifier(), requested) } -func executorResponseTranslatorExists(from, to sdktranslator.Format) bool { +func (a *executorAdapter) executorResponseTranslationAvailable(from, to sdktranslator.Format) bool { + if from == "" || to == "" || from == to { + return true + } + if sdktranslator.HasResponseTransformer(to, from) { + return true + } + return a != nil && a.host.hasResponseTranslator() +} + +func (h *Host) hasResponseTranslator() bool { + for _, record := range h.Snapshot().records { + if h.isPluginFused(record.id) || record.plugin.Capabilities.ResponseTranslator == nil { + continue + } + return true + } + return false +} + +func executorNativeStreamResponseTranslatorExists(from, to sdktranslator.Format) bool { if from == "" || to == "" || from == to { return true } @@ -1484,7 +1518,7 @@ func executorStreamTranslationFellBack(prepared preparedExecutorCall, payload [] // A plugin executor only reaches this path after host-side response translation // has been selected. An unchanged single frame is the SDK registry fallback, // not a valid translated frame to send to the client. - return executorResponseTranslatorExists(prepared.outputFormat, prepared.requestedFormat) + return executorNativeStreamResponseTranslatorExists(prepared.outputFormat, prepared.requestedFormat) } func (a *executorAdapter) emitTranslatedExecutorStreamTail(ctx context.Context, prepared preparedExecutorCall, out chan<- pluginapi.ExecutorStreamChunk, param *any) { diff --git a/internal/pluginhost/adapters_test.go b/internal/pluginhost/adapters_test.go index a9db914ee..b5a5d8b3e 100644 --- a/internal/pluginhost/adapters_test.go +++ b/internal/pluginhost/adapters_test.go @@ -78,7 +78,7 @@ func TestPluginModelInfoToRegistryModelInfoClonesThinkingAndSlices(t *testing.T) } } -func TestExecutorResponseTranslatorExistsRequiresStreamTransform(t *testing.T) { +func TestExecutorNativeStreamResponseTranslatorExistsRequiresStreamTransform(t *testing.T) { outputFormat := sdktranslator.Format("plugin-output-non-stream-only") requestedFormat := sdktranslator.Format("client-output-non-stream-only") sdktranslator.Register(requestedFormat, outputFormat, nil, sdktranslator.ResponseTransform{ @@ -87,7 +87,7 @@ func TestExecutorResponseTranslatorExistsRequiresStreamTransform(t *testing.T) { }, }) - if executorResponseTranslatorExists(outputFormat, requestedFormat) { + if executorNativeStreamResponseTranslatorExists(outputFormat, requestedFormat) { t.Fatal("non-stream-only response transformer was accepted for stream executor output") } @@ -99,7 +99,7 @@ func TestExecutorResponseTranslatorExistsRequiresStreamTransform(t *testing.T) { }, }) - if !executorResponseTranslatorExists(streamOutputFormat, streamRequestedFormat) { + if !executorNativeStreamResponseTranslatorExists(streamOutputFormat, streamRequestedFormat) { t.Fatal("stream response transformer was not accepted for stream executor output") } } @@ -2684,6 +2684,112 @@ func TestExecutorAdapterMethods(t *testing.T) { } } +func TestExecutorAdapterUsesResponseFormatForOutputTranslation(t *testing.T) { + claudeResponse := []byte(`{"id":"msg_1","type":"message","model":"claude-test","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`) + openAIRequest := []byte(`{"model":"model-1","messages":[{"role":"user","content":"hi"}]}`) + + var captured pluginapi.ExecutorRequest + adapter := &executorAdapter{ + host: New(), + pluginID: "executor-plugin", + provider: "plugin-provider", + inputFormats: []sdktranslator.Format{sdktranslator.FormatClaude}, + outputFormats: []sdktranslator.Format{sdktranslator.FormatClaude}, + executor: &fakeExecutor{ + execute: func(ctx context.Context, req pluginapi.ExecutorRequest) (pluginapi.ExecutorResponse, error) { + captured = req + return pluginapi.ExecutorResponse{Payload: claudeResponse}, nil + }, + }, + } + + resp, errExecute := adapter.Execute(context.Background(), &coreauth.Auth{}, coreexecutor.Request{ + Model: "model-1", + Format: sdktranslator.FormatOpenAI, + Payload: openAIRequest, + }, coreexecutor.Options{ + SourceFormat: sdktranslator.FormatOpenAI, + ResponseFormat: sdktranslator.FormatClaude, + }) + if errExecute != nil { + t.Fatalf("Execute() error = %v", errExecute) + } + if captured.SourceFormat != sdktranslator.FormatClaude.String() { + t.Fatalf("executor SourceFormat = %q, want %q", captured.SourceFormat, sdktranslator.FormatClaude) + } + if captured.Format != sdktranslator.FormatClaude.String() { + t.Fatalf("executor Format = %q, want %q", captured.Format, sdktranslator.FormatClaude) + } + if bytes.Equal(captured.Payload, openAIRequest) || !bytes.Contains(captured.Payload, []byte(`"max_tokens":32000`)) { + t.Fatalf("executor payload = %s, want translated Claude request", captured.Payload) + } + if !bytes.Equal(resp.Payload, claudeResponse) { + t.Fatalf("Execute() payload = %s, want Claude response payload %s", resp.Payload, claudeResponse) + } +} + +func TestExecutorAdapterSelectsCustomOutputWithHostResponseTranslator(t *testing.T) { + customOutputFormat := sdktranslator.Format("plugin-custom-output") + requestedFormat := sdktranslator.FormatOpenAI + body := []byte("plugin-body") + translatedBody := []byte("translated-body") + var captured pluginapi.ResponseTransformRequest + + host := newHostWithRecords(capabilityRecord{ + id: "response-translator", + plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{ + ResponseTranslator: responseTranslatorFunc(func(ctx context.Context, req pluginapi.ResponseTransformRequest) (pluginapi.PayloadResponse, error) { + captured = req + return pluginapi.PayloadResponse{Body: translatedBody}, nil + }), + }}, + }) + sdktranslator.SetPluginHooks(host) + t.Cleanup(func() { + sdktranslator.SetPluginHooks(nil) + }) + + adapter := &executorAdapter{ + host: host, + pluginID: "executor-plugin", + provider: "plugin-provider", + inputFormats: []sdktranslator.Format{sdktranslator.FormatOpenAI}, + outputFormats: []sdktranslator.Format{customOutputFormat}, + executor: &fakeExecutor{ + execute: func(ctx context.Context, req pluginapi.ExecutorRequest) (pluginapi.ExecutorResponse, error) { + if req.Format != customOutputFormat.String() { + t.Fatalf("executor Format = %q, want %q", req.Format, customOutputFormat) + } + return pluginapi.ExecutorResponse{Payload: body}, nil + }, + }, + } + + resp, errExecute := adapter.Execute(context.Background(), &coreauth.Auth{}, coreexecutor.Request{ + Model: "model-1", + Format: sdktranslator.FormatOpenAI, + Payload: []byte(`{"model":"model-1"}`), + }, coreexecutor.Options{ + SourceFormat: sdktranslator.FormatOpenAI, + ResponseFormat: requestedFormat, + }) + if errExecute != nil { + t.Fatalf("Execute() error = %v", errExecute) + } + if !bytes.Equal(resp.Payload, translatedBody) { + t.Fatalf("Execute() payload = %q, want %q", resp.Payload, translatedBody) + } + if captured.FromFormat != customOutputFormat.String() || captured.ToFormat != requestedFormat.String() { + t.Fatalf("translator formats = %q -> %q, want %q -> %q", captured.FromFormat, captured.ToFormat, customOutputFormat, requestedFormat) + } + if captured.Stream { + t.Fatal("translator Stream = true, want false") + } + if !bytes.Equal(captured.Body, body) { + t.Fatalf("translator body = %q, want %q", captured.Body, body) + } +} + func TestExecutorAdapterConsumesTranslatedStreamChunksWithoutOutput(t *testing.T) { adapter := &executorAdapter{} request := []byte(`{"model":"qmodel_latest","stream":true,"tool_choice":"auto","parallel_tool_calls":true}`) @@ -2736,6 +2842,47 @@ func TestExecutorAdapterConsumesTranslatedStreamChunksWithoutOutput(t *testing.T } } +func TestExecutorAdapterKeepsRawStreamFallbackWithOnlyHostResponseTranslator(t *testing.T) { + customOutputFormat := sdktranslator.Format("plugin-custom-stream-output") + requestedFormat := sdktranslator.FormatOpenAI + payload := []byte(`{"custom":"chunk"}`) + host := newHostWithRecords(capabilityRecord{ + id: "empty-response-translator", + plugin: pluginapi.Plugin{Capabilities: pluginapi.Capabilities{ + ResponseTranslator: responseTranslatorFunc(func(ctx context.Context, req pluginapi.ResponseTransformRequest) (pluginapi.PayloadResponse, error) { + return pluginapi.PayloadResponse{}, nil + }), + }}, + }) + sdktranslator.SetPluginHooks(host) + t.Cleanup(func() { + sdktranslator.SetPluginHooks(nil) + }) + adapter := &executorAdapter{ + host: host, + } + prepared := preparedExecutorCall{ + req: coreexecutor.Request{ + Model: "model-1", + Payload: []byte(`{"model":"model-1"}`), + }, + opts: coreexecutor.Options{ + OriginalRequest: []byte(`{"model":"model-1","stream":true}`), + }, + requestedFormat: requestedFormat, + outputFormat: customOutputFormat, + } + var param any + + frames := adapter.translateExecutorStreamPayload(context.Background(), prepared, payload, ¶m) + if len(frames) != 1 { + t.Fatalf("translated stream frame count = %d, want 1", len(frames)) + } + if !bytes.Equal(frames[0], payload) { + t.Fatalf("translated stream frame = %q, want raw payload %q", frames[0], payload) + } +} + func TestExecutorAdapterPanicFusesAndReturnsError(t *testing.T) { host := New() calls := 0 diff --git a/internal/pluginhost/callback_contexts.go b/internal/pluginhost/callback_contexts.go index b3e07d9f1..b87e67ed6 100644 --- a/internal/pluginhost/callback_contexts.go +++ b/internal/pluginhost/callback_contexts.go @@ -10,11 +10,16 @@ import ( type callbackContextRegistry struct { next atomic.Uint64 mu sync.RWMutex - contexts map[string]context.Context + contexts map[string]callbackContextEntry +} + +type callbackContextEntry struct { + ctx context.Context + cleanup []func() } func newCallbackContextRegistry() *callbackContextRegistry { - return &callbackContextRegistry{contexts: make(map[string]context.Context)} + return &callbackContextRegistry{contexts: make(map[string]callbackContextEntry)} } func (r *callbackContextRegistry) open(ctx context.Context) (string, func()) { @@ -26,19 +31,45 @@ func (r *callbackContextRegistry) open(ctx context.Context) (string, func()) { } id := strconv.FormatUint(r.next.Add(1), 10) r.mu.Lock() - r.contexts[id] = ctx + r.contexts[id] = callbackContextEntry{ctx: ctx} r.mu.Unlock() var once sync.Once return id, func() { once.Do(func() { + var cleanup []func() r.mu.Lock() + entry := r.contexts[id] delete(r.contexts, id) r.mu.Unlock() + cleanup = entry.cleanup + for _, fn := range cleanup { + if fn != nil { + fn() + } + } }) } } +func (r *callbackContextRegistry) addCleanup(id string, cleanup func()) bool { + if r == nil || id == "" || cleanup == nil { + return false + } + r.mu.Lock() + entry, ok := r.contexts[id] + if ok { + entry.cleanup = append(entry.cleanup, cleanup) + r.contexts[id] = entry + } + r.mu.Unlock() + if !ok { + cleanup() + return false + } + return true +} + func (r *callbackContextRegistry) resolve(id string, fallback context.Context) context.Context { if fallback == nil { fallback = context.Background() @@ -47,7 +78,7 @@ func (r *callbackContextRegistry) resolve(id string, fallback context.Context) c return fallback } r.mu.RLock() - ctx := r.contexts[id] + ctx := r.contexts[id].ctx r.mu.RUnlock() if ctx == nil { return fallback @@ -62,6 +93,16 @@ func (h *Host) openCallbackContext(ctx context.Context) (string, func()) { return h.callbackContexts.open(ctx) } +func (h *Host) addCallbackCleanup(id string, cleanup func()) bool { + if h == nil || h.callbackContexts == nil { + if id != "" && cleanup != nil { + cleanup() + } + return false + } + return h.callbackContexts.addCleanup(id, cleanup) +} + func (h *Host) resolveCallbackContext(id string, fallback context.Context) context.Context { if h == nil || h.callbackContexts == nil { if fallback == nil { diff --git a/internal/pluginhost/host.go b/internal/pluginhost/host.go index 7469f4472..fefc5bd86 100644 --- a/internal/pluginhost/host.go +++ b/internal/pluginhost/host.go @@ -9,6 +9,8 @@ import ( "sync/atomic" "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + "github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces" + "github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers" "github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginabi" "github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginapi" log "github.com/sirupsen/logrus" @@ -21,12 +23,18 @@ type loadedPlugin struct { client pluginClient } +type modelExecutor interface { + ExecuteModel(context.Context, handlers.ModelExecutionRequest) (handlers.ModelExecutionResponse, *interfaces.ErrorMessage) + ExecuteModelStream(context.Context, handlers.ModelExecutionRequest) (handlers.ModelExecutionStream, *interfaces.ErrorMessage) +} + type Host struct { mu sync.Mutex loader pluginLoader loaded map[string]*loadedPlugin fused map[string]string runtimeConfig *config.Config + modelExecutor modelExecutor modelClientIDs map[string]struct{} executorModelClientIDs map[string]struct{} modelProviders map[string]string @@ -40,6 +48,7 @@ type Host struct { resourceRoutes map[string]resourceRouteRecord streams *streamBridge httpStreams *hostHTTPStreamBridge + modelStreams *modelStreamBridge callbackContexts *callbackContextRegistry snapshot atomic.Value } @@ -62,6 +71,7 @@ func New() *Host { resourceRoutes: make(map[string]resourceRouteRecord), streams: newStreamBridge(), httpStreams: newHostHTTPStreamBridge(), + modelStreams: newModelStreamBridge(), callbackContexts: newCallbackContextRegistry(), } h.snapshot.Store(emptySnapshot()) @@ -74,6 +84,25 @@ func NewForTest(loader pluginLoader) *Host { return h } +func (h *Host) SetModelExecutor(executor modelExecutor) { + if h == nil { + return + } + h.mu.Lock() + h.modelExecutor = executor + h.mu.Unlock() +} + +func (h *Host) currentModelExecutor() modelExecutor { + if h == nil { + return nil + } + h.mu.Lock() + executor := h.modelExecutor + h.mu.Unlock() + return executor +} + func (h *Host) Snapshot() *Snapshot { if h == nil { return emptySnapshot() diff --git a/internal/pluginhost/host_callbacks.go b/internal/pluginhost/host_callbacks.go index ab76256b1..dd12ceb30 100644 --- a/internal/pluginhost/host_callbacks.go +++ b/internal/pluginhost/host_callbacks.go @@ -6,7 +6,9 @@ import ( "fmt" "strings" + "github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v7/internal/logging" + "github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers" "github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginabi" "github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginapi" log "github.com/sirupsen/logrus" @@ -59,8 +61,21 @@ type rpcHostLogRequest struct { Fields map[string]any `json:"fields,omitempty"` } +type rpcHostModelExecutionRequest struct { + pluginapi.HostModelExecutionRequest + HostCallbackID string `json:"host_callback_id,omitempty"` +} + func (h *Host) callFromPlugin(ctx context.Context, method string, request []byte) ([]byte, error) { switch method { + case pluginabi.MethodHostModelExecute: + return h.callHostModelExecute(ctx, request) + case pluginabi.MethodHostModelExecuteStream: + return h.callHostModelExecuteStream(ctx, request) + case pluginabi.MethodHostModelStreamRead: + return h.callHostModelStreamRead(ctx, request) + case pluginabi.MethodHostModelStreamClose: + return h.callHostModelStreamClose(request) case pluginabi.MethodHostHTTPDo: return h.callHostHTTPDo(ctx, request) case pluginabi.MethodHostHTTPDoStream: @@ -207,6 +222,132 @@ func (h *Host) callHostStreamClose(request []byte) ([]byte, error) { return marshalRPCResult(rpcEmptyResponse{}) } +func (h *Host) callHostModelExecute(ctx context.Context, request []byte) ([]byte, error) { + var req rpcHostModelExecutionRequest + if errUnmarshal := json.Unmarshal(request, &req); errUnmarshal != nil { + return nil, fmt.Errorf("decode host model execution request: %w", errUnmarshal) + } + if req.Stream { + return nil, fmt.Errorf("host.model.execute requires stream=false") + } + executor := h.currentModelExecutor() + if executor == nil { + return nil, fmt.Errorf("host model executor is unavailable") + } + ctx = h.resolveCallbackContext(req.HostCallbackID, ctx) + resp, errMsg := executor.ExecuteModel(ctx, modelExecutionRequestFromPlugin(req.HostModelExecutionRequest)) + if errMsg != nil { + return nil, modelExecutionError(errMsg) + } + return marshalRPCResult(pluginapi.HostModelExecutionResponse{ + StatusCode: resp.StatusCode, + Headers: cloneHeader(resp.Headers), + Body: append([]byte(nil), resp.Body...), + }) +} + +func (h *Host) callHostModelExecuteStream(ctx context.Context, request []byte) ([]byte, error) { + var req rpcHostModelExecutionRequest + if errUnmarshal := json.Unmarshal(request, &req); errUnmarshal != nil { + return nil, fmt.Errorf("decode host model execution stream request: %w", errUnmarshal) + } + if !req.Stream { + return nil, fmt.Errorf("host.model.execute_stream requires stream=true") + } + executor := h.currentModelExecutor() + if executor == nil { + return nil, fmt.Errorf("host model executor is unavailable") + } + ctx = h.resolveCallbackContext(req.HostCallbackID, ctx) + if ctx == nil { + ctx = context.Background() + } + streamCtx, cancel := context.WithCancel(ctx) + stream, errMsg := executor.ExecuteModelStream(streamCtx, modelExecutionRequestFromPlugin(req.HostModelExecutionRequest)) + if errMsg != nil { + cancel() + return nil, modelExecutionError(errMsg) + } + streamID := "" + if h != nil && h.modelStreams != nil { + streamID = h.modelStreams.open(req.HostCallbackID, stream.Chunks, cancel) + } + if streamID == "" { + cancel() + return nil, fmt.Errorf("host model stream bridge is unavailable") + } + if req.HostCallbackID != "" { + h.addCallbackCleanup(req.HostCallbackID, func() { + h.modelStreams.close(streamID) + }) + } + return marshalRPCResult(pluginapi.HostModelStreamResponse{ + StatusCode: stream.StatusCode, + Headers: cloneHeader(stream.Headers), + StreamID: streamID, + }) +} + +func (h *Host) callHostModelStreamRead(ctx context.Context, request []byte) ([]byte, error) { + var req pluginapi.HostModelStreamReadRequest + if errUnmarshal := json.Unmarshal(request, &req); errUnmarshal != nil { + return nil, fmt.Errorf("decode host model stream read request: %w", errUnmarshal) + } + if h == nil || h.modelStreams == nil { + return nil, fmt.Errorf("host model stream bridge is unavailable") + } + chunk, done, errRead := h.modelStreams.read(ctx, req.StreamID) + if errRead != nil { + return nil, errRead + } + resp := pluginapi.HostModelStreamReadResponse{ + Payload: append([]byte(nil), chunk.Payload...), + Done: done, + } + if chunk.Err != nil { + resp.Error = chunk.Err.Error() + resp.Done = true + } + return marshalRPCResult(resp) +} + +func (h *Host) callHostModelStreamClose(request []byte) ([]byte, error) { + var req pluginapi.HostModelStreamCloseRequest + if errUnmarshal := json.Unmarshal(request, &req); errUnmarshal != nil { + return nil, fmt.Errorf("decode host model stream close request: %w", errUnmarshal) + } + if h != nil && h.modelStreams != nil { + h.modelStreams.close(req.StreamID) + } + return marshalRPCResult(rpcEmptyResponse{}) +} + +func modelExecutionRequestFromPlugin(req pluginapi.HostModelExecutionRequest) handlers.ModelExecutionRequest { + return handlers.ModelExecutionRequest{ + EntryProtocol: req.EntryProtocol, + ExitProtocol: req.ExitProtocol, + Model: req.Model, + Stream: req.Stream, + Body: append([]byte(nil), req.Body...), + Headers: cloneHeader(req.Headers), + Query: cloneValues(req.Query), + Alt: req.Alt, + } +} + +func modelExecutionError(errMsg *interfaces.ErrorMessage) error { + if errMsg == nil { + return nil + } + if errMsg.Error != nil { + return errMsg.Error + } + if errMsg.StatusCode > 0 { + return fmt.Errorf("model execution failed with status %d", errMsg.StatusCode) + } + return fmt.Errorf("model execution failed") +} + func (h *Host) callHostLog(ctx context.Context, request []byte) ([]byte, error) { var req rpcHostLogRequest if errUnmarshal := json.Unmarshal(request, &req); errUnmarshal != nil { diff --git a/internal/pluginhost/host_callbacks_test.go b/internal/pluginhost/host_callbacks_test.go index a28f33da0..e0ca16a41 100644 --- a/internal/pluginhost/host_callbacks_test.go +++ b/internal/pluginhost/host_callbacks_test.go @@ -6,18 +6,34 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "net/url" "strings" "testing" "time" "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + "github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v7/internal/logging" + "github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers" "github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginabi" "github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginapi" log "github.com/sirupsen/logrus" ) +type fakeHostModelExecutor struct { + executeModel func(context.Context, handlers.ModelExecutionRequest) (handlers.ModelExecutionResponse, *interfaces.ErrorMessage) + executeModelStream func(context.Context, handlers.ModelExecutionRequest) (handlers.ModelExecutionStream, *interfaces.ErrorMessage) +} + +func (e *fakeHostModelExecutor) ExecuteModel(ctx context.Context, req handlers.ModelExecutionRequest) (handlers.ModelExecutionResponse, *interfaces.ErrorMessage) { + return e.executeModel(ctx, req) +} + +func (e *fakeHostModelExecutor) ExecuteModelStream(ctx context.Context, req handlers.ModelExecutionRequest) (handlers.ModelExecutionStream, *interfaces.ErrorMessage) { + return e.executeModelStream(ctx, req) +} + func TestHostHTTPDoCallbackUsesHostHTTPClient(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { @@ -217,6 +233,448 @@ func TestHostStreamCallbacksEmitAndClose(t *testing.T) { } } +func TestHostModelExecuteCallback(t *testing.T) { + host := New() + var got handlers.ModelExecutionRequest + host.SetModelExecutor(&fakeHostModelExecutor{ + executeModel: func(ctx context.Context, req handlers.ModelExecutionRequest) (handlers.ModelExecutionResponse, *interfaces.ErrorMessage) { + got = req + return handlers.ModelExecutionResponse{ + StatusCode: http.StatusAccepted, + Headers: http.Header{"X-Model": []string{"ok"}}, + Body: []byte(`{"response":true}`), + }, nil + }, + }) + + rawReq, errMarshal := json.Marshal(rpcHostModelExecutionRequest{ + HostModelExecutionRequest: pluginapi.HostModelExecutionRequest{ + EntryProtocol: "openai", + ExitProtocol: "claude", + Model: "model-1", + Body: []byte(`{"request":true}`), + Headers: http.Header{"X-Request": []string{"yes"}}, + Query: url.Values{"alt": []string{"sse"}}, + Alt: "raw", + }, + }) + if errMarshal != nil { + t.Fatalf("marshal request: %v", errMarshal) + } + rawResp, errCall := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelExecute, rawReq) + if errCall != nil { + t.Fatalf("callFromPlugin() error = %v", errCall) + } + + resp, errDecode := decodeRPCEnvelope[pluginapi.HostModelExecutionResponse](rawResp) + if errDecode != nil { + t.Fatalf("decode response: %v", errDecode) + } + if resp.StatusCode != http.StatusAccepted || string(resp.Body) != `{"response":true}` { + t.Fatalf("response = %#v, want accepted body", resp) + } + if resp.Headers.Get("X-Model") != "ok" { + t.Fatalf("X-Model = %q, want ok", resp.Headers.Get("X-Model")) + } + if got.EntryProtocol != "openai" || got.ExitProtocol != "claude" || got.Model != "model-1" || got.Stream { + t.Fatalf("request protocols/model/stream = %#v", got) + } + if string(got.Body) != `{"request":true}` { + t.Fatalf("request body = %q, want original body", got.Body) + } + if got.Headers.Get("X-Request") != "yes" { + t.Fatalf("request header = %q, want yes", got.Headers.Get("X-Request")) + } + if got.Query.Get("alt") != "sse" { + t.Fatalf("query alt = %q, want sse", got.Query.Get("alt")) + } + if got.Alt != "raw" { + t.Fatalf("alt = %q, want raw", got.Alt) + } +} + +func TestHostModelStreamClosesWithCallbackScope(t *testing.T) { + host := New() + ctxSeen := make(chan context.Context, 1) + host.SetModelExecutor(&fakeHostModelExecutor{ + executeModelStream: func(ctx context.Context, req handlers.ModelExecutionRequest) (handlers.ModelExecutionStream, *interfaces.ErrorMessage) { + ctxSeen <- ctx + return handlers.ModelExecutionStream{ + StatusCode: http.StatusOK, + Headers: http.Header{"X-Stream": []string{"ok"}}, + Chunks: make(chan handlers.ModelExecutionChunk), + }, nil + }, + }) + callbackID, closeCallback := host.openCallbackContext(context.Background()) + defer closeCallback() + + rawReq, errMarshal := json.Marshal(rpcHostModelExecutionRequest{ + HostModelExecutionRequest: pluginapi.HostModelExecutionRequest{ + EntryProtocol: "openai", + ExitProtocol: "openai", + Model: "model-1", + Stream: true, + Body: []byte(`{"stream":true}`), + }, + HostCallbackID: callbackID, + }) + if errMarshal != nil { + t.Fatalf("marshal request: %v", errMarshal) + } + rawResp, errCall := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelExecuteStream, rawReq) + if errCall != nil { + t.Fatalf("callFromPlugin() error = %v", errCall) + } + resp, errDecode := decodeRPCEnvelope[pluginapi.HostModelStreamResponse](rawResp) + if errDecode != nil { + t.Fatalf("decode response: %v", errDecode) + } + if resp.StreamID == "" { + t.Fatalf("stream id is empty: %#v", resp) + } + + var streamCtx context.Context + select { + case streamCtx = <-ctxSeen: + case <-time.After(time.Second): + t.Fatal("model executor was not called") + } + closeCallback() + select { + case <-streamCtx.Done(): + case <-time.After(time.Second): + t.Fatal("stream context was not canceled after callback scope closed") + } +} + +func TestHostModelStreamReadAfterCallbackCloseReturnsDone(t *testing.T) { + host := New() + chunks := make(chan handlers.ModelExecutionChunk) + host.SetModelExecutor(&fakeHostModelExecutor{ + executeModelStream: func(ctx context.Context, req handlers.ModelExecutionRequest) (handlers.ModelExecutionStream, *interfaces.ErrorMessage) { + return handlers.ModelExecutionStream{ + StatusCode: http.StatusOK, + Chunks: chunks, + }, nil + }, + }) + callbackID, closeCallback := host.openCallbackContext(context.Background()) + + rawReq, errMarshal := json.Marshal(rpcHostModelExecutionRequest{ + HostModelExecutionRequest: pluginapi.HostModelExecutionRequest{ + EntryProtocol: "openai", + ExitProtocol: "openai", + Model: "model-1", + Stream: true, + Body: []byte(`{"stream":true}`), + }, + HostCallbackID: callbackID, + }) + if errMarshal != nil { + t.Fatalf("marshal request: %v", errMarshal) + } + rawResp, errCall := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelExecuteStream, rawReq) + if errCall != nil { + t.Fatalf("execute stream callback error = %v", errCall) + } + resp, errDecode := decodeRPCEnvelope[pluginapi.HostModelStreamResponse](rawResp) + if errDecode != nil { + t.Fatalf("decode stream response: %v", errDecode) + } + if resp.StreamID == "" { + t.Fatalf("stream id is empty: %#v", resp) + } + + closeCallback() + readReq, errMarshal := json.Marshal(pluginapi.HostModelStreamReadRequest{StreamID: resp.StreamID}) + if errMarshal != nil { + t.Fatalf("marshal read request: %v", errMarshal) + } + readDone := make(chan pluginapi.HostModelStreamReadResponse, 1) + readErr := make(chan error, 1) + go func() { + rawRead, errRead := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelStreamRead, readReq) + if errRead != nil { + readErr <- errRead + return + } + doneResp, errDecodeRead := decodeRPCEnvelope[pluginapi.HostModelStreamReadResponse](rawRead) + if errDecodeRead != nil { + readErr <- errDecodeRead + return + } + readDone <- doneResp + }() + select { + case errRead := <-readErr: + t.Fatalf("read after callback close error = %v", errRead) + case doneResp := <-readDone: + if !doneResp.Done || len(doneResp.Payload) != 0 || doneResp.Error != "" { + t.Fatalf("read after callback close = %#v, want done without payload/error", doneResp) + } + case <-time.After(time.Second): + t.Fatal("read after callback close blocked") + } +} + +func TestHostModelExecuteStreamStartupErrorCleansUp(t *testing.T) { + host := New() + ctxSeen := make(chan context.Context, 1) + host.SetModelExecutor(&fakeHostModelExecutor{ + executeModelStream: func(ctx context.Context, req handlers.ModelExecutionRequest) (handlers.ModelExecutionStream, *interfaces.ErrorMessage) { + ctxSeen <- ctx + return handlers.ModelExecutionStream{}, &interfaces.ErrorMessage{ + StatusCode: http.StatusBadGateway, + } + }, + }) + + rawReq, errMarshal := json.Marshal(pluginapi.HostModelExecutionRequest{ + EntryProtocol: "openai", + ExitProtocol: "openai", + Model: "model-1", + Stream: true, + Body: []byte(`{"stream":true}`), + }) + if errMarshal != nil { + t.Fatalf("marshal request: %v", errMarshal) + } + rawResp, errCall := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelExecuteStream, rawReq) + if errCall == nil { + t.Fatalf("execute stream callback error is nil, raw response = %q", rawResp) + } + if rawResp != nil { + t.Fatalf("raw response = %q, want nil on startup error", rawResp) + } + if !strings.Contains(errCall.Error(), "status 502") { + t.Fatalf("execute stream callback error = %v, want status 502", errCall) + } + + var streamCtx context.Context + select { + case streamCtx = <-ctxSeen: + case <-time.After(time.Second): + t.Fatal("model executor was not called") + } + select { + case <-streamCtx.Done(): + case <-time.After(time.Second): + t.Fatal("stream context was not canceled after startup error") + } + gotCount := hostModelStreamCountForTest(t, host) + if gotCount != 0 { + t.Fatalf("model stream count = %d, want 0", gotCount) + } +} + +func TestHostModelCallbacksValidateStreamMode(t *testing.T) { + host := New() + + rawExecuteReq, errMarshal := json.Marshal(pluginapi.HostModelExecutionRequest{ + EntryProtocol: "openai", + ExitProtocol: "openai", + Model: "model-1", + Stream: true, + }) + if errMarshal != nil { + t.Fatalf("marshal execute request: %v", errMarshal) + } + _, errCall := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelExecute, rawExecuteReq) + if errCall == nil || !strings.Contains(errCall.Error(), "host.model.execute requires stream=false") { + t.Fatalf("execute callback error = %v, want stream=false validation error", errCall) + } + + rawStreamReq, errMarshal := json.Marshal(pluginapi.HostModelExecutionRequest{ + EntryProtocol: "openai", + ExitProtocol: "openai", + Model: "model-1", + Stream: false, + }) + if errMarshal != nil { + t.Fatalf("marshal execute stream request: %v", errMarshal) + } + _, errCall = host.callFromPlugin(context.Background(), pluginabi.MethodHostModelExecuteStream, rawStreamReq) + if errCall == nil || !strings.Contains(errCall.Error(), "host.model.execute_stream requires stream=true") { + t.Fatalf("execute stream callback error = %v, want stream=true validation error", errCall) + } +} + +func TestHostModelCallbacksRequireExecutor(t *testing.T) { + host := New() + + rawExecuteReq, errMarshal := json.Marshal(pluginapi.HostModelExecutionRequest{ + EntryProtocol: "openai", + ExitProtocol: "openai", + Model: "model-1", + }) + if errMarshal != nil { + t.Fatalf("marshal execute request: %v", errMarshal) + } + _, errCall := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelExecute, rawExecuteReq) + if errCall == nil || !strings.Contains(errCall.Error(), "host model executor is unavailable") { + t.Fatalf("execute callback error = %v, want unavailable executor error", errCall) + } + + rawStreamReq, errMarshal := json.Marshal(pluginapi.HostModelExecutionRequest{ + EntryProtocol: "openai", + ExitProtocol: "openai", + Model: "model-1", + Stream: true, + }) + if errMarshal != nil { + t.Fatalf("marshal execute stream request: %v", errMarshal) + } + _, errCall = host.callFromPlugin(context.Background(), pluginabi.MethodHostModelExecuteStream, rawStreamReq) + if errCall == nil || !strings.Contains(errCall.Error(), "host model executor is unavailable") { + t.Fatalf("execute stream callback error = %v, want unavailable executor error", errCall) + } +} + +func TestHostModelStreamReadAndCloseValidateStreamID(t *testing.T) { + host := New() + + rawReadReq, errMarshal := json.Marshal(pluginapi.HostModelStreamReadRequest{}) + if errMarshal != nil { + t.Fatalf("marshal read request: %v", errMarshal) + } + _, errRead := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelStreamRead, rawReadReq) + if errRead == nil || !strings.Contains(errRead.Error(), "model stream id is required") { + t.Fatalf("read callback error = %v, want required stream id error", errRead) + } + + rawCloseReq, errMarshal := json.Marshal(pluginapi.HostModelStreamCloseRequest{}) + if errMarshal != nil { + t.Fatalf("marshal close request: %v", errMarshal) + } + rawClose, errClose := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelStreamClose, rawCloseReq) + if errClose != nil { + t.Fatalf("close callback error = %v", errClose) + } + _, errDecode := decodeRPCEnvelope[rpcEmptyResponse](rawClose) + if errDecode != nil { + t.Fatalf("decode close response: %v", errDecode) + } +} + +func TestHostModelStreamReadReturnsPayloadAndTerminalError(t *testing.T) { + host := New() + chunks := make(chan handlers.ModelExecutionChunk, 2) + chunks <- handlers.ModelExecutionChunk{Payload: []byte("first")} + chunks <- handlers.ModelExecutionChunk{Err: &handlers.ModelExecutionStreamError{ + StatusCode: http.StatusBadGateway, + Message: "terminal boom", + }} + host.SetModelExecutor(&fakeHostModelExecutor{ + executeModelStream: func(ctx context.Context, req handlers.ModelExecutionRequest) (handlers.ModelExecutionStream, *interfaces.ErrorMessage) { + return handlers.ModelExecutionStream{ + StatusCode: http.StatusOK, + Headers: http.Header{"X-Stream": []string{"ok"}}, + Chunks: chunks, + }, nil + }, + }) + + streamID := openHostModelStreamForTest(t, host) + readReq, errMarshal := json.Marshal(pluginapi.HostModelStreamReadRequest{StreamID: streamID}) + if errMarshal != nil { + t.Fatalf("marshal read request: %v", errMarshal) + } + rawRead, errRead := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelStreamRead, readReq) + if errRead != nil { + t.Fatalf("read callback error = %v", errRead) + } + first, errDecode := decodeRPCEnvelope[pluginapi.HostModelStreamReadResponse](rawRead) + if errDecode != nil { + t.Fatalf("decode read response: %v", errDecode) + } + if string(first.Payload) != "first" || first.Done || first.Error != "" { + t.Fatalf("first read = %#v, want payload without done", first) + } + + rawRead, errRead = host.callFromPlugin(context.Background(), pluginabi.MethodHostModelStreamRead, readReq) + if errRead != nil { + t.Fatalf("terminal read callback error = %v", errRead) + } + terminal, errDecode := decodeRPCEnvelope[pluginapi.HostModelStreamReadResponse](rawRead) + if errDecode != nil { + t.Fatalf("decode terminal response: %v", errDecode) + } + if !terminal.Done || terminal.Error != "terminal boom" || len(terminal.Payload) != 0 { + t.Fatalf("terminal read = %#v, want done terminal error", terminal) + } +} + +func TestHostModelStreamExplicitCloseCancelsStream(t *testing.T) { + host := New() + ctxSeen := make(chan context.Context, 1) + host.SetModelExecutor(&fakeHostModelExecutor{ + executeModelStream: func(ctx context.Context, req handlers.ModelExecutionRequest) (handlers.ModelExecutionStream, *interfaces.ErrorMessage) { + ctxSeen <- ctx + return handlers.ModelExecutionStream{ + StatusCode: http.StatusOK, + Chunks: make(chan handlers.ModelExecutionChunk), + }, nil + }, + }) + + streamID := openHostModelStreamForTest(t, host) + var streamCtx context.Context + select { + case streamCtx = <-ctxSeen: + case <-time.After(time.Second): + t.Fatal("model executor was not called") + } + closeReq, errMarshal := json.Marshal(pluginapi.HostModelStreamCloseRequest{StreamID: streamID}) + if errMarshal != nil { + t.Fatalf("marshal close request: %v", errMarshal) + } + if _, errClose := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelStreamClose, closeReq); errClose != nil { + t.Fatalf("close callback error = %v", errClose) + } + select { + case <-streamCtx.Done(): + case <-time.After(time.Second): + t.Fatal("stream context was not canceled after explicit close") + } + if _, errClose := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelStreamClose, closeReq); errClose != nil { + t.Fatalf("second close callback error = %v", errClose) + } +} + +func openHostModelStreamForTest(t *testing.T, host *Host) string { + t.Helper() + rawReq, errMarshal := json.Marshal(pluginapi.HostModelExecutionRequest{ + EntryProtocol: "openai", + ExitProtocol: "openai", + Model: "model-1", + Stream: true, + Body: []byte(`{"stream":true}`), + }) + if errMarshal != nil { + t.Fatalf("marshal request: %v", errMarshal) + } + rawResp, errCall := host.callFromPlugin(context.Background(), pluginabi.MethodHostModelExecuteStream, rawReq) + if errCall != nil { + t.Fatalf("execute stream callback error = %v", errCall) + } + resp, errDecode := decodeRPCEnvelope[pluginapi.HostModelStreamResponse](rawResp) + if errDecode != nil { + t.Fatalf("decode stream response: %v", errDecode) + } + if resp.StreamID == "" { + t.Fatalf("stream id is empty: %#v", resp) + } + return resp.StreamID +} + +func hostModelStreamCountForTest(t *testing.T, host *Host) int { + t.Helper() + host.modelStreams.mu.Lock() + defer host.modelStreams.mu.Unlock() + return len(host.modelStreams.streams) +} + func TestHostLogCallbackRestoresRegisteredRequestContext(t *testing.T) { host := New() ctx := logging.WithRequestID(context.Background(), "request-123") diff --git a/internal/pluginhost/host_test.go b/internal/pluginhost/host_test.go index 72dc93629..78354a5f1 100644 --- a/internal/pluginhost/host_test.go +++ b/internal/pluginhost/host_test.go @@ -3,6 +3,7 @@ package pluginhost import ( "context" "encoding/json" + "net/http" "testing" "github.com/router-for-me/CLIProxyAPI/v7/internal/config" @@ -268,6 +269,40 @@ func TestRPCInterceptorsIncludeHostCallbackID(t *testing.T) { } } +func TestRPCManagementIncludesHostCallbackID(t *testing.T) { + client := &capturePluginClient{} + host := New() + adapter := &rpcPluginAdapter{ + host: host, + client: client, + } + + if _, errHandle := adapter.HandleManagement(context.Background(), pluginapi.ManagementRequest{ + Method: http.MethodGet, + Path: "/v0/management/plugins/test/status", + Body: []byte("request"), + }); errHandle != nil { + t.Fatalf("HandleManagement() error = %v", errHandle) + } + var req rpcManagementRequest + if errDecode := json.Unmarshal(client.requests[pluginabi.MethodManagementHandle], &req); errDecode != nil { + t.Fatalf("decode management request: %v", errDecode) + } + if req.HostCallbackID == "" { + t.Fatal("management handle host_callback_id is empty") + } + if req.Method != http.MethodGet || req.Path != "/v0/management/plugins/test/status" || string(req.Body) != "request" { + t.Fatalf("management request = %#v, want forwarded request fields", req.ManagementRequest) + } + + host.callbackContexts.mu.RLock() + _, exists := host.callbackContexts.contexts[req.HostCallbackID] + host.callbackContexts.mu.RUnlock() + if exists { + t.Fatal("management host_callback_id scope was not closed") + } +} + func TestSanitizePluginRequestRemovesNonJSONMetadata(t *testing.T) { req := pluginapi.RequestInterceptRequest{ Metadata: map[string]any{ diff --git a/internal/pluginhost/model_stream_bridge.go b/internal/pluginhost/model_stream_bridge.go new file mode 100644 index 000000000..7ee61326b --- /dev/null +++ b/internal/pluginhost/model_stream_bridge.go @@ -0,0 +1,91 @@ +package pluginhost + +import ( + "context" + "fmt" + "strconv" + "sync" + "sync/atomic" + + "github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers" +) + +type modelStreamBridge struct { + next atomic.Uint64 + mu sync.Mutex + streams map[string]modelStreamEntry +} + +type modelStreamEntry struct { + ownerCallbackID string + chunks <-chan handlers.ModelExecutionChunk + cancel context.CancelFunc +} + +func newModelStreamBridge() *modelStreamBridge { + return &modelStreamBridge{streams: make(map[string]modelStreamEntry)} +} + +func (b *modelStreamBridge) open(ownerCallbackID string, chunks <-chan handlers.ModelExecutionChunk, cancel context.CancelFunc) string { + if b == nil || chunks == nil { + if cancel != nil { + cancel() + } + return "" + } + id := strconv.FormatUint(b.next.Add(1), 10) + b.mu.Lock() + b.streams[id] = modelStreamEntry{ + ownerCallbackID: ownerCallbackID, + chunks: chunks, + cancel: cancel, + } + b.mu.Unlock() + return id +} + +func (b *modelStreamBridge) read(ctx context.Context, id string) (handlers.ModelExecutionChunk, bool, error) { + if b == nil { + return handlers.ModelExecutionChunk{}, true, fmt.Errorf("model stream bridge is unavailable") + } + if id == "" { + return handlers.ModelExecutionChunk{}, true, fmt.Errorf("model stream id is required") + } + b.mu.Lock() + entry, ok := b.streams[id] + b.mu.Unlock() + if !ok || entry.chunks == nil { + return handlers.ModelExecutionChunk{}, true, nil + } + if ctx == nil { + ctx = context.Background() + } + select { + case <-ctx.Done(): + b.close(id) + return handlers.ModelExecutionChunk{}, true, ctx.Err() + case chunk, okRead := <-entry.chunks: + if !okRead { + b.close(id) + return handlers.ModelExecutionChunk{}, true, nil + } + if chunk.Err != nil { + b.close(id) + return chunk, true, nil + } + return chunk, false, nil + } +} + +func (b *modelStreamBridge) close(id string) { + if b == nil || id == "" { + return + } + b.mu.Lock() + entry := b.streams[id] + delete(b.streams, id) + b.mu.Unlock() + if entry.cancel != nil { + entry.cancel() + } +} diff --git a/internal/pluginhost/rpc_client.go b/internal/pluginhost/rpc_client.go index ff69bb209..6ef163116 100644 --- a/internal/pluginhost/rpc_client.go +++ b/internal/pluginhost/rpc_client.go @@ -516,7 +516,12 @@ func (a *rpcPluginAdapter) RegisterManagement(ctx context.Context, req pluginapi } func (a *rpcPluginAdapter) HandleManagement(ctx context.Context, req pluginapi.ManagementRequest) (pluginapi.ManagementResponse, error) { - return callPlugin[pluginapi.ManagementResponse](ctx, a.client, pluginabi.MethodManagementHandle, req) + callbackID, closeCallback := a.openHostCallbackContext(ctx) + defer closeCallback() + return callPlugin[pluginapi.ManagementResponse](ctx, a.client, pluginabi.MethodManagementHandle, rpcManagementRequest{ + ManagementRequest: req, + HostCallbackID: callbackID, + }) } func httpResponseFromPlugin(resp pluginapi.ExecutorHTTPResponse, req *http.Request) *http.Response { diff --git a/internal/pluginhost/rpc_schema.go b/internal/pluginhost/rpc_schema.go index bf2527266..1d4b10ff3 100644 --- a/internal/pluginhost/rpc_schema.go +++ b/internal/pluginhost/rpc_schema.go @@ -102,6 +102,11 @@ type rpcThinkingApplyRequest struct { HostCallbackID string `json:"host_callback_id,omitempty"` } +type rpcManagementRequest struct { + pluginapi.ManagementRequest + HostCallbackID string `json:"host_callback_id,omitempty"` +} + type rpcManagementRegistrationResponse struct { Routes []pluginapi.ManagementRoute `json:"routes,omitempty"` Resources []pluginapi.ResourceRoute `json:"resources,omitempty"` diff --git a/internal/runtime/executor/aistudio_executor.go b/internal/runtime/executor/aistudio_executor.go index ea6fccf83..ab5889352 100644 --- a/internal/runtime/executor/aistudio_executor.go +++ b/internal/runtime/executor/aistudio_executor.go @@ -184,8 +184,9 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, return resp, statusErr{code: wsResp.Status, msg: string(wsResp.Body)} } reporter.Publish(ctx, helps.ParseGeminiUsage(wsResp.Body)) + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) var param any - out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, ¶m) + out := sdktranslator.TranslateNonStream(ctx, body.toFormat, responseFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, ¶m) resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON(out), Headers: wsResp.Headers.Clone()} return resp, nil } @@ -289,6 +290,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth out := make(chan cliproxyexecutor.StreamChunk) go func(first wsrelay.StreamEvent) { defer close(out) + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) var param any metadataLogged := false processEvent := func(event wsrelay.StreamEvent) bool { @@ -316,7 +318,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth if detail, ok := helps.ParseGeminiStreamUsage(filtered); ok { reporter.Publish(ctx, detail) } - lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, ¶m) + lines := sdktranslator.TranslateStream(ctx, body.toFormat, responseFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, ¶m) for i := range lines { select { case out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}: @@ -338,7 +340,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth reporter.MarkFirstResponseByte() helps.AppendAPIResponseChunk(ctx, e.cfg, event.Payload) } - lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, event.Payload, ¶m) + lines := sdktranslator.TranslateStream(ctx, body.toFormat, responseFormat, req.Model, opts.OriginalRequest, translatedReq, event.Payload, ¶m) for i := range lines { select { case out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}: @@ -423,7 +425,8 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A if totalTokens <= 0 { return cliproxyexecutor.Response{}, fmt.Errorf("wsrelay: totalTokens missing in response") } - translated := sdktranslator.TranslateTokenCount(ctx, body.toFormat, opts.SourceFormat, totalTokens, resp.Body) + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) + translated := sdktranslator.TranslateTokenCount(ctx, body.toFormat, responseFormat, totalTokens, resp.Body) return cliproxyexecutor.Response{Payload: translated}, nil } diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index affde053f..2889ca144 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -543,6 +543,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au defer reporter.TrackFailure(ctx, &err) from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("antigravity") originalPayloadSource := req.Payload @@ -710,7 +711,7 @@ attemptLoop: } reporter.Publish(ctx, helps.ParseAntigravityUsage(bodyBytes)) var param any - converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, ¶m) + converted := sdktranslator.TranslateNonStream(ctx, to, responseFormat, req.Model, opts.OriginalRequest, translated, bodyBytes, ¶m) resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()} reporter.EnsurePublished(ctx) return resp, nil @@ -743,6 +744,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth * defer reporter.TrackFailure(ctx, &err) from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("antigravity") originalPayloadSource := req.Payload @@ -973,7 +975,7 @@ attemptLoop: reporter.Publish(ctx, helps.ParseAntigravityUsage(resp.Payload)) var param any - converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, ¶m) + converted := sdktranslator.TranslateNonStream(ctx, to, responseFormat, req.Model, opts.OriginalRequest, translated, resp.Payload, ¶m) resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()} reporter.EnsurePublished(ctx) @@ -1205,6 +1207,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya defer reporter.TrackFailure(ctx, &err) from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("antigravity") originalPayloadSource := req.Payload @@ -1411,7 +1414,7 @@ attemptLoop: reporter.Publish(ctx, detail) } - chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(payload), ¶m) + chunks := sdktranslator.TranslateStream(ctx, to, responseFormat, req.Model, opts.OriginalRequest, translated, bytes.Clone(payload), ¶m) for i := range chunks { select { case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}: @@ -1420,7 +1423,7 @@ attemptLoop: } } } - tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("[DONE]"), ¶m) + tail := sdktranslator.TranslateStream(ctx, to, responseFormat, req.Model, opts.OriginalRequest, translated, []byte("[DONE]"), ¶m) for i := range tail { select { case out <- cliproxyexecutor.StreamChunk{Payload: tail[i]}: @@ -1511,6 +1514,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut baseModel := thinking.ParseSuffix(req.Model).ModelName from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("antigravity") respCtx := context.WithValue(ctx, "alt", opts.Alt) originalPayloadSource := req.Payload @@ -1631,7 +1635,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices { count := gjson.GetBytes(bodyBytes, "totalTokens").Int() - translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes) + translated := sdktranslator.TranslateTokenCount(respCtx, to, responseFormat, count, bodyBytes) return cliproxyexecutor.Response{Payload: translated, Headers: httpResp.Header.Clone()}, nil } diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index 3766900e0..b306b5a76 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -174,6 +174,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r reporter := helps.NewExecutorUsageReporter(ctx, e, baseModel, auth) defer reporter.TrackFailure(ctx, &err) from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("claude") // Use streaming translation to preserve function calling, except for claude. stream := from != to @@ -332,7 +333,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r out := sdktranslator.TranslateNonStream( ctx, to, - from, + responseFormat, req.Model, opts.OriginalRequest, bodyForTranslation, @@ -357,6 +358,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A reporter := helps.NewExecutorUsageReporter(ctx, e, baseModel, auth) defer reporter.TrackFailure(ctx, &err) from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("claude") originalPayloadSource := req.Payload if len(opts.OriginalRequest) > 0 { @@ -488,8 +490,8 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A } }() - // If from == to (Claude → Claude), directly forward the SSE stream without translation - if from == to { + // If the response target is Claude, directly forward the SSE stream without translation. + if responseFormat == to { scanner := bufio.NewScanner(decodedBody) scanner.Buffer(nil, 52_428_800) // 50MB for scanner.Scan() { @@ -534,7 +536,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A chunks := sdktranslator.TranslateStream( ctx, to, - from, + responseFormat, req.Model, opts.OriginalRequest, bodyForTranslation, @@ -628,6 +630,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut } from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("claude") // Use streaming translation to preserve function calling, except for claude. stream := from != to @@ -725,7 +728,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut } helps.AppendAPIResponseChunk(ctx, e.cfg, data) count := gjson.GetBytes(data, "input_tokens").Int() - out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) + out := sdktranslator.TranslateTokenCount(ctx, to, responseFormat, count, data) return cliproxyexecutor.Response{Payload: out, Headers: resp.Header.Clone()}, nil } diff --git a/internal/runtime/executor/codex_executor.go b/internal/runtime/executor/codex_executor.go index 73187963c..776408fc8 100644 --- a/internal/runtime/executor/codex_executor.go +++ b/internal/runtime/executor/codex_executor.go @@ -790,6 +790,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re defer reporter.TrackFailure(ctx, &err) from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("codex") originalPayloadSource := req.Payload if len(opts.OriginalRequest) > 0 { @@ -941,7 +942,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re var param any clientCompletedData := applyCodexIdentityExposeResponsePayload(completedData, identityState) - out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, clientCompletedData, ¶m) + out := sdktranslator.TranslateNonStream(ctx, to, responseFormat, req.Model, originalPayload, body, clientCompletedData, ¶m) resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()} return resp, nil } @@ -961,6 +962,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A defer reporter.TrackFailure(ctx, &err) from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("openai-response") originalPayloadSource := req.Payload if len(opts.OriginalRequest) > 0 { @@ -1043,7 +1045,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A reporter.EnsurePublished(ctx) var param any clientData := applyCodexIdentityExposeResponsePayload(upstreamData, identityState) - out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, clientData, ¶m) + out := sdktranslator.TranslateNonStream(ctx, to, responseFormat, req.Model, originalPayload, body, clientData, ¶m) resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()} return resp, nil } @@ -1066,6 +1068,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au defer reporter.TrackFailure(ctx, &err) from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("codex") originalPayloadSource := req.Payload if len(opts.OriginalRequest) > 0 { @@ -1190,7 +1193,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au } translatedLine = applyCodexIdentityExposeResponsePayload(translatedLine, identityState) - chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, body, translatedLine, ¶m) + chunks := sdktranslator.TranslateStream(ctx, to, responseFormat, req.Model, originalPayload, body, translatedLine, ¶m) for i := range chunks { select { case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}: @@ -1215,6 +1218,7 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth baseModel := thinking.ParseSuffix(req.Model).ModelName from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("codex") body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) @@ -1242,7 +1246,7 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth } usageJSON := fmt.Sprintf(`{"response":{"usage":{"input_tokens":%d,"output_tokens":0,"total_tokens":%d}}}`, count, count) - translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, []byte(usageJSON)) + translated := sdktranslator.TranslateTokenCount(ctx, to, responseFormat, count, []byte(usageJSON)) return cliproxyexecutor.Response{Payload: translated}, nil } diff --git a/internal/runtime/executor/codex_websockets_executor.go b/internal/runtime/executor/codex_websockets_executor.go index 8d68a251e..603d20e54 100644 --- a/internal/runtime/executor/codex_websockets_executor.go +++ b/internal/runtime/executor/codex_websockets_executor.go @@ -188,6 +188,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut defer reporter.TrackFailure(ctx, &err) from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("codex") originalPayloadSource := req.Payload if len(opts.OriginalRequest) > 0 { @@ -382,7 +383,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut } var param any clientPayload := applyCodexIdentityExposeResponsePayload(payload, identityState) - out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, clientBody, clientPayload, ¶m) + out := sdktranslator.TranslateNonStream(ctx, to, responseFormat, req.Model, originalPayload, clientBody, clientPayload, ¶m) resp = cliproxyexecutor.Response{Payload: out} return resp, nil } @@ -408,6 +409,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr defer reporter.TrackFailure(ctx, &err) from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("codex") body := req.Payload userPayload := req.Payload @@ -652,7 +654,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr clientPayload := applyCodexIdentityExposeResponsePayload(payload, identityState) line := encodeCodexWebsocketAsSSE(clientPayload) - chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, clientBody, clientBody, line, ¶m) + chunks := sdktranslator.TranslateStream(ctx, to, responseFormat, req.Model, clientBody, clientBody, line, ¶m) for i := range chunks { if !send(cliproxyexecutor.StreamChunk{Payload: chunks[i]}) { terminateReason = "context_done" diff --git a/internal/runtime/executor/gemini_cli_executor.go b/internal/runtime/executor/gemini_cli_executor.go index 0d15e1d0e..7055f8ad0 100644 --- a/internal/runtime/executor/gemini_cli_executor.go +++ b/internal/runtime/executor/gemini_cli_executor.go @@ -122,6 +122,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth defer reporter.TrackFailure(ctx, &err) from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("gemini-cli") originalPayloadSource := req.Payload @@ -234,7 +235,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth if httpResp.StatusCode >= 200 && httpResp.StatusCode < 300 { reporter.Publish(ctx, helps.ParseGeminiCLIUsage(data)) var param any - out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, ¶m) + out := sdktranslator.TranslateNonStream(respCtx, to, responseFormat, attemptModel, opts.OriginalRequest, payload, data, ¶m) resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()} return resp, nil } @@ -281,6 +282,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut defer reporter.TrackFailure(ctx, &err) from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("gemini-cli") originalPayloadSource := req.Payload @@ -415,7 +417,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut reporter.Publish(ctx, detail) } if bytes.HasPrefix(line, dataTag) { - segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone(line), ¶m) + segments := sdktranslator.TranslateStream(respCtx, to, responseFormat, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone(line), ¶m) for i := range segments { select { case out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}: @@ -426,7 +428,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut } } - segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m) + segments := sdktranslator.TranslateStream(respCtx, to, responseFormat, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m) for i := range segments { select { case out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}: @@ -460,7 +462,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut helps.AppendAPIResponseChunk(ctx, e.cfg, data) reporter.Publish(ctx, helps.ParseGeminiCLIUsage(data)) var param any - segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, data, ¶m) + segments := sdktranslator.TranslateStream(respCtx, to, responseFormat, attemptModel, opts.OriginalRequest, reqBody, data, ¶m) for i := range segments { select { case out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}: @@ -469,7 +471,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut } } - segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m) + segments = sdktranslator.TranslateStream(respCtx, to, responseFormat, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m) for i := range segments { select { case out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}: @@ -502,6 +504,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth. } from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("gemini-cli") models := cliPreviewFallbackOrder(baseModel) @@ -587,7 +590,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth. helps.AppendAPIResponseChunk(ctx, e.cfg, data) if resp.StatusCode >= 200 && resp.StatusCode < 300 { count := gjson.GetBytes(data, "totalTokens").Int() - translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data) + translated := sdktranslator.TranslateTokenCount(respCtx, to, responseFormat, count, data) return cliproxyexecutor.Response{Payload: translated, Headers: resp.Header.Clone()}, nil } lastStatus = resp.StatusCode diff --git a/internal/runtime/executor/gemini_executor.go b/internal/runtime/executor/gemini_executor.go index 585a06425..6f502a737 100644 --- a/internal/runtime/executor/gemini_executor.go +++ b/internal/runtime/executor/gemini_executor.go @@ -117,6 +117,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r // Official Gemini API via API key or OAuth bearer from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("gemini") originalPayloadSource := req.Payload if len(opts.OriginalRequest) > 0 { @@ -210,7 +211,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r helps.AppendAPIResponseChunk(ctx, e.cfg, data) reporter.Publish(ctx, helps.ParseGeminiUsage(data)) var param any - out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m) + out := sdktranslator.TranslateNonStream(ctx, to, responseFormat, req.Model, opts.OriginalRequest, body, data, ¶m) resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()} return resp, nil } @@ -228,6 +229,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A defer reporter.TrackFailure(ctx, &err) from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("gemini") originalPayloadSource := req.Payload if len(opts.OriginalRequest) > 0 { @@ -329,7 +331,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A if detail, ok := helps.ParseGeminiStreamUsage(payload); ok { reporter.Publish(ctx, detail) } - lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(payload), ¶m) + lines := sdktranslator.TranslateStream(ctx, to, responseFormat, req.Model, opts.OriginalRequest, body, bytes.Clone(payload), ¶m) for i := range lines { select { case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}: @@ -338,7 +340,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A } } } - lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m) + lines := sdktranslator.TranslateStream(ctx, to, responseFormat, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m) for i := range lines { select { case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}: @@ -365,6 +367,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut apiKey, bearer := geminiCreds(auth) from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("gemini") translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) @@ -439,7 +442,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut } count := gjson.GetBytes(data, "totalTokens").Int() - translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data) + translated := sdktranslator.TranslateTokenCount(respCtx, to, responseFormat, count, data) return cliproxyexecutor.Response{Payload: translated, Headers: resp.Header.Clone()}, nil } diff --git a/internal/runtime/executor/gemini_vertex_executor.go b/internal/runtime/executor/gemini_vertex_executor.go index 75d31844b..b0677415a 100644 --- a/internal/runtime/executor/gemini_vertex_executor.go +++ b/internal/runtime/executor/gemini_vertex_executor.go @@ -429,10 +429,10 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au } // Standard Gemini translation (works for both Gemini and converted Imagen responses) - from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("gemini") var param any - out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m) + out := sdktranslator.TranslateNonStream(ctx, to, responseFormat, req.Model, opts.OriginalRequest, body, data, ¶m) resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()} return resp, nil } @@ -445,6 +445,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip defer reporter.TrackFailure(ctx, &err) from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("gemini") originalPayloadSource := req.Payload @@ -546,7 +547,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip helps.AppendAPIResponseChunk(ctx, e.cfg, data) reporter.Publish(ctx, helps.ParseGeminiUsage(data)) var param any - out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m) + out := sdktranslator.TranslateNonStream(ctx, to, responseFormat, req.Model, opts.OriginalRequest, body, data, ¶m) resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()} return resp, nil } @@ -559,6 +560,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte defer reporter.TrackFailure(ctx, &err) from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("gemini") originalPayloadSource := req.Payload @@ -666,7 +668,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte if detail, ok := helps.ParseGeminiStreamUsage(line); ok { reporter.Publish(ctx, detail) } - lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m) + lines := sdktranslator.TranslateStream(ctx, to, responseFormat, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m) for i := range lines { select { case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}: @@ -675,7 +677,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte } } } - lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m) + lines := sdktranslator.TranslateStream(ctx, to, responseFormat, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m) for i := range lines { select { case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}: @@ -703,6 +705,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth defer reporter.TrackFailure(ctx, &err) from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("gemini") originalPayloadSource := req.Payload @@ -810,7 +813,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth if detail, ok := helps.ParseGeminiStreamUsage(line); ok { reporter.Publish(ctx, detail) } - lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m) + lines := sdktranslator.TranslateStream(ctx, to, responseFormat, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m) for i := range lines { select { case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}: @@ -819,7 +822,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth } } } - lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m) + lines := sdktranslator.TranslateStream(ctx, to, responseFormat, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m) for i := range lines { select { case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}: @@ -844,6 +847,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context baseModel := thinking.ParseSuffix(req.Model).ModelName from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("gemini") translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) @@ -925,7 +929,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context } helps.AppendAPIResponseChunk(ctx, e.cfg, data) count := gjson.GetBytes(data, "totalTokens").Int() - out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) + out := sdktranslator.TranslateTokenCount(ctx, to, responseFormat, count, data) return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil } @@ -934,6 +938,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth * baseModel := thinking.ParseSuffix(req.Model).ModelName from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("gemini") translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) @@ -1015,7 +1020,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth * } helps.AppendAPIResponseChunk(ctx, e.cfg, data) count := gjson.GetBytes(data, "totalTokens").Int() - out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) + out := sdktranslator.TranslateTokenCount(ctx, to, responseFormat, count, data) return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil } diff --git a/internal/runtime/executor/kimi_executor.go b/internal/runtime/executor/kimi_executor.go index ef3fff11c..f296687f6 100644 --- a/internal/runtime/executor/kimi_executor.go +++ b/internal/runtime/executor/kimi_executor.go @@ -78,6 +78,7 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL return e.ClaudeExecutor.Execute(ctx, auth, req, opts) } + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) baseModel := thinking.ParseSuffix(req.Model).ModelName @@ -175,7 +176,7 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req var param any // Note: TranslateNonStream uses req.Model (original with suffix) to preserve // the original model name in the response for client compatibility. - out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m) + out := sdktranslator.TranslateNonStream(ctx, to, responseFormat, req.Model, opts.OriginalRequest, body, data, ¶m) resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()} return resp, nil } @@ -187,6 +188,7 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL return e.ClaudeExecutor.ExecuteStream(ctx, auth, req, opts) } + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) baseModel := thinking.ParseSuffix(req.Model).ModelName token := kimiCreds(auth) @@ -292,7 +294,7 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut if detail, ok := helps.ParseOpenAIStreamUsage(line); ok { reporter.Publish(ctx, detail) } - chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m) + chunks := sdktranslator.TranslateStream(ctx, to, responseFormat, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m) for i := range chunks { select { case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}: @@ -301,7 +303,7 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut } } } - doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m) + doneChunks := sdktranslator.TranslateStream(ctx, to, responseFormat, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m) for i := range doneChunks { select { case out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}: diff --git a/internal/runtime/executor/openai_compat_executor.go b/internal/runtime/executor/openai_compat_executor.go index 5013eb909..5bfba83df 100644 --- a/internal/runtime/executor/openai_compat_executor.go +++ b/internal/runtime/executor/openai_compat_executor.go @@ -99,6 +99,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A } from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("openai") endpoint := "/chat/completions" if opts.Alt == "responses/compact" { @@ -193,7 +194,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A reporter.EnsurePublished(ctx) // Translate response back to source format when needed var param any - out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m) + out := sdktranslator.TranslateNonStream(ctx, to, responseFormat, req.Model, opts.OriginalRequest, translated, body, ¶m) resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()} return resp, nil } @@ -304,6 +305,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy } from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("openai") originalPayloadSource := req.Payload if len(opts.OriginalRequest) > 0 { @@ -421,7 +423,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy } // OpenAI-compatible streams must use SSE data lines. - chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(trimmedLine), ¶m) + chunks := sdktranslator.TranslateStream(ctx, to, responseFormat, req.Model, opts.OriginalRequest, translated, bytes.Clone(trimmedLine), ¶m) for i := range chunks { select { case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}: @@ -441,7 +443,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy // In case the upstream close the stream without a terminal [DONE] marker. // Feed a synthetic done marker through the translator so pending // response.completed events are still emitted exactly once. - chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("data: [DONE]"), ¶m) + chunks := sdktranslator.TranslateStream(ctx, to, responseFormat, req.Model, opts.OriginalRequest, translated, []byte("data: [DONE]"), ¶m) for i := range chunks { select { case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}: @@ -577,6 +579,7 @@ func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyau baseModel := thinking.ParseSuffix(req.Model).ModelName from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("openai") translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) @@ -598,7 +601,7 @@ func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyau } usageJSON := helps.BuildOpenAIUsageJSON(count) - translatedUsage := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON) + translatedUsage := sdktranslator.TranslateTokenCount(ctx, to, responseFormat, count, usageJSON) return cliproxyexecutor.Response{Payload: translatedUsage}, nil } diff --git a/internal/runtime/executor/xai_executor.go b/internal/runtime/executor/xai_executor.go index aeab85d7a..4dbc029b3 100644 --- a/internal/runtime/executor/xai_executor.go +++ b/internal/runtime/executor/xai_executor.go @@ -173,7 +173,7 @@ func (e *XAIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req } completedData := xaiPatchCompletedOutput(eventData, outputItemsByIndex, outputItemsFallback) var param any - out := sdktranslator.TranslateNonStream(ctx, prepared.to, prepared.from, req.Model, prepared.originalPayload, prepared.body, completedData, ¶m) + out := sdktranslator.TranslateNonStream(ctx, prepared.to, prepared.responseFormat, req.Model, prepared.originalPayload, prepared.body, completedData, ¶m) return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil } } @@ -366,7 +366,7 @@ func (e *XAIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth translatedLine = append([]byte("data: "), eventData...) } } - chunks := sdktranslator.TranslateStream(ctx, prepared.to, prepared.from, req.Model, prepared.originalPayload, prepared.body, translatedLine, ¶m) + chunks := sdktranslator.TranslateStream(ctx, prepared.to, prepared.responseFormat, req.Model, prepared.originalPayload, prepared.body, translatedLine, ¶m) for i := range chunks { select { case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}: @@ -402,7 +402,7 @@ func (e *XAIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, return cliproxyexecutor.Response{}, fmt.Errorf("xai executor: token counting failed: %w", err) } usageJSON := fmt.Sprintf(`{"response":{"usage":{"input_tokens":%d,"output_tokens":0,"total_tokens":%d}}}`, count, count) - translated := sdktranslator.TranslateTokenCount(ctx, prepared.to, prepared.from, int64(count), []byte(usageJSON)) + translated := sdktranslator.TranslateTokenCount(ctx, prepared.to, prepared.responseFormat, int64(count), []byte(usageJSON)) return cliproxyexecutor.Response{Payload: translated}, nil } @@ -472,6 +472,7 @@ func (e *XAIExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cl type xaiPreparedRequest struct { baseModel string from sdktranslator.Format + responseFormat sdktranslator.Format to sdktranslator.Format originalPayload []byte body []byte @@ -481,6 +482,7 @@ type xaiPreparedRequest struct { func (e *XAIExecutor) prepareResponsesRequest(ctx context.Context, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) (*xaiPreparedRequest, error) { baseModel := thinking.ParseSuffix(req.Model).ModelName from := opts.SourceFormat + responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("codex") originalPayloadSource := req.Payload if len(opts.OriginalRequest) > 0 { @@ -519,6 +521,7 @@ func (e *XAIExecutor) prepareResponsesRequest(ctx context.Context, req cliproxye return &xaiPreparedRequest{ baseModel: baseModel, from: from, + responseFormat: responseFormat, to: to, originalPayload: originalPayload, body: body, diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 42756dc08..6ad218550 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -628,13 +628,19 @@ func (h *BaseAPIHandler) ExecuteImageWithAuthManager(ctx context.Context, handle } func (h *BaseAPIHandler) executeWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string, allowImageModel bool) ([]byte, http.Header, *interfaces.ErrorMessage) { + return h.executeWithAuthManagerFormats(ctx, handlerType, handlerType, modelName, rawJSON, alt, allowImageModel, modelExecutionOptions{}) +} + +func (h *BaseAPIHandler) executeWithAuthManagerFormats(ctx context.Context, entryProtocol, exitProtocol, modelName string, rawJSON []byte, alt string, allowImageModel bool, execOptions modelExecutionOptions) ([]byte, http.Header, *interfaces.ErrorMessage) { + responseProtocol := modelExecutionResponseProtocol(entryProtocol, exitProtocol) providers, normalizedModel, errMsg := h.getRequestDetailsWithOptions(modelName, allowImageModel) if errMsg != nil { return nil, nil, errMsg } reqMeta := requestExecutionMetadata(ctx) reqMeta[coreexecutor.RequestedModelMetadataKey] = modelName - setReasoningEffortMetadata(reqMeta, handlerType, normalizedModel, rawJSON) + addModelExecutionSourceMetadata(reqMeta, execOptions.InternalSource) + setReasoningEffortMetadata(reqMeta, entryProtocol, normalizedModel, rawJSON) setServiceTierMetadata(reqMeta, rawJSON) payload := rawJSON if len(payload) == 0 { @@ -649,12 +655,14 @@ func (h *BaseAPIHandler) executeWithAuthManager(ctx context.Context, handlerType Stream: false, Alt: alt, OriginalRequest: rawJSON, - SourceFormat: sdktranslator.FromString(handlerType), - Headers: headersFromContext(ctx), + SourceFormat: sdktranslator.FromString(entryProtocol), + ResponseFormat: sdktranslator.FromString(responseProtocol), + Headers: modelExecutionHeaders(ctx, execOptions.Headers), + Query: cloneURLValues(execOptions.Query), RequestAfterAuthInterceptor: h.requestAfterAuthInterceptor(afterAuthCapture), } opts.Metadata = reqMeta - req, opts = h.applyRequestInterceptorsBeforeAuth(ctx, handlerType, modelName, req, opts) + req, opts = h.applyRequestInterceptorsBeforeAuth(ctx, entryProtocol, modelName, req, opts) resp, err := h.AuthManager.Execute(ctx, providers, req, opts) if err != nil { err = enrichAuthSelectionError(err, providers, normalizedModel) @@ -675,7 +683,7 @@ func (h *BaseAPIHandler) executeWithAuthManager(ctx context.Context, handlerType executedReq, executedOpts := afterAuthCapture.apply(req, opts) rawResponseHeaders := cloneHeader(resp.Headers) responseHeaders := downstreamHeadersFromExecutor(rawResponseHeaders, PassthroughHeadersEnabled(h.Cfg)) - body, responseHeaders := h.applyResponseInterceptors(ctx, handlerType, normalizedModel, modelName, executedOpts, rawResponseHeaders, responseHeaders, executedOpts.OriginalRequest, executedReq.Payload, resp.Payload, http.StatusOK) + body, responseHeaders := h.applyResponseInterceptors(ctx, responseProtocol, normalizedModel, modelName, executedOpts, rawResponseHeaders, responseHeaders, executedOpts.OriginalRequest, executedReq.Payload, resp.Payload, http.StatusOK) return body, responseHeaders, nil } @@ -746,6 +754,11 @@ func (h *BaseAPIHandler) ExecuteImageStreamWithAuthManager(ctx context.Context, } func (h *BaseAPIHandler) executeStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string, allowImageModel bool) (<-chan []byte, http.Header, <-chan *interfaces.ErrorMessage) { + return h.executeStreamWithAuthManagerFormats(ctx, handlerType, handlerType, modelName, rawJSON, alt, allowImageModel, modelExecutionOptions{}) +} + +func (h *BaseAPIHandler) executeStreamWithAuthManagerFormats(ctx context.Context, entryProtocol, exitProtocol, modelName string, rawJSON []byte, alt string, allowImageModel bool, execOptions modelExecutionOptions) (<-chan []byte, http.Header, <-chan *interfaces.ErrorMessage) { + responseProtocol := modelExecutionResponseProtocol(entryProtocol, exitProtocol) providers, normalizedModel, errMsg := h.getRequestDetailsWithOptions(modelName, allowImageModel) if errMsg != nil { errChan := make(chan *interfaces.ErrorMessage, 1) @@ -755,7 +768,8 @@ func (h *BaseAPIHandler) executeStreamWithAuthManager(ctx context.Context, handl } reqMeta := requestExecutionMetadata(ctx) reqMeta[coreexecutor.RequestedModelMetadataKey] = modelName - setReasoningEffortMetadata(reqMeta, handlerType, normalizedModel, rawJSON) + addModelExecutionSourceMetadata(reqMeta, execOptions.InternalSource) + setReasoningEffortMetadata(reqMeta, entryProtocol, normalizedModel, rawJSON) setServiceTierMetadata(reqMeta, rawJSON) payload := rawJSON if len(payload) == 0 { @@ -770,12 +784,14 @@ func (h *BaseAPIHandler) executeStreamWithAuthManager(ctx context.Context, handl Stream: true, Alt: alt, OriginalRequest: rawJSON, - SourceFormat: sdktranslator.FromString(handlerType), - Headers: headersFromContext(ctx), + SourceFormat: sdktranslator.FromString(entryProtocol), + ResponseFormat: sdktranslator.FromString(responseProtocol), + Headers: modelExecutionHeaders(ctx, execOptions.Headers), + Query: cloneURLValues(execOptions.Query), RequestAfterAuthInterceptor: h.requestAfterAuthInterceptor(afterAuthCapture), } opts.Metadata = reqMeta - req, opts = h.applyRequestInterceptorsBeforeAuth(ctx, handlerType, modelName, req, opts) + req, opts = h.applyRequestInterceptorsBeforeAuth(ctx, entryProtocol, modelName, req, opts) streamResult, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts) if err != nil { err = enrichAuthSelectionError(err, providers, normalizedModel) @@ -831,7 +847,7 @@ func (h *BaseAPIHandler) executeStreamWithAuthManager(ctx context.Context, handl } executedReq, executedOpts := executedRequest() intercepted := interceptorHost.InterceptStreamChunk(ctx, pluginapi.StreamChunkInterceptRequest{ - SourceFormat: handlerType, + SourceFormat: responseProtocol, Model: normalizedModel, RequestedModel: modelName, RequestHeaders: cloneHeader(executedOpts.Headers), @@ -986,7 +1002,7 @@ func (h *BaseAPIHandler) executeStreamWithAuthManager(ctx context.Context, handl if streamInterceptorsActive { executedReq, executedOpts := executedRequest() intercepted := interceptorHost.InterceptStreamChunk(ctx, pluginapi.StreamChunkInterceptRequest{ - SourceFormat: handlerType, + SourceFormat: responseProtocol, Model: normalizedModel, RequestedModel: modelName, RequestHeaders: cloneHeader(executedOpts.Headers), @@ -1009,7 +1025,7 @@ func (h *BaseAPIHandler) executeStreamWithAuthManager(ctx context.Context, handl } else { chunkIndex++ } - if handlerType == "openai-response" { + if responseProtocol == "openai-response" { if errValidate := validateSSEDataJSON(payload); errValidate != nil { _ = sendErr(&interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: errValidate}) return diff --git a/sdk/api/handlers/model_execution.go b/sdk/api/handlers/model_execution.go new file mode 100644 index 000000000..e004fea2c --- /dev/null +++ b/sdk/api/handlers/model_execution.go @@ -0,0 +1,252 @@ +package handlers + +import ( + "errors" + "net/http" + "net/url" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces" + "golang.org/x/net/context" +) + +const ( + modelExecutionMetadataSourceKey = "source" + modelExecutionInternalSource = "plugin_host_model_callback" +) + +type modelExecutionOptions struct { + Headers http.Header + Query url.Values + InternalSource bool +} + +// ModelExecutionRequest describes an internal model execution request. +type ModelExecutionRequest struct { + EntryProtocol string + ExitProtocol string + Model string + Stream bool + Body []byte + Headers http.Header + Query url.Values + Alt string +} + +// ModelExecutionResponse describes a non-streaming internal model execution response. +type ModelExecutionResponse struct { + StatusCode int + Headers http.Header + Body []byte +} + +// ModelExecutionStream describes a streaming internal model execution response. +type ModelExecutionStream struct { + StatusCode int + Headers http.Header + Chunks <-chan ModelExecutionChunk +} + +// ModelExecutionChunk carries either a streaming payload or a terminal stream error. +type ModelExecutionChunk struct { + Payload []byte + Err *ModelExecutionStreamError +} + +// ModelExecutionStreamError carries a JSON-friendly terminal stream error. +type ModelExecutionStreamError struct { + StatusCode int `json:"status_code"` + Message string `json:"message"` + Headers http.Header `json:"headers"` +} + +// Error returns the stream error message or the HTTP status text. +func (e *ModelExecutionStreamError) Error() string { + if e == nil { + return "" + } + if e.Message != "" { + return e.Message + } + return http.StatusText(e.StatusCode) +} + +// ExecuteModel executes an internal non-streaming model request. +func (h *BaseAPIHandler) ExecuteModel(ctx context.Context, req ModelExecutionRequest) (ModelExecutionResponse, *interfaces.ErrorMessage) { + if req.Stream { + return ModelExecutionResponse{}, modelExecutionModeError("ExecuteModel requires Stream=false") + } + body, headers, errMsg := h.executeWithAuthManagerFormats(ctx, req.EntryProtocol, req.ExitProtocol, req.Model, cloneBytes(req.Body), req.Alt, false, modelExecutionOptions{ + Headers: req.Headers, + Query: req.Query, + InternalSource: true, + }) + if errMsg != nil { + return ModelExecutionResponse{}, errMsg + } + return ModelExecutionResponse{ + StatusCode: http.StatusOK, + Headers: cloneHeader(headers), + Body: cloneBytes(body), + }, nil +} + +// ExecuteModelStream executes an internal streaming model request. +func (h *BaseAPIHandler) ExecuteModelStream(ctx context.Context, req ModelExecutionRequest) (ModelExecutionStream, *interfaces.ErrorMessage) { + if !req.Stream { + return ModelExecutionStream{}, modelExecutionModeError("ExecuteModelStream requires Stream=true") + } + dataChan, headers, errChan := h.executeStreamWithAuthManagerFormats(ctx, req.EntryProtocol, req.ExitProtocol, req.Model, cloneBytes(req.Body), req.Alt, false, modelExecutionOptions{ + Headers: req.Headers, + Query: req.Query, + InternalSource: true, + }) + chunks, errMsg := prepareModelExecutionStream(ctx, dataChan, errChan) + if errMsg != nil { + return ModelExecutionStream{}, errMsg + } + return ModelExecutionStream{ + StatusCode: http.StatusOK, + Headers: cloneHeader(headers), + Chunks: chunks, + }, nil +} + +func modelExecutionModeError(message string) *interfaces.ErrorMessage { + return &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: errors.New(message)} +} + +func modelExecutionResponseProtocol(entryProtocol, exitProtocol string) string { + if exitProtocol == "" { + return entryProtocol + } + return exitProtocol +} + +func modelExecutionHeaders(ctx context.Context, headers http.Header) http.Header { + if len(headers) > 0 { + return cloneHeader(headers) + } + return headersFromContext(ctx) +} + +func cloneURLValues(src url.Values) url.Values { + if src == nil { + return nil + } + dst := make(url.Values, len(src)) + for key, values := range src { + dst[key] = append([]string(nil), values...) + } + return dst +} + +func addModelExecutionSourceMetadata(meta map[string]any, internalSource bool) { + if !internalSource || meta == nil { + return + } + meta[modelExecutionMetadataSourceKey] = modelExecutionInternalSource +} + +func prepareModelExecutionStream(ctx context.Context, dataChan <-chan []byte, errChan <-chan *interfaces.ErrorMessage) (<-chan ModelExecutionChunk, *interfaces.ErrorMessage) { + pending, nextDataChan, nextErrChan, errMsg := receiveInitialModelExecutionChunk(ctx, dataChan, errChan) + if errMsg != nil { + return nil, errMsg + } + return wrapModelExecutionChunks(ctx, nextDataChan, nextErrChan, pending), nil +} + +func receiveInitialModelExecutionChunk(ctx context.Context, dataChan <-chan []byte, errChan <-chan *interfaces.ErrorMessage) ([]ModelExecutionChunk, <-chan []byte, <-chan *interfaces.ErrorMessage, *interfaces.ErrorMessage) { + var done <-chan struct{} + if ctx != nil { + done = ctx.Done() + } + for dataChan != nil || errChan != nil { + select { + case payload, ok := <-dataChan: + if !ok { + dataChan = nil + continue + } + return []ModelExecutionChunk{{Payload: cloneBytes(payload)}}, dataChan, errChan, nil + case errMsg, ok := <-errChan: + if !ok { + errChan = nil + continue + } + if errMsg != nil { + return nil, dataChan, errChan, errMsg + } + case <-done: + return nil, dataChan, errChan, nil + } + } + return nil, dataChan, errChan, nil +} + +func wrapModelExecutionChunks(ctx context.Context, dataChan <-chan []byte, errChan <-chan *interfaces.ErrorMessage, pending []ModelExecutionChunk) <-chan ModelExecutionChunk { + chunks := make(chan ModelExecutionChunk) + go func() { + defer close(chunks) + var done <-chan struct{} + if ctx != nil { + done = ctx.Done() + } + for _, chunk := range pending { + if !sendModelExecutionChunk(ctx, chunks, chunk) { + return + } + } + for dataChan != nil || errChan != nil { + select { + case <-done: + return + case payload, ok := <-dataChan: + if !ok { + dataChan = nil + continue + } + if !sendModelExecutionChunk(ctx, chunks, ModelExecutionChunk{Payload: cloneBytes(payload)}) { + return + } + case errMsg, ok := <-errChan: + if !ok { + errChan = nil + continue + } + if errMsg != nil { + _ = sendModelExecutionChunk(ctx, chunks, ModelExecutionChunk{Err: modelExecutionStreamErrorFromMessage(errMsg)}) + return + } + } + } + }() + return chunks +} + +func modelExecutionStreamErrorFromMessage(errMsg *interfaces.ErrorMessage) *ModelExecutionStreamError { + if errMsg == nil { + return nil + } + message := "" + if errMsg.Error != nil { + message = errMsg.Error.Error() + } + return &ModelExecutionStreamError{ + StatusCode: errMsg.StatusCode, + Message: message, + Headers: cloneHeader(errMsg.Addon), + } +} + +func sendModelExecutionChunk(ctx context.Context, chunks chan<- ModelExecutionChunk, chunk ModelExecutionChunk) bool { + if ctx == nil { + chunks <- chunk + return true + } + select { + case <-ctx.Done(): + return false + case chunks <- chunk: + return true + } +} diff --git a/sdk/api/handlers/model_execution_test.go b/sdk/api/handlers/model_execution_test.go new file mode 100644 index 000000000..642fcf42a --- /dev/null +++ b/sdk/api/handlers/model_execution_test.go @@ -0,0 +1,392 @@ +package handlers + +import ( + "context" + "fmt" + "net/http" + "net/url" + "sync" + "testing" + "time" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces" + "github.com/router-for-me/CLIProxyAPI/v7/internal/registry" + coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" + coreexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor" + sdkconfig "github.com/router-for-me/CLIProxyAPI/v7/sdk/config" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator" +) + +type modelExecutionCaptureExecutor struct { + provider string + + mu sync.Mutex + lastRequest coreexecutor.Request + lastOptions coreexecutor.Options + execute func(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) + stream func(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) +} + +type modelExecutionStatusHeaderError struct { + statusCode int + message string + headers http.Header +} + +func (e modelExecutionStatusHeaderError) Error() string { + return e.message +} + +func (e modelExecutionStatusHeaderError) StatusCode() int { + return e.statusCode +} + +func (e modelExecutionStatusHeaderError) Headers() http.Header { + return e.headers +} + +func (e *modelExecutionCaptureExecutor) Identifier() string { + if e.provider != "" { + return e.provider + } + return "codex" +} + +func (e *modelExecutionCaptureExecutor) Execute(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (coreexecutor.Response, error) { + e.capture(req, opts) + if e.execute != nil { + return e.execute(ctx, auth, req, opts) + } + return coreexecutor.Response{Payload: []byte("model-execution-ok")}, nil +} + +func (e *modelExecutionCaptureExecutor) ExecuteStream(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (*coreexecutor.StreamResult, error) { + e.capture(req, opts) + if e.stream != nil { + return e.stream(ctx, auth, req, opts) + } + chunks := make(chan coreexecutor.StreamChunk) + close(chunks) + return &coreexecutor.StreamResult{Chunks: chunks}, nil +} + +func (e *modelExecutionCaptureExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { + return auth, nil +} + +func (e *modelExecutionCaptureExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{Payload: []byte("0")}, nil +} + +func (e *modelExecutionCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) { + return nil, &coreauth.Error{Code: "not_implemented", Message: "HttpRequest not implemented", HTTPStatus: http.StatusNotImplemented} +} + +func (e *modelExecutionCaptureExecutor) capture(req coreexecutor.Request, opts coreexecutor.Options) { + e.mu.Lock() + defer e.mu.Unlock() + e.lastRequest = coreexecutor.Request{ + Model: req.Model, + Payload: cloneBytes(req.Payload), + Format: req.Format, + Metadata: req.Metadata, + } + e.lastOptions = coreexecutor.Options{ + Stream: opts.Stream, + Alt: opts.Alt, + Headers: cloneHeader(opts.Headers), + Query: cloneURLValues(opts.Query), + OriginalRequest: cloneBytes(opts.OriginalRequest), + SourceFormat: opts.SourceFormat, + ResponseFormat: opts.ResponseFormat, + Metadata: opts.Metadata, + } +} + +func (e *modelExecutionCaptureExecutor) captured() (coreexecutor.Request, coreexecutor.Options) { + e.mu.Lock() + defer e.mu.Unlock() + return e.lastRequest, e.lastOptions +} + +func newModelExecutionHandler(t *testing.T, model string, executor *modelExecutionCaptureExecutor, cfg *sdkconfig.SDKConfig) *BaseAPIHandler { + t.Helper() + manager := coreauth.NewManager(nil, nil, nil) + manager.RegisterExecutor(executor) + auth := &coreauth.Auth{ + ID: "model-execution-" + model, + Provider: executor.Identifier(), + Status: coreauth.StatusActive, + Metadata: map[string]any{"email": model + "@example.com"}, + } + if _, errRegister := manager.Register(context.Background(), auth); errRegister != nil { + t.Fatalf("manager.Register(): %v", errRegister) + } + registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: model}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(auth.ID) + }) + return NewBaseAPIHandlers(cfg, manager) +} + +func TestExecuteModelCarriesEntryAndExitProtocols(t *testing.T) { + model := "model-execution-nonstream-model" + requestBody := []byte(fmt.Sprintf(`{"model":%q}`, model)) + executor := &modelExecutionCaptureExecutor{ + execute: func(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{ + Payload: []byte(`{"ok":true}`), + Headers: http.Header{ + "X-Upstream": []string{"nonstream"}, + }, + }, nil + }, + } + handler := newModelExecutionHandler(t, model, executor, &sdkconfig.SDKConfig{PassthroughHeaders: true}) + + resp, errMsg := handler.ExecuteModel(context.Background(), ModelExecutionRequest{ + EntryProtocol: "openai", + ExitProtocol: "claude", + Model: model, + Body: requestBody, + Headers: http.Header{"X-Callback": []string{"nonstream"}}, + Query: url.Values{"q": []string{"callback"}}, + }) + if errMsg != nil { + t.Fatalf("ExecuteModel() error = %+v", errMsg) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } + if string(resp.Body) != `{"ok":true}` { + t.Fatalf("body = %q, want executor response", resp.Body) + } + if resp.Headers.Get("X-Upstream") != "nonstream" { + t.Fatalf("headers = %#v, want upstream header", resp.Headers) + } + + gotReq, gotOpts := executor.captured() + if gotReq.Model != model { + t.Fatalf("executor model = %q, want %q", gotReq.Model, model) + } + if string(gotReq.Payload) != string(requestBody) { + t.Fatalf("executor payload = %q, want %q", gotReq.Payload, requestBody) + } + if gotOpts.Stream { + t.Fatal("executor stream option = true, want false") + } + if gotOpts.SourceFormat != sdktranslator.FormatOpenAI { + t.Fatalf("SourceFormat = %q, want %q", gotOpts.SourceFormat, sdktranslator.FormatOpenAI) + } + if gotOpts.ResponseFormat != sdktranslator.FormatClaude { + t.Fatalf("ResponseFormat = %q, want %q", gotOpts.ResponseFormat, sdktranslator.FormatClaude) + } + if gotOpts.Metadata[coreexecutor.RequestedModelMetadataKey] != model { + t.Fatalf("requested model metadata = %#v, want %q", gotOpts.Metadata[coreexecutor.RequestedModelMetadataKey], model) + } + if gotOpts.Metadata[modelExecutionMetadataSourceKey] != modelExecutionInternalSource { + t.Fatalf("source metadata = %#v, want %q", gotOpts.Metadata[modelExecutionMetadataSourceKey], modelExecutionInternalSource) + } + if gotOpts.Headers.Get("X-Callback") != "nonstream" { + t.Fatalf("executor headers = %#v, want callback header", gotOpts.Headers) + } + if gotOpts.Query.Get("q") != "callback" { + t.Fatalf("executor query = %#v, want callback query", gotOpts.Query) + } +} + +func TestExecuteModelStream(t *testing.T) { + model := "model-execution-stream-model" + requestBody := []byte(fmt.Sprintf(`{"model":%q,"stream":true}`, model)) + executor := &modelExecutionCaptureExecutor{ + stream: func(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (*coreexecutor.StreamResult, error) { + chunks := make(chan coreexecutor.StreamChunk, 1) + chunks <- coreexecutor.StreamChunk{Payload: []byte("stream-one")} + close(chunks) + return &coreexecutor.StreamResult{ + Headers: http.Header{"X-Upstream": []string{"stream"}}, + Chunks: chunks, + }, nil + }, + } + handler := newModelExecutionHandler(t, model, executor, &sdkconfig.SDKConfig{PassthroughHeaders: true}) + + stream, errMsg := handler.ExecuteModelStream(context.Background(), ModelExecutionRequest{ + EntryProtocol: "openai", + ExitProtocol: "claude", + Model: model, + Stream: true, + Body: requestBody, + Headers: http.Header{"X-Callback": []string{"stream"}}, + }) + if errMsg != nil { + t.Fatalf("ExecuteModelStream() error = %+v", errMsg) + } + if stream.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want %d", stream.StatusCode, http.StatusOK) + } + if stream.Headers.Get("X-Upstream") != "stream" { + t.Fatalf("headers = %#v, want upstream header", stream.Headers) + } + chunk, ok := <-stream.Chunks + if !ok { + t.Fatal("stream chunks closed before payload") + } + if chunk.Err != nil { + t.Fatalf("stream chunk error = %+v", chunk.Err) + } + if string(chunk.Payload) != "stream-one" { + t.Fatalf("stream chunk payload = %q, want stream-one", chunk.Payload) + } + if chunk, ok = <-stream.Chunks; ok { + t.Fatalf("unexpected extra stream chunk: %+v", chunk) + } + + gotReq, gotOpts := executor.captured() + if gotReq.Model != model { + t.Fatalf("executor model = %q, want %q", gotReq.Model, model) + } + if string(gotReq.Payload) != string(requestBody) { + t.Fatalf("executor payload = %q, want %q", gotReq.Payload, requestBody) + } + if !gotOpts.Stream { + t.Fatal("executor stream option = false, want true") + } + if gotOpts.SourceFormat != sdktranslator.FormatOpenAI { + t.Fatalf("SourceFormat = %q, want %q", gotOpts.SourceFormat, sdktranslator.FormatOpenAI) + } + if gotOpts.ResponseFormat != sdktranslator.FormatClaude { + t.Fatalf("ResponseFormat = %q, want %q", gotOpts.ResponseFormat, sdktranslator.FormatClaude) + } + if gotOpts.Metadata[coreexecutor.RequestedModelMetadataKey] != model { + t.Fatalf("requested model metadata = %#v, want %q", gotOpts.Metadata[coreexecutor.RequestedModelMetadataKey], model) + } + if gotOpts.Metadata[modelExecutionMetadataSourceKey] != modelExecutionInternalSource { + t.Fatalf("source metadata = %#v, want %q", gotOpts.Metadata[modelExecutionMetadataSourceKey], modelExecutionInternalSource) + } + if gotOpts.Headers.Get("X-Callback") != "stream" { + t.Fatalf("executor headers = %#v, want callback header", gotOpts.Headers) + } +} + +func TestExecuteModelStreamStartupError(t *testing.T) { + model := "model-execution-stream-startup-error-model" + requestBody := []byte(fmt.Sprintf(`{"model":%q,"stream":true}`, model)) + executor := &modelExecutionCaptureExecutor{ + stream: func(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (*coreexecutor.StreamResult, error) { + chunks := make(chan coreexecutor.StreamChunk, 1) + chunks <- coreexecutor.StreamChunk{Err: fmt.Errorf("startup failed")} + close(chunks) + return &coreexecutor.StreamResult{Chunks: chunks}, nil + }, + } + handler := newModelExecutionHandler(t, model, executor, &sdkconfig.SDKConfig{}) + + stream, errMsg := handler.ExecuteModelStream(context.Background(), ModelExecutionRequest{ + EntryProtocol: "openai", + ExitProtocol: "claude", + Model: model, + Stream: true, + Body: requestBody, + }) + if errMsg == nil { + t.Fatal("ExecuteModelStream() error = nil, want startup error") + } + if errMsg.StatusCode != http.StatusInternalServerError { + t.Fatalf("status = %d, want %d", errMsg.StatusCode, http.StatusInternalServerError) + } + if errMsg.Error == nil || errMsg.Error.Error() != "startup failed" { + t.Fatalf("error = %v, want startup failed", errMsg.Error) + } + if stream.Chunks != nil { + t.Fatal("stream chunks created for startup error") + } +} + +func TestExecuteModelStreamTerminalError(t *testing.T) { + model := "model-execution-stream-terminal-error-model" + requestBody := []byte(fmt.Sprintf(`{"model":%q,"stream":true}`, model)) + errorHeaders := http.Header{"X-Stream-Error": []string{"terminal"}} + executor := &modelExecutionCaptureExecutor{ + stream: func(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (*coreexecutor.StreamResult, error) { + chunks := make(chan coreexecutor.StreamChunk, 2) + chunks <- coreexecutor.StreamChunk{Payload: []byte("stream-before-error")} + chunks <- coreexecutor.StreamChunk{Err: modelExecutionStatusHeaderError{ + statusCode: http.StatusTooManyRequests, + message: "rate limited", + headers: errorHeaders, + }} + close(chunks) + return &coreexecutor.StreamResult{Chunks: chunks}, nil + }, + } + handler := newModelExecutionHandler(t, model, executor, &sdkconfig.SDKConfig{}) + + stream, errMsg := handler.ExecuteModelStream(context.Background(), ModelExecutionRequest{ + EntryProtocol: "openai", + ExitProtocol: "claude", + Model: model, + Stream: true, + Body: requestBody, + }) + if errMsg != nil { + t.Fatalf("ExecuteModelStream() error = %+v", errMsg) + } + + chunk, ok := <-stream.Chunks + if !ok { + t.Fatal("stream chunks closed before payload") + } + if chunk.Err != nil { + t.Fatalf("first stream chunk error = %+v", chunk.Err) + } + if string(chunk.Payload) != "stream-before-error" { + t.Fatalf("first stream chunk payload = %q, want stream-before-error", chunk.Payload) + } + + chunk, ok = <-stream.Chunks + if !ok { + t.Fatal("stream chunks closed before terminal error") + } + if len(chunk.Payload) != 0 { + t.Fatalf("terminal stream chunk payload = %q, want empty", chunk.Payload) + } + if chunk.Err == nil { + t.Fatal("terminal stream chunk error = nil") + } + if chunk.Err.StatusCode != http.StatusTooManyRequests { + t.Fatalf("terminal status = %d, want %d", chunk.Err.StatusCode, http.StatusTooManyRequests) + } + if chunk.Err.Message != "rate limited" { + t.Fatalf("terminal message = %q, want rate limited", chunk.Err.Message) + } + if chunk.Err.Error() != "rate limited" { + t.Fatalf("terminal Error() = %q, want rate limited", chunk.Err.Error()) + } + if chunk.Err.Headers.Get("X-Stream-Error") != "terminal" { + t.Fatalf("terminal headers = %#v, want stream error header", chunk.Err.Headers) + } + if chunk, ok = <-stream.Chunks; ok { + t.Fatalf("unexpected extra stream chunk: %+v", chunk) + } +} + +func TestExecuteModelStreamContextCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + dataChan := make(chan []byte) + errChan := make(chan *interfaces.ErrorMessage) + chunks := wrapModelExecutionChunks(ctx, dataChan, errChan, nil) + + cancel() + + timeout := time.NewTimer(time.Second) + defer timeout.Stop() + select { + case chunk, ok := <-chunks: + if ok { + t.Fatalf("stream chunks yielded after cancel: %+v", chunk) + } + case <-timeout.C: + t.Fatal("stream chunks did not close after context cancellation") + } +} diff --git a/sdk/cliproxy/executor/types.go b/sdk/cliproxy/executor/types.go index 9f5c4a451..e27a821b9 100644 --- a/sdk/cliproxy/executor/types.go +++ b/sdk/cliproxy/executor/types.go @@ -94,12 +94,23 @@ type Options struct { OriginalRequest []byte // SourceFormat identifies the inbound schema. SourceFormat sdktranslator.Format + // ResponseFormat identifies the downstream response schema. + // Empty means responses should use SourceFormat for backward compatibility. + ResponseFormat sdktranslator.Format // Metadata carries extra execution hints shared across selection and executors. Metadata map[string]any // RequestAfterAuthInterceptor runs after credential selection and before executor translation. RequestAfterAuthInterceptor RequestAfterAuthInterceptor } +// ResponseFormatOrSource returns the response target format for an execution. +func ResponseFormatOrSource(opts Options) sdktranslator.Format { + if opts.ResponseFormat != "" { + return opts.ResponseFormat + } + return opts.SourceFormat +} + // Response wraps either a full provider response or metadata for streaming flows. type Response struct { // Payload is the provider response in the executor format. diff --git a/sdk/cliproxy/executor/types_test.go b/sdk/cliproxy/executor/types_test.go new file mode 100644 index 000000000..431272a8c --- /dev/null +++ b/sdk/cliproxy/executor/types_test.go @@ -0,0 +1,26 @@ +package executor + +import ( + "testing" + + sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator" +) + +func TestResponseFormatOrSourceUsesExplicitResponseFormat(t *testing.T) { + opts := Options{ + SourceFormat: sdktranslator.FormatOpenAI, + ResponseFormat: sdktranslator.FormatClaude, + } + + if got := ResponseFormatOrSource(opts); got != sdktranslator.FormatClaude { + t.Fatalf("ResponseFormatOrSource() = %q, want %q", got, sdktranslator.FormatClaude) + } +} + +func TestResponseFormatOrSourceFallsBackToSourceFormat(t *testing.T) { + opts := Options{SourceFormat: sdktranslator.FormatGemini} + + if got := ResponseFormatOrSource(opts); got != sdktranslator.FormatGemini { + t.Fatalf("ResponseFormatOrSource() = %q, want %q", got, sdktranslator.FormatGemini) + } +} diff --git a/sdk/pluginabi/types.go b/sdk/pluginabi/types.go index 69852234d..fcaf7f184 100644 --- a/sdk/pluginabi/types.go +++ b/sdk/pluginabi/types.go @@ -56,13 +56,17 @@ const ( MethodManagementRegister = "management.register" MethodManagementHandle = "management.handle" - MethodHostHTTPDo = "host.http.do" - MethodHostHTTPDoStream = "host.http.do_stream" - MethodHostHTTPStreamRead = "host.http.stream_read" - MethodHostHTTPStreamClose = "host.http.stream_close" - MethodHostStreamEmit = "host.stream.emit" - MethodHostStreamClose = "host.stream.close" - MethodHostLog = "host.log" + MethodHostHTTPDo = "host.http.do" + MethodHostHTTPDoStream = "host.http.do_stream" + MethodHostHTTPStreamRead = "host.http.stream_read" + MethodHostHTTPStreamClose = "host.http.stream_close" + MethodHostModelExecute = "host.model.execute" + MethodHostModelExecuteStream = "host.model.execute_stream" + MethodHostModelStreamRead = "host.model.stream_read" + MethodHostModelStreamClose = "host.model.stream_close" + MethodHostStreamEmit = "host.stream.emit" + MethodHostStreamClose = "host.stream.close" + MethodHostLog = "host.log" ) type Envelope struct { diff --git a/sdk/pluginabi/types_test.go b/sdk/pluginabi/types_test.go index 7b6ff7da6..3c3f14453 100644 --- a/sdk/pluginabi/types_test.go +++ b/sdk/pluginabi/types_test.go @@ -48,6 +48,18 @@ func TestMethodNamesAreStable(t *testing.T) { if MethodHostHTTPStreamRead != "host.http.stream_read" { t.Fatalf("MethodHostHTTPStreamRead = %q", MethodHostHTTPStreamRead) } + if MethodHostModelExecute != "host.model.execute" { + t.Fatalf("MethodHostModelExecute = %q", MethodHostModelExecute) + } + if MethodHostModelExecuteStream != "host.model.execute_stream" { + t.Fatalf("MethodHostModelExecuteStream = %q", MethodHostModelExecuteStream) + } + if MethodHostModelStreamRead != "host.model.stream_read" { + t.Fatalf("MethodHostModelStreamRead = %q", MethodHostModelStreamRead) + } + if MethodHostModelStreamClose != "host.model.stream_close" { + t.Fatalf("MethodHostModelStreamClose = %q", MethodHostModelStreamClose) + } if MethodExecutorExecuteStream != "executor.execute_stream" { t.Fatalf("MethodExecutorExecuteStream = %q", MethodExecutorExecuteStream) } diff --git a/sdk/pluginapi/types.go b/sdk/pluginapi/types.go index 7ec03c4d9..7aa117132 100644 --- a/sdk/pluginapi/types.go +++ b/sdk/pluginapi/types.go @@ -524,6 +524,68 @@ type HostHTTPClient interface { DoStream(context.Context, HTTPRequest) (HTTPStreamResponse, error) } +// HostModelExecutionRequest describes a model execution request issued through the host. +type HostModelExecutionRequest struct { + // EntryProtocol is the inbound client protocol format. + EntryProtocol string `json:"entry_protocol"` + // ExitProtocol is the target provider protocol format. + ExitProtocol string `json:"exit_protocol"` + // Model is the requested model identifier. + Model string `json:"model"` + // Stream reports whether the request expects streaming output. + Stream bool `json:"stream"` + // Body contains the raw request body. + Body []byte `json:"body"` + // Headers contains request headers. + Headers http.Header `json:"headers"` + // Query contains request query parameters. + Query url.Values `json:"query"` + // Alt carries an alternate route or mode suffix when present. + Alt string `json:"alt"` +} + +// HostModelExecutionResponse describes a non-streaming host model execution response. +type HostModelExecutionResponse struct { + // StatusCode is the model execution HTTP status code. + StatusCode int `json:"status_code"` + // Headers contains response headers. + Headers http.Header `json:"headers"` + // Body contains the raw response body. + Body []byte `json:"body"` +} + +// HostModelStreamResponse describes a streaming host model execution response. +type HostModelStreamResponse struct { + // StatusCode is the model execution HTTP status code. + StatusCode int `json:"status_code"` + // Headers contains response headers. + Headers http.Header `json:"headers"` + // StreamID identifies the host-owned stream for later reads. + StreamID string `json:"stream_id"` +} + +// HostModelStreamReadRequest asks the host to read the next model stream chunk. +type HostModelStreamReadRequest struct { + // StreamID identifies the host-owned stream. + StreamID string `json:"stream_id"` +} + +// HostModelStreamReadResponse returns one model stream chunk or terminal state. +type HostModelStreamReadResponse struct { + // Payload contains the raw stream chunk bytes. + Payload []byte `json:"payload"` + // Error reports a stream error associated with this read. + Error string `json:"error"` + // Done reports whether the stream has ended. + Done bool `json:"done"` +} + +// HostModelStreamCloseRequest asks the host to close a model stream. +type HostModelStreamCloseRequest struct { + // StreamID identifies the host-owned stream. + StreamID string `json:"stream_id"` +} + // HTTPRequest describes an upstream HTTP request issued through the host. type HTTPRequest struct { // Method is the HTTP method. diff --git a/sdk/pluginapi/types_test.go b/sdk/pluginapi/types_test.go index 18725755c..d42470b79 100644 --- a/sdk/pluginapi/types_test.go +++ b/sdk/pluginapi/types_test.go @@ -3,6 +3,8 @@ package pluginapi import ( "context" "encoding/json" + "net/http" + "net/url" "strings" "testing" ) @@ -113,6 +115,153 @@ func TestHostInjectedHTTPClientIsNotEncodedInPluginJSON(t *testing.T) { } } +func TestHostModelTypesPreserveFields(t *testing.T) { + request := HostModelExecutionRequest{ + EntryProtocol: "openai", + ExitProtocol: "claude", + Model: "gpt-test", + Stream: true, + Body: []byte(`{"input":"hello"}`), + Headers: http.Header{"X-Test": []string{"one", "two"}}, + Query: url.Values{"alt": []string{"beta"}}, + Alt: "chat", + } + rawRequest, errMarshalRequest := json.Marshal(request) + if errMarshalRequest != nil { + t.Fatalf("marshal HostModelExecutionRequest: %v", errMarshalRequest) + } + requestJSON := string(rawRequest) + for _, field := range []string{"entry_protocol", "exit_protocol", "model", "stream", "body", "headers", "query", "alt"} { + if !strings.Contains(requestJSON, `"`+field+`"`) { + t.Fatalf("HostModelExecutionRequest JSON missing field %q: %s", field, requestJSON) + } + } + var decodedRequest HostModelExecutionRequest + if errUnmarshalRequest := json.Unmarshal(rawRequest, &decodedRequest); errUnmarshalRequest != nil { + t.Fatalf("unmarshal HostModelExecutionRequest: %v", errUnmarshalRequest) + } + if decodedRequest.EntryProtocol != request.EntryProtocol || + decodedRequest.ExitProtocol != request.ExitProtocol || + decodedRequest.Model != request.Model || + decodedRequest.Stream != request.Stream || + string(decodedRequest.Body) != string(request.Body) || + decodedRequest.Headers.Get("X-Test") != "one" || + decodedRequest.Query.Get("alt") != "beta" || + decodedRequest.Alt != request.Alt { + t.Fatalf("HostModelExecutionRequest round trip = %#v", decodedRequest) + } + if got := decodedRequest.Headers.Values("X-Test"); len(got) != 2 || got[1] != "two" { + t.Fatalf("HostModelExecutionRequest headers = %#v", decodedRequest.Headers) + } + + response := HostModelExecutionResponse{ + StatusCode: http.StatusAccepted, + Headers: http.Header{"Content-Type": []string{"application/json"}}, + Body: []byte(`{"ok":true}`), + } + rawResponse, errMarshalResponse := json.Marshal(response) + if errMarshalResponse != nil { + t.Fatalf("marshal HostModelExecutionResponse: %v", errMarshalResponse) + } + responseJSON := string(rawResponse) + for _, field := range []string{"status_code", "headers", "body"} { + if !strings.Contains(responseJSON, `"`+field+`"`) { + t.Fatalf("HostModelExecutionResponse JSON missing field %q: %s", field, responseJSON) + } + } + var decodedResponse HostModelExecutionResponse + if errUnmarshalResponse := json.Unmarshal(rawResponse, &decodedResponse); errUnmarshalResponse != nil { + t.Fatalf("unmarshal HostModelExecutionResponse: %v", errUnmarshalResponse) + } + if decodedResponse.StatusCode != response.StatusCode || + decodedResponse.Headers.Get("Content-Type") != "application/json" || + string(decodedResponse.Body) != string(response.Body) { + t.Fatalf("HostModelExecutionResponse round trip = %#v", decodedResponse) + } + + streamResponse := HostModelStreamResponse{ + StatusCode: http.StatusOK, + Headers: http.Header{"Content-Type": []string{"text/event-stream"}}, + StreamID: "stream-1", + } + rawStreamResponse, errMarshalStreamResponse := json.Marshal(streamResponse) + if errMarshalStreamResponse != nil { + t.Fatalf("marshal HostModelStreamResponse: %v", errMarshalStreamResponse) + } + streamResponseJSON := string(rawStreamResponse) + for _, field := range []string{"status_code", "headers", "stream_id"} { + if !strings.Contains(streamResponseJSON, `"`+field+`"`) { + t.Fatalf("HostModelStreamResponse JSON missing field %q: %s", field, streamResponseJSON) + } + } + var decodedStreamResponse HostModelStreamResponse + if errUnmarshalStreamResponse := json.Unmarshal(rawStreamResponse, &decodedStreamResponse); errUnmarshalStreamResponse != nil { + t.Fatalf("unmarshal HostModelStreamResponse: %v", errUnmarshalStreamResponse) + } + if decodedStreamResponse.StatusCode != streamResponse.StatusCode || + decodedStreamResponse.Headers.Get("Content-Type") != "text/event-stream" || + decodedStreamResponse.StreamID != streamResponse.StreamID { + t.Fatalf("HostModelStreamResponse round trip = %#v", decodedStreamResponse) + } + + readRequest := HostModelStreamReadRequest{StreamID: "stream-1"} + rawReadRequest, errMarshalReadRequest := json.Marshal(readRequest) + if errMarshalReadRequest != nil { + t.Fatalf("marshal HostModelStreamReadRequest: %v", errMarshalReadRequest) + } + if !strings.Contains(string(rawReadRequest), `"stream_id"`) { + t.Fatalf("HostModelStreamReadRequest JSON missing stream_id: %s", rawReadRequest) + } + var decodedReadRequest HostModelStreamReadRequest + if errUnmarshalReadRequest := json.Unmarshal(rawReadRequest, &decodedReadRequest); errUnmarshalReadRequest != nil { + t.Fatalf("unmarshal HostModelStreamReadRequest: %v", errUnmarshalReadRequest) + } + if decodedReadRequest.StreamID != readRequest.StreamID { + t.Fatalf("HostModelStreamReadRequest round trip = %#v", decodedReadRequest) + } + + readResponse := HostModelStreamReadResponse{ + Payload: []byte("data: test\n\n"), + Error: "temporary stream error", + Done: true, + } + rawReadResponse, errMarshalReadResponse := json.Marshal(readResponse) + if errMarshalReadResponse != nil { + t.Fatalf("marshal HostModelStreamReadResponse: %v", errMarshalReadResponse) + } + readResponseJSON := string(rawReadResponse) + for _, field := range []string{"payload", "error", "done"} { + if !strings.Contains(readResponseJSON, `"`+field+`"`) { + t.Fatalf("HostModelStreamReadResponse JSON missing field %q: %s", field, readResponseJSON) + } + } + var decodedReadResponse HostModelStreamReadResponse + if errUnmarshalReadResponse := json.Unmarshal(rawReadResponse, &decodedReadResponse); errUnmarshalReadResponse != nil { + t.Fatalf("unmarshal HostModelStreamReadResponse: %v", errUnmarshalReadResponse) + } + if string(decodedReadResponse.Payload) != string(readResponse.Payload) || + decodedReadResponse.Error != readResponse.Error || + decodedReadResponse.Done != readResponse.Done { + t.Fatalf("HostModelStreamReadResponse round trip = %#v", decodedReadResponse) + } + + closeRequest := HostModelStreamCloseRequest{StreamID: "stream-1"} + rawCloseRequest, errMarshalCloseRequest := json.Marshal(closeRequest) + if errMarshalCloseRequest != nil { + t.Fatalf("marshal HostModelStreamCloseRequest: %v", errMarshalCloseRequest) + } + if !strings.Contains(string(rawCloseRequest), `"stream_id"`) { + t.Fatalf("HostModelStreamCloseRequest JSON missing stream_id: %s", rawCloseRequest) + } + var decodedCloseRequest HostModelStreamCloseRequest + if errUnmarshalCloseRequest := json.Unmarshal(rawCloseRequest, &decodedCloseRequest); errUnmarshalCloseRequest != nil { + t.Fatalf("unmarshal HostModelStreamCloseRequest: %v", errUnmarshalCloseRequest) + } + if decodedCloseRequest.StreamID != closeRequest.StreamID { + t.Fatalf("HostModelStreamCloseRequest round trip = %#v", decodedCloseRequest) + } +} + func TestSchedulerTypesExposeRoutingFields(t *testing.T) { request := SchedulerPickRequest{ Plugin: Metadata{Name: "scheduler-plugin"},