mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-06-11 16:53:43 +08:00
- Introduced support for file-backed logging of API requests and responses to handle large payloads efficiently.
- Refactored `attachWebsocketLogSources` to `attachRequestLogSources` for broader request and response handling.
- Added new methods for appending request/response data to file-backed sources and updated existing logging workflows for compatibility.
- Improved cleanup and merge logic for file-backed sources during request processing.
- Updated tests to cover newly introduced file-backed logging functionality.
243 lines
6.2 KiB
Go
243 lines
6.2 KiB
Go
package middleware
|
|
|
|
import (
|
|
"bytes"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/klauspost/compress/zstd"
|
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/logging"
|
|
)
|
|
|
|
func TestShouldSkipMethodForRequestLogging(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
req *http.Request
|
|
skip bool
|
|
}{
|
|
{
|
|
name: "nil request",
|
|
req: nil,
|
|
skip: true,
|
|
},
|
|
{
|
|
name: "post request should not skip",
|
|
req: &http.Request{
|
|
Method: http.MethodPost,
|
|
URL: &url.URL{Path: "/v1/responses"},
|
|
},
|
|
skip: false,
|
|
},
|
|
{
|
|
name: "plain get should skip",
|
|
req: &http.Request{
|
|
Method: http.MethodGet,
|
|
URL: &url.URL{Path: "/v1/models"},
|
|
Header: http.Header{},
|
|
},
|
|
skip: true,
|
|
},
|
|
{
|
|
name: "responses websocket upgrade should not skip",
|
|
req: &http.Request{
|
|
Method: http.MethodGet,
|
|
URL: &url.URL{Path: "/v1/responses"},
|
|
Header: http.Header{"Upgrade": []string{"websocket"}},
|
|
},
|
|
skip: false,
|
|
},
|
|
{
|
|
name: "responses get without upgrade should skip",
|
|
req: &http.Request{
|
|
Method: http.MethodGet,
|
|
URL: &url.URL{Path: "/v1/responses"},
|
|
Header: http.Header{},
|
|
},
|
|
skip: true,
|
|
},
|
|
}
|
|
|
|
for i := range tests {
|
|
got := shouldSkipMethodForRequestLogging(tests[i].req)
|
|
if got != tests[i].skip {
|
|
t.Fatalf("%s: got skip=%t, want %t", tests[i].name, got, tests[i].skip)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestShouldCaptureRequestBody(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
loggerEnabled bool
|
|
req *http.Request
|
|
want bool
|
|
}{
|
|
{
|
|
name: "logger enabled always captures",
|
|
loggerEnabled: true,
|
|
req: &http.Request{
|
|
Body: io.NopCloser(strings.NewReader("{}")),
|
|
ContentLength: -1,
|
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
|
},
|
|
want: true,
|
|
},
|
|
{
|
|
name: "nil request",
|
|
loggerEnabled: false,
|
|
req: nil,
|
|
want: false,
|
|
},
|
|
{
|
|
name: "small known size json in error-only mode",
|
|
loggerEnabled: false,
|
|
req: &http.Request{
|
|
Body: io.NopCloser(strings.NewReader("{}")),
|
|
ContentLength: 2,
|
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
|
},
|
|
want: true,
|
|
},
|
|
{
|
|
name: "large known size skipped in error-only mode",
|
|
loggerEnabled: false,
|
|
req: &http.Request{
|
|
Body: io.NopCloser(strings.NewReader("x")),
|
|
ContentLength: maxErrorOnlyCapturedRequestBodyBytes + 1,
|
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
|
},
|
|
want: false,
|
|
},
|
|
{
|
|
name: "unknown size skipped in error-only mode",
|
|
loggerEnabled: false,
|
|
req: &http.Request{
|
|
Body: io.NopCloser(strings.NewReader("x")),
|
|
ContentLength: -1,
|
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
|
},
|
|
want: false,
|
|
},
|
|
{
|
|
name: "multipart skipped in error-only mode",
|
|
loggerEnabled: false,
|
|
req: &http.Request{
|
|
Body: io.NopCloser(strings.NewReader("x")),
|
|
ContentLength: 1,
|
|
Header: http.Header{"Content-Type": []string{"multipart/form-data; boundary=abc"}},
|
|
},
|
|
want: false,
|
|
},
|
|
}
|
|
|
|
for i := range tests {
|
|
got := shouldCaptureRequestBody(tests[i].loggerEnabled, tests[i].req)
|
|
if got != tests[i].want {
|
|
t.Fatalf("%s: got %t, want %t", tests[i].name, got, tests[i].want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestAttachRequestLogSourcesUsesLoggerLogsDir(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
logsDir := t.TempDir()
|
|
logger := logging.NewFileRequestLogger(true, logsDir, "", 0)
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
c.Request = httptest.NewRequest(http.MethodGet, "/v1/responses", nil)
|
|
c.Request.Header.Set("Upgrade", "websocket")
|
|
|
|
attachRequestLogSources(c, logger, true)
|
|
defer cleanupFileBodySourcesFromContext(c)
|
|
|
|
for _, key := range []string{
|
|
logging.WebsocketTimelineSourceContextKey,
|
|
logging.APIWebsocketTimelineSourceContextKey,
|
|
} {
|
|
value, exists := c.Get(key)
|
|
if !exists {
|
|
t.Fatalf("expected %s source to be attached", key)
|
|
}
|
|
source, ok := value.(*logging.FileBodySource)
|
|
if !ok || source == nil {
|
|
t.Fatalf("%s source type = %T", key, value)
|
|
}
|
|
file, errPart := source.CreatePart("probe")
|
|
if errPart != nil {
|
|
t.Fatalf("CreatePart(%s): %v", key, errPart)
|
|
}
|
|
path := file.Name()
|
|
if errClose := file.Close(); errClose != nil {
|
|
t.Fatalf("close part: %v", errClose)
|
|
}
|
|
if !strings.HasPrefix(path, logsDir+string(os.PathSeparator)) {
|
|
t.Fatalf("%s part path %s is not under logs dir %s", key, path, logsDir)
|
|
}
|
|
}
|
|
}
|
|
|
|
func cleanupFileBodySourcesFromContext(c *gin.Context) {
|
|
if c == nil {
|
|
return
|
|
}
|
|
for _, key := range []string{
|
|
logging.WebsocketTimelineSourceContextKey,
|
|
logging.APIWebsocketTimelineSourceContextKey,
|
|
} {
|
|
value, exists := c.Get(key)
|
|
if !exists {
|
|
continue
|
|
}
|
|
if source, ok := value.(*logging.FileBodySource); ok && source != nil {
|
|
_ = source.Cleanup()
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestCaptureRequestInfoDecodesZstdRequestBodyForLog(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
payload := []byte(`{"model":"test-model","stream":true}`)
|
|
var compressed bytes.Buffer
|
|
encoder, errNewWriter := zstd.NewWriter(&compressed)
|
|
if errNewWriter != nil {
|
|
t.Fatalf("zstd.NewWriter: %v", errNewWriter)
|
|
}
|
|
if _, errWrite := encoder.Write(payload); errWrite != nil {
|
|
t.Fatalf("zstd write: %v", errWrite)
|
|
}
|
|
if errClose := encoder.Close(); errClose != nil {
|
|
t.Fatalf("zstd close: %v", errClose)
|
|
}
|
|
compressedBytes := compressed.Bytes()
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(compressedBytes))
|
|
req.Header.Set("Content-Encoding", "zstd")
|
|
c.Request = req
|
|
|
|
info, errCapture := captureRequestInfo(c, true)
|
|
if errCapture != nil {
|
|
t.Fatalf("captureRequestInfo: %v", errCapture)
|
|
}
|
|
if !bytes.Equal(info.Body, payload) {
|
|
t.Fatalf("logged request body = %q, want %q", string(info.Body), string(payload))
|
|
}
|
|
|
|
restoredBody, errRead := io.ReadAll(c.Request.Body)
|
|
if errRead != nil {
|
|
t.Fatalf("read restored request body: %v", errRead)
|
|
}
|
|
if !bytes.Equal(restoredBody, compressedBytes) {
|
|
t.Fatal("request body was not restored with the original compressed bytes")
|
|
}
|
|
}
|