mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-05-23 04:39:53 +08:00
feat(api, xai): integrate xAI Grok image models and extend API endpoints for image support
- Added new xAI Grok image models (`grok-imagine-image`, `grok-imagine-image-quality`) with high-fidelity and aspect ratio configurations. - Extended `isSupportedImagesModel` logic to validate xAI models. - Implemented API request builders for image generation/editing with customizable options (e.g., resolution, aspect ratio, response format). - Enhanced `/v1/images` endpoints to handle xAI model capabilities, including response normalization and model-specific handlers. - Updated unit tests to validate xAI model validation, request structure, and API integration.
This commit is contained in:
10
README.md
10
README.md
@@ -2,7 +2,7 @@
|
||||
|
||||
English | [中文](README_CN.md) | [日本語](README_JA.md)
|
||||
|
||||
A proxy server that provides OpenAI/Gemini/Claude/Codex compatible API interfaces for CLI.
|
||||
A proxy server that provides OpenAI/Gemini/Claude/Codex/Grok compatible API interfaces for CLI.
|
||||
|
||||
It now also supports OpenAI Codex (GPT models) and Claude Code via OAuth.
|
||||
|
||||
@@ -41,20 +41,22 @@ VisionCoder is also offering our users a limited-time <a href="https://coder.vis
|
||||
|
||||
## Overview
|
||||
|
||||
- OpenAI/Gemini/Claude compatible API endpoints for CLI models
|
||||
- OpenAI/Gemini/Claude/Grok compatible API endpoints for CLI models
|
||||
- OpenAI Codex support (GPT models) via OAuth login
|
||||
- Claude Code support via OAuth login
|
||||
- Grok Build support via OAuth login
|
||||
- Amp CLI and IDE extensions support with provider routing
|
||||
- Streaming, non-streaming, and WebSocket responses where supported
|
||||
- Function calling/tools support
|
||||
- Multimodal input support (text and images)
|
||||
- Multiple accounts with round-robin load balancing (Gemini, OpenAI, Claude)
|
||||
- Simple CLI authentication flows (Gemini, OpenAI, Claude)
|
||||
- Multiple accounts with round-robin load balancing (Gemini, OpenAI, Claude, Grok)
|
||||
- Simple CLI authentication flows (Gemini, OpenAI, Claude, Grok)
|
||||
- Generative Language API Key support
|
||||
- AI Studio Build multi-account load balancing
|
||||
- Gemini CLI multi-account load balancing
|
||||
- Claude Code multi-account load balancing
|
||||
- OpenAI Codex multi-account load balancing
|
||||
- Grok Build multi-account load balancing
|
||||
- OpenAI-compatible upstream providers via config (e.g., OpenRouter)
|
||||
- Reusable Go SDK for embedding the proxy (see `docs/sdk-usage.md`)
|
||||
|
||||
|
||||
10
README_CN.md
10
README_CN.md
@@ -2,7 +2,7 @@
|
||||
|
||||
[English](README.md) | 中文 | [日本語](README_JA.md)
|
||||
|
||||
一个为 CLI 提供 OpenAI/Gemini/Claude/Codex 兼容 API 接口的代理服务器。
|
||||
一个为 CLI 提供 OpenAI/Gemini/Claude/Codex/Grok 兼容 API 接口的代理服务器。
|
||||
|
||||
现已支持通过 OAuth 登录接入 OpenAI Codex(GPT 系列)和 Claude Code。
|
||||
|
||||
@@ -42,19 +42,21 @@ VisionCoder 还为我们的用户提供 <a href="https://coder.visioncoder.cn" t
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 为 CLI 模型提供 OpenAI/Gemini/Claude/Codex 兼容的 API 端点
|
||||
- 为 CLI 模型提供 OpenAI/Gemini/Claude/Codex/Grok 兼容的 API 端点
|
||||
- 新增 OpenAI Codex(GPT 系列)支持(OAuth 登录)
|
||||
- 新增 Claude Code 支持(OAuth 登录)
|
||||
- 新增 Grok Build 支持(OAuth 登录)
|
||||
- 支持流式、非流式响应,以及受支持场景下的 WebSocket 响应
|
||||
- 函数调用/工具支持
|
||||
- 多模态输入(文本、图片)
|
||||
- 多账户支持与轮询负载均衡(Gemini、OpenAI、Claude)
|
||||
- 简单的 CLI 身份验证流程(Gemini、OpenAI、Claude)
|
||||
- 多账户支持与轮询负载均衡(Gemini、OpenAI、Claude、Grok)
|
||||
- 简单的 CLI 身份验证流程(Gemini、OpenAI、Claude、Grok)
|
||||
- 支持 Gemini AIStudio API 密钥
|
||||
- 支持 AI Studio Build 多账户轮询
|
||||
- 支持 Gemini CLI 多账户轮询
|
||||
- 支持 Claude Code 多账户轮询
|
||||
- 支持 OpenAI Codex 多账户轮询
|
||||
- 支持 Grok Build 多账户轮询
|
||||
- 通过配置接入上游 OpenAI 兼容提供商(例如 OpenRouter)
|
||||
- 可复用的 Go SDK(见 `docs/sdk-usage_CN.md`)
|
||||
|
||||
|
||||
10
README_JA.md
10
README_JA.md
@@ -2,7 +2,7 @@
|
||||
|
||||
[English](README.md) | [中文](README_CN.md) | 日本語
|
||||
|
||||
CLI向けのOpenAI/Gemini/Claude/Codex互換APIインターフェースを提供するプロキシサーバーです。
|
||||
CLI向けのOpenAI/Gemini/Claude/Codex/Grok互換APIインターフェースを提供するプロキシサーバーです。
|
||||
|
||||
OAuth経由でOpenAI Codex(GPTモデル)およびClaude Codeもサポートしています。
|
||||
|
||||
@@ -39,20 +39,22 @@ PackyCodeは当ソフトウェアのユーザーに特別割引を提供して
|
||||
|
||||
## 概要
|
||||
|
||||
- CLIモデル向けのOpenAI/Gemini/Claude互換APIエンドポイント
|
||||
- CLIモデル向けのOpenAI/Gemini/Claude/Grok互換APIエンドポイント
|
||||
- OAuthログインによるOpenAI Codexサポート(GPTモデル)
|
||||
- OAuthログインによるClaude Codeサポート
|
||||
- OAuthログインによるGrok Buildサポート
|
||||
- プロバイダールーティングによるAmp CLIおよびIDE拡張機能のサポート
|
||||
- ストリーミング、非ストリーミング、および対応環境でのWebSocketレスポンス
|
||||
- 関数呼び出し/ツールのサポート
|
||||
- マルチモーダル入力サポート(テキストと画像)
|
||||
- ラウンドロビン負荷分散による複数アカウント対応(Gemini、OpenAI、Claude)
|
||||
- シンプルなCLI認証フロー(Gemini、OpenAI、Claude)
|
||||
- ラウンドロビン負荷分散による複数アカウント対応(Gemini、OpenAI、Claude、Grok)
|
||||
- シンプルなCLI認証フロー(Gemini、OpenAI、Claude、Grok)
|
||||
- Generative Language APIキーのサポート
|
||||
- AI Studioビルドのマルチアカウント負荷分散
|
||||
- Gemini CLIのマルチアカウント負荷分散
|
||||
- Claude Codeのマルチアカウント負荷分散
|
||||
- OpenAI Codexのマルチアカウント負荷分散
|
||||
- Grok Buildのマルチアカウント負荷分散
|
||||
- 設定によるOpenAI互換アップストリームプロバイダー(例:OpenRouter)
|
||||
- プロキシ埋め込み用の再利用可能なGo SDK(`docs/sdk-usage.md`を参照)
|
||||
|
||||
|
||||
@@ -6,7 +6,11 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
const codexBuiltinImageModelID = "gpt-image-2"
|
||||
const (
|
||||
codexBuiltinImageModelID = "gpt-image-2"
|
||||
xaiBuiltinImageModelID = "grok-imagine-image"
|
||||
xaiBuiltinImageQualityModelID = "grok-imagine-image-quality"
|
||||
)
|
||||
|
||||
// staticModelsJSON mirrors the top-level structure of models.json.
|
||||
type staticModelsJSON struct {
|
||||
@@ -81,7 +85,7 @@ func GetAntigravityModels() []*ModelInfo {
|
||||
|
||||
// GetXAIModels returns the standard xAI Grok model definitions.
|
||||
func GetXAIModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().XAI)
|
||||
return WithXAIBuiltins(cloneModelInfos(getModels().XAI))
|
||||
}
|
||||
|
||||
// WithCodexBuiltins injects hard-coded Codex-only model definitions that should
|
||||
@@ -91,6 +95,12 @@ func WithCodexBuiltins(models []*ModelInfo) []*ModelInfo {
|
||||
return upsertModelInfos(models, codexBuiltinImageModelInfo())
|
||||
}
|
||||
|
||||
// WithXAIBuiltins injects hard-coded xAI image model definitions that should
|
||||
// not depend on remote models.json updates.
|
||||
func WithXAIBuiltins(models []*ModelInfo) []*ModelInfo {
|
||||
return upsertModelInfos(models, xaiBuiltinImageModelInfo(), xaiBuiltinImageQualityModelInfo())
|
||||
}
|
||||
|
||||
func codexBuiltinImageModelInfo() *ModelInfo {
|
||||
return &ModelInfo{
|
||||
ID: codexBuiltinImageModelID,
|
||||
@@ -103,6 +113,32 @@ func codexBuiltinImageModelInfo() *ModelInfo {
|
||||
}
|
||||
}
|
||||
|
||||
func xaiBuiltinImageModelInfo() *ModelInfo {
|
||||
return &ModelInfo{
|
||||
ID: xaiBuiltinImageModelID,
|
||||
Object: "model",
|
||||
Created: 1735689600, // 2025-01-01
|
||||
OwnedBy: "xai",
|
||||
Type: "xai",
|
||||
DisplayName: "Grok Imagine Image",
|
||||
Name: xaiBuiltinImageModelID,
|
||||
Description: "xAI Grok image generation model.",
|
||||
}
|
||||
}
|
||||
|
||||
func xaiBuiltinImageQualityModelInfo() *ModelInfo {
|
||||
return &ModelInfo{
|
||||
ID: xaiBuiltinImageQualityModelID,
|
||||
Object: "model",
|
||||
Created: 1735689600, // 2025-01-01
|
||||
OwnedBy: "xai",
|
||||
Type: "xai",
|
||||
DisplayName: "Grok Imagine Image Quality",
|
||||
Name: xaiBuiltinImageQualityModelID,
|
||||
Description: "xAI Grok higher-fidelity image generation model.",
|
||||
}
|
||||
}
|
||||
|
||||
func upsertModelInfos(models []*ModelInfo, extras ...*ModelInfo) []*ModelInfo {
|
||||
if len(extras) == 0 {
|
||||
return models
|
||||
|
||||
@@ -46,8 +46,7 @@
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high",
|
||||
"xhigh"
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
@@ -473,6 +472,30 @@
|
||||
"dynamic_allowed": true
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gemini-2.5-flash-image",
|
||||
"object": "model",
|
||||
"created": 1763596800,
|
||||
"owned_by": "google",
|
||||
"type": "gemini",
|
||||
"display_name": "Gemini 2.5 Flash Image",
|
||||
"name": "models/gemini-2.5-flash-image",
|
||||
"version": "001",
|
||||
"description": "Our state-of-the-art image generation and editing model.",
|
||||
"inputTokenLimit": 1048576,
|
||||
"outputTokenLimit": 65536,
|
||||
"supportedGenerationMethods": [
|
||||
"generateContent",
|
||||
"countTokens",
|
||||
"createCachedContent",
|
||||
"batchGenerateContent"
|
||||
],
|
||||
"thinking": {
|
||||
"max": 24576,
|
||||
"zero_allowed": true,
|
||||
"dynamic_allowed": true
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gemini-2.5-flash-lite",
|
||||
"object": "model",
|
||||
@@ -1990,12 +2013,12 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "gemini-3.1-pro-high",
|
||||
"id": "gemini-pro-agent",
|
||||
"object": "model",
|
||||
"owned_by": "antigravity",
|
||||
"type": "antigravity",
|
||||
"display_name": "Gemini 3.1 Pro (High)",
|
||||
"name": "gemini-3.1-pro-high",
|
||||
"name": "gemini-pro-agent",
|
||||
"description": "Gemini 3.1 Pro (High)",
|
||||
"context_length": 1048576,
|
||||
"max_completion_tokens": 65535,
|
||||
|
||||
@@ -27,6 +27,13 @@ import (
|
||||
|
||||
var xaiDataTag = []byte("data:")
|
||||
|
||||
const (
|
||||
xaiImageHandlerType = "openai-image"
|
||||
xaiImagesGenerationsPath = "/images/generations"
|
||||
xaiImagesEditsPath = "/images/edits"
|
||||
xaiDefaultImageEndpointPath = xaiImagesGenerationsPath
|
||||
)
|
||||
|
||||
// XAIExecutor is a stateless executor for xAI Grok's Responses API.
|
||||
type XAIExecutor struct {
|
||||
cfg *config.Config
|
||||
@@ -76,6 +83,10 @@ func (e *XAIExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
}
|
||||
|
||||
func (e *XAIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
if endpointPath := xaiImageEndpointPath(opts); endpointPath != "" {
|
||||
return e.executeImages(ctx, auth, req, endpointPath)
|
||||
}
|
||||
|
||||
token, baseURL := xaiCreds(auth)
|
||||
if baseURL == "" {
|
||||
baseURL = xaiauth.DefaultAPIBaseURL
|
||||
@@ -151,6 +162,51 @@ func (e *XAIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
return resp, statusErr{code: http.StatusRequestTimeout, msg: "xai stream error: stream disconnected before response.completed"}
|
||||
}
|
||||
|
||||
func (e *XAIExecutor) executeImages(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, endpointPath string) (resp cliproxyexecutor.Response, err error) {
|
||||
token, baseURL := xaiCreds(auth)
|
||||
if baseURL == "" {
|
||||
baseURL = xaiauth.DefaultAPIBaseURL
|
||||
}
|
||||
if endpointPath == "" {
|
||||
endpointPath = xaiDefaultImageEndpointPath
|
||||
}
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + endpointPath
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(req.Payload))
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
applyXAIHeaders(httpReq, auth, token, false, "")
|
||||
e.recordXAIRequest(ctx, auth, url, httpReq.Header.Clone(), req.Payload)
|
||||
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("xai executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
|
||||
data, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
return resp, statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||
}
|
||||
|
||||
return cliproxyexecutor.Response{Payload: data, Headers: httpResp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
func (e *XAIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
token, baseURL := xaiCreds(auth)
|
||||
if baseURL == "" {
|
||||
@@ -454,6 +510,21 @@ func xaiExecutionSessionID(req cliproxyexecutor.Request, opts cliproxyexecutor.O
|
||||
return ""
|
||||
}
|
||||
|
||||
func xaiImageEndpointPath(opts cliproxyexecutor.Options) string {
|
||||
if opts.SourceFormat.String() != xaiImageHandlerType {
|
||||
return ""
|
||||
}
|
||||
|
||||
path := xaiMetadataString(opts.Metadata, cliproxyexecutor.RequestPathMetadataKey)
|
||||
if strings.HasSuffix(path, "/images/edits") {
|
||||
return xaiImagesEditsPath
|
||||
}
|
||||
if strings.HasSuffix(path, "/images/generations") {
|
||||
return xaiImagesGenerationsPath
|
||||
}
|
||||
return xaiDefaultImageEndpointPath
|
||||
}
|
||||
|
||||
func xaiMetadataString(meta map[string]any, key string) string {
|
||||
if len(meta) == 0 || key == "" {
|
||||
return ""
|
||||
|
||||
@@ -136,3 +136,96 @@ func TestXAIExecutorOmitsUnsupportedReasoningEffort(t *testing.T) {
|
||||
t.Fatalf("unsupported xAI model must omit reasoning key: %s", string(gotBody))
|
||||
}
|
||||
}
|
||||
|
||||
func TestXAIExecutorExecuteImagesUsesImagesEndpoint(t *testing.T) {
|
||||
var gotPath string
|
||||
var gotAuth string
|
||||
var gotAccept string
|
||||
var gotBody []byte
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotAuth = r.Header.Get("Authorization")
|
||||
gotAccept = r.Header.Get("Accept")
|
||||
var errRead error
|
||||
gotBody, errRead = io.ReadAll(r.Body)
|
||||
if errRead != nil {
|
||||
t.Fatalf("read body: %v", errRead)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"created":123,"data":[{"b64_json":"AA=="}]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
exec := NewXAIExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "xai",
|
||||
Attributes: map[string]string{
|
||||
"base_url": server.URL,
|
||||
"auth_kind": "oauth",
|
||||
},
|
||||
Metadata: map[string]any{"access_token": "xai-token"},
|
||||
}
|
||||
|
||||
resp, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "grok-imagine-image",
|
||||
Payload: []byte(`{"model":"grok-imagine-image","prompt":"draw"}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai-image"),
|
||||
Metadata: map[string]any{
|
||||
cliproxyexecutor.RequestPathMetadataKey: "/v1/images/generations",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
|
||||
if gotPath != "/images/generations" {
|
||||
t.Fatalf("path = %q, want /images/generations", gotPath)
|
||||
}
|
||||
if gotAuth != "Bearer xai-token" {
|
||||
t.Fatalf("Authorization = %q, want Bearer xai-token", gotAuth)
|
||||
}
|
||||
if gotAccept != "application/json" {
|
||||
t.Fatalf("Accept = %q, want application/json", gotAccept)
|
||||
}
|
||||
if string(gotBody) != `{"model":"grok-imagine-image","prompt":"draw"}` {
|
||||
t.Fatalf("body = %s", string(gotBody))
|
||||
}
|
||||
if gjson.GetBytes(resp.Payload, "data.0.b64_json").String() != "AA==" {
|
||||
t.Fatalf("payload = %s", string(resp.Payload))
|
||||
}
|
||||
}
|
||||
|
||||
func TestXAIExecutorExecuteImagesUsesEditsEndpoint(t *testing.T) {
|
||||
var gotPath string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"created":123,"data":[{"url":"https://x.ai/image.png"}]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
exec := NewXAIExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "xai",
|
||||
Attributes: map[string]string{"base_url": server.URL},
|
||||
Metadata: map[string]any{"access_token": "xai-token"},
|
||||
}
|
||||
|
||||
_, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "grok-imagine-image",
|
||||
Payload: []byte(`{"model":"grok-imagine-image","prompt":"edit","image":{"type":"image_url","url":"https://example.com/a.png"}}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai-image"),
|
||||
Metadata: map[string]any{
|
||||
cliproxyexecutor.RequestPathMetadataKey: "/v1/images/edits",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
|
||||
if gotPath != "/images/edits" {
|
||||
t.Fatalf("path = %q, want /images/edits", gotPath)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,10 +23,15 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
defaultImagesMainModel = "gpt-5.4-mini"
|
||||
defaultImagesToolModel = "gpt-image-2"
|
||||
imagesGenerationsPath = "/v1/images/generations"
|
||||
imagesEditsPath = "/v1/images/edits"
|
||||
defaultImagesMainModel = "gpt-5.4-mini"
|
||||
defaultImagesToolModel = "gpt-image-2"
|
||||
defaultXAIImagesModel = "grok-imagine-image"
|
||||
xaiImagesQualityModel = "grok-imagine-image-quality"
|
||||
xaiImagesHandlerType = "openai-image"
|
||||
xaiImagesDefaultAspectRatio = "1:1"
|
||||
xaiImagesDefaultResolution = "1k"
|
||||
imagesGenerationsPath = "/v1/images/generations"
|
||||
imagesEditsPath = "/v1/images/edits"
|
||||
)
|
||||
|
||||
type imageCallResult struct {
|
||||
@@ -42,6 +47,13 @@ type sseFrameAccumulator struct {
|
||||
pending []byte
|
||||
}
|
||||
|
||||
type xaiImageResult struct {
|
||||
B64JSON string
|
||||
URL string
|
||||
RevisedPrompt string
|
||||
MimeType string
|
||||
}
|
||||
|
||||
func (a *sseFrameAccumulator) AddChunk(chunk []byte) [][]byte {
|
||||
if len(chunk) == 0 {
|
||||
return nil
|
||||
@@ -102,12 +114,36 @@ func (a *sseFrameAccumulator) Flush() [][]byte {
|
||||
return frames
|
||||
}
|
||||
|
||||
func isSupportedImagesModel(model string) bool {
|
||||
baseModel := strings.TrimSpace(model)
|
||||
if idx := strings.LastIndex(baseModel, "/"); idx >= 0 && idx < len(baseModel)-1 {
|
||||
baseModel = strings.TrimSpace(baseModel[idx+1:])
|
||||
func imagesModelParts(model string) (prefix string, baseModel string) {
|
||||
model = strings.TrimSpace(model)
|
||||
if idx := strings.LastIndex(model, "/"); idx >= 0 && idx < len(model)-1 {
|
||||
return strings.TrimSpace(model[:idx]), strings.TrimSpace(model[idx+1:])
|
||||
}
|
||||
return baseModel == defaultImagesToolModel
|
||||
return "", model
|
||||
}
|
||||
|
||||
func imagesModelBase(model string) string {
|
||||
_, baseModel := imagesModelParts(model)
|
||||
return strings.ToLower(strings.TrimSpace(baseModel))
|
||||
}
|
||||
|
||||
func isXAIImagesModel(model string) bool {
|
||||
prefix, baseModel := imagesModelParts(model)
|
||||
baseModel = strings.ToLower(strings.TrimSpace(baseModel))
|
||||
if baseModel != defaultXAIImagesModel && baseModel != xaiImagesQualityModel {
|
||||
return false
|
||||
}
|
||||
|
||||
prefix = strings.ToLower(strings.TrimSpace(prefix))
|
||||
return prefix == "" || prefix == "xai" || prefix == "x-ai" || prefix == "grok"
|
||||
}
|
||||
|
||||
func isSupportedImagesModel(model string) bool {
|
||||
baseModel := imagesModelBase(model)
|
||||
if baseModel == defaultImagesToolModel {
|
||||
return true
|
||||
}
|
||||
return isXAIImagesModel(model)
|
||||
}
|
||||
|
||||
func rejectUnsupportedImagesModel(c *gin.Context, model string) bool {
|
||||
@@ -117,13 +153,182 @@ func rejectUnsupportedImagesModel(c *gin.Context, model string) bool {
|
||||
|
||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: fmt.Sprintf("Model %s is not supported on %s or %s. Use %s.", model, imagesGenerationsPath, imagesEditsPath, defaultImagesToolModel),
|
||||
Message: fmt.Sprintf("Model %s is not supported on %s or %s. Use %s, %s, or %s.", model, imagesGenerationsPath, imagesEditsPath, defaultImagesToolModel, defaultXAIImagesModel, xaiImagesQualityModel),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
func normalizeImagesResponseFormat(responseFormat string) string {
|
||||
if strings.EqualFold(strings.TrimSpace(responseFormat), "url") {
|
||||
return "url"
|
||||
}
|
||||
return "b64_json"
|
||||
}
|
||||
|
||||
func canonicalXAIImagesModel(model string) string {
|
||||
baseModel := imagesModelBase(model)
|
||||
if baseModel == xaiImagesQualityModel {
|
||||
return xaiImagesQualityModel
|
||||
}
|
||||
return defaultXAIImagesModel
|
||||
}
|
||||
|
||||
func xaiImagesAspectRatio(raw string, fallback string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(raw)) {
|
||||
case "1:1", "square":
|
||||
return "1:1"
|
||||
case "16:9", "landscape":
|
||||
return "16:9"
|
||||
case "9:16", "portrait":
|
||||
return "9:16"
|
||||
case "4:3":
|
||||
return "4:3"
|
||||
case "3:4":
|
||||
return "3:4"
|
||||
case "3:2":
|
||||
return "3:2"
|
||||
case "2:3":
|
||||
return "2:3"
|
||||
default:
|
||||
return fallback
|
||||
}
|
||||
}
|
||||
|
||||
func xaiImagesAspectRatioFromSize(size string, fallback string) string {
|
||||
size = strings.ToLower(strings.TrimSpace(size))
|
||||
switch size {
|
||||
case "1024x1024", "2048x2048", "1:1":
|
||||
return "1:1"
|
||||
case "1792x1024", "16:9":
|
||||
return "16:9"
|
||||
case "1024x1792", "9:16":
|
||||
return "9:16"
|
||||
case "1536x1024", "3:2":
|
||||
return "3:2"
|
||||
case "1024x1536", "2:3":
|
||||
return "2:3"
|
||||
default:
|
||||
return fallback
|
||||
}
|
||||
}
|
||||
|
||||
func xaiImagesResolution(raw string, size string, fallback string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(raw)) {
|
||||
case "1k", "2k":
|
||||
return strings.ToLower(strings.TrimSpace(raw))
|
||||
}
|
||||
if strings.Contains(strings.ToLower(strings.TrimSpace(size)), "2048") {
|
||||
return "2k"
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func xaiImagesRef(imageURL string) []byte {
|
||||
ref := []byte(`{"type":"image_url","url":""}`)
|
||||
ref, _ = sjson.SetBytes(ref, "url", strings.TrimSpace(imageURL))
|
||||
return ref
|
||||
}
|
||||
|
||||
func buildXAIImagesBaseRequest(model string, prompt string, responseFormat string, aspectRatio string, resolution string, n int64) []byte {
|
||||
req := []byte(`{}`)
|
||||
req, _ = sjson.SetBytes(req, "model", canonicalXAIImagesModel(model))
|
||||
req, _ = sjson.SetBytes(req, "prompt", strings.TrimSpace(prompt))
|
||||
req, _ = sjson.SetBytes(req, "response_format", normalizeImagesResponseFormat(responseFormat))
|
||||
if aspectRatio != "" {
|
||||
req, _ = sjson.SetBytes(req, "aspect_ratio", aspectRatio)
|
||||
}
|
||||
if resolution != "" {
|
||||
req, _ = sjson.SetBytes(req, "resolution", resolution)
|
||||
}
|
||||
if n > 0 {
|
||||
req, _ = sjson.SetBytes(req, "n", n)
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
func buildXAIImagesGenerationsRequest(rawJSON []byte, model string, responseFormat string) []byte {
|
||||
prompt := strings.TrimSpace(gjson.GetBytes(rawJSON, "prompt").String())
|
||||
size := strings.TrimSpace(gjson.GetBytes(rawJSON, "size").String())
|
||||
aspectRatio := xaiImagesAspectRatio(gjson.GetBytes(rawJSON, "aspect_ratio").String(), "")
|
||||
aspectRatio = xaiImagesAspectRatioFromSize(size, aspectRatio)
|
||||
if aspectRatio == "" {
|
||||
aspectRatio = xaiImagesDefaultAspectRatio
|
||||
}
|
||||
resolution := xaiImagesResolution(gjson.GetBytes(rawJSON, "resolution").String(), size, xaiImagesDefaultResolution)
|
||||
n := int64(0)
|
||||
if v := gjson.GetBytes(rawJSON, "n"); v.Exists() && v.Type == gjson.Number {
|
||||
n = v.Int()
|
||||
}
|
||||
return buildXAIImagesBaseRequest(model, prompt, responseFormat, aspectRatio, resolution, n)
|
||||
}
|
||||
|
||||
func buildXAIImagesEditRequest(model string, prompt string, images []string, responseFormat string, aspectRatio string, resolution string, n int64) []byte {
|
||||
req := buildXAIImagesBaseRequest(model, prompt, responseFormat, aspectRatio, resolution, n)
|
||||
trimmedImages := make([]string, 0, len(images))
|
||||
for _, img := range images {
|
||||
if strings.TrimSpace(img) != "" {
|
||||
trimmedImages = append(trimmedImages, strings.TrimSpace(img))
|
||||
}
|
||||
}
|
||||
if len(trimmedImages) == 1 {
|
||||
req, _ = sjson.SetRawBytes(req, "image", xaiImagesRef(trimmedImages[0]))
|
||||
return req
|
||||
}
|
||||
for _, img := range trimmedImages {
|
||||
req, _ = sjson.SetRawBytes(req, "images.-1", xaiImagesRef(img))
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
func collectXAIImagesFromJSON(rawJSON []byte) []string {
|
||||
var images []string
|
||||
appendImage := func(url string) {
|
||||
url = strings.TrimSpace(url)
|
||||
if url != "" {
|
||||
images = append(images, url)
|
||||
}
|
||||
}
|
||||
|
||||
if image := gjson.GetBytes(rawJSON, "image"); image.Exists() {
|
||||
if image.Type == gjson.String {
|
||||
appendImage(image.String())
|
||||
} else if image.Type == gjson.JSON {
|
||||
appendImage(image.Get("image_url.url").String())
|
||||
if imageURL := image.Get("image_url"); imageURL.Type == gjson.String {
|
||||
appendImage(imageURL.String())
|
||||
}
|
||||
appendImage(image.Get("url").String())
|
||||
}
|
||||
}
|
||||
if imagesResult := gjson.GetBytes(rawJSON, "images"); imagesResult.IsArray() {
|
||||
for _, img := range imagesResult.Array() {
|
||||
if img.Type == gjson.String {
|
||||
appendImage(img.String())
|
||||
continue
|
||||
}
|
||||
appendImage(img.Get("image_url.url").String())
|
||||
if imageURL := img.Get("image_url"); imageURL.Type == gjson.String {
|
||||
appendImage(imageURL.String())
|
||||
}
|
||||
appendImage(img.Get("url").String())
|
||||
}
|
||||
}
|
||||
return images
|
||||
}
|
||||
|
||||
func xaiImagesEditOptionsFromJSON(rawJSON []byte) (aspectRatio string, resolution string, n int64) {
|
||||
size := strings.TrimSpace(gjson.GetBytes(rawJSON, "size").String())
|
||||
aspectRatio = xaiImagesAspectRatio(gjson.GetBytes(rawJSON, "aspect_ratio").String(), "")
|
||||
aspectRatio = xaiImagesAspectRatioFromSize(size, aspectRatio)
|
||||
resolution = xaiImagesResolution(gjson.GetBytes(rawJSON, "resolution").String(), size, "")
|
||||
if v := gjson.GetBytes(rawJSON, "n"); v.Exists() && v.Type == gjson.Number {
|
||||
n = v.Int()
|
||||
}
|
||||
return aspectRatio, resolution, n
|
||||
}
|
||||
|
||||
func mimeTypeFromOutputFormat(outputFormat string) string {
|
||||
if outputFormat == "" {
|
||||
return "image/png"
|
||||
@@ -249,6 +454,12 @@ func (h *OpenAIAPIHandler) ImagesGenerations(c *gin.Context) {
|
||||
}
|
||||
stream := gjson.GetBytes(rawJSON, "stream").Bool()
|
||||
|
||||
if isXAIImagesModel(imageModel) {
|
||||
xaiReq := buildXAIImagesGenerationsRequest(rawJSON, imageModel, responseFormat)
|
||||
h.handleXAIImages(c, xaiReq, responseFormat, "image_generation", stream)
|
||||
return
|
||||
}
|
||||
|
||||
tool := []byte(`{"type":"image_generation","action":"generate"}`)
|
||||
tool, _ = sjson.SetBytes(tool, "model", imageModel)
|
||||
|
||||
@@ -372,6 +583,22 @@ func (h *OpenAIAPIHandler) imagesEditsFromMultipart(c *gin.Context) {
|
||||
images = append(images, dataURL)
|
||||
}
|
||||
|
||||
responseFormat := strings.TrimSpace(c.PostForm("response_format"))
|
||||
if responseFormat == "" {
|
||||
responseFormat = "b64_json"
|
||||
}
|
||||
stream := parseBoolField(c.PostForm("stream"), false)
|
||||
|
||||
if isXAIImagesModel(imageModel) {
|
||||
aspectRatio := xaiImagesAspectRatio(c.PostForm("aspect_ratio"), "")
|
||||
aspectRatio = xaiImagesAspectRatioFromSize(c.PostForm("size"), aspectRatio)
|
||||
resolution := xaiImagesResolution(c.PostForm("resolution"), c.PostForm("size"), "")
|
||||
n := parseIntField(c.PostForm("n"), 0)
|
||||
xaiReq := buildXAIImagesEditRequest(imageModel, prompt, images, responseFormat, aspectRatio, resolution, n)
|
||||
h.handleXAIImages(c, xaiReq, responseFormat, "image_edit", stream)
|
||||
return
|
||||
}
|
||||
|
||||
var maskDataURL *string
|
||||
if maskFiles := form.File["mask"]; len(maskFiles) > 0 && maskFiles[0] != nil {
|
||||
dataURL, err := multipartFileToDataURL(maskFiles[0])
|
||||
@@ -387,12 +614,6 @@ func (h *OpenAIAPIHandler) imagesEditsFromMultipart(c *gin.Context) {
|
||||
maskDataURL = &dataURL
|
||||
}
|
||||
|
||||
responseFormat := strings.TrimSpace(c.PostForm("response_format"))
|
||||
if responseFormat == "" {
|
||||
responseFormat = "b64_json"
|
||||
}
|
||||
stream := parseBoolField(c.PostForm("stream"), false)
|
||||
|
||||
tool := []byte(`{"type":"image_generation","action":"edit"}`)
|
||||
tool, _ = sjson.SetBytes(tool, "model", imageModel)
|
||||
|
||||
@@ -474,6 +695,29 @@ func (h *OpenAIAPIHandler) imagesEditsFromJSON(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
responseFormat := strings.TrimSpace(gjson.GetBytes(rawJSON, "response_format").String())
|
||||
if responseFormat == "" {
|
||||
responseFormat = "b64_json"
|
||||
}
|
||||
stream := gjson.GetBytes(rawJSON, "stream").Bool()
|
||||
|
||||
if isXAIImagesModel(imageModel) {
|
||||
images := collectXAIImagesFromJSON(rawJSON)
|
||||
if len(images) == 0 {
|
||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: "Invalid request: image is required",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
aspectRatio, resolution, n := xaiImagesEditOptionsFromJSON(rawJSON)
|
||||
xaiReq := buildXAIImagesEditRequest(imageModel, prompt, images, responseFormat, aspectRatio, resolution, n)
|
||||
h.handleXAIImages(c, xaiReq, responseFormat, "image_edit", stream)
|
||||
return
|
||||
}
|
||||
|
||||
var images []string
|
||||
imagesResult := gjson.GetBytes(rawJSON, "images")
|
||||
if imagesResult.IsArray() {
|
||||
@@ -511,12 +755,6 @@ func (h *OpenAIAPIHandler) imagesEditsFromJSON(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
responseFormat := strings.TrimSpace(gjson.GetBytes(rawJSON, "response_format").String())
|
||||
if responseFormat == "" {
|
||||
responseFormat = "b64_json"
|
||||
}
|
||||
stream := gjson.GetBytes(rawJSON, "stream").Bool()
|
||||
|
||||
tool := []byte(`{"type":"image_generation","action":"edit"}`)
|
||||
tool, _ = sjson.SetBytes(tool, "model", imageModel)
|
||||
|
||||
@@ -580,6 +818,191 @@ func buildImagesResponsesRequest(prompt string, images []string, toolJSON []byte
|
||||
return req
|
||||
}
|
||||
|
||||
func extractXAIImagesResponse(payload []byte) (results []xaiImageResult, createdAt int64, usageRaw []byte, err error) {
|
||||
if !json.Valid(payload) {
|
||||
return nil, 0, nil, fmt.Errorf("upstream returned invalid image response JSON")
|
||||
}
|
||||
|
||||
createdAt = gjson.GetBytes(payload, "created").Int()
|
||||
if createdAt <= 0 {
|
||||
createdAt = time.Now().Unix()
|
||||
}
|
||||
|
||||
data := gjson.GetBytes(payload, "data")
|
||||
if data.IsArray() {
|
||||
for _, item := range data.Array() {
|
||||
result := xaiImageResult{
|
||||
B64JSON: strings.TrimSpace(item.Get("b64_json").String()),
|
||||
URL: strings.TrimSpace(item.Get("url").String()),
|
||||
RevisedPrompt: strings.TrimSpace(item.Get("revised_prompt").String()),
|
||||
MimeType: strings.TrimSpace(item.Get("mime_type").String()),
|
||||
}
|
||||
if result.MimeType == "" {
|
||||
result.MimeType = mimeTypeFromOutputFormat(strings.TrimSpace(item.Get("output_format").String()))
|
||||
}
|
||||
if result.MimeType == "" {
|
||||
result.MimeType = "image/png"
|
||||
}
|
||||
if result.B64JSON == "" && result.URL == "" {
|
||||
continue
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
}
|
||||
if len(results) == 0 {
|
||||
return nil, 0, nil, fmt.Errorf("upstream did not return image output")
|
||||
}
|
||||
|
||||
if usage := gjson.GetBytes(payload, "usage"); usage.Exists() && usage.IsObject() {
|
||||
usageRaw = []byte(usage.Raw)
|
||||
}
|
||||
|
||||
return results, createdAt, usageRaw, nil
|
||||
}
|
||||
|
||||
func buildImagesAPIResponseFromXAI(payload []byte, responseFormat string) ([]byte, error) {
|
||||
results, createdAt, usageRaw, err := extractXAIImagesResponse(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := []byte(`{"created":0,"data":[]}`)
|
||||
out, _ = sjson.SetBytes(out, "created", createdAt)
|
||||
responseFormat = normalizeImagesResponseFormat(responseFormat)
|
||||
|
||||
for _, img := range results {
|
||||
item := []byte(`{}`)
|
||||
if responseFormat == "url" {
|
||||
if img.URL != "" {
|
||||
item, _ = sjson.SetBytes(item, "url", img.URL)
|
||||
} else {
|
||||
item, _ = sjson.SetBytes(item, "url", "data:"+mimeTypeFromOutputFormat(img.MimeType)+";base64,"+img.B64JSON)
|
||||
}
|
||||
} else if img.B64JSON != "" {
|
||||
item, _ = sjson.SetBytes(item, "b64_json", img.B64JSON)
|
||||
} else {
|
||||
item, _ = sjson.SetBytes(item, "url", img.URL)
|
||||
}
|
||||
if img.RevisedPrompt != "" {
|
||||
item, _ = sjson.SetBytes(item, "revised_prompt", img.RevisedPrompt)
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "data.-1", item)
|
||||
}
|
||||
|
||||
if len(usageRaw) > 0 && json.Valid(usageRaw) {
|
||||
out, _ = sjson.SetRawBytes(out, "usage", usageRaw)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (h *OpenAIAPIHandler) handleXAIImages(c *gin.Context, xaiReq []byte, responseFormat string, streamPrefix string, stream bool) {
|
||||
if stream {
|
||||
h.streamXAIImages(c, xaiReq, responseFormat, streamPrefix)
|
||||
return
|
||||
}
|
||||
h.collectXAIImages(c, xaiReq, responseFormat)
|
||||
}
|
||||
|
||||
func (h *OpenAIAPIHandler) collectXAIImages(c *gin.Context, xaiReq []byte, responseFormat string) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
||||
|
||||
model := strings.TrimSpace(gjson.GetBytes(xaiReq, "model").String())
|
||||
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiImagesHandlerType, model, xaiReq, "")
|
||||
stopKeepAlive()
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
if errMsg.Error != nil {
|
||||
cliCancel(errMsg.Error)
|
||||
} else {
|
||||
cliCancel(nil)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
out, err := buildImagesAPIResponseFromXAI(resp, responseFormat)
|
||||
if err != nil {
|
||||
errMsg := &interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err}
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(err)
|
||||
return
|
||||
}
|
||||
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
_, _ = c.Writer.Write(out)
|
||||
cliCancel(nil)
|
||||
}
|
||||
|
||||
func (h *OpenAIAPIHandler) streamXAIImages(c *gin.Context, xaiReq []byte, responseFormat string, streamPrefix string) {
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: "Streaming not supported",
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
model := strings.TrimSpace(gjson.GetBytes(xaiReq, "model").String())
|
||||
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiImagesHandlerType, model, xaiReq, "")
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
if errMsg.Error != nil {
|
||||
cliCancel(errMsg.Error)
|
||||
} else {
|
||||
cliCancel(nil)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
results, _, usageRaw, err := extractXAIImagesResponse(resp)
|
||||
if err != nil {
|
||||
errMsg := &interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err}
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(err)
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
|
||||
eventName := streamPrefix + ".completed"
|
||||
responseFormat = normalizeImagesResponseFormat(responseFormat)
|
||||
for _, img := range results {
|
||||
data := []byte(`{"type":""}`)
|
||||
data, _ = sjson.SetBytes(data, "type", eventName)
|
||||
if responseFormat == "url" {
|
||||
if img.URL != "" {
|
||||
data, _ = sjson.SetBytes(data, "url", img.URL)
|
||||
} else {
|
||||
data, _ = sjson.SetBytes(data, "url", "data:"+mimeTypeFromOutputFormat(img.MimeType)+";base64,"+img.B64JSON)
|
||||
}
|
||||
} else if img.B64JSON != "" {
|
||||
data, _ = sjson.SetBytes(data, "b64_json", img.B64JSON)
|
||||
} else {
|
||||
data, _ = sjson.SetBytes(data, "url", img.URL)
|
||||
}
|
||||
if len(usageRaw) > 0 && json.Valid(usageRaw) {
|
||||
data, _ = sjson.SetRawBytes(data, "usage", usageRaw)
|
||||
}
|
||||
if strings.TrimSpace(eventName) != "" {
|
||||
_, _ = fmt.Fprintf(c.Writer, "event: %s\n", eventName)
|
||||
}
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(data))
|
||||
flusher.Flush()
|
||||
}
|
||||
cliCancel(nil)
|
||||
}
|
||||
|
||||
func (h *OpenAIAPIHandler) collectImagesFromResponses(c *gin.Context, responsesReq []byte, responseFormat string) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ func assertUnsupportedImagesModelResponse(t *testing.T, resp *httptest.ResponseR
|
||||
}
|
||||
|
||||
message := gjson.GetBytes(resp.Body.Bytes(), "error.message").String()
|
||||
expectedMessage := "Model " + model + " is not supported on " + imagesGenerationsPath + " or " + imagesEditsPath + ". Use " + defaultImagesToolModel + "."
|
||||
expectedMessage := "Model " + model + " is not supported on " + imagesGenerationsPath + " or " + imagesEditsPath + ". Use " + defaultImagesToolModel + ", " + defaultXAIImagesModel + ", or " + xaiImagesQualityModel + "."
|
||||
if message != expectedMessage {
|
||||
t.Fatalf("error message = %q, want %q", message, expectedMessage)
|
||||
}
|
||||
@@ -49,8 +49,8 @@ func assertUnsupportedImagesModelResponse(t *testing.T, resp *httptest.ResponseR
|
||||
}
|
||||
}
|
||||
|
||||
func TestImagesModelValidationAllowsGPTImage2WithOptionalPrefix(t *testing.T) {
|
||||
for _, model := range []string{"gpt-image-2", "codex/gpt-image-2"} {
|
||||
func TestImagesModelValidationAllowsGPTImage2AndXAIModels(t *testing.T) {
|
||||
for _, model := range []string{"gpt-image-2", "codex/gpt-image-2", "grok-imagine-image", "xai/grok-imagine-image", "grok-imagine-image-quality", "xai/grok-imagine-image-quality"} {
|
||||
if !isSupportedImagesModel(model) {
|
||||
t.Fatalf("expected %s to be supported", model)
|
||||
}
|
||||
@@ -58,6 +58,90 @@ func TestImagesModelValidationAllowsGPTImage2WithOptionalPrefix(t *testing.T) {
|
||||
if isSupportedImagesModel("gpt-5.4-mini") {
|
||||
t.Fatal("expected gpt-5.4-mini to be rejected")
|
||||
}
|
||||
if isSupportedImagesModel("codex/grok-imagine-image") {
|
||||
t.Fatal("expected codex/grok-imagine-image to be rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildXAIImagesGenerationsRequest(t *testing.T) {
|
||||
rawJSON := []byte(`{"model":"xai/grok-imagine-image-quality","prompt":"abstract art","aspect_ratio":"landscape","resolution":"2k","n":2,"response_format":"url"}`)
|
||||
|
||||
req := buildXAIImagesGenerationsRequest(rawJSON, "xai/grok-imagine-image-quality", "url")
|
||||
|
||||
if got := gjson.GetBytes(req, "model").String(); got != "grok-imagine-image-quality" {
|
||||
t.Fatalf("model = %q, want grok-imagine-image-quality", got)
|
||||
}
|
||||
if got := gjson.GetBytes(req, "prompt").String(); got != "abstract art" {
|
||||
t.Fatalf("prompt = %q, want abstract art", got)
|
||||
}
|
||||
if got := gjson.GetBytes(req, "aspect_ratio").String(); got != "16:9" {
|
||||
t.Fatalf("aspect_ratio = %q, want 16:9", got)
|
||||
}
|
||||
if got := gjson.GetBytes(req, "resolution").String(); got != "2k" {
|
||||
t.Fatalf("resolution = %q, want 2k", got)
|
||||
}
|
||||
if got := gjson.GetBytes(req, "response_format").String(); got != "url" {
|
||||
t.Fatalf("response_format = %q, want url", got)
|
||||
}
|
||||
if got := gjson.GetBytes(req, "n").Int(); got != 2 {
|
||||
t.Fatalf("n = %d, want 2", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildXAIImagesEditRequest(t *testing.T) {
|
||||
req := buildXAIImagesEditRequest("grok-imagine-image", "edit it", []string{"data:image/png;base64,AA==", "https://example.com/image.png"}, "b64_json", "3:2", "1k", 0)
|
||||
|
||||
if got := gjson.GetBytes(req, "model").String(); got != "grok-imagine-image" {
|
||||
t.Fatalf("model = %q, want grok-imagine-image", got)
|
||||
}
|
||||
if got := gjson.GetBytes(req, "images.0.type").String(); got != "image_url" {
|
||||
t.Fatalf("images.0.type = %q, want image_url", got)
|
||||
}
|
||||
if got := gjson.GetBytes(req, "images.0.url").String(); got != "data:image/png;base64,AA==" {
|
||||
t.Fatalf("images.0.url = %q", got)
|
||||
}
|
||||
if got := gjson.GetBytes(req, "images.1.url").String(); got != "https://example.com/image.png" {
|
||||
t.Fatalf("images.1.url = %q", got)
|
||||
}
|
||||
if gjson.GetBytes(req, "image").Exists() {
|
||||
t.Fatalf("multiple image edits must use images array: %s", string(req))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildXAIImagesEditRequestSingleImage(t *testing.T) {
|
||||
req := buildXAIImagesEditRequest("grok-imagine-image", "edit it", []string{"https://example.com/image.png"}, "url", "", "", 0)
|
||||
|
||||
if got := gjson.GetBytes(req, "image.type").String(); got != "image_url" {
|
||||
t.Fatalf("image.type = %q, want image_url", got)
|
||||
}
|
||||
if got := gjson.GetBytes(req, "image.url").String(); got != "https://example.com/image.png" {
|
||||
t.Fatalf("image.url = %q", got)
|
||||
}
|
||||
if gjson.GetBytes(req, "images").Exists() {
|
||||
t.Fatalf("single image edit must use image object: %s", string(req))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildImagesAPIResponseFromXAI(t *testing.T) {
|
||||
payload := []byte(`{"created":123,"data":[{"b64_json":"AA==","revised_prompt":"refined","mime_type":"image/png"}],"usage":{"total_tokens":0}}`)
|
||||
|
||||
out, err := buildImagesAPIResponseFromXAI(payload, "b64_json")
|
||||
if err != nil {
|
||||
t.Fatalf("buildImagesAPIResponseFromXAI() error = %v", err)
|
||||
}
|
||||
|
||||
if got := gjson.GetBytes(out, "created").Int(); got != 123 {
|
||||
t.Fatalf("created = %d, want 123", got)
|
||||
}
|
||||
if got := gjson.GetBytes(out, "data.0.b64_json").String(); got != "AA==" {
|
||||
t.Fatalf("data.0.b64_json = %q, want AA==", got)
|
||||
}
|
||||
if got := gjson.GetBytes(out, "data.0.revised_prompt").String(); got != "refined" {
|
||||
t.Fatalf("data.0.revised_prompt = %q, want refined", got)
|
||||
}
|
||||
if !gjson.GetBytes(out, "usage").Exists() {
|
||||
t.Fatalf("usage missing: %s", string(out))
|
||||
}
|
||||
}
|
||||
|
||||
func TestImagesGenerationsRejectsUnsupportedModel(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user