diff --git a/internal/api/middleware/request_logging.go b/internal/api/middleware/request_logging.go index 4caa0937d..561219c4f 100644 --- a/internal/api/middleware/request_logging.go +++ b/internal/api/middleware/request_logging.go @@ -58,6 +58,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc { wrapper.logOnErrorOnly = true } c.Writer = wrapper + attachWebsocketLogSources(c, logger, loggerEnabled) // Process the request c.Next() @@ -70,6 +71,26 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc { } } +type fileBodySourceFactory interface { + NewFileBodySource(prefix string) (*logging.FileBodySource, error) +} + +func attachWebsocketLogSources(c *gin.Context, logger logging.RequestLogger, loggerEnabled bool) { + if c == nil || !loggerEnabled || !isResponsesWebsocketUpgrade(c.Request) { + return + } + factory, ok := logger.(fileBodySourceFactory) + if !ok || factory == nil { + return + } + if source, errSource := factory.NewFileBodySource("websocket-timeline"); errSource == nil { + c.Set(logging.WebsocketTimelineSourceContextKey, source) + } + if source, errSource := factory.NewFileBodySource("api-websocket-timeline"); errSource == nil { + c.Set(logging.APIWebsocketTimelineSourceContextKey, source) + } +} + func shouldSkipMethodForRequestLogging(req *http.Request) bool { if req == nil { return true diff --git a/internal/api/middleware/request_logging_test.go b/internal/api/middleware/request_logging_test.go index 732993253..c64b844a8 100644 --- a/internal/api/middleware/request_logging_test.go +++ b/internal/api/middleware/request_logging_test.go @@ -6,11 +6,13 @@ import ( "net/http" "net/http/httptest" "net/url" + "os" "strings" "testing" "github.com/gin-gonic/gin" "github.com/klauspost/compress/zstd" + "github.com/router-for-me/CLIProxyAPI/v7/internal/logging" ) func TestShouldSkipMethodForRequestLogging(t *testing.T) { @@ -142,6 +144,63 @@ func TestShouldCaptureRequestBody(t *testing.T) { } } +func TestAttachWebsocketLogSourcesUsesLoggerLogsDir(t *testing.T) { + gin.SetMode(gin.TestMode) + + logsDir := t.TempDir() + logger := logging.NewFileRequestLogger(true, logsDir, "", 0) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/v1/responses", nil) + c.Request.Header.Set("Upgrade", "websocket") + + attachWebsocketLogSources(c, logger, true) + defer cleanupFileBodySourcesFromContext(c) + + for _, key := range []string{ + logging.WebsocketTimelineSourceContextKey, + logging.APIWebsocketTimelineSourceContextKey, + } { + value, exists := c.Get(key) + if !exists { + t.Fatalf("expected %s source to be attached", key) + } + source, ok := value.(*logging.FileBodySource) + if !ok || source == nil { + t.Fatalf("%s source type = %T", key, value) + } + file, errPart := source.CreatePart("probe") + if errPart != nil { + t.Fatalf("CreatePart(%s): %v", key, errPart) + } + path := file.Name() + if errClose := file.Close(); errClose != nil { + t.Fatalf("close part: %v", errClose) + } + if !strings.HasPrefix(path, logsDir+string(os.PathSeparator)) { + t.Fatalf("%s part path %s is not under logs dir %s", key, path, logsDir) + } + } +} + +func cleanupFileBodySourcesFromContext(c *gin.Context) { + if c == nil { + return + } + for _, key := range []string{ + logging.WebsocketTimelineSourceContextKey, + logging.APIWebsocketTimelineSourceContextKey, + } { + value, exists := c.Get(key) + if !exists { + continue + } + if source, ok := value.(*logging.FileBodySource); ok && source != nil { + _ = source.Cleanup() + } + } +} + func TestCaptureRequestInfoDecodesZstdRequestBodyForLog(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/internal/api/middleware/response_writer.go b/internal/api/middleware/response_writer.go index 5a89ed0fd..4d4960054 100644 --- a/internal/api/middleware/response_writer.go +++ b/internal/api/middleware/response_writer.go @@ -280,7 +280,10 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error { hasAPIError := len(slicesAPIResponseError) > 0 || finalStatusCode >= http.StatusBadRequest forceLog := w.logOnErrorOnly && hasAPIError && !w.logger.IsEnabled() + websocketTimelineSource := w.extractWebsocketTimelineSource(c) + apiWebsocketTimelineSource := w.extractAPIWebsocketTimelineSource(c) if !w.logger.IsEnabled() && !forceLog { + cleanupFileBodySources(websocketTimelineSource, apiWebsocketTimelineSource) return nil } @@ -307,6 +310,13 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error { _ = w.streamWriter.WriteAPIResponse(apiResponse) } apiWebsocketTimeline := w.extractAPIWebsocketTimeline(c) + var errMerge error + apiWebsocketTimeline, errMerge = mergeFileBodySource(apiWebsocketTimeline, apiWebsocketTimelineSource) + if errMerge != nil { + cleanupFileBodySources(websocketTimelineSource) + return errMerge + } + cleanupFileBodySources(websocketTimelineSource) if len(apiWebsocketTimeline) > 0 { _ = w.streamWriter.WriteAPIWebsocketTimeline(apiWebsocketTimeline) } @@ -318,7 +328,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error { return nil } - return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.extractResponseBody(c), w.extractWebsocketTimeline(c), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIWebsocketTimeline(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog) + return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.extractResponseBody(c), w.extractWebsocketTimeline(c), websocketTimelineSource, w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIWebsocketTimeline(c), apiWebsocketTimelineSource, w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog) } func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string { @@ -370,6 +380,10 @@ func (w *ResponseWriterWrapper) extractAPIWebsocketTimeline(c *gin.Context) []by return bytes.Clone(data) } +func (w *ResponseWriterWrapper) extractAPIWebsocketTimelineSource(c *gin.Context) *logging.FileBodySource { + return extractFileBodySource(c, logging.APIWebsocketTimelineSourceContextKey) +} + func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time.Time { ts, isExist := c.Get("API_RESPONSE_TIMESTAMP") if !isExist { @@ -405,6 +419,25 @@ func (w *ResponseWriterWrapper) extractWebsocketTimeline(c *gin.Context) []byte return extractBodyOverride(c, websocketTimelineOverrideContextKey) } +func (w *ResponseWriterWrapper) extractWebsocketTimelineSource(c *gin.Context) *logging.FileBodySource { + return extractFileBodySource(c, logging.WebsocketTimelineSourceContextKey) +} + +func extractFileBodySource(c *gin.Context, key string) *logging.FileBodySource { + if c == nil { + return nil + } + value, exists := c.Get(key) + if !exists { + return nil + } + source, ok := value.(*logging.FileBodySource) + if !ok || source == nil { + return nil + } + return source +} + func extractBodyOverride(c *gin.Context, key string) []byte { if c == nil { return nil @@ -426,11 +459,48 @@ func extractBodyOverride(c *gin.Context, key string) []byte { return nil } -func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body, websocketTimeline, apiRequestBody, apiResponseBody, apiWebsocketTimeline []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error { +func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body, websocketTimeline []byte, websocketTimelineSource *logging.FileBodySource, apiRequestBody, apiResponseBody, apiWebsocketTimeline []byte, apiWebsocketTimelineSource *logging.FileBodySource, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error { if w.requestInfo == nil { + cleanupFileBodySources(websocketTimelineSource, apiWebsocketTimelineSource) return nil } + if loggerWithSources, ok := w.logger.(interface { + LogRequestWithOptionsAndSources(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, *logging.FileBodySource, []byte, []byte, []byte, *logging.FileBodySource, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error + }); ok { + return loggerWithSources.LogRequestWithOptionsAndSources( + w.requestInfo.URL, + w.requestInfo.Method, + w.requestInfo.Headers, + requestBody, + statusCode, + headers, + body, + websocketTimeline, + websocketTimelineSource, + apiRequestBody, + apiResponseBody, + apiWebsocketTimeline, + apiWebsocketTimelineSource, + apiResponseErrors, + forceLog, + w.requestInfo.RequestID, + w.requestInfo.Timestamp, + apiResponseTimestamp, + ) + } + + var errMerge error + websocketTimeline, errMerge = mergeFileBodySource(websocketTimeline, websocketTimelineSource) + if errMerge != nil { + cleanupFileBodySources(apiWebsocketTimelineSource) + return errMerge + } + apiWebsocketTimeline, errMerge = mergeFileBodySource(apiWebsocketTimeline, apiWebsocketTimelineSource) + if errMerge != nil { + return errMerge + } + if loggerWithOptions, ok := w.logger.(interface { LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error }); ok { @@ -472,3 +542,34 @@ func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, h apiResponseTimestamp, ) } + +func mergeFileBodySource(payload []byte, source *logging.FileBodySource) ([]byte, error) { + if source == nil { + return payload, nil + } + defer cleanupFileBodySources(source) + if !source.HasPayload() { + return payload, nil + } + var buf bytes.Buffer + if len(payload) > 0 { + buf.Write(payload) + if !bytes.HasSuffix(payload, []byte("\n")) { + buf.WriteByte('\n') + } + buf.WriteByte('\n') + } + if errWrite := source.WriteTo(&buf); errWrite != nil { + return nil, errWrite + } + return buf.Bytes(), nil +} + +func cleanupFileBodySources(sources ...*logging.FileBodySource) { + for _, source := range sources { + if source == nil { + continue + } + _ = source.Cleanup() + } +} diff --git a/internal/logging/request_logger.go b/internal/logging/request_logger.go index 26b2f42b3..8a8b6fbde 100644 --- a/internal/logging/request_logger.go +++ b/internal/logging/request_logger.go @@ -17,6 +17,7 @@ import ( "regexp" "sort" "strings" + "sync" "sync/atomic" "time" @@ -32,6 +33,11 @@ import ( var requestLogID atomic.Uint64 +const ( + WebsocketTimelineSourceContextKey = "WEBSOCKET_TIMELINE_SOURCE" + APIWebsocketTimelineSourceContextKey = "API_WEBSOCKET_TIMELINE_SOURCE" +) + type homeRequestLogClient interface { HeartbeatOK() bool RPushRequestLog(ctx context.Context, payload []byte) error @@ -41,6 +47,199 @@ var currentHomeRequestLogClient = func() homeRequestLogClient { return home.Current() } +// FileBodySource stores large log sections as ordered temp-file parts. +type FileBodySource struct { + mu sync.Mutex + dir string + paths []string + cleaned bool +} + +// NewFileBodySourceInDir creates a temp-backed source under baseDir. +func NewFileBodySourceInDir(baseDir string, prefix string) (*FileBodySource, error) { + prefix = sanitizeTempPrefix(prefix) + baseDir = strings.TrimSpace(baseDir) + if baseDir == "" { + return nil, fmt.Errorf("base directory is required") + } + if errMkdir := os.MkdirAll(baseDir, 0755); errMkdir != nil { + return nil, errMkdir + } + dir, errCreate := os.MkdirTemp(baseDir, "request-log-parts-"+prefix+"-*") + if errCreate != nil { + return nil, errCreate + } + return &FileBodySource{dir: dir}, nil +} + +func sanitizeTempPrefix(prefix string) string { + prefix = strings.TrimSpace(prefix) + if prefix == "" { + return "log" + } + var builder strings.Builder + for _, r := range prefix { + switch { + case r >= 'a' && r <= 'z': + builder.WriteRune(r) + case r >= 'A' && r <= 'Z': + builder.WriteRune(r) + case r >= '0' && r <= '9': + builder.WriteRune(r) + case r == '-' || r == '_': + builder.WriteRune(r) + default: + builder.WriteByte('-') + } + } + out := strings.Trim(builder.String(), "-_") + if out == "" { + return "log" + } + return out +} + +// CreatePart creates one ordered detail log part. +func (s *FileBodySource) CreatePart(prefix string) (*os.File, error) { + if s == nil { + return nil, fmt.Errorf("file body source is nil") + } + s.mu.Lock() + defer s.mu.Unlock() + if s.cleaned { + return nil, fmt.Errorf("file body source has been cleaned") + } + prefix = sanitizeTempPrefix(prefix) + file, errCreate := os.CreateTemp(s.dir, prefix+"-*.tmp") + if errCreate != nil { + return nil, errCreate + } + s.paths = append(s.paths, file.Name()) + return file, nil +} + +// AppendPart appends one complete ordered part to the source. +func (s *FileBodySource) AppendPart(data []byte) error { + data = bytes.TrimSpace(data) + if len(data) == 0 { + return nil + } + file, errCreate := s.CreatePart("part") + if errCreate != nil { + return errCreate + } + writeErr := writeLogPart(file, data, false) + if errClose := file.Close(); errClose != nil { + if writeErr == nil { + writeErr = errClose + } + } + return writeErr +} + +// HasPayload reports whether any detail parts were recorded. +func (s *FileBodySource) HasPayload() bool { + if s == nil { + return false + } + s.mu.Lock() + defer s.mu.Unlock() + return len(s.paths) > 0 && !s.cleaned +} + +// Paths returns a copy of the ordered part paths. +func (s *FileBodySource) Paths() []string { + if s == nil { + return nil + } + s.mu.Lock() + defer s.mu.Unlock() + out := make([]string, len(s.paths)) + copy(out, s.paths) + return out +} + +// WriteTo merges all ordered parts into w. +func (s *FileBodySource) WriteTo(w io.Writer) error { + if s == nil || w == nil { + return nil + } + paths := s.Paths() + for i, path := range paths { + if i > 0 { + if _, errWrite := io.WriteString(w, "\n"); errWrite != nil { + return errWrite + } + } + file, errOpen := os.Open(path) + if errOpen != nil { + return errOpen + } + _, errCopy := io.Copy(w, file) + if errClose := file.Close(); errClose != nil { + log.WithError(errClose).Warn("failed to close log part file") + if errCopy == nil { + errCopy = errClose + } + } + if errCopy != nil { + return errCopy + } + } + return nil +} + +// Bytes merges all ordered parts into memory. +func (s *FileBodySource) Bytes() ([]byte, error) { + var buf bytes.Buffer + if errWrite := s.WriteTo(&buf); errWrite != nil { + return nil, errWrite + } + return buf.Bytes(), nil +} + +// Cleanup removes all temp detail parts and their directory. +func (s *FileBodySource) Cleanup() error { + if s == nil { + return nil + } + s.mu.Lock() + if s.cleaned { + s.mu.Unlock() + return nil + } + paths := make([]string, len(s.paths)) + copy(paths, s.paths) + dir := s.dir + s.paths = nil + s.cleaned = true + s.mu.Unlock() + + var firstErr error + for _, path := range paths { + if errRemove := os.Remove(path); errRemove != nil && !os.IsNotExist(errRemove) && firstErr == nil { + firstErr = errRemove + } + } + if dir != "" { + if errRemove := os.Remove(dir); errRemove != nil && !os.IsNotExist(errRemove) && firstErr == nil { + firstErr = errRemove + } + } + return firstErr +} + +func cleanupFileBodySources(sources ...*FileBodySource) { + for _, source := range sources { + if source == nil { + continue + } + if errCleanup := source.Cleanup(); errCleanup != nil { + log.WithError(errCleanup).Warn("failed to clean up log part files") + } + } +} + // RequestLogger defines the interface for logging HTTP requests and responses. // It provides methods for logging both regular and streaming HTTP request/response cycles. type RequestLogger interface { @@ -274,6 +473,17 @@ func (l *FileRequestLogger) SetErrorLogsMaxFiles(maxFiles int) { l.errorLogsMaxFiles = maxFiles } +// NewFileBodySource creates a temp-backed source under the request log directory. +func (l *FileRequestLogger) NewFileBodySource(prefix string) (*FileBodySource, error) { + if l == nil { + return nil, fmt.Errorf("file request logger is nil") + } + if errEnsure := l.ensureLogsDir(); errEnsure != nil { + return nil, errEnsure + } + return NewFileBodySourceInDir(l.logsDir, prefix) +} + // LogRequest logs a complete non-streaming request/response cycle to a file. // // Parameters: @@ -299,10 +509,21 @@ func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[st // LogRequestWithOptions logs a request with optional forced logging behavior. // The force flag allows writing error logs even when regular request logging is disabled. func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error { - return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp) + return l.logRequestWithSources(url, method, requestHeaders, body, statusCode, responseHeaders, response, websocketTimeline, nil, apiRequest, apiResponse, apiWebsocketTimeline, nil, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp) } func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error { + return l.logRequestWithSources(url, method, requestHeaders, body, statusCode, responseHeaders, response, websocketTimeline, nil, apiRequest, apiResponse, apiWebsocketTimeline, nil, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp) +} + +// LogRequestWithOptionsAndSources logs a request with optional file-backed large sections. +func (l *FileRequestLogger) LogRequestWithOptionsAndSources(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline []byte, websocketTimelineSource *FileBodySource, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiWebsocketTimelineSource *FileBodySource, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error { + return l.logRequestWithSources(url, method, requestHeaders, body, statusCode, responseHeaders, response, websocketTimeline, websocketTimelineSource, apiRequest, apiResponse, apiWebsocketTimeline, apiWebsocketTimelineSource, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp) +} + +func (l *FileRequestLogger) logRequestWithSources(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline []byte, websocketTimelineSource *FileBodySource, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiWebsocketTimelineSource *FileBodySource, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error { + defer cleanupFileBodySources(websocketTimelineSource, apiWebsocketTimelineSource) + if !l.enabled && !force { return nil } @@ -322,9 +543,11 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st body, "", websocketTimeline, + websocketTimelineSource, apiRequest, apiResponse, apiWebsocketTimeline, + apiWebsocketTimelineSource, apiResponseErrors, statusCode, responseHeaders, @@ -382,9 +605,11 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st body, requestBodyPath, websocketTimeline, + websocketTimelineSource, apiRequest, apiResponse, apiWebsocketTimeline, + apiWebsocketTimelineSource, apiResponseErrors, statusCode, responseHeaders, @@ -430,7 +655,7 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[ } if l.homeEnabled { - client := home.Current() + client := currentHomeRequestLogClient() if client == nil || !client.HeartbeatOK() { return &NoOpStreamingLogWriter{}, nil } @@ -650,9 +875,11 @@ func (l *FileRequestLogger) writeNonStreamingLog( requestBody []byte, requestBodyPath string, websocketTimeline []byte, + websocketTimelineSource *FileBodySource, apiRequest []byte, apiResponse []byte, apiWebsocketTimeline []byte, + apiWebsocketTimelineSource *FileBodySource, apiResponseErrors []*interfaces.ErrorMessage, statusCode int, responseHeaders map[string][]string, @@ -664,16 +891,16 @@ func (l *FileRequestLogger) writeNonStreamingLog( if requestTimestamp.IsZero() { requestTimestamp = time.Now() } - isWebsocketTranscript := hasSectionPayload(websocketTimeline) - downstreamTransport := inferDownstreamTransport(requestHeaders, websocketTimeline) - upstreamTransport := inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors) + isWebsocketTranscript := hasSectionPayload(websocketTimeline) || hasFileBodySourcePayload(websocketTimelineSource) + downstreamTransport := inferDownstreamTransport(requestHeaders, websocketTimeline, websocketTimelineSource) + upstreamTransport := inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline, apiWebsocketTimelineSource, apiResponseErrors) if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, requestTimestamp, downstreamTransport, upstreamTransport, !isWebsocketTranscript); errWrite != nil { return errWrite } - if errWrite := writeAPISection(w, "=== WEBSOCKET TIMELINE ===\n", "=== WEBSOCKET TIMELINE", websocketTimeline, time.Time{}); errWrite != nil { + if errWrite := writeAPISectionWithSource(w, "=== WEBSOCKET TIMELINE ===\n", "=== WEBSOCKET TIMELINE", websocketTimeline, websocketTimelineSource, time.Time{}); errWrite != nil { return errWrite } - if errWrite := writeAPISection(w, "=== API WEBSOCKET TIMELINE ===\n", "=== API WEBSOCKET TIMELINE", apiWebsocketTimeline, time.Time{}); errWrite != nil { + if errWrite := writeAPISectionWithSource(w, "=== API WEBSOCKET TIMELINE ===\n", "=== API WEBSOCKET TIMELINE", apiWebsocketTimeline, apiWebsocketTimelineSource, time.Time{}); errWrite != nil { return errWrite } if errWrite := writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest, time.Time{}); errWrite != nil { @@ -829,8 +1056,12 @@ func hasSectionPayload(payload []byte) bool { return len(bytes.TrimSpace(payload)) > 0 } -func inferDownstreamTransport(headers map[string][]string, websocketTimeline []byte) string { - if hasSectionPayload(websocketTimeline) { +func hasFileBodySourcePayload(source *FileBodySource) bool { + return source != nil && source.HasPayload() +} + +func inferDownstreamTransport(headers map[string][]string, websocketTimeline []byte, websocketTimelineSource *FileBodySource) string { + if hasSectionPayload(websocketTimeline) || hasFileBodySourcePayload(websocketTimelineSource) { return "websocket" } for key, values := range headers { @@ -845,9 +1076,9 @@ func inferDownstreamTransport(headers map[string][]string, websocketTimeline []b return "http" } -func inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline []byte, _ []*interfaces.ErrorMessage) string { +func inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline []byte, apiWebsocketTimelineSource *FileBodySource, _ []*interfaces.ErrorMessage) string { hasHTTP := hasSectionPayload(apiRequest) || hasSectionPayload(apiResponse) - hasWS := hasSectionPayload(apiWebsocketTimeline) + hasWS := hasSectionPayload(apiWebsocketTimeline) || hasFileBodySourcePayload(apiWebsocketTimelineSource) switch { case hasHTTP && hasWS: return "websocket+http" @@ -860,6 +1091,26 @@ func inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline []byte } } +func writeLogPart(w io.Writer, payload []byte, prependNewline bool) error { + if w == nil { + return nil + } + if prependNewline { + if _, errWrite := io.WriteString(w, "\n"); errWrite != nil { + return errWrite + } + } + if _, errWrite := w.Write(payload); errWrite != nil { + return errWrite + } + if !bytes.HasSuffix(payload, []byte("\n")) { + if _, errWrite := io.WriteString(w, "\n"); errWrite != nil { + return errWrite + } + } + return nil +} + func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte, timestamp time.Time) error { if len(payload) == 0 { return nil @@ -889,6 +1140,33 @@ func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, pa return nil } +func writeAPISectionWithSource(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte, source *FileBodySource, timestamp time.Time) error { + if !hasFileBodySourcePayload(source) { + return writeAPISection(w, sectionHeader, sectionPrefix, payload, timestamp) + } + if len(payload) > 0 { + if errWrite := writeAPISection(w, sectionHeader, sectionPrefix, payload, timestamp); errWrite != nil { + return errWrite + } + } + if _, errWrite := io.WriteString(w, sectionHeader); errWrite != nil { + return errWrite + } + if !timestamp.IsZero() { + if _, errWrite := io.WriteString(w, fmt.Sprintf("Timestamp: %s\n", timestamp.Format(time.RFC3339Nano))); errWrite != nil { + return errWrite + } + } + tracker := &trailingNewlineTrackingWriter{writer: w} + if errWrite := source.WriteTo(tracker); errWrite != nil { + return errWrite + } + if errWrite := writeSectionSpacing(w, tracker.trailingNewlines); errWrite != nil { + return errWrite + } + return nil +} + func writeAPIErrorResponses(w io.Writer, apiResponseErrors []*interfaces.ErrorMessage) error { for i := 0; i < len(apiResponseErrors); i++ { if apiResponseErrors[i] == nil { @@ -998,8 +1276,8 @@ func responseBodyStartsWithLeadingNewline(reader *bufio.Reader) bool { func (l *FileRequestLogger) formatLogContent(url, method string, headers map[string][]string, body, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, response []byte, status int, responseHeaders map[string][]string, apiResponseErrors []*interfaces.ErrorMessage) string { var content strings.Builder isWebsocketTranscript := hasSectionPayload(websocketTimeline) - downstreamTransport := inferDownstreamTransport(headers, websocketTimeline) - upstreamTransport := inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors) + downstreamTransport := inferDownstreamTransport(headers, websocketTimeline, nil) + upstreamTransport := inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline, nil, apiResponseErrors) // Request info content.WriteString(l.formatRequestInfo(url, method, headers, body, downstreamTransport, upstreamTransport, !isWebsocketTranscript)) @@ -1510,7 +1788,7 @@ func (w *FileStreamingLogWriter) asyncWriter() { } func (w *FileStreamingLogWriter) writeFinalLog(logFile *os.File) error { - if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp, "http", inferUpstreamTransport(w.apiRequest, w.apiResponse, w.apiWebsocketTimeline, nil), true); errWrite != nil { + if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp, "http", inferUpstreamTransport(w.apiRequest, w.apiResponse, w.apiWebsocketTimeline, nil, nil), true); errWrite != nil { return errWrite } if errWrite := writeAPISection(logFile, "=== API WEBSOCKET TIMELINE ===\n", "=== API WEBSOCKET TIMELINE", w.apiWebsocketTimeline, time.Time{}); errWrite != nil { @@ -1751,7 +2029,7 @@ func (w *homeStreamingLogWriter) Close() error { responsePayload := w.responseBody.Bytes() var buf bytes.Buffer - upstreamTransport := inferUpstreamTransport(w.apiRequest, w.apiResponse, w.apiWebsocketTime, nil) + upstreamTransport := inferUpstreamTransport(w.apiRequest, w.apiResponse, w.apiWebsocketTime, nil, nil) if errWrite := writeRequestInfoWithBody(&buf, w.url, w.method, w.requestHeaders, w.requestBody, "", w.timestamp, "http", upstreamTransport, true); errWrite != nil { return errWrite } diff --git a/internal/logging/request_logger_home_test.go b/internal/logging/request_logger_home_test.go index 4f66cacec..2d974f31d 100644 --- a/internal/logging/request_logger_home_test.go +++ b/internal/logging/request_logger_home_test.go @@ -6,6 +6,7 @@ import ( "encoding/json" "net/http" "os" + "strings" "testing" "time" ) @@ -97,6 +98,160 @@ func TestFileRequestLogger_HomeEnabled_ForwardsWhenRequestLogEnabled(t *testing. } } +func TestFileRequestLogger_LogRequestWithSourcesWritesLocalLogAndCleansParts(t *testing.T) { + logsDir := t.TempDir() + logger := NewFileRequestLogger(true, logsDir, "", 0) + + timelineSource, errSource := logger.NewFileBodySource("websocket-timeline-test") + if errSource != nil { + t.Fatalf("logger.NewFileBodySource: %v", errSource) + } + if errAppend := timelineSource.AppendPart([]byte("Timestamp: 2026-05-25T12:00:00Z\nEvent: websocket.request\n{}")); errAppend != nil { + t.Fatalf("AppendPart request: %v", errAppend) + } + if errAppend := timelineSource.AppendPart([]byte("Timestamp: 2026-05-25T12:00:01Z\nEvent: websocket.response\n{}")); errAppend != nil { + t.Fatalf("AppendPart response: %v", errAppend) + } + partPaths := timelineSource.Paths() + for _, path := range partPaths { + if !strings.HasPrefix(path, logsDir+string(os.PathSeparator)) { + t.Fatalf("part path %s is not under logs dir %s", path, logsDir) + } + } + + errLog := logger.LogRequestWithOptionsAndSources( + "/v1/responses/ws", + http.MethodGet, + map[string][]string{"Upgrade": {"websocket"}}, + nil, + http.StatusSwitchingProtocols, + map[string][]string{"Upgrade": {"websocket"}}, + nil, + nil, + timelineSource, + nil, + nil, + nil, + nil, + nil, + false, + "ws-req-1", + time.Now(), + time.Now(), + ) + if errLog != nil { + t.Fatalf("LogRequestWithOptionsAndSources error: %v", errLog) + } + + for _, path := range partPaths { + if _, errStat := os.Stat(path); !os.IsNotExist(errStat) { + t.Fatalf("expected part %s to be removed, stat err=%v", path, errStat) + } + } + + entries, errRead := os.ReadDir(logsDir) + if errRead != nil { + t.Fatalf("failed to read logs dir: %v", errRead) + } + var logPath string + for _, entry := range entries { + if entry.IsDir() { + continue + } + logPath = logsDir + string(os.PathSeparator) + entry.Name() + break + } + if logPath == "" { + t.Fatal("expected local request log file") + } + raw, errReadLog := os.ReadFile(logPath) + if errReadLog != nil { + t.Fatalf("read log file: %v", errReadLog) + } + if !bytes.Contains(raw, []byte("=== WEBSOCKET TIMELINE ===")) { + t.Fatalf("websocket timeline section missing: %s", string(raw)) + } + if !bytes.Contains(raw, []byte("Event: websocket.request")) || !bytes.Contains(raw, []byte("Event: websocket.response")) { + t.Fatalf("merged websocket events missing: %s", string(raw)) + } +} + +func TestFileRequestLogger_HomeEnabled_ForwardsSourceLogAndCleansParts(t *testing.T) { + original := currentHomeRequestLogClient + defer func() { + currentHomeRequestLogClient = original + }() + + stub := &stubHomeRequestLogClient{heartbeatOK: true} + currentHomeRequestLogClient = func() homeRequestLogClient { + return stub + } + + logsDir := t.TempDir() + logger := NewFileRequestLogger(true, logsDir, "", 0) + logger.SetHomeEnabled(true) + + timelineSource, errSource := logger.NewFileBodySource("home-websocket-timeline-test") + if errSource != nil { + t.Fatalf("logger.NewFileBodySource: %v", errSource) + } + if errAppend := timelineSource.AppendPart([]byte("Timestamp: 2026-05-25T12:00:00Z\nEvent: websocket.request\n{}")); errAppend != nil { + t.Fatalf("AppendPart request: %v", errAppend) + } + partPaths := timelineSource.Paths() + for _, path := range partPaths { + if !strings.HasPrefix(path, logsDir+string(os.PathSeparator)) { + t.Fatalf("part path %s is not under logs dir %s", path, logsDir) + } + } + + errLog := logger.LogRequestWithOptionsAndSources( + "/v1/responses/ws", + http.MethodGet, + map[string][]string{"Upgrade": {"websocket"}}, + nil, + http.StatusSwitchingProtocols, + map[string][]string{"Upgrade": {"websocket"}}, + nil, + nil, + timelineSource, + nil, + nil, + nil, + nil, + nil, + false, + "home-ws-req-1", + time.Now(), + time.Now(), + ) + if errLog != nil { + t.Fatalf("LogRequestWithOptionsAndSources error: %v", errLog) + } + if len(stub.pushed) != 1 { + t.Fatalf("home pushed records = %d, want 1", len(stub.pushed)) + } + + var got struct { + RequestID string `json:"request_id"` + RequestLog string `json:"request_log"` + } + if errUnmarshal := json.Unmarshal(stub.pushed[0], &got); errUnmarshal != nil { + t.Fatalf("unmarshal payload: %v payload=%s", errUnmarshal, string(stub.pushed[0])) + } + if got.RequestID != "home-ws-req-1" { + t.Fatalf("request_id = %q, want home-ws-req-1", got.RequestID) + } + if !strings.Contains(got.RequestLog, "Event: websocket.request") { + t.Fatalf("forwarded request_log missing websocket request: %s", got.RequestLog) + } + for _, path := range partPaths { + if _, errStat := os.Stat(path); !os.IsNotExist(errStat) { + t.Fatalf("expected part %s to be removed, stat err=%v", path, errStat) + } + } +} + func TestFileRequestLogger_HomeEnabled_ForwardsStreamingRequestID(t *testing.T) { original := currentHomeRequestLogClient defer func() { diff --git a/internal/runtime/executor/helps/logging_helpers.go b/internal/runtime/executor/helps/logging_helpers.go index 87fc7ac34..c32230585 100644 --- a/internal/runtime/executor/helps/logging_helpers.go +++ b/internal/runtime/executor/helps/logging_helpers.go @@ -416,6 +416,13 @@ func appendAPIWebsocketTimeline(ginCtx *gin.Context, chunk []byte) { if len(data) == 0 { return } + if source, ok := apiWebsocketTimelineSource(ginCtx); ok { + if errAppend := source.AppendPart(data); errAppend == nil { + return + } else { + log.WithError(errAppend).Warn("failed to append api websocket timeline log part") + } + } if existing, exists := ginCtx.Get(apiWebsocketTimelineKey); exists { if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 { combined := make([]byte, 0, len(existingBytes)+len(data)+2) @@ -432,6 +439,18 @@ func appendAPIWebsocketTimeline(ginCtx *gin.Context, chunk []byte) { ginCtx.Set(apiWebsocketTimelineKey, bytes.Clone(data)) } +func apiWebsocketTimelineSource(ginCtx *gin.Context) (*logging.FileBodySource, bool) { + if ginCtx == nil { + return nil, false + } + value, exists := ginCtx.Get(logging.APIWebsocketTimelineSourceContextKey) + if !exists { + return nil, false + } + source, ok := value.(*logging.FileBodySource) + return source, ok && source != nil +} + func markAPIResponseTimestamp(ginCtx *gin.Context) { if ginCtx == nil { return diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index 574338fd7..eae042b9e 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" "strconv" "strings" @@ -14,6 +15,7 @@ import ( "github.com/google/uuid" "github.com/gorilla/websocket" "github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces" + requestlogging "github.com/router-for-me/CLIProxyAPI/v7/internal/logging" "github.com/router-for-me/CLIProxyAPI/v7/internal/registry" "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking" "github.com/router-for-me/CLIProxyAPI/v7/internal/util" @@ -43,6 +45,166 @@ var responsesWebsocketUpgrader = websocket.Upgrader{ }, } +type websocketTimelineAppender interface { + Append(eventType string, payload []byte, timestamp time.Time) +} + +type websocketTimelineLog struct { + enabled bool + source *requestlogging.FileBodySource + builder *strings.Builder + + currentPart io.WriteCloser + currentPartHasLog bool +} + +func newWebsocketTimelineLog(enabled bool, source *requestlogging.FileBodySource) *websocketTimelineLog { + if !enabled { + return &websocketTimelineLog{} + } + if source == nil { + return newInMemoryWebsocketTimelineLog() + } + return &websocketTimelineLog{ + enabled: true, + source: source, + } +} + +func newInMemoryWebsocketTimelineLog() *websocketTimelineLog { + return &websocketTimelineLog{ + enabled: true, + builder: &strings.Builder{}, + } +} + +func websocketTimelineSourceFromContext(c *gin.Context) *requestlogging.FileBodySource { + if c == nil { + return nil + } + value, exists := c.Get(requestlogging.WebsocketTimelineSourceContextKey) + if !exists { + return nil + } + source, ok := value.(*requestlogging.FileBodySource) + if !ok { + return nil + } + return source +} + +func (l *websocketTimelineLog) BeginRequest() { + if l == nil || !l.enabled || l.source == nil { + return + } + l.closeCurrentPart() + part, errCreate := l.source.CreatePart("request") + if errCreate != nil { + log.WithError(errCreate).Warn("failed to create websocket request detail log") + return + } + l.currentPart = part + l.currentPartHasLog = false +} + +func (l *websocketTimelineLog) Append(eventType string, payload []byte, timestamp time.Time) { + if l == nil || !l.enabled { + return + } + data := formatWebsocketTimelineEvent(eventType, payload, timestamp) + if len(data) == 0 { + return + } + if l.source != nil { + if l.currentPart == nil { + l.BeginRequest() + } + if l.currentPart == nil { + return + } + if errWrite := writeWebsocketTimelinePart(l.currentPart, data, l.currentPartHasLog); errWrite != nil { + log.WithError(errWrite).Warn("failed to write websocket request detail log") + return + } + l.currentPartHasLog = true + return + } + if l.builder != nil { + writeWebsocketTimelineBuilder(l.builder, data) + } +} + +func (l *websocketTimelineLog) SetContext(c *gin.Context) { + if l == nil || !l.enabled { + return + } + l.closeCurrentPart() + if l.source != nil { + if l.source.HasPayload() { + c.Set(requestlogging.WebsocketTimelineSourceContextKey, l.source) + return + } + if errCleanup := l.source.Cleanup(); errCleanup != nil { + log.WithError(errCleanup).Warn("failed to clean up empty websocket timeline log parts") + } + } + if l.builder != nil { + setWebsocketTimelineBody(c, l.builder.String()) + } +} + +func (l *websocketTimelineLog) String() string { + if l == nil || !l.enabled { + return "" + } + l.closeCurrentPart() + if l.source != nil { + data, errRead := l.source.Bytes() + if errRead != nil { + return "" + } + return string(data) + } + if l.builder == nil { + return "" + } + return l.builder.String() +} + +func (l *websocketTimelineLog) closeCurrentPart() { + if l == nil || l.currentPart == nil { + return + } + if errClose := l.currentPart.Close(); errClose != nil { + log.WithError(errClose).Warn("failed to close websocket request detail log") + } + l.currentPart = nil + l.currentPartHasLog = false +} + +func writeWebsocketTimelinePart(w io.Writer, data []byte, prependNewline bool) error { + if w == nil || len(data) == 0 { + return nil + } + if prependNewline { + if _, errWrite := io.WriteString(w, "\n"); errWrite != nil { + return errWrite + } + } + _, errWrite := w.Write(data) + return errWrite +} + +func writeWebsocketTimelineBuilder(builder *strings.Builder, data []byte) { + if builder == nil || len(data) == 0 { + return + } + if builder.Len() > 0 { + builder.WriteString("\n") + } + builder.Write(data) +} + // ResponsesWebsocket handles websocket requests for /v1/responses. // It accepts `response.create` and `response.append` requests and streams // response events back as JSON websocket text messages. @@ -57,6 +219,9 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { clientIP := websocketClientAddress(c) log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientIP) + requestLogEnabled := h != nil && h.Cfg != nil && h.Cfg.RequestLog + wsTimelineLog := newWebsocketTimelineLog(requestLogEnabled, websocketTimelineSourceFromContext(c)) + wsDone := make(chan struct{}) defer close(wsDone) @@ -82,11 +247,10 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } var wsTerminateErr error - var wsTimelineLog strings.Builder defer func() { releaseResponsesWebsocketToolCaches(downstreamSessionKey) if wsTerminateErr != nil { - appendWebsocketTimelineDisconnect(&wsTimelineLog, wsTerminateErr, time.Now()) + appendWebsocketTimelineDisconnect(wsTimelineLog, wsTerminateErr, time.Now()) // log.Infof("responses websocket: session closing id=%s reason=%v", passthroughSessionID, wsTerminateErr) } else { log.Infof("responses websocket: session closing id=%s", passthroughSessionID) @@ -95,7 +259,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { h.AuthManager.CloseExecutionSession(passthroughSessionID) log.Infof("responses websocket: upstream execution session closed id=%s", passthroughSessionID) } - setWebsocketTimelineBody(c, wsTimelineLog.String()) + wsTimelineLog.SetContext(c) if errClose := conn.Close(); errClose != nil { log.Warnf("responses websocket: close connection error: %v", errClose) } @@ -136,7 +300,8 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { // websocketPayloadEventType(payload), // websocketPayloadPreview(payload), // ) - appendWebsocketTimelineEvent(&wsTimelineLog, "request", payload, time.Now()) + wsTimelineLog.BeginRequest() + wsTimelineLog.Append("request", payload, time.Now()) allowIncrementalInputWithPreviousResponseID := false if pinnedAuthID != "" { @@ -180,7 +345,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { if errMsg != nil { h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg) markAPIResponseTimestamp(c) - errorPayload, errWrite := writeResponsesWebsocketError(conn, &wsTimelineLog, errMsg) + errorPayload, errWrite := writeResponsesWebsocketError(conn, wsTimelineLog, errMsg) log.Infof( "responses websocket: downstream_out id=%s type=%d event=%s payload=%s", passthroughSessionID, @@ -208,7 +373,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } lastRequest = updatedLastRequest lastResponseOutput = []byte("[]") - if errWrite := writeResponsesWebsocketSyntheticPrewarm(c, conn, requestJSON, &wsTimelineLog, passthroughSessionID); errWrite != nil { + if errWrite := writeResponsesWebsocketSyntheticPrewarm(c, conn, requestJSON, wsTimelineLog, passthroughSessionID); errWrite != nil { wsTerminateErr = errWrite return } @@ -248,7 +413,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "") - completedOutput, forwardErrMsg, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsTimelineLog, passthroughSessionID) + completedOutput, forwardErrMsg, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, wsTimelineLog, passthroughSessionID) if errForward != nil { wsTerminateErr = errForward log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward) @@ -708,7 +873,7 @@ func writeResponsesWebsocketSyntheticPrewarm( c *gin.Context, conn *websocket.Conn, requestJSON []byte, - wsTimelineLog *strings.Builder, + wsTimelineLog websocketTimelineAppender, sessionID string, ) error { payloads, errPayloads := syntheticResponsesWebsocketPrewarmPayloads(requestJSON) @@ -859,7 +1024,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( cancel handlers.APIHandlerCancelFunc, data <-chan []byte, errs <-chan *interfaces.ErrorMessage, - wsTimelineLog *strings.Builder, + wsTimelineLog websocketTimelineAppender, sessionID string, ) ([]byte, *interfaces.ErrorMessage, error) { completed := false @@ -1031,7 +1196,7 @@ func websocketJSONPayloadsFromChunk(chunk []byte) [][]byte { return payloads } -func writeResponsesWebsocketError(conn *websocket.Conn, wsTimelineLog *strings.Builder, errMsg *interfaces.ErrorMessage) ([]byte, error) { +func writeResponsesWebsocketError(conn *websocket.Conn, wsTimelineLog websocketTimelineAppender, errMsg *interfaces.ErrorMessage) ([]byte, error) { status := http.StatusInternalServerError errText := http.StatusText(status) if errMsg != nil { @@ -1155,29 +1320,35 @@ func setWebsocketBody(c *gin.Context, key string, body string) { c.Set(key, []byte(trimmedBody)) } -func writeResponsesWebsocketPayload(conn *websocket.Conn, wsTimelineLog *strings.Builder, payload []byte, timestamp time.Time) error { - appendWebsocketTimelineEvent(wsTimelineLog, "response", payload, timestamp) +func writeResponsesWebsocketPayload(conn *websocket.Conn, wsTimelineLog websocketTimelineAppender, payload []byte, timestamp time.Time) error { + if wsTimelineLog != nil { + wsTimelineLog.Append("response", payload, timestamp) + } return conn.WriteMessage(websocket.TextMessage, payload) } -func appendWebsocketTimelineDisconnect(builder *strings.Builder, err error, timestamp time.Time) { +func appendWebsocketTimelineDisconnect(timeline websocketTimelineAppender, err error, timestamp time.Time) { if err == nil { return } - appendWebsocketTimelineEvent(builder, "disconnect", []byte(err.Error()), timestamp) + if timeline != nil { + timeline.Append("disconnect", []byte(err.Error()), timestamp) + } } func appendWebsocketTimelineEvent(builder *strings.Builder, eventType string, payload []byte, timestamp time.Time) { if builder == nil { return } + writeWebsocketTimelineBuilder(builder, formatWebsocketTimelineEvent(eventType, payload, timestamp)) +} + +func formatWebsocketTimelineEvent(eventType string, payload []byte, timestamp time.Time) []byte { trimmedPayload := bytes.TrimSpace(payload) if len(trimmedPayload) == 0 { - return - } - if builder.Len() > 0 { - builder.WriteString("\n") + return nil } + var builder strings.Builder builder.WriteString("Timestamp: ") builder.WriteString(timestamp.Format(time.RFC3339Nano)) builder.WriteString("\n") @@ -1186,6 +1357,7 @@ func appendWebsocketTimelineEvent(builder *strings.Builder, eventType string, pa builder.WriteString("\n") builder.Write(trimmedPayload) builder.WriteString("\n") + return []byte(builder.String()) } func markAPIResponseTimestamp(c *gin.Context) { diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go index 7ff58fa3c..8b945b50c 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_test.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go @@ -15,6 +15,7 @@ import ( "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces" + requestlogging "github.com/router-for-me/CLIProxyAPI/v7/internal/logging" "github.com/router-for-me/CLIProxyAPI/v7/internal/registry" "github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers" coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" @@ -593,6 +594,34 @@ func TestSetWebsocketTimelineBody(t *testing.T) { } } +func TestWebsocketTimelineLogFallsBackToMemoryWithoutSource(t *testing.T) { + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + ts := time.Date(2026, time.April, 1, 12, 34, 56, 789000000, time.UTC) + + timelineLog := newWebsocketTimelineLog(true, nil) + timelineLog.BeginRequest() + timelineLog.Append("request", []byte(`{"type":"response.create"}`), ts) + timelineLog.SetContext(c) + + value, exists := c.Get(wsTimelineBodyKey) + if !exists { + t.Fatalf("timeline body key not set") + } + bodyBytes, ok := value.([]byte) + if !ok { + t.Fatalf("timeline body key type mismatch") + } + got := string(bodyBytes) + if !strings.Contains(got, "Event: websocket.request") { + t.Fatalf("timeline event not found: %s", got) + } + if !strings.Contains(got, `{"type":"response.create"}`) { + t.Fatalf("timeline payload not found: %s", got) + } +} + func TestRepairResponsesWebsocketToolCallsInsertsCachedOutput(t *testing.T) { cache := newWebsocketToolOutputCache(time.Minute, 10) sessionKey := "session-1" @@ -867,14 +896,14 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) { close(data) close(errCh) - var timelineLog strings.Builder + timelineLog := newInMemoryWebsocketTimelineLog() completedOutput, errMsg, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket( ctx, conn, func(...interface{}) {}, data, errCh, - &timelineLog, + timelineLog, "session-1", ) if err != nil { @@ -945,7 +974,7 @@ func TestForwardResponsesWebsocketLogsAttemptedResponseOnWriteFailure(t *testing close(data) close(errCh) - var timelineLog strings.Builder + timelineLog := newInMemoryWebsocketTimelineLog() if errClose := conn.Close(); errClose != nil { serverErrCh <- errClose return @@ -957,7 +986,7 @@ func TestForwardResponsesWebsocketLogsAttemptedResponseOnWriteFailure(t *testing func(...interface{}) {}, data, errCh, - &timelineLog, + timelineLog, "session-1", ) if err == nil { @@ -994,18 +1023,36 @@ func TestResponsesWebsocketTimelineRecordsDisconnectEvent(t *testing.T) { gin.SetMode(gin.TestMode) manager := coreauth.NewManager(nil, nil, nil) - base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{RequestLog: true}, manager) h := NewOpenAIResponsesAPIHandler(base) + logsDir := t.TempDir() timelineCh := make(chan string, 1) router := gin.New() router.GET("/v1/responses/ws", func(c *gin.Context) { + source, errSource := requestlogging.NewFileBodySourceInDir(logsDir, "websocket-timeline-test") + if errSource != nil { + timelineCh <- "" + return + } + c.Set(requestlogging.WebsocketTimelineSourceContextKey, source) h.ResponsesWebsocket(c) timeline := "" if value, exists := c.Get(wsTimelineBodyKey); exists { if body, ok := value.([]byte); ok { timeline = string(body) } + } else if value, exists := c.Get(requestlogging.WebsocketTimelineSourceContextKey); exists { + if source, ok := value.(*requestlogging.FileBodySource); ok { + body, _ := source.Bytes() + timeline = string(body) + _ = source.Cleanup() + } + } + if value, exists := c.Get(requestlogging.APIWebsocketTimelineSourceContextKey); exists { + if source, ok := value.(*requestlogging.FileBodySource); ok { + _ = source.Cleanup() + } } timelineCh <- timeline })