diff --git a/internal/pluginhost/adapters.go b/internal/pluginhost/adapters.go index 5be003588..a5801e224 100644 --- a/internal/pluginhost/adapters.go +++ b/internal/pluginhost/adapters.go @@ -1399,7 +1399,7 @@ func executorResponseTranslatorExists(from, to sdktranslator.Format) bool { if from == "" || to == "" || from == to { return true } - return sdktranslator.HasResponseTransformer(to, from) + return sdktranslator.HasStreamResponseTransformer(to, from) } func (a *executorAdapter) translateExecutorResponse(ctx context.Context, prepared preparedExecutorCall, payload []byte, stream bool, param *any) []byte { diff --git a/internal/pluginhost/adapters_test.go b/internal/pluginhost/adapters_test.go index e718b57de..a9db914ee 100644 --- a/internal/pluginhost/adapters_test.go +++ b/internal/pluginhost/adapters_test.go @@ -78,6 +78,32 @@ func TestPluginModelInfoToRegistryModelInfoClonesThinkingAndSlices(t *testing.T) } } +func TestExecutorResponseTranslatorExistsRequiresStreamTransform(t *testing.T) { + outputFormat := sdktranslator.Format("plugin-output-non-stream-only") + requestedFormat := sdktranslator.Format("client-output-non-stream-only") + sdktranslator.Register(requestedFormat, outputFormat, nil, sdktranslator.ResponseTransform{ + NonStream: func(ctx context.Context, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte { + return rawJSON + }, + }) + + if executorResponseTranslatorExists(outputFormat, requestedFormat) { + t.Fatal("non-stream-only response transformer was accepted for stream executor output") + } + + streamOutputFormat := sdktranslator.Format("plugin-output-stream") + streamRequestedFormat := sdktranslator.Format("client-output-stream") + sdktranslator.Register(streamRequestedFormat, streamOutputFormat, nil, sdktranslator.ResponseTransform{ + Stream: func(ctx context.Context, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte { + return [][]byte{rawJSON} + }, + }) + + if !executorResponseTranslatorExists(streamOutputFormat, streamRequestedFormat) { + t.Fatal("stream response transformer was not accepted for stream executor output") + } +} + func TestRegisterModelsRegistersProviderModelsAndClientID(t *testing.T) { modelRegistry := newFakeModelRegistry() host := newHostWithRecords(capabilityRecord{ diff --git a/sdk/translator/helpers.go b/sdk/translator/helpers.go index db38d745b..80c83d529 100644 --- a/sdk/translator/helpers.go +++ b/sdk/translator/helpers.go @@ -17,6 +17,16 @@ func HasResponseTransformerByFormatName(from, to Format) bool { return HasResponseTransformer(from, to) } +// HasStreamResponseTransformerByFormatName reports whether a stream response translator exists between two schemas. +func HasStreamResponseTransformerByFormatName(from, to Format) bool { + return HasStreamResponseTransformer(from, to) +} + +// HasNonStreamResponseTransformerByFormatName reports whether a non-stream response translator exists between two schemas. +func HasNonStreamResponseTransformerByFormatName(from, to Format) bool { + return HasNonStreamResponseTransformer(from, to) +} + // TranslateStreamByFormatName converts streaming responses between schemas by their string identifiers. func TranslateStreamByFormatName(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte { return TranslateStream(ctx, from, to, model, originalRequestRawJSON, requestRawJSON, rawJSON, param) diff --git a/sdk/translator/registry.go b/sdk/translator/registry.go index ac07107b8..ad4d351db 100644 --- a/sdk/translator/registry.go +++ b/sdk/translator/registry.go @@ -107,7 +107,33 @@ func (r *Registry) HasResponseTransformer(from, to Format) bool { defer r.mu.RUnlock() if byTarget, ok := r.responses[from]; ok { - if _, isOk := byTarget[to]; isOk { + if fn, isOk := byTarget[to]; isOk && hasAnyResponseTransform(fn) { + return true + } + } + return false +} + +// HasStreamResponseTransformer indicates whether a streaming response translator exists. +func (r *Registry) HasStreamResponseTransformer(from, to Format) bool { + r.mu.RLock() + defer r.mu.RUnlock() + + if byTarget, ok := r.responses[from]; ok { + if fn, isOk := byTarget[to]; isOk && fn.Stream != nil { + return true + } + } + return false +} + +// HasNonStreamResponseTransformer indicates whether a non-streaming response translator exists. +func (r *Registry) HasNonStreamResponseTransformer(from, to Format) bool { + r.mu.RLock() + defer r.mu.RUnlock() + + if byTarget, ok := r.responses[from]; ok { + if fn, isOk := byTarget[to]; isOk && fn.NonStream != nil { return true } } @@ -117,9 +143,9 @@ func (r *Registry) HasResponseTransformer(from, to Format) bool { // TranslateStream applies the registered streaming response translator. func (r *Registry) TranslateStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte { r.mu.RLock() - var fn ResponseTransform + var stream ResponseStreamTransform if byTarget, ok := r.responses[to]; ok { - fn = byTarget[from] + stream = byTarget[from].Stream } hooks := r.hooks r.mu.RUnlock() @@ -130,14 +156,16 @@ func (r *Registry) TranslateStream(ctx context.Context, from, to Format, model s } var outputs [][]byte - if fn.Stream != nil { - outputs = fn.Stream(ctx, model, originalRequestRawJSON, requestRawJSON, body, param) + usedNativeTransform := false + if stream != nil { + usedNativeTransform = true + outputs = stream(ctx, model, originalRequestRawJSON, requestRawJSON, body, param) } else if hooks != nil { if translated, ok := hooks.TranslateResponse(ctx, from, to, model, originalRequestRawJSON, requestRawJSON, body, true); ok { outputs = [][]byte{translated} } } - if outputs == nil { + if outputs == nil && !usedNativeTransform { outputs = [][]byte{body} } if hooks != nil { @@ -220,6 +248,16 @@ func HasResponseTransformer(from, to Format) bool { return defaultRegistry.HasResponseTransformer(from, to) } +// HasStreamResponseTransformer inspects the default registry for a streaming response translator. +func HasStreamResponseTransformer(from, to Format) bool { + return defaultRegistry.HasStreamResponseTransformer(from, to) +} + +// HasNonStreamResponseTransformer inspects the default registry for a non-streaming response translator. +func HasNonStreamResponseTransformer(from, to Format) bool { + return defaultRegistry.HasNonStreamResponseTransformer(from, to) +} + // TranslateStream is a helper on the default registry. func TranslateStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte { return defaultRegistry.TranslateStream(ctx, from, to, model, originalRequestRawJSON, requestRawJSON, rawJSON, param) @@ -234,3 +272,7 @@ func TranslateNonStream(ctx context.Context, from, to Format, model string, orig func TranslateTokenCount(ctx context.Context, from, to Format, count int64, rawJSON []byte) []byte { return defaultRegistry.TranslateTokenCount(ctx, from, to, count, rawJSON) } + +func hasAnyResponseTransform(fn ResponseTransform) bool { + return fn.Stream != nil || fn.NonStream != nil || fn.TokenCount != nil +} diff --git a/sdk/translator/registry_test.go b/sdk/translator/registry_test.go index 0b01053b4..f154cb397 100644 --- a/sdk/translator/registry_test.go +++ b/sdk/translator/registry_test.go @@ -164,6 +164,70 @@ func TestHasRequestTransformer(t *testing.T) { } } +func TestHasResponseTransformerIgnoresEmptyRegistration(t *testing.T) { + r := NewRegistry() + from := Format("from") + to := Format("to") + + r.Register(from, to, func(model string, rawJSON []byte, stream bool) []byte { + return rawJSON + }, ResponseTransform{}) + + if r.HasResponseTransformer(from, to) { + t.Fatal("empty response transform was reported as a response transformer") + } + if r.HasStreamResponseTransformer(from, to) { + t.Fatal("empty response transform was reported as a stream response transformer") + } + if r.HasNonStreamResponseTransformer(from, to) { + t.Fatal("empty response transform was reported as a non-stream response transformer") + } +} + +func TestHasResponseTransformerChecksConcreteResponseKinds(t *testing.T) { + ctx := context.Background() + r := NewRegistry() + from := Format("from") + streamOnlyTo := Format("stream-to") + nonStreamOnlyTo := Format("non-stream-to") + + r.Register(from, streamOnlyTo, nil, ResponseTransform{ + Stream: func(ctx context.Context, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte { + return [][]byte{rawJSON} + }, + }) + r.Register(from, nonStreamOnlyTo, nil, ResponseTransform{ + NonStream: func(ctx context.Context, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte { + return rawJSON + }, + }) + + if !r.HasResponseTransformer(from, streamOnlyTo) { + t.Fatal("stream response transform was not reported as a response transformer") + } + if !r.HasStreamResponseTransformer(from, streamOnlyTo) { + t.Fatal("stream response transform was not reported as a stream response transformer") + } + if r.HasNonStreamResponseTransformer(from, streamOnlyTo) { + t.Fatal("stream-only transform was reported as a non-stream response transformer") + } + + if !r.HasResponseTransformer(from, nonStreamOnlyTo) { + t.Fatal("non-stream response transform was not reported as a response transformer") + } + if r.HasStreamResponseTransformer(from, nonStreamOnlyTo) { + t.Fatal("non-stream-only transform was reported as a stream response transformer") + } + if !r.HasNonStreamResponseTransformer(from, nonStreamOnlyTo) { + t.Fatal("non-stream response transform was not reported as a non-stream response transformer") + } + + got := r.TranslateStream(ctx, streamOnlyTo, from, "model", nil, nil, []byte(`data: {"ok":true}`), nil) + if len(got) != 1 || string(got[0]) != `data: {"ok":true}` { + t.Fatalf("stream transform output = %q", got) + } +} + func TestTranslateRequest_PluginTranslatorOnlyWhenNativeMissing(t *testing.T) { from := Format("from") to := Format("to") @@ -243,6 +307,50 @@ func TestTranslateNonStream_PluginTranslatorOnlyWhenNativeMissing(t *testing.T) } } +func TestTranslateStream_NativeEmptyOutputSuppressesRawFallback(t *testing.T) { + ctx := context.Background() + from := Format("client") + to := Format("upstream") + + r := NewRegistry() + r.Register(to, from, nil, ResponseTransform{ + Stream: func(ctx context.Context, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte { + return nil + }, + }) + + got := r.TranslateStream(ctx, from, to, "model", nil, nil, []byte(`data: {"raw":true}`), nil) + if len(got) != 0 { + t.Fatalf("native stream transformer returned empty output, got raw fallback %q", got) + } +} + +func TestTranslateStream_PluginTranslatorUsedWhenNativeStreamMissing(t *testing.T) { + ctx := context.Background() + from := Format("client") + to := Format("upstream") + + r := NewRegistry() + hooks := &fakePluginHooks{ + responseTranslateBody: []byte(`data: {"plugin":true}`), + responseTranslateOK: true, + } + r.SetPluginHooks(hooks) + r.Register(to, from, nil, ResponseTransform{ + NonStream: func(ctx context.Context, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte { + return []byte(`{"native-non-stream":true}`) + }, + }) + + got := r.TranslateStream(ctx, from, to, "model", nil, nil, []byte(`data: {"raw":true}`), nil) + if len(got) != 1 || string(got[0]) != `data: {"plugin":true}` { + t.Fatalf("plugin stream translator was not used, got %q", got) + } + if !hasCall(hooks.calls, "translate-response") { + t.Fatal("plugin response translator was not called when native stream transformer was missing") + } +} + func TestPluginNormalizersChainAfterNative(t *testing.T) { ctx := context.Background() r := NewRegistry()