From acf98ed10e9bcf39c332bf79098f8a4d87d4d1d8 Mon Sep 17 00:00:00 2001
From: Luis Pater
Date: Wed, 1 Apr 2026 17:28:50 +0800
Subject: [PATCH] fix(openai): add session reference counter and cache lifecycle management for websocket tools

---
Review notes (commentary only; this text sits after the "---" line and
before the first "diff --git" line, so it is not part of the commit
message or of the diff itself):

* This copy of the patch had its line breaks and indentation collapsed.
  They are reconstructed here following gofmt conventions. Column
  alignment inside unchanged context lines (for example the const line
  and the struct field in the first hunk of the second file) could not
  be recovered; apply with `git apply --ignore-whitespace` if the
  context does not match exactly.
* The author e-mail address is absent from the From: header above.
* websocketToolSessionRefCounter.release returns true whenever the
  stored count is <= 1, which includes a key that was never acquired.
* releaseResponsesWebsocketToolCaches calls deleteSession only after
  release has returned and dropped the counter mutex, so a retain for
  the same key that lands in between is not seen by the purge.
* Both default caches are now built with ttl 0, and the constructor
  only substitutes the default TTL for negative values. How the rest of
  the cache treats ttl == 0 is not visible in this patch.
* websocketDownstreamSessionKey now consults the Session_id header last
  instead of first.

 .../openai/openai_responses_websocket.go      |  2 +
 ...nai_responses_websocket_toolcall_repair.go | 87 +++++++++++++++++--
 2 files changed, 83 insertions(+), 6 deletions(-)

diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go
index b8076601..6c43e931 100644
--- a/sdk/api/handlers/openai/openai_responses_websocket.go
+++ b/sdk/api/handlers/openai/openai_responses_websocket.go
@@ -55,6 +55,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
 	}
 	passthroughSessionID := uuid.NewString()
 	downstreamSessionKey := websocketDownstreamSessionKey(c.Request)
+	retainResponsesWebsocketToolCaches(downstreamSessionKey)
 	clientRemoteAddr := ""
 	if c != nil && c.Request != nil {
 		clientRemoteAddr = strings.TrimSpace(c.Request.RemoteAddr)
@@ -63,6 +64,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
 	var wsTerminateErr error
 	var wsBodyLog strings.Builder
 	defer func() {
+		releaseResponsesWebsocketToolCaches(downstreamSessionKey)
 		if wsTerminateErr != nil {
 			// log.Infof("responses websocket: session closing id=%s reason=%v", passthroughSessionID, wsTerminateErr)
 		} else {
diff --git a/sdk/api/handlers/openai/openai_responses_websocket_toolcall_repair.go b/sdk/api/handlers/openai/openai_responses_websocket_toolcall_repair.go
index 8333bce6..530aca96 100644
--- a/sdk/api/handlers/openai/openai_responses_websocket_toolcall_repair.go
+++ b/sdk/api/handlers/openai/openai_responses_websocket_toolcall_repair.go
@@ -16,8 +16,9 @@ const (
 	websocketToolOutputCacheTTL = 30 * time.Minute
 )
 
-var defaultWebsocketToolOutputCache = newWebsocketToolOutputCache(websocketToolOutputCacheTTL, websocketToolOutputCacheMaxPerSession)
-var defaultWebsocketToolCallCache = newWebsocketToolOutputCache(websocketToolOutputCacheTTL, websocketToolOutputCacheMaxPerSession)
+var defaultWebsocketToolOutputCache = newWebsocketToolOutputCache(0, websocketToolOutputCacheMaxPerSession)
+var defaultWebsocketToolCallCache = newWebsocketToolOutputCache(0, websocketToolOutputCacheMaxPerSession)
+var defaultWebsocketToolSessionRefs = newWebsocketToolSessionRefCounter()
 
 type websocketToolOutputCache struct {
 	mu sync.Mutex
@@ -33,7 +34,7 @@ type websocketToolOutputSession struct {
 }
 
 func newWebsocketToolOutputCache(ttl time.Duration, maxPerSession int) *websocketToolOutputCache {
-	if ttl <= 0 {
+	if ttl < 0 {
 		ttl = websocketToolOutputCacheTTL
 	}
 	if maxPerSession <= 0 {
@@ -122,13 +123,22 @@ func (c *websocketToolOutputCache) cleanupLocked(now time.Time) {
 	}
 }
 
+func (c *websocketToolOutputCache) deleteSession(sessionKey string) {
+	sessionKey = strings.TrimSpace(sessionKey)
+	if sessionKey == "" || c == nil {
+		return
+	}
+
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	delete(c.sessions, sessionKey)
+}
+
 func websocketDownstreamSessionKey(req *http.Request) string {
 	if req == nil {
 		return ""
 	}
-	if sessionID := strings.TrimSpace(req.Header.Get("Session_id")); sessionID != "" {
-		return sessionID
-	}
 	if requestID := strings.TrimSpace(req.Header.Get("X-Client-Request-Id")); requestID != "" {
 		return requestID
 	}
@@ -137,9 +147,74 @@ func websocketDownstreamSessionKey(req *http.Request) string {
 			return sessionID
 		}
 	}
+	if sessionID := strings.TrimSpace(req.Header.Get("Session_id")); sessionID != "" {
+		return sessionID
+	}
 	return ""
 }
 
+type websocketToolSessionRefCounter struct {
+	mu     sync.Mutex
+	counts map[string]int
+}
+
+func newWebsocketToolSessionRefCounter() *websocketToolSessionRefCounter {
+	return &websocketToolSessionRefCounter{counts: make(map[string]int)}
+}
+
+func (c *websocketToolSessionRefCounter) acquire(sessionKey string) {
+	sessionKey = strings.TrimSpace(sessionKey)
+	if sessionKey == "" || c == nil {
+		return
+	}
+
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	c.counts[sessionKey]++
+}
+
+func (c *websocketToolSessionRefCounter) release(sessionKey string) bool {
+	sessionKey = strings.TrimSpace(sessionKey)
+	if sessionKey == "" || c == nil {
+		return false
+	}
+
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	count := c.counts[sessionKey]
+	if count <= 1 {
+		delete(c.counts, sessionKey)
+		return true
+	}
+	c.counts[sessionKey] = count - 1
+	return false
+}
+
+func retainResponsesWebsocketToolCaches(sessionKey string) {
+	if defaultWebsocketToolSessionRefs == nil {
+		return
+	}
+	defaultWebsocketToolSessionRefs.acquire(sessionKey)
+}
+
+func releaseResponsesWebsocketToolCaches(sessionKey string) {
+	if defaultWebsocketToolSessionRefs == nil {
+		return
+	}
+	if !defaultWebsocketToolSessionRefs.release(sessionKey) {
+		return
+	}
+
+	if defaultWebsocketToolOutputCache != nil {
+		defaultWebsocketToolOutputCache.deleteSession(sessionKey)
+	}
+	if defaultWebsocketToolCallCache != nil {
+		defaultWebsocketToolCallCache.deleteSession(sessionKey)
+	}
+}
+
 func repairResponsesWebsocketToolCalls(sessionKey string, payload []byte) []byte {
 	return repairResponsesWebsocketToolCallsWithCaches(defaultWebsocketToolOutputCache, defaultWebsocketToolCallCache, sessionKey, payload)
 }