From 82c9e0de58f91210061bb596ab65b5fb3aff2381 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sat, 16 May 2026 13:00:32 +0800 Subject: [PATCH] feat(api, watcher): add zstd decoding for request logs and payload diff support - Added `zstd` decoding support in request logging, including helper functions to process `Content-Encoding` headers. - Enhanced config diff logic to compare payload-specific rules and track changes in payload configurations. - Added tests to validate `zstd` decoding and payload diff behavior. --- internal/api/middleware/request_logging.go | 56 ++++++++++++++++++- .../api/middleware/request_logging_test.go | 45 +++++++++++++++ internal/watcher/diff/config_diff.go | 26 +++++++++ sdk/cliproxy/service.go | 3 + 4 files changed, 129 insertions(+), 1 deletion(-) diff --git a/internal/api/middleware/request_logging.go b/internal/api/middleware/request_logging.go index 7a10fad8a..4caa0937d 100644 --- a/internal/api/middleware/request_logging.go +++ b/internal/api/middleware/request_logging.go @@ -5,12 +5,14 @@ package middleware import ( "bytes" + "fmt" "io" "net/http" "strings" "time" "github.com/gin-gonic/gin" + "github.com/klauspost/compress/zstd" "github.com/router-for-me/CLIProxyAPI/v7/internal/logging" "github.com/router-for-me/CLIProxyAPI/v7/internal/util" ) @@ -136,7 +138,7 @@ func captureRequestInfo(c *gin.Context, captureBody bool) (*RequestInfo, error) // Restore the body for the actual request processing c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) - body = bodyBytes + body = decodeCapturedRequestBodyForLog(bodyBytes, c.Request.Header.Get("Content-Encoding")) } return &RequestInfo{ @@ -149,6 +151,58 @@ func captureRequestInfo(c *gin.Context, captureBody bool) (*RequestInfo, error) }, nil } +func decodeCapturedRequestBodyForLog(raw []byte, encoding string) []byte { + if len(raw) == 0 { + return raw + } + + decoded, errDecode := decodeCapturedRequestBody(raw, encoding) + if errDecode != nil { + return raw + } + return decoded +} + +func decodeCapturedRequestBody(raw []byte, encoding string) ([]byte, error) { + encoding = strings.TrimSpace(encoding) + if encoding == "" || strings.EqualFold(encoding, "identity") { + return raw, nil + } + + parts := strings.Split(encoding, ",") + body := raw + for i := len(parts) - 1; i >= 0; i-- { + enc := strings.ToLower(strings.TrimSpace(parts[i])) + switch enc { + case "", "identity": + continue + case "zstd": + decoded, errDecode := decodeCapturedZstdRequestBody(body) + if errDecode != nil { + return nil, errDecode + } + body = decoded + default: + return nil, fmt.Errorf("unsupported request content encoding: %s", enc) + } + } + return body, nil +} + +func decodeCapturedZstdRequestBody(raw []byte) ([]byte, error) { + decoder, errNewReader := zstd.NewReader(bytes.NewReader(raw)) + if errNewReader != nil { + return nil, fmt.Errorf("failed to create zstd request decoder: %w", errNewReader) + } + defer decoder.Close() + + decoded, errRead := io.ReadAll(decoder) + if errRead != nil { + return nil, fmt.Errorf("failed to decode zstd request body: %w", errRead) + } + return decoded, nil +} + // shouldLogRequest determines whether the request should be logged. // It skips management endpoints to avoid leaking secrets but allows // all other routes, including module-provided ones, to honor request-log. diff --git a/internal/api/middleware/request_logging_test.go b/internal/api/middleware/request_logging_test.go index c4354678c..732993253 100644 --- a/internal/api/middleware/request_logging_test.go +++ b/internal/api/middleware/request_logging_test.go @@ -1,11 +1,16 @@ package middleware import ( + "bytes" "io" "net/http" + "net/http/httptest" "net/url" "strings" "testing" + + "github.com/gin-gonic/gin" + "github.com/klauspost/compress/zstd" ) func TestShouldSkipMethodForRequestLogging(t *testing.T) { @@ -136,3 +141,43 @@ func TestShouldCaptureRequestBody(t *testing.T) { } } } + +func TestCaptureRequestInfoDecodesZstdRequestBodyForLog(t *testing.T) { + gin.SetMode(gin.TestMode) + + payload := []byte(`{"model":"test-model","stream":true}`) + var compressed bytes.Buffer + encoder, errNewWriter := zstd.NewWriter(&compressed) + if errNewWriter != nil { + t.Fatalf("zstd.NewWriter: %v", errNewWriter) + } + if _, errWrite := encoder.Write(payload); errWrite != nil { + t.Fatalf("zstd write: %v", errWrite) + } + if errClose := encoder.Close(); errClose != nil { + t.Fatalf("zstd close: %v", errClose) + } + compressedBytes := compressed.Bytes() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(compressedBytes)) + req.Header.Set("Content-Encoding", "zstd") + c.Request = req + + info, errCapture := captureRequestInfo(c, true) + if errCapture != nil { + t.Fatalf("captureRequestInfo: %v", errCapture) + } + if !bytes.Equal(info.Body, payload) { + t.Fatalf("logged request body = %q, want %q", string(info.Body), string(payload)) + } + + restoredBody, errRead := io.ReadAll(c.Request.Body) + if errRead != nil { + t.Fatalf("read restored request body: %v", errRead) + } + if !bytes.Equal(restoredBody, compressedBytes) { + t.Fatal("request body was not restored with the original compressed bytes") + } +} diff --git a/internal/watcher/diff/config_diff.go b/internal/watcher/diff/config_diff.go index c206049e4..dcfa595f6 100644 --- a/internal/watcher/diff/config_diff.go +++ b/internal/watcher/diff/config_diff.go @@ -93,6 +93,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { if oldCfg.Routing.Strategy != newCfg.Routing.Strategy { changes = append(changes, fmt.Sprintf("routing.strategy: %s -> %s", oldCfg.Routing.Strategy, newCfg.Routing.Strategy)) } + if !reflect.DeepEqual(oldCfg.Payload, newCfg.Payload) { + changes = appendPayloadConfigChanges(changes, oldCfg.Payload, newCfg.Payload) + } // API keys (redacted) and counts if len(oldCfg.APIKeys) != len(newCfg.APIKeys) { @@ -338,6 +341,29 @@ func trimStrings(in []string) []string { return out } +func appendPayloadConfigChanges(changes []string, oldPayload, newPayload config.PayloadConfig) []string { + changes = appendPayloadRuleChanges(changes, "default", oldPayload.Default, newPayload.Default) + changes = appendPayloadRuleChanges(changes, "default-raw", oldPayload.DefaultRaw, newPayload.DefaultRaw) + changes = appendPayloadRuleChanges(changes, "override", oldPayload.Override, newPayload.Override) + changes = appendPayloadRuleChanges(changes, "override-raw", oldPayload.OverrideRaw, newPayload.OverrideRaw) + changes = appendPayloadFilterRuleChanges(changes, "filter", oldPayload.Filter, newPayload.Filter) + return changes +} + +func appendPayloadRuleChanges(changes []string, section string, oldRules, newRules []config.PayloadRule) []string { + if reflect.DeepEqual(oldRules, newRules) { + return changes + } + return append(changes, fmt.Sprintf("payload.%s: updated (%d -> %d rules)", section, len(oldRules), len(newRules))) +} + +func appendPayloadFilterRuleChanges(changes []string, section string, oldRules, newRules []config.PayloadFilterRule) []string { + if reflect.DeepEqual(oldRules, newRules) { + return changes + } + return append(changes, fmt.Sprintf("payload.%s: updated (%d -> %d rules)", section, len(oldRules), len(newRules))) +} + func equalStringMap(a, b map[string]string) bool { if len(a) != len(b) { return false diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 8685872e0..823daad0b 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -555,6 +555,9 @@ func (s *Service) applyConfigUpdate(newCfg *config.Config) { s.coreManager.SetConfig(newCfg) s.coreManager.SetOAuthModelAlias(newCfg.OAuthModelAlias) } + if newCfg.Home.Enabled { + s.registerHomeExecutors() + } s.rebindExecutors() }