From e399edd3cc9aaa5b42702f792df8a5aae9212206 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Wed, 27 May 2026 00:46:51 +0800 Subject: [PATCH] feat(images): add support for configurable GPT Image 2 base model and improved SSE handling - Introduced `GPTImage2BaseModel` configuration for hosted image generation tools with validation for "gpt-" prefix. - Added logic to dynamically resolve and apply the base model in Codex executor workflows. - Enhanced server-sent events (SSE) implementation with keep-alive tickers and error events for stream reliability. - Updated configuration file examples and internal documentation. --- config.example.yaml | 4 + internal/config/sdk_config.go | 7 + .../runtime/executor/codex_openai_images.go | 36 +- internal/watcher/diff/config_diff.go | 3 + .../handlers/openai/openai_images_handlers.go | 401 +++++++++++++----- 5 files changed, 326 insertions(+), 125 deletions(-) diff --git a/config.example.yaml b/config.example.yaml index 959f1f401..6a53c9400 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -100,6 +100,10 @@ disable-cooling: false # - "chat": disable image_generation injection on non-images endpoints, but keep /v1/images/generations and /v1/images/edits enabled. disable-image-generation: false +# Base model used when proxying gpt-image-2 via the hosted image_generation tool (Responses API). +# Must start with "gpt-" (case-insensitive). If unset or invalid, defaults to "gpt-5.4-mini". +# gpt-image-2-base-model: "gpt-5.4-mini" + # Core auth auto-refresh worker pool size (OAuth/file-based auth token refresh). # When > 0, overrides the default worker count (16). # auth-auto-refresh-workers: 16 diff --git a/internal/config/sdk_config.go b/internal/config/sdk_config.go index 48c0fe5f1..d7a49e9d4 100644 --- a/internal/config/sdk_config.go +++ b/internal/config/sdk_config.go @@ -19,6 +19,13 @@ type SDKConfig struct { // while keeping /v1/images/generations and /v1/images/edits enabled and preserving image_generation there. DisableImageGeneration DisableImageGenerationMode `yaml:"disable-image-generation" json:"disable-image-generation"` + // GPTImage2BaseModel sets the base (mainline) model used when proxying GPT Image 2 + // requests via the hosted image_generation tool (e.g. Codex OAuth /v1/images/*). + // + // The value must start with "gpt-" (case-insensitive). If empty or invalid, the + // default base model ("gpt-5.4-mini") is used. + GPTImage2BaseModel string `yaml:"gpt-image-2-base-model,omitempty" json:"gpt-image-2-base-model,omitempty"` + // EnableGeminiCLIEndpoint controls whether Gemini CLI internal endpoints (/v1internal:*) are enabled. // Default is false for safety; when false, /v1internal:* requests are rejected. EnableGeminiCLIEndpoint bool `yaml:"enable-gemini-cli-endpoint" json:"enable-gemini-cli-endpoint"` diff --git a/internal/runtime/executor/codex_openai_images.go b/internal/runtime/executor/codex_openai_images.go index 0db259e41..142971118 100644 --- a/internal/runtime/executor/codex_openai_images.go +++ b/internal/runtime/executor/codex_openai_images.go @@ -63,6 +63,20 @@ func codexIsImagesEndpointPath(path string) bool { return strings.HasSuffix(path, codexImagesGenerationsPath) || strings.HasSuffix(path, codexImagesEditsPath) } +func (e *CodexExecutor) resolveGPTImage2BaseModel() string { + if e == nil || e.cfg == nil { + return codexOpenAIImagesMainModel + } + model := strings.TrimSpace(e.cfg.GPTImage2BaseModel) + if model == "" { + return codexOpenAIImagesMainModel + } + if strings.HasPrefix(strings.ToLower(model), "gpt-") { + return model + } + return codexOpenAIImagesMainModel +} + func (e *CodexExecutor) executeOpenAIImage(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { prepared, errPrepare := codexPrepareOpenAIImageRequest(req, opts) if errPrepare != nil { @@ -74,10 +88,11 @@ func (e *CodexExecutor) executeOpenAIImage(ctx context.Context, auth *cliproxyau baseURL = "https://chatgpt.com/backend-api/codex" } - reporter := helps.NewUsageReporter(ctx, e.Identifier(), codexOpenAIImagesMainModel, auth) + mainModel := e.resolveGPTImage2BaseModel() + reporter := helps.NewUsageReporter(ctx, e.Identifier(), mainModel, auth) defer reporter.TrackFailure(ctx, &err) - body, errBuild := e.prepareCodexOpenAIImageBody(prepared.Body, req, opts) + body, errBuild := e.prepareCodexOpenAIImageBody(prepared.Body, req, opts, mainModel) if errBuild != nil { return resp, errBuild } @@ -161,10 +176,11 @@ func (e *CodexExecutor) executeOpenAIImageStream(ctx context.Context, auth *clip baseURL = "https://chatgpt.com/backend-api/codex" } - reporter := helps.NewUsageReporter(ctx, e.Identifier(), codexOpenAIImagesMainModel, auth) + mainModel := e.resolveGPTImage2BaseModel() + reporter := helps.NewUsageReporter(ctx, e.Identifier(), mainModel, auth) defer reporter.TrackFailure(ctx, &err) - body, errBuild := e.prepareCodexOpenAIImageBody(prepared.Body, req, opts) + body, errBuild := e.prepareCodexOpenAIImageBody(prepared.Body, req, opts, mainModel) if errBuild != nil { return nil, errBuild } @@ -277,18 +293,22 @@ func (e *CodexExecutor) executeOpenAIImageStream(ctx context.Context, auth *clip return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } -func (e *CodexExecutor) prepareCodexOpenAIImageBody(body []byte, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) ([]byte, error) { +func (e *CodexExecutor) prepareCodexOpenAIImageBody(body []byte, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, mainModel string) ([]byte, error) { out := body + mainModel = strings.TrimSpace(mainModel) + if mainModel == "" { + mainModel = codexOpenAIImagesMainModel + } var errThinking error - out, errThinking = thinking.ApplyThinking(out, codexOpenAIImagesMainModel, codexOpenAIImageSourceFormat, "codex", e.Identifier()) + out, errThinking = thinking.ApplyThinking(out, mainModel, codexOpenAIImageSourceFormat, "codex", e.Identifier()) if errThinking != nil { return nil, errThinking } requestedModel := helps.PayloadRequestedModel(opts, req.Model) requestPath := helps.PayloadRequestPath(opts) - out = helps.ApplyPayloadConfigWithRequest(e.cfg, codexOpenAIImagesMainModel, "codex", codexOpenAIImageSourceFormat, "", out, body, requestedModel, requestPath, opts.Headers) - out, _ = sjson.SetBytes(out, "model", codexOpenAIImagesMainModel) + out = helps.ApplyPayloadConfigWithRequest(e.cfg, mainModel, "codex", codexOpenAIImageSourceFormat, "", out, body, requestedModel, requestPath, opts.Headers) + out, _ = sjson.SetBytes(out, "model", mainModel) out, _ = sjson.SetBytes(out, "stream", true) out, _ = sjson.DeleteBytes(out, "previous_response_id") out, _ = sjson.DeleteBytes(out, "prompt_cache_retention") diff --git a/internal/watcher/diff/config_diff.go b/internal/watcher/diff/config_diff.go index dcfa595f6..beda1be85 100644 --- a/internal/watcher/diff/config_diff.go +++ b/internal/watcher/diff/config_diff.go @@ -48,6 +48,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { if oldCfg.DisableImageGeneration != newCfg.DisableImageGeneration { changes = append(changes, fmt.Sprintf("disable-image-generation: %v -> %v", oldCfg.DisableImageGeneration, newCfg.DisableImageGeneration)) } + if strings.TrimSpace(oldCfg.GPTImage2BaseModel) != strings.TrimSpace(newCfg.GPTImage2BaseModel) { + changes = append(changes, fmt.Sprintf("gpt-image-2-base-model: %s -> %s", strings.TrimSpace(oldCfg.GPTImage2BaseModel), strings.TrimSpace(newCfg.GPTImage2BaseModel))) + } if oldCfg.RequestLog != newCfg.RequestLog { changes = append(changes, fmt.Sprintf("request-log: %t -> %t", oldCfg.RequestLog, newCfg.RequestLog)) } diff --git a/sdk/api/handlers/openai/openai_images_handlers.go b/sdk/api/handlers/openai/openai_images_handlers.go index 067471f4d..479dd3e6b 100644 --- a/sdk/api/handlers/openai/openai_images_handlers.go +++ b/sdk/api/handlers/openai/openai_images_handlers.go @@ -56,6 +56,80 @@ type xaiImageResult struct { MimeType string } +type imagesStreamExecutionResult struct { + Data <-chan []byte + UpstreamHeaders http.Header + Errs <-chan *interfaces.ErrorMessage +} + +func setImagesSSEHeaders(c *gin.Context) { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") +} + +func (h *OpenAIAPIHandler) newImagesStreamKeepAliveTicker() (*time.Ticker, <-chan time.Time) { + if h == nil || h.BaseAPIHandler == nil { + return nil, nil + } + interval := handlers.StreamingKeepAliveInterval(h.Cfg) + if interval <= 0 { + return nil, nil + } + ticker := time.NewTicker(interval) + return ticker, ticker.C +} + +func writeImagesStreamKeepAlive(c *gin.Context, flusher http.Flusher) { + _, _ = c.Writer.Write([]byte(": keep-alive\n\n")) + flusher.Flush() +} + +func writeImagesStreamErrorEvent(c *gin.Context, errMsg *interfaces.ErrorMessage) { + if errMsg == nil { + return + } + status := http.StatusInternalServerError + if errMsg.StatusCode > 0 { + status = errMsg.StatusCode + } + errText := http.StatusText(status) + if errMsg.Error != nil && strings.TrimSpace(errMsg.Error.Error()) != "" { + errText = errMsg.Error.Error() + } + body := handlers.BuildErrorResponseBody(status, errText) + _, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body)) +} + +func (h *OpenAIAPIHandler) waitImagesStreamExecution(c *gin.Context, flusher http.Flusher, execute func() imagesStreamExecutionResult) (imagesStreamExecutionResult, bool, bool) { + resultChan := make(chan imagesStreamExecutionResult, 1) + go func() { + resultChan <- execute() + }() + + keepAlive, keepAliveC := h.newImagesStreamKeepAliveTicker() + defer func() { + if keepAlive != nil { + keepAlive.Stop() + } + }() + + streamStarted := false + for { + select { + case <-c.Request.Context().Done(): + return imagesStreamExecutionResult{}, streamStarted, true + case result := <-resultChan: + return result, streamStarted, false + case <-keepAliveC: + setImagesSSEHeaders(c) + writeImagesStreamKeepAlive(c, flusher) + streamStarted = true + } + } +} + func (a *sseFrameAccumulator) AddChunk(chunk []byte) [][]byte { if len(chunk) == 0 { return nil @@ -1109,14 +1183,26 @@ func (h *OpenAIAPIHandler) streamRoutedImages(c *gin.Context, imageReq []byte, i cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) cliCtx = handlers.WithDisallowFreeAuth(cliCtx) model := strings.TrimSpace(imageModel) - dataChan, upstreamHeaders, errChan := h.ExecuteImageStreamWithAuthManager(cliCtx, xaiImagesHandlerType, model, imageReq, "") - - setSSEHeaders := func() { - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("Access-Control-Allow-Origin", "*") + execution, streamStarted, canceled := h.waitImagesStreamExecution(c, flusher, func() imagesStreamExecutionResult { + dataChan, upstreamHeaders, errChan := h.ExecuteImageStreamWithAuthManager(cliCtx, xaiImagesHandlerType, model, imageReq, "") + return imagesStreamExecutionResult{Data: dataChan, UpstreamHeaders: upstreamHeaders, Errs: errChan} + }) + if canceled { + cliCancel(c.Request.Context().Err()) + return } + dataChan := execution.Data + upstreamHeaders := execution.UpstreamHeaders + errChan := execution.Errs + keepAlive, keepAliveC := h.newImagesStreamKeepAliveTicker() + stopKeepAlive := func() { + if keepAlive != nil { + keepAlive.Stop() + keepAlive = nil + keepAliveC = nil + } + } + defer stopKeepAlive() for { select { @@ -1128,7 +1214,12 @@ func (h *OpenAIAPIHandler) streamRoutedImages(c *gin.Context, imageReq []byte, i errChan = nil continue } - h.WriteErrorResponse(c, errMsg) + if streamStarted { + writeImagesStreamErrorEvent(c, errMsg) + flusher.Flush() + } else { + h.WriteErrorResponse(c, errMsg) + } if errMsg != nil { cliCancel(errMsg.Error) } else { @@ -1137,7 +1228,8 @@ func (h *OpenAIAPIHandler) streamRoutedImages(c *gin.Context, imageReq []byte, i return case chunk, ok := <-dataChan: if !ok { - setSSEHeaders() + stopKeepAlive() + setImagesSSEHeaders(c) handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = c.Writer.Write([]byte("\n")) flusher.Flush() @@ -1145,35 +1237,30 @@ func (h *OpenAIAPIHandler) streamRoutedImages(c *gin.Context, imageReq []byte, i return } - setSSEHeaders() + stopKeepAlive() + setImagesSSEHeaders(c) handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = c.Writer.Write(chunk) flusher.Flush() + streamStarted = true h.forwardRawImageStream(cliCtx, c, func(err error) { cliCancel(err) }, dataChan, errChan) return + case <-keepAliveC: + setImagesSSEHeaders(c) + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) + writeImagesStreamKeepAlive(c, flusher) + streamStarted = true } } } func (h *OpenAIAPIHandler) forwardRawImageStream(ctx context.Context, c *gin.Context, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { - emitError := func(errMsg *interfaces.ErrorMessage) { - if errMsg == nil { - return + keepAlive, keepAliveC := h.newImagesStreamKeepAliveTicker() + defer func() { + if keepAlive != nil { + keepAlive.Stop() } - status := http.StatusInternalServerError - if errMsg.StatusCode > 0 { - status = errMsg.StatusCode - } - errText := http.StatusText(status) - if errMsg.Error != nil && strings.TrimSpace(errMsg.Error.Error()) != "" { - errText = errMsg.Error.Error() - } - body := handlers.BuildErrorResponseBody(status, errText) - _, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body)) - if flusher, ok := c.Writer.(http.Flusher); ok { - flusher.Flush() - } - } + }() for { select { @@ -1185,7 +1272,10 @@ func (h *OpenAIAPIHandler) forwardRawImageStream(ctx context.Context, c *gin.Con return case errMsg, ok := <-errs: if ok && errMsg != nil { - emitError(errMsg) + writeImagesStreamErrorEvent(c, errMsg) + if flusher, ok := c.Writer.(http.Flusher); ok { + flusher.Flush() + } cancel(errMsg.Error) return } @@ -1199,6 +1289,10 @@ func (h *OpenAIAPIHandler) forwardRawImageStream(ctx context.Context, c *gin.Con if flusher, ok := c.Writer.(http.Flusher); ok { flusher.Flush() } + case <-keepAliveC: + if flusher, ok := c.Writer.(http.Flusher); ok { + writeImagesStreamKeepAlive(c, flusher) + } } } } @@ -1217,14 +1311,26 @@ func (h *OpenAIAPIHandler) streamOpenAICompatImages(c *gin.Context, compatReq [] cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) model := strings.TrimSpace(imageModel) - dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, xaiImagesHandlerType, model, compatReq, "") - - setSSEHeaders := func() { - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("Access-Control-Allow-Origin", "*") + execution, streamStarted, canceled := h.waitImagesStreamExecution(c, flusher, func() imagesStreamExecutionResult { + dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, xaiImagesHandlerType, model, compatReq, "") + return imagesStreamExecutionResult{Data: dataChan, UpstreamHeaders: upstreamHeaders, Errs: errChan} + }) + if canceled { + cliCancel(c.Request.Context().Err()) + return } + dataChan := execution.Data + upstreamHeaders := execution.UpstreamHeaders + errChan := execution.Errs + keepAlive, keepAliveC := h.newImagesStreamKeepAliveTicker() + stopKeepAlive := func() { + if keepAlive != nil { + keepAlive.Stop() + keepAlive = nil + keepAliveC = nil + } + } + defer stopKeepAlive() for { select { @@ -1236,7 +1342,12 @@ func (h *OpenAIAPIHandler) streamOpenAICompatImages(c *gin.Context, compatReq [] errChan = nil continue } - h.WriteErrorResponse(c, errMsg) + if streamStarted { + writeImagesStreamErrorEvent(c, errMsg) + flusher.Flush() + } else { + h.WriteErrorResponse(c, errMsg) + } if errMsg != nil { cliCancel(errMsg.Error) } else { @@ -1245,38 +1356,34 @@ func (h *OpenAIAPIHandler) streamOpenAICompatImages(c *gin.Context, compatReq [] return case chunk, ok := <-dataChan: if !ok { - setSSEHeaders() + stopKeepAlive() + setImagesSSEHeaders(c) handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) flusher.Flush() cliCancel(nil) return } - setSSEHeaders() + stopKeepAlive() + setImagesSSEHeaders(c) handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = c.Writer.Write(chunk) flusher.Flush() + streamStarted = true h.ForwardStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan, handlers.StreamForwardOptions{ WriteChunk: func(next []byte) { _, _ = c.Writer.Write(next) }, WriteTerminalError: func(errMsg *interfaces.ErrorMessage) { - if errMsg == nil { - return - } - status := http.StatusInternalServerError - if errMsg.StatusCode > 0 { - status = errMsg.StatusCode - } - errText := http.StatusText(status) - if errMsg.Error != nil && errMsg.Error.Error() != "" { - errText = errMsg.Error.Error() - } - body := handlers.BuildErrorResponseBody(status, errText) - _, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body)) + writeImagesStreamErrorEvent(c, errMsg) }, }) return + case <-keepAliveC: + setImagesSSEHeaders(c) + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) + writeImagesStreamKeepAlive(c, flusher) + streamStarted = true } } } @@ -1337,57 +1444,96 @@ func (h *OpenAIAPIHandler) streamImagesWithModel(c *gin.Context, imageReq []byte cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) model = strings.TrimSpace(model) - resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiImagesHandlerType, model, imageReq, "") - if errMsg != nil { - h.WriteErrorResponse(c, errMsg) - if errMsg.Error != nil { + type imageStreamResult struct { + resp []byte + upstreamHeaders http.Header + errMsg *interfaces.ErrorMessage + } + resultChan := make(chan imageStreamResult, 1) + go func() { + resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiImagesHandlerType, model, imageReq, "") + resultChan <- imageStreamResult{resp: resp, upstreamHeaders: upstreamHeaders, errMsg: errMsg} + }() + + keepAlive, keepAliveC := h.newImagesStreamKeepAliveTicker() + stopKeepAlive := func() { + if keepAlive != nil { + keepAlive.Stop() + keepAlive = nil + keepAliveC = nil + } + } + defer stopKeepAlive() + streamStarted := false + writeError := func(errMsg *interfaces.ErrorMessage) { + if streamStarted { + writeImagesStreamErrorEvent(c, errMsg) + flusher.Flush() + } else { + h.WriteErrorResponse(c, errMsg) + } + if errMsg != nil && errMsg.Error != nil { cliCancel(errMsg.Error) } else { cliCancel(nil) } - return } - results, _, usageRaw, err := extractXAIImagesResponse(resp) - if err != nil { - errMsg := &interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err} - h.WriteErrorResponse(c, errMsg) - cliCancel(err) - return - } - - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("Access-Control-Allow-Origin", "*") - handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) - - eventName := streamPrefix + ".completed" - responseFormat = normalizeImagesResponseFormat(responseFormat) - for _, img := range results { - data := []byte(`{"type":""}`) - data, _ = sjson.SetBytes(data, "type", eventName) - if responseFormat == "url" { - if img.URL != "" { - data, _ = sjson.SetBytes(data, "url", img.URL) - } else { - data, _ = sjson.SetBytes(data, "url", "data:"+mimeTypeFromOutputFormat(img.MimeType)+";base64,"+img.B64JSON) + for { + select { + case <-c.Request.Context().Done(): + cliCancel(c.Request.Context().Err()) + return + case <-keepAliveC: + setImagesSSEHeaders(c) + writeImagesStreamKeepAlive(c, flusher) + streamStarted = true + case result := <-resultChan: + stopKeepAlive() + if result.errMsg != nil { + writeError(result.errMsg) + return } - } else if img.B64JSON != "" { - data, _ = sjson.SetBytes(data, "b64_json", img.B64JSON) - } else { - data, _ = sjson.SetBytes(data, "url", img.URL) + + results, _, usageRaw, err := extractXAIImagesResponse(result.resp) + if err != nil { + writeError(&interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err}) + return + } + + setImagesSSEHeaders(c) + handlers.WriteUpstreamHeaders(c.Writer.Header(), result.upstreamHeaders) + + eventName := streamPrefix + ".completed" + responseFormat = normalizeImagesResponseFormat(responseFormat) + for _, img := range results { + data := []byte(`{"type":""}`) + data, _ = sjson.SetBytes(data, "type", eventName) + if responseFormat == "url" { + if img.URL != "" { + data, _ = sjson.SetBytes(data, "url", img.URL) + } else { + data, _ = sjson.SetBytes(data, "url", "data:"+mimeTypeFromOutputFormat(img.MimeType)+";base64,"+img.B64JSON) + } + } else if img.B64JSON != "" { + data, _ = sjson.SetBytes(data, "b64_json", img.B64JSON) + } else { + data, _ = sjson.SetBytes(data, "url", img.URL) + } + if len(usageRaw) > 0 && json.Valid(usageRaw) { + data, _ = sjson.SetRawBytes(data, "usage", usageRaw) + } + if strings.TrimSpace(eventName) != "" { + _, _ = fmt.Fprintf(c.Writer, "event: %s\n", eventName) + } + _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(data)) + flusher.Flush() + streamStarted = true + } + cliCancel(nil) + return } - if len(usageRaw) > 0 && json.Valid(usageRaw) { - data, _ = sjson.SetRawBytes(data, "usage", usageRaw) - } - if strings.TrimSpace(eventName) != "" { - _, _ = fmt.Fprintf(c.Writer, "event: %s\n", eventName) - } - _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(data)) - flusher.Flush() } - cliCancel(nil) } func (h *OpenAIAPIHandler) collectImagesFromResponses(c *gin.Context, responsesReq []byte, responseFormat string) { @@ -1593,14 +1739,26 @@ func (h *OpenAIAPIHandler) streamImagesFromResponses(c *gin.Context, responsesRe if mainModel == "" { mainModel = defaultImagesMainModel } - dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, "openai-response", mainModel, responsesReq, "") - - setSSEHeaders := func() { - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("Access-Control-Allow-Origin", "*") + execution, streamStarted, canceled := h.waitImagesStreamExecution(c, flusher, func() imagesStreamExecutionResult { + dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, "openai-response", mainModel, responsesReq, "") + return imagesStreamExecutionResult{Data: dataChan, UpstreamHeaders: upstreamHeaders, Errs: errChan} + }) + if canceled { + cliCancel(c.Request.Context().Err()) + return } + dataChan := execution.Data + upstreamHeaders := execution.UpstreamHeaders + errChan := execution.Errs + keepAlive, keepAliveC := h.newImagesStreamKeepAliveTicker() + stopKeepAlive := func() { + if keepAlive != nil { + keepAlive.Stop() + keepAlive = nil + keepAliveC = nil + } + } + defer stopKeepAlive() writeEvent := func(eventName string, dataJSON []byte) { if strings.TrimSpace(eventName) != "" { @@ -1610,7 +1768,7 @@ func (h *OpenAIAPIHandler) streamImagesFromResponses(c *gin.Context, responsesRe flusher.Flush() } - // Peek for first chunk/error so we can still return a JSON error body. + // Peek for the first chunk/error while still allowing configured SSE heartbeats. for { select { case <-c.Request.Context().Done(): @@ -1621,7 +1779,12 @@ func (h *OpenAIAPIHandler) streamImagesFromResponses(c *gin.Context, responsesRe errChan = nil continue } - h.WriteErrorResponse(c, errMsg) + if streamStarted { + writeImagesStreamErrorEvent(c, errMsg) + flusher.Flush() + } else { + h.WriteErrorResponse(c, errMsg) + } if errMsg != nil { cliCancel(errMsg.Error) } else { @@ -1630,7 +1793,8 @@ func (h *OpenAIAPIHandler) streamImagesFromResponses(c *gin.Context, responsesRe return case chunk, ok := <-dataChan: if !ok { - setSSEHeaders() + stopKeepAlive() + setImagesSSEHeaders(c) handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = c.Writer.Write([]byte("\n")) flusher.Flush() @@ -1638,11 +1802,17 @@ func (h *OpenAIAPIHandler) streamImagesFromResponses(c *gin.Context, responsesRe return } - setSSEHeaders() + stopKeepAlive() + setImagesSSEHeaders(c) handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) h.forwardImagesStream(cliCtx, c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan, chunk, responseFormat, streamPrefix, writeEvent) return + case <-keepAliveC: + setImagesSSEHeaders(c) + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) + writeImagesStreamKeepAlive(c, flusher) + streamStarted = true } } } @@ -1654,21 +1824,16 @@ func (h *OpenAIAPIHandler) forwardImagesStream(ctx context.Context, c *gin.Conte if responseFormat == "" { responseFormat = "b64_json" } + keepAlive, keepAliveC := h.newImagesStreamKeepAliveTicker() + defer func() { + if keepAlive != nil { + keepAlive.Stop() + } + }() emitError := func(errMsg *interfaces.ErrorMessage) { - if errMsg == nil { - return - } - status := http.StatusInternalServerError - if errMsg.StatusCode > 0 { - status = errMsg.StatusCode - } - errText := http.StatusText(status) - if errMsg.Error != nil && strings.TrimSpace(errMsg.Error.Error()) != "" { - errText = errMsg.Error.Error() - } - body := handlers.BuildErrorResponseBody(status, errText) - writeEvent("error", body) + writeImagesStreamErrorEvent(c, errMsg) + flusher.Flush() } processFrame := func(frame []byte) (done bool) { @@ -1768,6 +1933,8 @@ func (h *OpenAIAPIHandler) forwardImagesStream(ctx context.Context, c *gin.Conte return } } + case <-keepAliveC: + writeImagesStreamKeepAlive(c, flusher) } } }