mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-05-23 12:50:10 +08:00
- Added `zstd` decoding support to request logging, including helper functions that process `Content-Encoding` headers.
- Enhanced the config diff logic to compare payload-specific rules and track changes in payload configurations.
- Added tests that validate `zstd` decoding and payload diff behavior.
184 lines
4.6 KiB
Go
184 lines
4.6 KiB
Go
package middleware
|
|
|
|
import (
|
|
"bytes"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/klauspost/compress/zstd"
|
|
)
|
|
|
|
func TestShouldSkipMethodForRequestLogging(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
req *http.Request
|
|
skip bool
|
|
}{
|
|
{
|
|
name: "nil request",
|
|
req: nil,
|
|
skip: true,
|
|
},
|
|
{
|
|
name: "post request should not skip",
|
|
req: &http.Request{
|
|
Method: http.MethodPost,
|
|
URL: &url.URL{Path: "/v1/responses"},
|
|
},
|
|
skip: false,
|
|
},
|
|
{
|
|
name: "plain get should skip",
|
|
req: &http.Request{
|
|
Method: http.MethodGet,
|
|
URL: &url.URL{Path: "/v1/models"},
|
|
Header: http.Header{},
|
|
},
|
|
skip: true,
|
|
},
|
|
{
|
|
name: "responses websocket upgrade should not skip",
|
|
req: &http.Request{
|
|
Method: http.MethodGet,
|
|
URL: &url.URL{Path: "/v1/responses"},
|
|
Header: http.Header{"Upgrade": []string{"websocket"}},
|
|
},
|
|
skip: false,
|
|
},
|
|
{
|
|
name: "responses get without upgrade should skip",
|
|
req: &http.Request{
|
|
Method: http.MethodGet,
|
|
URL: &url.URL{Path: "/v1/responses"},
|
|
Header: http.Header{},
|
|
},
|
|
skip: true,
|
|
},
|
|
}
|
|
|
|
for i := range tests {
|
|
got := shouldSkipMethodForRequestLogging(tests[i].req)
|
|
if got != tests[i].skip {
|
|
t.Fatalf("%s: got skip=%t, want %t", tests[i].name, got, tests[i].skip)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestShouldCaptureRequestBody(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
loggerEnabled bool
|
|
req *http.Request
|
|
want bool
|
|
}{
|
|
{
|
|
name: "logger enabled always captures",
|
|
loggerEnabled: true,
|
|
req: &http.Request{
|
|
Body: io.NopCloser(strings.NewReader("{}")),
|
|
ContentLength: -1,
|
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
|
},
|
|
want: true,
|
|
},
|
|
{
|
|
name: "nil request",
|
|
loggerEnabled: false,
|
|
req: nil,
|
|
want: false,
|
|
},
|
|
{
|
|
name: "small known size json in error-only mode",
|
|
loggerEnabled: false,
|
|
req: &http.Request{
|
|
Body: io.NopCloser(strings.NewReader("{}")),
|
|
ContentLength: 2,
|
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
|
},
|
|
want: true,
|
|
},
|
|
{
|
|
name: "large known size skipped in error-only mode",
|
|
loggerEnabled: false,
|
|
req: &http.Request{
|
|
Body: io.NopCloser(strings.NewReader("x")),
|
|
ContentLength: maxErrorOnlyCapturedRequestBodyBytes + 1,
|
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
|
},
|
|
want: false,
|
|
},
|
|
{
|
|
name: "unknown size skipped in error-only mode",
|
|
loggerEnabled: false,
|
|
req: &http.Request{
|
|
Body: io.NopCloser(strings.NewReader("x")),
|
|
ContentLength: -1,
|
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
|
},
|
|
want: false,
|
|
},
|
|
{
|
|
name: "multipart skipped in error-only mode",
|
|
loggerEnabled: false,
|
|
req: &http.Request{
|
|
Body: io.NopCloser(strings.NewReader("x")),
|
|
ContentLength: 1,
|
|
Header: http.Header{"Content-Type": []string{"multipart/form-data; boundary=abc"}},
|
|
},
|
|
want: false,
|
|
},
|
|
}
|
|
|
|
for i := range tests {
|
|
got := shouldCaptureRequestBody(tests[i].loggerEnabled, tests[i].req)
|
|
if got != tests[i].want {
|
|
t.Fatalf("%s: got %t, want %t", tests[i].name, got, tests[i].want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestCaptureRequestInfoDecodesZstdRequestBodyForLog(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
payload := []byte(`{"model":"test-model","stream":true}`)
|
|
var compressed bytes.Buffer
|
|
encoder, errNewWriter := zstd.NewWriter(&compressed)
|
|
if errNewWriter != nil {
|
|
t.Fatalf("zstd.NewWriter: %v", errNewWriter)
|
|
}
|
|
if _, errWrite := encoder.Write(payload); errWrite != nil {
|
|
t.Fatalf("zstd write: %v", errWrite)
|
|
}
|
|
if errClose := encoder.Close(); errClose != nil {
|
|
t.Fatalf("zstd close: %v", errClose)
|
|
}
|
|
compressedBytes := compressed.Bytes()
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(compressedBytes))
|
|
req.Header.Set("Content-Encoding", "zstd")
|
|
c.Request = req
|
|
|
|
info, errCapture := captureRequestInfo(c, true)
|
|
if errCapture != nil {
|
|
t.Fatalf("captureRequestInfo: %v", errCapture)
|
|
}
|
|
if !bytes.Equal(info.Body, payload) {
|
|
t.Fatalf("logged request body = %q, want %q", string(info.Body), string(payload))
|
|
}
|
|
|
|
restoredBody, errRead := io.ReadAll(c.Request.Body)
|
|
if errRead != nil {
|
|
t.Fatalf("read restored request body: %v", errRead)
|
|
}
|
|
if !bytes.Equal(restoredBody, compressedBytes) {
|
|
t.Fatal("request body was not restored with the original compressed bytes")
|
|
}
|
|
}
|