mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-05-23 21:01:01 +08:00
- Added `zstd` decoding support in request logging, including helper functions to process `Content-Encoding` headers. - Enhanced config diff logic to compare payload-specific rules and track changes in payload configurations. - Added tests to validate `zstd` decoding and payload diff behavior.
220 lines
5.9 KiB
Go
220 lines
5.9 KiB
Go
// Package middleware provides HTTP middleware components for the CLI Proxy API server.
|
|
// This file contains the request logging middleware that captures comprehensive
|
|
// request and response data when enabled through configuration.
|
|
package middleware
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/klauspost/compress/zstd"
|
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/logging"
|
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
|
|
)
|
|
|
|
// maxErrorOnlyCapturedRequestBodyBytes is the largest declared request body
// size (1 MiB) that shouldCaptureRequestBody still allows to be captured when
// the request logger is not enabled.
const maxErrorOnlyCapturedRequestBodyBytes int64 = 1024 * 1024
|
|
|
|
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
|
|
// It captures detailed information about the request and response, including headers and body,
|
|
// and uses the provided RequestLogger to record this data. When full request logging is disabled,
|
|
// body capture is limited to small known-size payloads to avoid large per-request memory spikes.
|
|
func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
if logger == nil {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
if shouldSkipMethodForRequestLogging(c.Request) {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
path := c.Request.URL.Path
|
|
if !shouldLogRequest(path) {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
loggerEnabled := logger.IsEnabled()
|
|
|
|
// Capture request information
|
|
requestInfo, err := captureRequestInfo(c, shouldCaptureRequestBody(loggerEnabled, c.Request))
|
|
if err != nil {
|
|
// Log error but continue processing
|
|
// In a real implementation, you might want to use a proper logger here
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
// Create response writer wrapper
|
|
wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo)
|
|
if !loggerEnabled {
|
|
wrapper.logOnErrorOnly = true
|
|
}
|
|
c.Writer = wrapper
|
|
|
|
// Process the request
|
|
c.Next()
|
|
|
|
// Finalize logging after request processing
|
|
if err = wrapper.Finalize(c); err != nil {
|
|
// Log error but don't interrupt the response
|
|
// In a real implementation, you might want to use a proper logger here
|
|
}
|
|
}
|
|
}
|
|
|
|
func shouldSkipMethodForRequestLogging(req *http.Request) bool {
|
|
if req == nil {
|
|
return true
|
|
}
|
|
if req.Method != http.MethodGet {
|
|
return false
|
|
}
|
|
return !isResponsesWebsocketUpgrade(req)
|
|
}
|
|
|
|
// isResponsesWebsocketUpgrade reports whether req targets exactly the
// "/v1/responses" path and carries an Upgrade header whose value, ignoring
// surrounding whitespace and letter case, is "websocket". A nil request or a
// request without a URL yields false.
func isResponsesWebsocketUpgrade(req *http.Request) bool {
	if req == nil || req.URL == nil || req.URL.Path != "/v1/responses" {
		return false
	}
	upgrade := strings.TrimSpace(req.Header.Get("Upgrade"))
	return strings.EqualFold(upgrade, "websocket")
}
|
|
|
|
func shouldCaptureRequestBody(loggerEnabled bool, req *http.Request) bool {
|
|
if loggerEnabled {
|
|
return true
|
|
}
|
|
if req == nil || req.Body == nil {
|
|
return false
|
|
}
|
|
contentType := strings.ToLower(strings.TrimSpace(req.Header.Get("Content-Type")))
|
|
if strings.HasPrefix(contentType, "multipart/form-data") {
|
|
return false
|
|
}
|
|
if req.ContentLength <= 0 {
|
|
return false
|
|
}
|
|
return req.ContentLength <= maxErrorOnlyCapturedRequestBodyBytes
|
|
}
|
|
|
|
// captureRequestInfo extracts relevant information from the incoming HTTP request.
|
|
// It captures the URL, method, headers, and body. The request body is read and then
|
|
// restored so that it can be processed by subsequent handlers.
|
|
func captureRequestInfo(c *gin.Context, captureBody bool) (*RequestInfo, error) {
|
|
// Capture URL with sensitive query parameters masked
|
|
maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
|
|
url := c.Request.URL.Path
|
|
if maskedQuery != "" {
|
|
url += "?" + maskedQuery
|
|
}
|
|
|
|
// Capture method
|
|
method := c.Request.Method
|
|
|
|
// Capture headers
|
|
headers := make(map[string][]string)
|
|
for key, values := range c.Request.Header {
|
|
headers[key] = values
|
|
}
|
|
|
|
// Capture request body
|
|
var body []byte
|
|
if captureBody && c.Request.Body != nil {
|
|
// Read the body
|
|
bodyBytes, err := io.ReadAll(c.Request.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Restore the body for the actual request processing
|
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
|
body = decodeCapturedRequestBodyForLog(bodyBytes, c.Request.Header.Get("Content-Encoding"))
|
|
}
|
|
|
|
return &RequestInfo{
|
|
URL: url,
|
|
Method: method,
|
|
Headers: headers,
|
|
Body: body,
|
|
RequestID: logging.GetGinRequestID(c),
|
|
Timestamp: time.Now(),
|
|
}, nil
|
|
}
|
|
|
|
func decodeCapturedRequestBodyForLog(raw []byte, encoding string) []byte {
|
|
if len(raw) == 0 {
|
|
return raw
|
|
}
|
|
|
|
decoded, errDecode := decodeCapturedRequestBody(raw, encoding)
|
|
if errDecode != nil {
|
|
return raw
|
|
}
|
|
return decoded
|
|
}
|
|
|
|
func decodeCapturedRequestBody(raw []byte, encoding string) ([]byte, error) {
|
|
encoding = strings.TrimSpace(encoding)
|
|
if encoding == "" || strings.EqualFold(encoding, "identity") {
|
|
return raw, nil
|
|
}
|
|
|
|
parts := strings.Split(encoding, ",")
|
|
body := raw
|
|
for i := len(parts) - 1; i >= 0; i-- {
|
|
enc := strings.ToLower(strings.TrimSpace(parts[i]))
|
|
switch enc {
|
|
case "", "identity":
|
|
continue
|
|
case "zstd":
|
|
decoded, errDecode := decodeCapturedZstdRequestBody(body)
|
|
if errDecode != nil {
|
|
return nil, errDecode
|
|
}
|
|
body = decoded
|
|
default:
|
|
return nil, fmt.Errorf("unsupported request content encoding: %s", enc)
|
|
}
|
|
}
|
|
return body, nil
|
|
}
|
|
|
|
func decodeCapturedZstdRequestBody(raw []byte) ([]byte, error) {
|
|
decoder, errNewReader := zstd.NewReader(bytes.NewReader(raw))
|
|
if errNewReader != nil {
|
|
return nil, fmt.Errorf("failed to create zstd request decoder: %w", errNewReader)
|
|
}
|
|
defer decoder.Close()
|
|
|
|
decoded, errRead := io.ReadAll(decoder)
|
|
if errRead != nil {
|
|
return nil, fmt.Errorf("failed to decode zstd request body: %w", errRead)
|
|
}
|
|
return decoded, nil
|
|
}
|
|
|
|
// shouldLogRequest reports whether a request for path is eligible for request
// logging.
//
// Paths beginning with "/v0/management" or "/management" are never logged, to
// avoid leaking secrets. Paths beginning with "/api" are logged only when they
// also begin with "/api/provider". Every other path, including routes added by
// modules, is logged.
func shouldLogRequest(path string) bool {
	switch {
	case strings.HasPrefix(path, "/v0/management"), strings.HasPrefix(path, "/management"):
		return false
	case strings.HasPrefix(path, "/api"):
		return strings.HasPrefix(path, "/api/provider")
	default:
		return true
	}
}
|