diff --git a/README.md b/README.md index 8064db7d7..8ad0d9dc8 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ English | [中文](README_CN.md) | [日本語](README_JA.md) -A proxy server that provides OpenAI/Gemini/Claude/Codex compatible API interfaces for CLI. +A proxy server that provides OpenAI/Gemini/Claude/Codex/Grok compatible API interfaces for CLI. It now also supports OpenAI Codex (GPT models) and Claude Code via OAuth. @@ -41,20 +41,22 @@ VisionCoder is also offering our users a limited-time = 300 { + helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) + return resp, statusErr{code: httpResp.StatusCode, msg: string(data)} + } + + return cliproxyexecutor.Response{Payload: data, Headers: httpResp.Header.Clone()}, nil +} + func (e *XAIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { token, baseURL := xaiCreds(auth) if baseURL == "" { @@ -454,6 +510,21 @@ func xaiExecutionSessionID(req cliproxyexecutor.Request, opts cliproxyexecutor.O return "" } +func xaiImageEndpointPath(opts cliproxyexecutor.Options) string { + if opts.SourceFormat.String() != xaiImageHandlerType { + return "" + } + + path := xaiMetadataString(opts.Metadata, cliproxyexecutor.RequestPathMetadataKey) + if strings.HasSuffix(path, "/images/edits") { + return xaiImagesEditsPath + } + if strings.HasSuffix(path, "/images/generations") { + return xaiImagesGenerationsPath + } + return xaiDefaultImageEndpointPath +} + func xaiMetadataString(meta map[string]any, key string) string { if len(meta) == 0 || key == "" { return "" diff --git a/internal/runtime/executor/xai_executor_test.go b/internal/runtime/executor/xai_executor_test.go index a08d512bf..1a517f75b 100644 --- a/internal/runtime/executor/xai_executor_test.go +++ b/internal/runtime/executor/xai_executor_test.go @@ -136,3 +136,96 @@ func TestXAIExecutorOmitsUnsupportedReasoningEffort(t *testing.T) { t.Fatalf("unsupported xAI model must omit reasoning key: %s", string(gotBody)) } } + +func TestXAIExecutorExecuteImagesUsesImagesEndpoint(t *testing.T) { + var gotPath string + var gotAuth string + var gotAccept string + var gotBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + gotAuth = r.Header.Get("Authorization") + gotAccept = r.Header.Get("Accept") + var errRead error + gotBody, errRead = io.ReadAll(r.Body) + if errRead != nil { + t.Fatalf("read body: %v", errRead) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"created":123,"data":[{"b64_json":"AA=="}]}`)) + })) + defer server.Close() + + exec := NewXAIExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{ + Provider: "xai", + Attributes: map[string]string{ + "base_url": server.URL, + "auth_kind": "oauth", + }, + Metadata: map[string]any{"access_token": "xai-token"}, + } + + resp, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{ + Model: "grok-imagine-image", + Payload: []byte(`{"model":"grok-imagine-image","prompt":"draw"}`), + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("openai-image"), + Metadata: map[string]any{ + cliproxyexecutor.RequestPathMetadataKey: "/v1/images/generations", + }, + }) + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + + if gotPath != "/images/generations" { + t.Fatalf("path = %q, want /images/generations", gotPath) + } + if gotAuth != "Bearer xai-token" { + t.Fatalf("Authorization = %q, want Bearer xai-token", gotAuth) + } + if gotAccept != "application/json" { + t.Fatalf("Accept = %q, want application/json", gotAccept) + } + if string(gotBody) != `{"model":"grok-imagine-image","prompt":"draw"}` { + t.Fatalf("body = %s", string(gotBody)) + } + if gjson.GetBytes(resp.Payload, "data.0.b64_json").String() != "AA==" { + t.Fatalf("payload = %s", string(resp.Payload)) + } +} + +func TestXAIExecutorExecuteImagesUsesEditsEndpoint(t *testing.T) { + var gotPath string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"created":123,"data":[{"url":"https://x.ai/image.png"}]}`)) + })) + defer server.Close() + + exec := NewXAIExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{ + Provider: "xai", + Attributes: map[string]string{"base_url": server.URL}, + Metadata: map[string]any{"access_token": "xai-token"}, + } + + _, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{ + Model: "grok-imagine-image", + Payload: []byte(`{"model":"grok-imagine-image","prompt":"edit","image":{"type":"image_url","url":"https://example.com/a.png"}}`), + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("openai-image"), + Metadata: map[string]any{ + cliproxyexecutor.RequestPathMetadataKey: "/v1/images/edits", + }, + }) + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + + if gotPath != "/images/edits" { + t.Fatalf("path = %q, want /images/edits", gotPath) + } +} diff --git a/sdk/api/handlers/openai/openai_images_handlers.go b/sdk/api/handlers/openai/openai_images_handlers.go index 72f06093c..34bdbcdc9 100644 --- a/sdk/api/handlers/openai/openai_images_handlers.go +++ b/sdk/api/handlers/openai/openai_images_handlers.go @@ -23,10 +23,15 @@ import ( ) const ( - defaultImagesMainModel = "gpt-5.4-mini" - defaultImagesToolModel = "gpt-image-2" - imagesGenerationsPath = "/v1/images/generations" - imagesEditsPath = "/v1/images/edits" + defaultImagesMainModel = "gpt-5.4-mini" + defaultImagesToolModel = "gpt-image-2" + defaultXAIImagesModel = "grok-imagine-image" + xaiImagesQualityModel = "grok-imagine-image-quality" + xaiImagesHandlerType = "openai-image" + xaiImagesDefaultAspectRatio = "1:1" + xaiImagesDefaultResolution = "1k" + imagesGenerationsPath = "/v1/images/generations" + imagesEditsPath = "/v1/images/edits" ) type imageCallResult struct { @@ -42,6 +47,13 @@ type sseFrameAccumulator struct { pending []byte } +type xaiImageResult struct { + B64JSON string + URL string + RevisedPrompt string + MimeType string +} + func (a *sseFrameAccumulator) AddChunk(chunk []byte) [][]byte { if len(chunk) == 0 { return nil @@ -102,12 +114,36 @@ func (a *sseFrameAccumulator) Flush() [][]byte { return frames } -func isSupportedImagesModel(model string) bool { - baseModel := strings.TrimSpace(model) - if idx := strings.LastIndex(baseModel, "/"); idx >= 0 && idx < len(baseModel)-1 { - baseModel = strings.TrimSpace(baseModel[idx+1:]) +func imagesModelParts(model string) (prefix string, baseModel string) { + model = strings.TrimSpace(model) + if idx := strings.LastIndex(model, "/"); idx >= 0 && idx < len(model)-1 { + return strings.TrimSpace(model[:idx]), strings.TrimSpace(model[idx+1:]) } - return baseModel == defaultImagesToolModel + return "", model +} + +func imagesModelBase(model string) string { + _, baseModel := imagesModelParts(model) + return strings.ToLower(strings.TrimSpace(baseModel)) +} + +func isXAIImagesModel(model string) bool { + prefix, baseModel := imagesModelParts(model) + baseModel = strings.ToLower(strings.TrimSpace(baseModel)) + if baseModel != defaultXAIImagesModel && baseModel != xaiImagesQualityModel { + return false + } + + prefix = strings.ToLower(strings.TrimSpace(prefix)) + return prefix == "" || prefix == "xai" || prefix == "x-ai" || prefix == "grok" +} + +func isSupportedImagesModel(model string) bool { + baseModel := imagesModelBase(model) + if baseModel == defaultImagesToolModel { + return true + } + return isXAIImagesModel(model) } func rejectUnsupportedImagesModel(c *gin.Context, model string) bool { @@ -117,13 +153,182 @@ func rejectUnsupportedImagesModel(c *gin.Context, model string) bool { c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ Error: handlers.ErrorDetail{ - Message: fmt.Sprintf("Model %s is not supported on %s or %s. Use %s.", model, imagesGenerationsPath, imagesEditsPath, defaultImagesToolModel), + Message: fmt.Sprintf("Model %s is not supported on %s or %s. Use %s, %s, or %s.", model, imagesGenerationsPath, imagesEditsPath, defaultImagesToolModel, defaultXAIImagesModel, xaiImagesQualityModel), Type: "invalid_request_error", }, }) return true } +func normalizeImagesResponseFormat(responseFormat string) string { + if strings.EqualFold(strings.TrimSpace(responseFormat), "url") { + return "url" + } + return "b64_json" +} + +func canonicalXAIImagesModel(model string) string { + baseModel := imagesModelBase(model) + if baseModel == xaiImagesQualityModel { + return xaiImagesQualityModel + } + return defaultXAIImagesModel +} + +func xaiImagesAspectRatio(raw string, fallback string) string { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "1:1", "square": + return "1:1" + case "16:9", "landscape": + return "16:9" + case "9:16", "portrait": + return "9:16" + case "4:3": + return "4:3" + case "3:4": + return "3:4" + case "3:2": + return "3:2" + case "2:3": + return "2:3" + default: + return fallback + } +} + +func xaiImagesAspectRatioFromSize(size string, fallback string) string { + size = strings.ToLower(strings.TrimSpace(size)) + switch size { + case "1024x1024", "2048x2048", "1:1": + return "1:1" + case "1792x1024", "16:9": + return "16:9" + case "1024x1792", "9:16": + return "9:16" + case "1536x1024", "3:2": + return "3:2" + case "1024x1536", "2:3": + return "2:3" + default: + return fallback + } +} + +func xaiImagesResolution(raw string, size string, fallback string) string { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "1k", "2k": + return strings.ToLower(strings.TrimSpace(raw)) + } + if strings.Contains(strings.ToLower(strings.TrimSpace(size)), "2048") { + return "2k" + } + return fallback +} + +func xaiImagesRef(imageURL string) []byte { + ref := []byte(`{"type":"image_url","url":""}`) + ref, _ = sjson.SetBytes(ref, "url", strings.TrimSpace(imageURL)) + return ref +} + +func buildXAIImagesBaseRequest(model string, prompt string, responseFormat string, aspectRatio string, resolution string, n int64) []byte { + req := []byte(`{}`) + req, _ = sjson.SetBytes(req, "model", canonicalXAIImagesModel(model)) + req, _ = sjson.SetBytes(req, "prompt", strings.TrimSpace(prompt)) + req, _ = sjson.SetBytes(req, "response_format", normalizeImagesResponseFormat(responseFormat)) + if aspectRatio != "" { + req, _ = sjson.SetBytes(req, "aspect_ratio", aspectRatio) + } + if resolution != "" { + req, _ = sjson.SetBytes(req, "resolution", resolution) + } + if n > 0 { + req, _ = sjson.SetBytes(req, "n", n) + } + return req +} + +func buildXAIImagesGenerationsRequest(rawJSON []byte, model string, responseFormat string) []byte { + prompt := strings.TrimSpace(gjson.GetBytes(rawJSON, "prompt").String()) + size := strings.TrimSpace(gjson.GetBytes(rawJSON, "size").String()) + aspectRatio := xaiImagesAspectRatio(gjson.GetBytes(rawJSON, "aspect_ratio").String(), "") + aspectRatio = xaiImagesAspectRatioFromSize(size, aspectRatio) + if aspectRatio == "" { + aspectRatio = xaiImagesDefaultAspectRatio + } + resolution := xaiImagesResolution(gjson.GetBytes(rawJSON, "resolution").String(), size, xaiImagesDefaultResolution) + n := int64(0) + if v := gjson.GetBytes(rawJSON, "n"); v.Exists() && v.Type == gjson.Number { + n = v.Int() + } + return buildXAIImagesBaseRequest(model, prompt, responseFormat, aspectRatio, resolution, n) +} + +func buildXAIImagesEditRequest(model string, prompt string, images []string, responseFormat string, aspectRatio string, resolution string, n int64) []byte { + req := buildXAIImagesBaseRequest(model, prompt, responseFormat, aspectRatio, resolution, n) + trimmedImages := make([]string, 0, len(images)) + for _, img := range images { + if strings.TrimSpace(img) != "" { + trimmedImages = append(trimmedImages, strings.TrimSpace(img)) + } + } + if len(trimmedImages) == 1 { + req, _ = sjson.SetRawBytes(req, "image", xaiImagesRef(trimmedImages[0])) + return req + } + for _, img := range trimmedImages { + req, _ = sjson.SetRawBytes(req, "images.-1", xaiImagesRef(img)) + } + return req +} + +func collectXAIImagesFromJSON(rawJSON []byte) []string { + var images []string + appendImage := func(url string) { + url = strings.TrimSpace(url) + if url != "" { + images = append(images, url) + } + } + + if image := gjson.GetBytes(rawJSON, "image"); image.Exists() { + if image.Type == gjson.String { + appendImage(image.String()) + } else if image.Type == gjson.JSON { + appendImage(image.Get("image_url.url").String()) + if imageURL := image.Get("image_url"); imageURL.Type == gjson.String { + appendImage(imageURL.String()) + } + appendImage(image.Get("url").String()) + } + } + if imagesResult := gjson.GetBytes(rawJSON, "images"); imagesResult.IsArray() { + for _, img := range imagesResult.Array() { + if img.Type == gjson.String { + appendImage(img.String()) + continue + } + appendImage(img.Get("image_url.url").String()) + if imageURL := img.Get("image_url"); imageURL.Type == gjson.String { + appendImage(imageURL.String()) + } + appendImage(img.Get("url").String()) + } + } + return images +} + +func xaiImagesEditOptionsFromJSON(rawJSON []byte) (aspectRatio string, resolution string, n int64) { + size := strings.TrimSpace(gjson.GetBytes(rawJSON, "size").String()) + aspectRatio = xaiImagesAspectRatio(gjson.GetBytes(rawJSON, "aspect_ratio").String(), "") + aspectRatio = xaiImagesAspectRatioFromSize(size, aspectRatio) + resolution = xaiImagesResolution(gjson.GetBytes(rawJSON, "resolution").String(), size, "") + if v := gjson.GetBytes(rawJSON, "n"); v.Exists() && v.Type == gjson.Number { + n = v.Int() + } + return aspectRatio, resolution, n +} + func mimeTypeFromOutputFormat(outputFormat string) string { if outputFormat == "" { return "image/png" @@ -249,6 +454,12 @@ func (h *OpenAIAPIHandler) ImagesGenerations(c *gin.Context) { } stream := gjson.GetBytes(rawJSON, "stream").Bool() + if isXAIImagesModel(imageModel) { + xaiReq := buildXAIImagesGenerationsRequest(rawJSON, imageModel, responseFormat) + h.handleXAIImages(c, xaiReq, responseFormat, "image_generation", stream) + return + } + tool := []byte(`{"type":"image_generation","action":"generate"}`) tool, _ = sjson.SetBytes(tool, "model", imageModel) @@ -372,6 +583,22 @@ func (h *OpenAIAPIHandler) imagesEditsFromMultipart(c *gin.Context) { images = append(images, dataURL) } + responseFormat := strings.TrimSpace(c.PostForm("response_format")) + if responseFormat == "" { + responseFormat = "b64_json" + } + stream := parseBoolField(c.PostForm("stream"), false) + + if isXAIImagesModel(imageModel) { + aspectRatio := xaiImagesAspectRatio(c.PostForm("aspect_ratio"), "") + aspectRatio = xaiImagesAspectRatioFromSize(c.PostForm("size"), aspectRatio) + resolution := xaiImagesResolution(c.PostForm("resolution"), c.PostForm("size"), "") + n := parseIntField(c.PostForm("n"), 0) + xaiReq := buildXAIImagesEditRequest(imageModel, prompt, images, responseFormat, aspectRatio, resolution, n) + h.handleXAIImages(c, xaiReq, responseFormat, "image_edit", stream) + return + } + var maskDataURL *string if maskFiles := form.File["mask"]; len(maskFiles) > 0 && maskFiles[0] != nil { dataURL, err := multipartFileToDataURL(maskFiles[0]) @@ -387,12 +614,6 @@ func (h *OpenAIAPIHandler) imagesEditsFromMultipart(c *gin.Context) { maskDataURL = &dataURL } - responseFormat := strings.TrimSpace(c.PostForm("response_format")) - if responseFormat == "" { - responseFormat = "b64_json" - } - stream := parseBoolField(c.PostForm("stream"), false) - tool := []byte(`{"type":"image_generation","action":"edit"}`) tool, _ = sjson.SetBytes(tool, "model", imageModel) @@ -474,6 +695,29 @@ func (h *OpenAIAPIHandler) imagesEditsFromJSON(c *gin.Context) { return } + responseFormat := strings.TrimSpace(gjson.GetBytes(rawJSON, "response_format").String()) + if responseFormat == "" { + responseFormat = "b64_json" + } + stream := gjson.GetBytes(rawJSON, "stream").Bool() + + if isXAIImagesModel(imageModel) { + images := collectXAIImagesFromJSON(rawJSON) + if len(images) == 0 { + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: "Invalid request: image is required", + Type: "invalid_request_error", + }, + }) + return + } + aspectRatio, resolution, n := xaiImagesEditOptionsFromJSON(rawJSON) + xaiReq := buildXAIImagesEditRequest(imageModel, prompt, images, responseFormat, aspectRatio, resolution, n) + h.handleXAIImages(c, xaiReq, responseFormat, "image_edit", stream) + return + } + var images []string imagesResult := gjson.GetBytes(rawJSON, "images") if imagesResult.IsArray() { @@ -511,12 +755,6 @@ func (h *OpenAIAPIHandler) imagesEditsFromJSON(c *gin.Context) { return } - responseFormat := strings.TrimSpace(gjson.GetBytes(rawJSON, "response_format").String()) - if responseFormat == "" { - responseFormat = "b64_json" - } - stream := gjson.GetBytes(rawJSON, "stream").Bool() - tool := []byte(`{"type":"image_generation","action":"edit"}`) tool, _ = sjson.SetBytes(tool, "model", imageModel) @@ -580,6 +818,191 @@ func buildImagesResponsesRequest(prompt string, images []string, toolJSON []byte return req } +func extractXAIImagesResponse(payload []byte) (results []xaiImageResult, createdAt int64, usageRaw []byte, err error) { + if !json.Valid(payload) { + return nil, 0, nil, fmt.Errorf("upstream returned invalid image response JSON") + } + + createdAt = gjson.GetBytes(payload, "created").Int() + if createdAt <= 0 { + createdAt = time.Now().Unix() + } + + data := gjson.GetBytes(payload, "data") + if data.IsArray() { + for _, item := range data.Array() { + result := xaiImageResult{ + B64JSON: strings.TrimSpace(item.Get("b64_json").String()), + URL: strings.TrimSpace(item.Get("url").String()), + RevisedPrompt: strings.TrimSpace(item.Get("revised_prompt").String()), + MimeType: strings.TrimSpace(item.Get("mime_type").String()), + } + if result.MimeType == "" { + result.MimeType = mimeTypeFromOutputFormat(strings.TrimSpace(item.Get("output_format").String())) + } + if result.MimeType == "" { + result.MimeType = "image/png" + } + if result.B64JSON == "" && result.URL == "" { + continue + } + results = append(results, result) + } + } + if len(results) == 0 { + return nil, 0, nil, fmt.Errorf("upstream did not return image output") + } + + if usage := gjson.GetBytes(payload, "usage"); usage.Exists() && usage.IsObject() { + usageRaw = []byte(usage.Raw) + } + + return results, createdAt, usageRaw, nil +} + +func buildImagesAPIResponseFromXAI(payload []byte, responseFormat string) ([]byte, error) { + results, createdAt, usageRaw, err := extractXAIImagesResponse(payload) + if err != nil { + return nil, err + } + + out := []byte(`{"created":0,"data":[]}`) + out, _ = sjson.SetBytes(out, "created", createdAt) + responseFormat = normalizeImagesResponseFormat(responseFormat) + + for _, img := range results { + item := []byte(`{}`) + if responseFormat == "url" { + if img.URL != "" { + item, _ = sjson.SetBytes(item, "url", img.URL) + } else { + item, _ = sjson.SetBytes(item, "url", "data:"+mimeTypeFromOutputFormat(img.MimeType)+";base64,"+img.B64JSON) + } + } else if img.B64JSON != "" { + item, _ = sjson.SetBytes(item, "b64_json", img.B64JSON) + } else { + item, _ = sjson.SetBytes(item, "url", img.URL) + } + if img.RevisedPrompt != "" { + item, _ = sjson.SetBytes(item, "revised_prompt", img.RevisedPrompt) + } + out, _ = sjson.SetRawBytes(out, "data.-1", item) + } + + if len(usageRaw) > 0 && json.Valid(usageRaw) { + out, _ = sjson.SetRawBytes(out, "usage", usageRaw) + } + + return out, nil +} + +func (h *OpenAIAPIHandler) handleXAIImages(c *gin.Context, xaiReq []byte, responseFormat string, streamPrefix string, stream bool) { + if stream { + h.streamXAIImages(c, xaiReq, responseFormat, streamPrefix) + return + } + h.collectXAIImages(c, xaiReq, responseFormat) +} + +func (h *OpenAIAPIHandler) collectXAIImages(c *gin.Context, xaiReq []byte, responseFormat string) { + c.Header("Content-Type", "application/json") + + cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) + stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx) + + model := strings.TrimSpace(gjson.GetBytes(xaiReq, "model").String()) + resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiImagesHandlerType, model, xaiReq, "") + stopKeepAlive() + if errMsg != nil { + h.WriteErrorResponse(c, errMsg) + if errMsg.Error != nil { + cliCancel(errMsg.Error) + } else { + cliCancel(nil) + } + return + } + + out, err := buildImagesAPIResponseFromXAI(resp, responseFormat) + if err != nil { + errMsg := &interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err} + h.WriteErrorResponse(c, errMsg) + cliCancel(err) + return + } + + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) + _, _ = c.Writer.Write(out) + cliCancel(nil) +} + +func (h *OpenAIAPIHandler) streamXAIImages(c *gin.Context, xaiReq []byte, responseFormat string, streamPrefix string) { + flusher, ok := c.Writer.(http.Flusher) + if !ok { + c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: "Streaming not supported", + Type: "server_error", + }, + }) + return + } + + cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) + model := strings.TrimSpace(gjson.GetBytes(xaiReq, "model").String()) + resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiImagesHandlerType, model, xaiReq, "") + if errMsg != nil { + h.WriteErrorResponse(c, errMsg) + if errMsg.Error != nil { + cliCancel(errMsg.Error) + } else { + cliCancel(nil) + } + return + } + + results, _, usageRaw, err := extractXAIImagesResponse(resp) + if err != nil { + errMsg := &interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err} + h.WriteErrorResponse(c, errMsg) + cliCancel(err) + return + } + + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) + + eventName := streamPrefix + ".completed" + responseFormat = normalizeImagesResponseFormat(responseFormat) + for _, img := range results { + data := []byte(`{"type":""}`) + data, _ = sjson.SetBytes(data, "type", eventName) + if responseFormat == "url" { + if img.URL != "" { + data, _ = sjson.SetBytes(data, "url", img.URL) + } else { + data, _ = sjson.SetBytes(data, "url", "data:"+mimeTypeFromOutputFormat(img.MimeType)+";base64,"+img.B64JSON) + } + } else if img.B64JSON != "" { + data, _ = sjson.SetBytes(data, "b64_json", img.B64JSON) + } else { + data, _ = sjson.SetBytes(data, "url", img.URL) + } + if len(usageRaw) > 0 && json.Valid(usageRaw) { + data, _ = sjson.SetRawBytes(data, "usage", usageRaw) + } + if strings.TrimSpace(eventName) != "" { + _, _ = fmt.Fprintf(c.Writer, "event: %s\n", eventName) + } + _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(data)) + flusher.Flush() + } + cliCancel(nil) +} + func (h *OpenAIAPIHandler) collectImagesFromResponses(c *gin.Context, responsesReq []byte, responseFormat string) { c.Header("Content-Type", "application/json") diff --git a/sdk/api/handlers/openai/openai_images_handlers_test.go b/sdk/api/handlers/openai/openai_images_handlers_test.go index 779659961..57df272ac 100644 --- a/sdk/api/handlers/openai/openai_images_handlers_test.go +++ b/sdk/api/handlers/openai/openai_images_handlers_test.go @@ -40,7 +40,7 @@ func assertUnsupportedImagesModelResponse(t *testing.T, resp *httptest.ResponseR } message := gjson.GetBytes(resp.Body.Bytes(), "error.message").String() - expectedMessage := "Model " + model + " is not supported on " + imagesGenerationsPath + " or " + imagesEditsPath + ". Use " + defaultImagesToolModel + "." + expectedMessage := "Model " + model + " is not supported on " + imagesGenerationsPath + " or " + imagesEditsPath + ". Use " + defaultImagesToolModel + ", " + defaultXAIImagesModel + ", or " + xaiImagesQualityModel + "." if message != expectedMessage { t.Fatalf("error message = %q, want %q", message, expectedMessage) } @@ -49,8 +49,8 @@ func assertUnsupportedImagesModelResponse(t *testing.T, resp *httptest.ResponseR } } -func TestImagesModelValidationAllowsGPTImage2WithOptionalPrefix(t *testing.T) { - for _, model := range []string{"gpt-image-2", "codex/gpt-image-2"} { +func TestImagesModelValidationAllowsGPTImage2AndXAIModels(t *testing.T) { + for _, model := range []string{"gpt-image-2", "codex/gpt-image-2", "grok-imagine-image", "xai/grok-imagine-image", "grok-imagine-image-quality", "xai/grok-imagine-image-quality"} { if !isSupportedImagesModel(model) { t.Fatalf("expected %s to be supported", model) } @@ -58,6 +58,90 @@ func TestImagesModelValidationAllowsGPTImage2WithOptionalPrefix(t *testing.T) { if isSupportedImagesModel("gpt-5.4-mini") { t.Fatal("expected gpt-5.4-mini to be rejected") } + if isSupportedImagesModel("codex/grok-imagine-image") { + t.Fatal("expected codex/grok-imagine-image to be rejected") + } +} + +func TestBuildXAIImagesGenerationsRequest(t *testing.T) { + rawJSON := []byte(`{"model":"xai/grok-imagine-image-quality","prompt":"abstract art","aspect_ratio":"landscape","resolution":"2k","n":2,"response_format":"url"}`) + + req := buildXAIImagesGenerationsRequest(rawJSON, "xai/grok-imagine-image-quality", "url") + + if got := gjson.GetBytes(req, "model").String(); got != "grok-imagine-image-quality" { + t.Fatalf("model = %q, want grok-imagine-image-quality", got) + } + if got := gjson.GetBytes(req, "prompt").String(); got != "abstract art" { + t.Fatalf("prompt = %q, want abstract art", got) + } + if got := gjson.GetBytes(req, "aspect_ratio").String(); got != "16:9" { + t.Fatalf("aspect_ratio = %q, want 16:9", got) + } + if got := gjson.GetBytes(req, "resolution").String(); got != "2k" { + t.Fatalf("resolution = %q, want 2k", got) + } + if got := gjson.GetBytes(req, "response_format").String(); got != "url" { + t.Fatalf("response_format = %q, want url", got) + } + if got := gjson.GetBytes(req, "n").Int(); got != 2 { + t.Fatalf("n = %d, want 2", got) + } +} + +func TestBuildXAIImagesEditRequest(t *testing.T) { + req := buildXAIImagesEditRequest("grok-imagine-image", "edit it", []string{"data:image/png;base64,AA==", "https://example.com/image.png"}, "b64_json", "3:2", "1k", 0) + + if got := gjson.GetBytes(req, "model").String(); got != "grok-imagine-image" { + t.Fatalf("model = %q, want grok-imagine-image", got) + } + if got := gjson.GetBytes(req, "images.0.type").String(); got != "image_url" { + t.Fatalf("images.0.type = %q, want image_url", got) + } + if got := gjson.GetBytes(req, "images.0.url").String(); got != "data:image/png;base64,AA==" { + t.Fatalf("images.0.url = %q", got) + } + if got := gjson.GetBytes(req, "images.1.url").String(); got != "https://example.com/image.png" { + t.Fatalf("images.1.url = %q", got) + } + if gjson.GetBytes(req, "image").Exists() { + t.Fatalf("multiple image edits must use images array: %s", string(req)) + } +} + +func TestBuildXAIImagesEditRequestSingleImage(t *testing.T) { + req := buildXAIImagesEditRequest("grok-imagine-image", "edit it", []string{"https://example.com/image.png"}, "url", "", "", 0) + + if got := gjson.GetBytes(req, "image.type").String(); got != "image_url" { + t.Fatalf("image.type = %q, want image_url", got) + } + if got := gjson.GetBytes(req, "image.url").String(); got != "https://example.com/image.png" { + t.Fatalf("image.url = %q", got) + } + if gjson.GetBytes(req, "images").Exists() { + t.Fatalf("single image edit must use image object: %s", string(req)) + } +} + +func TestBuildImagesAPIResponseFromXAI(t *testing.T) { + payload := []byte(`{"created":123,"data":[{"b64_json":"AA==","revised_prompt":"refined","mime_type":"image/png"}],"usage":{"total_tokens":0}}`) + + out, err := buildImagesAPIResponseFromXAI(payload, "b64_json") + if err != nil { + t.Fatalf("buildImagesAPIResponseFromXAI() error = %v", err) + } + + if got := gjson.GetBytes(out, "created").Int(); got != 123 { + t.Fatalf("created = %d, want 123", got) + } + if got := gjson.GetBytes(out, "data.0.b64_json").String(); got != "AA==" { + t.Fatalf("data.0.b64_json = %q, want AA==", got) + } + if got := gjson.GetBytes(out, "data.0.revised_prompt").String(); got != "refined" { + t.Fatalf("data.0.revised_prompt = %q, want refined", got) + } + if !gjson.GetBytes(out, "usage").Exists() { + t.Fatalf("usage missing: %s", string(out)) + } } func TestImagesGenerationsRejectsUnsupportedModel(t *testing.T) {