diff --git a/sdk/api/handlers/openai/openai_handlers.go b/sdk/api/handlers/openai/openai_handlers.go index 29dc0ea0b..e1cde111c 100644 --- a/sdk/api/handlers/openai/openai_handlers.go +++ b/sdk/api/handlers/openai/openai_handlers.go @@ -96,7 +96,7 @@ func (h *OpenAIAPIHandler) OpenAIModels(c *gin.Context) { // Parameters: // - c: The Gin context containing the HTTP request and response func (h *OpenAIAPIHandler) ChatCompletions(c *gin.Context) { - rawJSON, err := c.GetRawData() + rawJSON, err := handlers.ReadRequestBody(c) // If data retrieval fails, return a 400 Bad Request error. if err != nil { c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ @@ -151,7 +151,7 @@ func shouldTreatAsResponsesFormat(rawJSON []byte) bool { // Parameters: // - c: The Gin context containing the HTTP request and response func (h *OpenAIAPIHandler) Completions(c *gin.Context) { - rawJSON, err := c.GetRawData() + rawJSON, err := handlers.ReadRequestBody(c) // If data retrieval fails, return a 400 Bad Request error. if err != nil { c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ diff --git a/sdk/api/handlers/openai/openai_images_handlers.go b/sdk/api/handlers/openai/openai_images_handlers.go index 6e6e8ef6f..72f06093c 100644 --- a/sdk/api/handlers/openai/openai_images_handlers.go +++ b/sdk/api/handlers/openai/openai_images_handlers.go @@ -204,7 +204,7 @@ func (h *OpenAIAPIHandler) ImagesGenerations(c *gin.Context) { return } - rawJSON, err := c.GetRawData() + rawJSON, err := handlers.ReadRequestBody(c) if err != nil { c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ Error: handlers.ErrorDetail{ @@ -435,7 +435,7 @@ func (h *OpenAIAPIHandler) imagesEditsFromMultipart(c *gin.Context) { } func (h *OpenAIAPIHandler) imagesEditsFromJSON(c *gin.Context) { - rawJSON, err := c.GetRawData() + rawJSON, err := handlers.ReadRequestBody(c) if err != nil { c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ Error: handlers.ErrorDetail{ diff --git a/sdk/api/handlers/openai/openai_responses_compact_test.go b/sdk/api/handlers/openai/openai_responses_compact_test.go index 48b7e3bbd..4d3b4574d 100644 --- a/sdk/api/handlers/openai/openai_responses_compact_test.go +++ b/sdk/api/handlers/openai/openai_responses_compact_test.go @@ -1,6 +1,7 @@ package openai import ( + "bytes" "context" "errors" "net/http" @@ -9,6 +10,7 @@ import ( "testing" "github.com/gin-gonic/gin" + "github.com/klauspost/compress/zstd" "github.com/router-for-me/CLIProxyAPI/v7/internal/registry" "github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers" coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" @@ -118,3 +120,55 @@ func TestOpenAIResponsesCompactExecute(t *testing.T) { t.Fatalf("body = %s", resp.Body.String()) } } + +func TestOpenAIResponsesCompactDecodesZstdRequestBody(t *testing.T) { + gin.SetMode(gin.TestMode) + executor := &compactCaptureExecutor{} + manager := coreauth.NewManager(nil, nil, nil) + manager.RegisterExecutor(executor) + + auth := &coreauth.Auth{ID: "auth3", Provider: executor.Identifier(), Status: coreauth.StatusActive} + if _, err := manager.Register(context.Background(), auth); err != nil { + t.Fatalf("Register auth: %v", err) + } + registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(auth.ID) + }) + + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + h := NewOpenAIResponsesAPIHandler(base) + router := gin.New() + router.POST("/v1/responses/compact", h.Compact) + + var compressed bytes.Buffer + encoder, err := zstd.NewWriter(&compressed) + if err != nil { + t.Fatalf("zstd.NewWriter: %v", err) + } + if _, errWrite := encoder.Write([]byte(`{"model":"test-model","input":"hello"}`)); errWrite != nil { + t.Fatalf("zstd write: %v", errWrite) + } + if errClose := encoder.Close(); errClose != nil { + t.Fatalf("zstd close: %v", errClose) + } + + req := httptest.NewRequest(http.MethodPost, "/v1/responses/compact", bytes.NewReader(compressed.Bytes())) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Content-Encoding", "zstd") + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", resp.Code, http.StatusOK, resp.Body.String()) + } + if executor.calls != 1 { + t.Fatalf("executor calls = %d, want 1", executor.calls) + } + if executor.alt != "responses/compact" { + t.Fatalf("alt = %q, want %q", executor.alt, "responses/compact") + } + if strings.TrimSpace(resp.Body.String()) != `{"ok":true}` { + t.Fatalf("body = %s", resp.Body.String()) + } +} diff --git a/sdk/api/handlers/openai/openai_responses_handlers.go b/sdk/api/handlers/openai/openai_responses_handlers.go index 5b2c006a3..e9063b86d 100644 --- a/sdk/api/handlers/openai/openai_responses_handlers.go +++ b/sdk/api/handlers/openai/openai_responses_handlers.go @@ -370,7 +370,7 @@ func (h *OpenAIResponsesAPIHandler) OpenAIResponsesModels(c *gin.Context) { // Parameters: // - c: The Gin context containing the HTTP request and response func (h *OpenAIResponsesAPIHandler) Responses(c *gin.Context) { - rawJSON, err := c.GetRawData() + rawJSON, err := handlers.ReadRequestBody(c) // If data retrieval fails, return a 400 Bad Request error. if err != nil { c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ @@ -393,7 +393,7 @@ func (h *OpenAIResponsesAPIHandler) Responses(c *gin.Context) { } func (h *OpenAIResponsesAPIHandler) Compact(c *gin.Context) { - rawJSON, err := c.GetRawData() + rawJSON, err := handlers.ReadRequestBody(c) if err != nil { c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ Error: handlers.ErrorDetail{ diff --git a/sdk/api/handlers/request_body.go b/sdk/api/handlers/request_body.go new file mode 100644 index 000000000..568872d2b --- /dev/null +++ b/sdk/api/handlers/request_body.go @@ -0,0 +1,73 @@ +package handlers + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "strings" + + "github.com/gin-gonic/gin" + "github.com/klauspost/compress/zstd" +) + +// ReadRequestBody reads the incoming request body and decodes supported +// Content-Encoding values before handlers inspect JSON fields. +func ReadRequestBody(c *gin.Context) ([]byte, error) { + raw, err := c.GetRawData() + if err != nil { + return nil, err + } + + encoding := "" + if c != nil && c.Request != nil { + encoding = strings.TrimSpace(c.Request.Header.Get("Content-Encoding")) + } + if encoding == "" || strings.EqualFold(encoding, "identity") { + return raw, nil + } + + decoded, err := decodeRequestBody(raw, encoding) + if err != nil { + if json.Valid(raw) { + return raw, nil + } + return nil, err + } + return decoded, nil +} + +func decodeRequestBody(raw []byte, encoding string) ([]byte, error) { + parts := strings.Split(encoding, ",") + body := raw + for i := len(parts) - 1; i >= 0; i-- { + enc := strings.ToLower(strings.TrimSpace(parts[i])) + switch enc { + case "", "identity": + continue + case "zstd": + decoded, err := decodeZstdRequestBody(body) + if err != nil { + return nil, err + } + body = decoded + default: + return nil, fmt.Errorf("unsupported request content encoding: %s", enc) + } + } + return body, nil +} + +func decodeZstdRequestBody(raw []byte) ([]byte, error) { + decoder, err := zstd.NewReader(bytes.NewReader(raw)) + if err != nil { + return nil, fmt.Errorf("failed to create zstd request decoder: %w", err) + } + defer decoder.Close() + + decoded, err := io.ReadAll(decoder) + if err != nil { + return nil, fmt.Errorf("failed to decode zstd request body: %w", err) + } + return decoded, nil +}