Files
CLIProxyAPI/internal/api/middleware/request_logging_test.go
Luis Pater 5753d1a089 feat(logging): enable file-backed request/response sources for enhanced API logging
- Introduced support for file-backed logging of API requests and responses to handle large payloads efficiently.
- Refactored `attachWebsocketLogSources` to `attachRequestLogSources` for broader request and response handling.
- Added new methods for appending request/response data to file-backed sources and updated existing logging workflows for compatibility.
- Improved cleanup and merge logic for file-backed sources during request processing.
- Updated tests to cover newly introduced file-backed logging functionality.
2026-06-05 01:48:05 +08:00

243 lines
6.2 KiB
Go

package middleware
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd"
"github.com/router-for-me/CLIProxyAPI/v7/internal/logging"
)
// TestShouldSkipMethodForRequestLogging verifies which requests
// shouldSkipMethodForRequestLogging reports as skippable: a nil request and
// plain GET requests are expected to be skipped, while POST requests and a GET
// on /v1/responses carrying an "Upgrade: websocket" header are not.
//
// Each table entry runs as its own subtest so that one failing case does not
// hide the results of the remaining cases.
func TestShouldSkipMethodForRequestLogging(t *testing.T) {
	tests := []struct {
		name string
		req  *http.Request
		skip bool
	}{
		{
			name: "nil request",
			req:  nil,
			skip: true,
		},
		{
			name: "post request should not skip",
			req: &http.Request{
				Method: http.MethodPost,
				URL:    &url.URL{Path: "/v1/responses"},
			},
			skip: false,
		},
		{
			name: "plain get should skip",
			req: &http.Request{
				Method: http.MethodGet,
				URL:    &url.URL{Path: "/v1/models"},
				Header: http.Header{},
			},
			skip: true,
		},
		{
			name: "responses websocket upgrade should not skip",
			req: &http.Request{
				Method: http.MethodGet,
				URL:    &url.URL{Path: "/v1/responses"},
				Header: http.Header{"Upgrade": []string{"websocket"}},
			},
			skip: false,
		},
		{
			name: "responses get without upgrade should skip",
			req: &http.Request{
				Method: http.MethodGet,
				URL:    &url.URL{Path: "/v1/responses"},
				Header: http.Header{},
			},
			skip: true,
		},
	}

	for _, tc := range tests {
		// Subtests run synchronously here (no t.Parallel), so capturing tc is
		// safe on every Go version.
		t.Run(tc.name, func(t *testing.T) {
			got := shouldSkipMethodForRequestLogging(tc.req)
			if got != tc.skip {
				t.Fatalf("%s: got skip=%t, want %t", tc.name, got, tc.skip)
			}
		})
	}
}
// TestShouldCaptureRequestBody verifies the decision made by
// shouldCaptureRequestBody for combinations of the loggerEnabled flag and
// request shape. With the logger enabled the body is expected to be captured;
// with it disabled (the "error-only" cases below) only a small JSON body of
// known size is expected to be captured, while nil requests, bodies larger than
// maxErrorOnlyCapturedRequestBodyBytes, bodies of unknown length, and multipart
// bodies are not.
//
// Each table entry runs as its own subtest so that one failing case does not
// hide the results of the remaining cases.
func TestShouldCaptureRequestBody(t *testing.T) {
	tests := []struct {
		name          string
		loggerEnabled bool
		req           *http.Request
		want          bool
	}{
		{
			name:          "logger enabled always captures",
			loggerEnabled: true,
			req: &http.Request{
				Body:          io.NopCloser(strings.NewReader("{}")),
				ContentLength: -1,
				Header:        http.Header{"Content-Type": []string{"application/json"}},
			},
			want: true,
		},
		{
			name:          "nil request",
			loggerEnabled: false,
			req:           nil,
			want:          false,
		},
		{
			name:          "small known size json in error-only mode",
			loggerEnabled: false,
			req: &http.Request{
				Body:          io.NopCloser(strings.NewReader("{}")),
				ContentLength: 2,
				Header:        http.Header{"Content-Type": []string{"application/json"}},
			},
			want: true,
		},
		{
			name:          "large known size skipped in error-only mode",
			loggerEnabled: false,
			req: &http.Request{
				Body: io.NopCloser(strings.NewReader("x")),
				// One byte over the limit; only the declared length matters
				// for this case, not the actual body content.
				ContentLength: maxErrorOnlyCapturedRequestBodyBytes + 1,
				Header:        http.Header{"Content-Type": []string{"application/json"}},
			},
			want: false,
		},
		{
			name:          "unknown size skipped in error-only mode",
			loggerEnabled: false,
			req: &http.Request{
				Body:          io.NopCloser(strings.NewReader("x")),
				ContentLength: -1,
				Header:        http.Header{"Content-Type": []string{"application/json"}},
			},
			want: false,
		},
		{
			name:          "multipart skipped in error-only mode",
			loggerEnabled: false,
			req: &http.Request{
				Body:          io.NopCloser(strings.NewReader("x")),
				ContentLength: 1,
				Header:        http.Header{"Content-Type": []string{"multipart/form-data; boundary=abc"}},
			},
			want: false,
		},
	}

	for _, tc := range tests {
		// Subtests run synchronously here (no t.Parallel), so capturing tc is
		// safe on every Go version.
		t.Run(tc.name, func(t *testing.T) {
			got := shouldCaptureRequestBody(tc.loggerEnabled, tc.req)
			if got != tc.want {
				t.Fatalf("%s: got %t, want %t", tc.name, got, tc.want)
			}
		})
	}
}
// TestAttachRequestLogSourcesUsesLoggerLogsDir builds a websocket-upgrade GET
// request, calls attachRequestLogSources, and checks that both websocket
// timeline context keys hold a non-nil *logging.FileBodySource whose newly
// created part lives under the temporary directory handed to
// logging.NewFileRequestLogger.
func TestAttachRequestLogSourcesUsesLoggerLogsDir(t *testing.T) {
	gin.SetMode(gin.TestMode)

	logsDir := t.TempDir()
	requestLogger := logging.NewFileRequestLogger(true, logsDir, "", 0)

	ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
	ctx.Request = httptest.NewRequest(http.MethodGet, "/v1/responses", nil)
	ctx.Request.Header.Set("Upgrade", "websocket")

	attachRequestLogSources(ctx, requestLogger, true)
	defer cleanupFileBodySourcesFromContext(ctx)

	// The trailing separator ensures a sibling directory that merely shares
	// the logsDir name as a prefix is not accepted.
	wantPrefix := logsDir + string(os.PathSeparator)

	sourceKeys := []string{
		logging.WebsocketTimelineSourceContextKey,
		logging.APIWebsocketTimelineSourceContextKey,
	}
	for _, key := range sourceKeys {
		stored, found := ctx.Get(key)
		if !found {
			t.Fatalf("expected %s source to be attached", key)
		}

		source, isFileSource := stored.(*logging.FileBodySource)
		if !isFileSource || source == nil {
			t.Fatalf("%s source type = %T", key, stored)
		}

		part, errPart := source.CreatePart("probe")
		if errPart != nil {
			t.Fatalf("CreatePart(%s): %v", key, errPart)
		}

		// Capture the name first, then close the part before asserting on
		// its location.
		partPath := part.Name()
		if errClose := part.Close(); errClose != nil {
			t.Fatalf("close part: %v", errClose)
		}

		if !strings.HasPrefix(partPath, wantPrefix) {
			t.Fatalf("%s part path %s is not under logs dir %s", key, partPath, logsDir)
		}
	}
}
// cleanupFileBodySourcesFromContext calls Cleanup on each
// *logging.FileBodySource stored in c under the two websocket timeline context
// keys. A nil context is a no-op; keys that are absent, or whose value is not a
// non-nil *logging.FileBodySource, are skipped.
func cleanupFileBodySourcesFromContext(c *gin.Context) {
	if c == nil {
		return
	}

	sourceKeys := []string{
		logging.WebsocketTimelineSourceContextKey,
		logging.APIWebsocketTimelineSourceContextKey,
	}
	for _, key := range sourceKeys {
		stored, found := c.Get(key)
		if !found {
			continue
		}

		source, isFileSource := stored.(*logging.FileBodySource)
		if !isFileSource || source == nil {
			continue
		}

		// Best-effort cleanup in a test helper: the returned error is
		// deliberately discarded.
		_ = source.Cleanup()
	}
}
// TestCaptureRequestInfoDecodesZstdRequestBodyForLog sends a zstd-compressed
// JSON payload with "Content-Encoding: zstd" through captureRequestInfo and
// checks two things: the captured info.Body equals the uncompressed JSON, and
// the request body left on the context still yields the original compressed
// bytes when read afterwards.
func TestCaptureRequestInfoDecodesZstdRequestBodyForLog(t *testing.T) {
	gin.SetMode(gin.TestMode)

	plain := []byte(`{"model":"test-model","stream":true}`)

	// Compress the payload up front; Close must succeed before the buffer
	// contents are taken as the request body.
	var zbuf bytes.Buffer
	zw, errNewWriter := zstd.NewWriter(&zbuf)
	if errNewWriter != nil {
		t.Fatalf("zstd.NewWriter: %v", errNewWriter)
	}
	if _, errWrite := zw.Write(plain); errWrite != nil {
		t.Fatalf("zstd write: %v", errWrite)
	}
	if errClose := zw.Close(); errClose != nil {
		t.Fatalf("zstd close: %v", errClose)
	}
	wire := zbuf.Bytes()

	ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
	ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(wire))
	ctx.Request.Header.Set("Content-Encoding", "zstd")

	info, errCapture := captureRequestInfo(ctx, true)
	if errCapture != nil {
		t.Fatalf("captureRequestInfo: %v", errCapture)
	}
	if !bytes.Equal(info.Body, plain) {
		t.Fatalf("logged request body = %q, want %q", string(info.Body), string(plain))
	}

	// Whatever reads the request next must still see the compressed bytes,
	// not the decoded copy captured above.
	replayed, errRead := io.ReadAll(ctx.Request.Body)
	if errRead != nil {
		t.Fatalf("read restored request body: %v", errRead)
	}
	if !bytes.Equal(replayed, wire) {
		t.Fatal("request body was not restored with the original compressed bytes")
	}
}