Files
CLIProxyAPI/internal/api/middleware/request_logging.go
Luis Pater 82c9e0de58 feat(api, watcher): add zstd decoding for request logs and payload diff support
- Added `zstd` decoding support in request logging, including helper functions to process `Content-Encoding` headers.
- Enhanced config diff logic to compare payload-specific rules and track changes in payload configurations.
- Added tests to validate `zstd` decoding and payload diff behavior.
2026-05-16 13:00:32 +08:00

220 lines
5.9 KiB
Go

// Package middleware provides HTTP middleware components for the CLI Proxy API server.
// This file contains the request logging middleware that captures comprehensive
// request and response data when enabled through configuration.
package middleware
import (
"bytes"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd"
"github.com/router-for-me/CLIProxyAPI/v7/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
)
// maxErrorOnlyCapturedRequestBodyBytes is the largest declared request size
// (Content-Length) whose body is still captured while full request logging is
// disabled; shouldCaptureRequestBody rejects anything larger.
const maxErrorOnlyCapturedRequestBodyBytes int64 = 1 << 20 // 1 MiB
// RequestLoggingMiddleware returns a Gin handler that records eligible requests
// and their responses through the supplied RequestLogger.
//
// A request is passed straight down the chain, unlogged, when logger is nil,
// when shouldSkipMethodForRequestLogging filters out its method, when
// shouldLogRequest rejects its path, or when the request snapshot cannot be
// captured. Otherwise the response writer is replaced by a wrapper for the
// rest of the chain and the wrapper is finalized once the chain returns.
//
// When logger.IsEnabled() reports false the wrapper is switched to
// log-on-error-only mode and the request body is captured only if
// shouldCaptureRequestBody allows it, which keeps per-request memory small.
func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
	return func(c *gin.Context) {
		// The checks are ordered so that each one only runs when the previous
		// one passed (a nil request is rejected before its URL is touched).
		eligible := logger != nil &&
			!shouldSkipMethodForRequestLogging(c.Request) &&
			shouldLogRequest(c.Request.URL.Path)
		if !eligible {
			c.Next()
			return
		}

		fullLogging := logger.IsEnabled()

		info, captureErr := captureRequestInfo(c, shouldCaptureRequestBody(fullLogging, c.Request))
		if captureErr != nil {
			// Logging is best-effort: serve the request even though it could
			// not be captured.
			c.Next()
			return
		}

		recorder := NewResponseWriterWrapper(c.Writer, logger, info)
		if !fullLogging {
			recorder.logOnErrorOnly = true
		}
		c.Writer = recorder

		c.Next()

		// A failure to finalize the log entry must never affect the response
		// that has already been produced, so the error is deliberately dropped.
		_ = recorder.Finalize(c)
	}
}
// shouldSkipMethodForRequestLogging reports whether a request is excluded from
// request logging because of its HTTP method.
//
// A nil request is always skipped. GET requests are skipped unless
// isResponsesWebsocketUpgrade recognizes them; every other method is logged.
func shouldSkipMethodForRequestLogging(req *http.Request) bool {
	switch {
	case req == nil:
		return true
	case req.Method == http.MethodGet:
		return !isResponsesWebsocketUpgrade(req)
	default:
		return false
	}
}
// isResponsesWebsocketUpgrade reports whether req targets exactly
// "/v1/responses" and carries an Upgrade header whose value, ignoring case and
// surrounding whitespace, is "websocket". A nil request or a request without a
// URL never matches.
func isResponsesWebsocketUpgrade(req *http.Request) bool {
	if req == nil || req.URL == nil || req.URL.Path != "/v1/responses" {
		return false
	}
	upgrade := strings.TrimSpace(req.Header.Get("Upgrade"))
	return strings.EqualFold(upgrade, "websocket")
}
// shouldCaptureRequestBody decides whether the request body is buffered for
// the log.
//
// With full logging enabled the body is always captured. Otherwise it is
// captured only when the request has a body, is not multipart/form-data, and
// declares a positive Content-Length of at most
// maxErrorOnlyCapturedRequestBodyBytes.
func shouldCaptureRequestBody(loggerEnabled bool, req *http.Request) bool {
	if loggerEnabled {
		return true
	}
	if req == nil || req.Body == nil {
		return false
	}

	mediaType := strings.TrimSpace(req.Header.Get("Content-Type"))
	if strings.HasPrefix(strings.ToLower(mediaType), "multipart/form-data") {
		return false
	}

	// Unknown (-1) and empty (0) lengths are never captured in this mode.
	declared := req.ContentLength
	return declared > 0 && declared <= maxErrorOnlyCapturedRequestBodyBytes
}
// captureRequestInfo snapshots the incoming request for the request log.
//
// The snapshot holds the request path plus its query string after
// util.MaskSensitiveQuery, the method, a copy of the headers, the request ID
// from logging.GetGinRequestID and the capture time. When captureBody is true
// and the request has a body, the body is read in full, put back on the request
// so later handlers can still consume it, and stored after
// decodeCapturedRequestBodyForLog has processed it with the request's
// Content-Encoding.
//
// An error is returned only when reading the body fails. In that case the bytes
// already read are put back in front of the remaining stream, so a caller that
// carries on with the handler chain does not hand over a half-drained body.
func captureRequestInfo(c *gin.Context, captureBody bool) (*RequestInfo, error) {
	req := c.Request

	// Build the logged URL with sensitive query parameters masked.
	loggedURL := req.URL.Path
	if maskedQuery := util.MaskSensitiveQuery(req.URL.RawQuery); maskedQuery != "" {
		loggedURL += "?" + maskedQuery
	}

	// Copy the value slices as well as the map: the snapshot is read after the
	// handler chain has run and must not alias header slices that handlers may
	// modify in place.
	headers := make(map[string][]string, len(req.Header))
	for key, values := range req.Header {
		copied := make([]string, len(values))
		copy(copied, values)
		headers[key] = copied
	}

	var body []byte
	if captureBody && req.Body != nil {
		bodyBytes, err := io.ReadAll(req.Body)
		if err != nil {
			// Replay what was consumed, then continue with the original stream.
			req.Body = io.NopCloser(io.MultiReader(bytes.NewReader(bodyBytes), req.Body))
			return nil, err
		}
		// Restore the body for the actual request processing.
		req.Body = io.NopCloser(bytes.NewReader(bodyBytes))

		// Content-Encoding may be split across several field lines; joining them
		// with commas yields the equivalent single list, in order.
		encoding := strings.Join(req.Header.Values("Content-Encoding"), ",")
		body = decodeCapturedRequestBodyForLog(bodyBytes, encoding)
	}

	return &RequestInfo{
		URL:       loggedURL,
		Method:    req.Method,
		Headers:   headers,
		Body:      body,
		RequestID: logging.GetGinRequestID(c),
		Timestamp: time.Now(),
	}, nil
}
// decodeCapturedRequestBodyForLog returns the captured body with its content
// encoding undone when possible. Empty input is returned unchanged, and so is
// any body that decodeCapturedRequestBody fails to decode, so the log always
// receives some representation of the payload.
func decodeCapturedRequestBodyForLog(raw []byte, encoding string) []byte {
	if len(raw) > 0 {
		if plain, err := decodeCapturedRequestBody(raw, encoding); err == nil {
			return plain
		}
	}
	return raw
}
// decodeCapturedRequestBody undoes the content codings named in encoding, a
// comma-separated Content-Encoding value.
//
// Codings are removed last-to-first, i.e. in the reverse of the order in which
// they are listed. Empty entries and "identity" are ignored, "zstd" is
// decompressed, and any other coding yields an error. An empty or
// identity-only value returns raw untouched.
func decodeCapturedRequestBody(raw []byte, encoding string) ([]byte, error) {
	trimmed := strings.TrimSpace(encoding)
	if trimmed == "" || strings.EqualFold(trimmed, "identity") {
		return raw, nil
	}

	codings := strings.Split(trimmed, ",")
	current := raw
	for step := range codings {
		// Walk the list from its tail towards its head.
		coding := strings.ToLower(strings.TrimSpace(codings[len(codings)-1-step]))
		if coding == "" || coding == "identity" {
			continue
		}
		if coding != "zstd" {
			return nil, fmt.Errorf("unsupported request content encoding: %s", coding)
		}
		plain, err := decodeCapturedZstdRequestBody(current)
		if err != nil {
			return nil, err
		}
		current = plain
	}
	return current, nil
}
// decodeCapturedZstdRequestBody decompresses a zstd-encoded request body that
// was captured for logging.
//
// The payload comes from the client, and a small zstd stream can expand to a
// very large output, so the decompressed size is capped. When the cap is
// exceeded an error is returned and no partial output is kept; the log helper
// decodeCapturedRequestBodyForLog then falls back to the raw bytes.
//
// It returns the decompressed bytes, or an error when the decoder cannot be
// created, the stream cannot be decoded, or the output exceeds the cap.
func decodeCapturedZstdRequestBody(raw []byte) ([]byte, error) {
	// Upper bound on the decompressed bytes kept for a single log entry.
	const maxDecodedBytes int64 = 64 << 20 // 64 MiB

	decoder, errNewReader := zstd.NewReader(bytes.NewReader(raw))
	if errNewReader != nil {
		return nil, fmt.Errorf("failed to create zstd request decoder: %w", errNewReader)
	}
	defer decoder.Close()

	// Read one byte past the cap so that an oversized payload can be told apart
	// from one that sits exactly at the limit.
	decoded, errRead := io.ReadAll(io.LimitReader(decoder, maxDecodedBytes+1))
	if errRead != nil {
		return nil, fmt.Errorf("failed to decode zstd request body: %w", errRead)
	}
	if int64(len(decoded)) > maxDecodedBytes {
		return nil, fmt.Errorf("decoded zstd request body exceeds %d bytes", maxDecodedBytes)
	}
	return decoded, nil
}
// shouldLogRequest reports whether requests to path take part in request
// logging.
//
// Management endpoints ("/v0/management" and "/management" prefixes) are never
// logged, to avoid leaking secrets. Under the "/api" prefix only
// "/api/provider" routes are logged. Every other path, including
// module-provided routes, is logged.
func shouldLogRequest(path string) bool {
	switch {
	case strings.HasPrefix(path, "/v0/management"), strings.HasPrefix(path, "/management"):
		return false
	case strings.HasPrefix(path, "/api"):
		return strings.HasPrefix(path, "/api/provider")
	default:
		return true
	}
}