package middleware import ( "bytes" "io" "net/http" "net/http/httptest" "net/url" "os" "strings" "testing" "github.com/gin-gonic/gin" "github.com/klauspost/compress/zstd" "github.com/router-for-me/CLIProxyAPI/v7/internal/logging" ) func TestShouldSkipMethodForRequestLogging(t *testing.T) { tests := []struct { name string req *http.Request skip bool }{ { name: "nil request", req: nil, skip: true, }, { name: "post request should not skip", req: &http.Request{ Method: http.MethodPost, URL: &url.URL{Path: "/v1/responses"}, }, skip: false, }, { name: "plain get should skip", req: &http.Request{ Method: http.MethodGet, URL: &url.URL{Path: "/v1/models"}, Header: http.Header{}, }, skip: true, }, { name: "responses websocket upgrade should not skip", req: &http.Request{ Method: http.MethodGet, URL: &url.URL{Path: "/v1/responses"}, Header: http.Header{"Upgrade": []string{"websocket"}}, }, skip: false, }, { name: "responses get without upgrade should skip", req: &http.Request{ Method: http.MethodGet, URL: &url.URL{Path: "/v1/responses"}, Header: http.Header{}, }, skip: true, }, } for i := range tests { got := shouldSkipMethodForRequestLogging(tests[i].req) if got != tests[i].skip { t.Fatalf("%s: got skip=%t, want %t", tests[i].name, got, tests[i].skip) } } } func TestShouldCaptureRequestBody(t *testing.T) { tests := []struct { name string loggerEnabled bool req *http.Request want bool }{ { name: "logger enabled always captures", loggerEnabled: true, req: &http.Request{ Body: io.NopCloser(strings.NewReader("{}")), ContentLength: -1, Header: http.Header{"Content-Type": []string{"application/json"}}, }, want: true, }, { name: "nil request", loggerEnabled: false, req: nil, want: false, }, { name: "small known size json in error-only mode", loggerEnabled: false, req: &http.Request{ Body: io.NopCloser(strings.NewReader("{}")), ContentLength: 2, Header: http.Header{"Content-Type": []string{"application/json"}}, }, want: true, }, { name: "large known size skipped in error-only mode", loggerEnabled: false, req: &http.Request{ Body: io.NopCloser(strings.NewReader("x")), ContentLength: maxErrorOnlyCapturedRequestBodyBytes + 1, Header: http.Header{"Content-Type": []string{"application/json"}}, }, want: false, }, { name: "unknown size skipped in error-only mode", loggerEnabled: false, req: &http.Request{ Body: io.NopCloser(strings.NewReader("x")), ContentLength: -1, Header: http.Header{"Content-Type": []string{"application/json"}}, }, want: false, }, { name: "multipart skipped in error-only mode", loggerEnabled: false, req: &http.Request{ Body: io.NopCloser(strings.NewReader("x")), ContentLength: 1, Header: http.Header{"Content-Type": []string{"multipart/form-data; boundary=abc"}}, }, want: false, }, } for i := range tests { got := shouldCaptureRequestBody(tests[i].loggerEnabled, tests[i].req) if got != tests[i].want { t.Fatalf("%s: got %t, want %t", tests[i].name, got, tests[i].want) } } } func TestAttachRequestLogSourcesUsesLoggerLogsDir(t *testing.T) { gin.SetMode(gin.TestMode) logsDir := t.TempDir() logger := logging.NewFileRequestLogger(true, logsDir, "", 0) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) c.Request = httptest.NewRequest(http.MethodGet, "/v1/responses", nil) c.Request.Header.Set("Upgrade", "websocket") attachRequestLogSources(c, logger, true) defer cleanupFileBodySourcesFromContext(c) for _, key := range []string{ logging.WebsocketTimelineSourceContextKey, logging.APIWebsocketTimelineSourceContextKey, } { value, exists := c.Get(key) if !exists { t.Fatalf("expected %s source to be attached", key) } source, ok := value.(*logging.FileBodySource) if !ok || source == nil { t.Fatalf("%s source type = %T", key, value) } file, errPart := source.CreatePart("probe") if errPart != nil { t.Fatalf("CreatePart(%s): %v", key, errPart) } path := file.Name() if errClose := file.Close(); errClose != nil { t.Fatalf("close part: %v", errClose) } if !strings.HasPrefix(path, logsDir+string(os.PathSeparator)) { t.Fatalf("%s part path %s is not under logs dir %s", key, path, logsDir) } } } func cleanupFileBodySourcesFromContext(c *gin.Context) { if c == nil { return } for _, key := range []string{ logging.WebsocketTimelineSourceContextKey, logging.APIWebsocketTimelineSourceContextKey, } { value, exists := c.Get(key) if !exists { continue } if source, ok := value.(*logging.FileBodySource); ok && source != nil { _ = source.Cleanup() } } } func TestCaptureRequestInfoDecodesZstdRequestBodyForLog(t *testing.T) { gin.SetMode(gin.TestMode) payload := []byte(`{"model":"test-model","stream":true}`) var compressed bytes.Buffer encoder, errNewWriter := zstd.NewWriter(&compressed) if errNewWriter != nil { t.Fatalf("zstd.NewWriter: %v", errNewWriter) } if _, errWrite := encoder.Write(payload); errWrite != nil { t.Fatalf("zstd write: %v", errWrite) } if errClose := encoder.Close(); errClose != nil { t.Fatalf("zstd close: %v", errClose) } compressedBytes := compressed.Bytes() recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(compressedBytes)) req.Header.Set("Content-Encoding", "zstd") c.Request = req info, errCapture := captureRequestInfo(c, true) if errCapture != nil { t.Fatalf("captureRequestInfo: %v", errCapture) } if !bytes.Equal(info.Body, payload) { t.Fatalf("logged request body = %q, want %q", string(info.Body), string(payload)) } restoredBody, errRead := io.ReadAll(c.Request.Body) if errRead != nil { t.Fatalf("read restored request body: %v", errRead) } if !bytes.Equal(restoredBody, compressedBytes) { t.Fatal("request body was not restored with the original compressed bytes") } }