mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-06-01 12:22:31 +08:00
feat(api): add request body decoding with Content-Encoding support
- Introduced `ReadRequestBody` helper function to support decoding request bodies based on "Content-Encoding" (e.g., `zstd`). - Replaced `c.GetRawData()` with `ReadRequestBody` across handlers to enable decoding. - Added test case to validate `zstd` decoding for compact responses.
This commit is contained in:
@@ -96,7 +96,7 @@ func (h *OpenAIAPIHandler) OpenAIModels(c *gin.Context) {
|
||||
// Parameters:
|
||||
// - c: The Gin context containing the HTTP request and response
|
||||
func (h *OpenAIAPIHandler) ChatCompletions(c *gin.Context) {
|
||||
rawJSON, err := c.GetRawData()
|
||||
rawJSON, err := handlers.ReadRequestBody(c)
|
||||
// If data retrieval fails, return a 400 Bad Request error.
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||
@@ -151,7 +151,7 @@ func shouldTreatAsResponsesFormat(rawJSON []byte) bool {
|
||||
// Parameters:
|
||||
// - c: The Gin context containing the HTTP request and response
|
||||
func (h *OpenAIAPIHandler) Completions(c *gin.Context) {
|
||||
rawJSON, err := c.GetRawData()
|
||||
rawJSON, err := handlers.ReadRequestBody(c)
|
||||
// If data retrieval fails, return a 400 Bad Request error.
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||
|
||||
@@ -204,7 +204,7 @@ func (h *OpenAIAPIHandler) ImagesGenerations(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
rawJSON, err := c.GetRawData()
|
||||
rawJSON, err := handlers.ReadRequestBody(c)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
@@ -435,7 +435,7 @@ func (h *OpenAIAPIHandler) imagesEditsFromMultipart(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *OpenAIAPIHandler) imagesEditsFromJSON(c *gin.Context) {
|
||||
rawJSON, err := c.GetRawData()
|
||||
rawJSON, err := handlers.ReadRequestBody(c)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
|
||||
@@ -118,3 +120,55 @@ func TestOpenAIResponsesCompactExecute(t *testing.T) {
|
||||
t.Fatalf("body = %s", resp.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesCompactDecodesZstdRequestBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
executor := &compactCaptureExecutor{}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
|
||||
auth := &coreauth.Auth{ID: "auth3", Provider: executor.Identifier(), Status: coreauth.StatusActive}
|
||||
if _, err := manager.Register(context.Background(), auth); err != nil {
|
||||
t.Fatalf("Register auth: %v", err)
|
||||
}
|
||||
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
|
||||
})
|
||||
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
router := gin.New()
|
||||
router.POST("/v1/responses/compact", h.Compact)
|
||||
|
||||
var compressed bytes.Buffer
|
||||
encoder, err := zstd.NewWriter(&compressed)
|
||||
if err != nil {
|
||||
t.Fatalf("zstd.NewWriter: %v", err)
|
||||
}
|
||||
if _, errWrite := encoder.Write([]byte(`{"model":"test-model","input":"hello"}`)); errWrite != nil {
|
||||
t.Fatalf("zstd write: %v", errWrite)
|
||||
}
|
||||
if errClose := encoder.Close(); errClose != nil {
|
||||
t.Fatalf("zstd close: %v", errClose)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses/compact", bytes.NewReader(compressed.Bytes()))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Content-Encoding", "zstd")
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d; body=%s", resp.Code, http.StatusOK, resp.Body.String())
|
||||
}
|
||||
if executor.calls != 1 {
|
||||
t.Fatalf("executor calls = %d, want 1", executor.calls)
|
||||
}
|
||||
if executor.alt != "responses/compact" {
|
||||
t.Fatalf("alt = %q, want %q", executor.alt, "responses/compact")
|
||||
}
|
||||
if strings.TrimSpace(resp.Body.String()) != `{"ok":true}` {
|
||||
t.Fatalf("body = %s", resp.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -370,7 +370,7 @@ func (h *OpenAIResponsesAPIHandler) OpenAIResponsesModels(c *gin.Context) {
|
||||
// Parameters:
|
||||
// - c: The Gin context containing the HTTP request and response
|
||||
func (h *OpenAIResponsesAPIHandler) Responses(c *gin.Context) {
|
||||
rawJSON, err := c.GetRawData()
|
||||
rawJSON, err := handlers.ReadRequestBody(c)
|
||||
// If data retrieval fails, return a 400 Bad Request error.
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||
@@ -393,7 +393,7 @@ func (h *OpenAIResponsesAPIHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) Compact(c *gin.Context) {
|
||||
rawJSON, err := c.GetRawData()
|
||||
rawJSON, err := handlers.ReadRequestBody(c)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
|
||||
73
sdk/api/handlers/request_body.go
Normal file
73
sdk/api/handlers/request_body.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
)
|
||||
|
||||
// ReadRequestBody reads the incoming request body and decodes supported
|
||||
// Content-Encoding values before handlers inspect JSON fields.
|
||||
func ReadRequestBody(c *gin.Context) ([]byte, error) {
|
||||
raw, err := c.GetRawData()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
encoding := ""
|
||||
if c != nil && c.Request != nil {
|
||||
encoding = strings.TrimSpace(c.Request.Header.Get("Content-Encoding"))
|
||||
}
|
||||
if encoding == "" || strings.EqualFold(encoding, "identity") {
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
decoded, err := decodeRequestBody(raw, encoding)
|
||||
if err != nil {
|
||||
if json.Valid(raw) {
|
||||
return raw, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
func decodeRequestBody(raw []byte, encoding string) ([]byte, error) {
|
||||
parts := strings.Split(encoding, ",")
|
||||
body := raw
|
||||
for i := len(parts) - 1; i >= 0; i-- {
|
||||
enc := strings.ToLower(strings.TrimSpace(parts[i]))
|
||||
switch enc {
|
||||
case "", "identity":
|
||||
continue
|
||||
case "zstd":
|
||||
decoded, err := decodeZstdRequestBody(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
body = decoded
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported request content encoding: %s", enc)
|
||||
}
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func decodeZstdRequestBody(raw []byte) ([]byte, error) {
|
||||
decoder, err := zstd.NewReader(bytes.NewReader(raw))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create zstd request decoder: %w", err)
|
||||
}
|
||||
defer decoder.Close()
|
||||
|
||||
decoded, err := io.ReadAll(decoder)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode zstd request body: %w", err)
|
||||
}
|
||||
return decoded, nil
|
||||
}
|
||||
Reference in New Issue
Block a user