feat(api): add OpenAI compatibility for image models

- Introduced OpenAI-compatible image model support in the API, enabling integration through image generation and editing endpoints.
- Added the OpenAIImageModelType registry type to classify image-capable models and validate their compatibility.
- Implemented request handling for OpenAI-compatible image models, including JSON and multipart formats.
- Enhanced executor methods to support OpenAI-compatible image streaming and non-streaming requests.
- Included tests to validate model registration, streaming behavior, and multipart payload formatting.
This commit is contained in:
Luis Pater
2026-05-19 09:36:05 +08:00
parent b67eb6f25d
commit feebe6c7f2
16 changed files with 1962 additions and 37 deletions

View File

@@ -277,6 +277,7 @@ nonstream-keepalive-interval: 0
# models: # The models supported by the provider.
# - name: "moonshotai/kimi-k2:free" # The actual model name.
# alias: "kimi-k2" # The alias used in the API.
# image: false # optional: set true to allow this model on /v1/images/generations and /v1/images/edits
# thinking: # optional: omit to default to levels ["low","medium","high"]
# levels: ["low", "medium", "high"]
# # You may repeat the same alias to build an internal model pool.

View File

@@ -585,6 +585,9 @@ type OpenAICompatibilityModel struct {
// Alias is the model name alias that clients will use to reference this model.
Alias string `yaml:"alias" json:"alias"`
// Image marks this model as callable through /v1/images/generations and /v1/images/edits.
Image bool `yaml:"image,omitempty" json:"image,omitempty"`
// Thinking configures the thinking/reasoning capability for this model.
// If nil, the model defaults to level-based reasoning with levels ["low", "medium", "high"].
Thinking *registry.ThinkingSupport `yaml:"thinking,omitempty" json:"thinking,omitempty"`

View File

@@ -15,6 +15,9 @@ import (
log "github.com/sirupsen/logrus"
)
// OpenAIImageModelType is the model type value that marks a model as callable
// through the OpenAI-compatible image endpoints (image generation and image editing).
const OpenAIImageModelType = "openai-image"
// ModelInfo represents information about an available model
type ModelInfo struct {
// ID is the unique identifier for the model

View File

@@ -147,6 +147,9 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
if opts.Alt == "responses/compact" {
return e.executeCompact(ctx, auth, req, opts)
}
if isCodexOpenAIImageRequest(opts) {
return e.executeOpenAIImage(ctx, auth, req, opts)
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
apiKey, baseURL := codexCreds(auth)
@@ -397,6 +400,9 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusBadRequest, msg: "streaming not supported for /responses/compact"}
}
if isCodexOpenAIImageRequest(opts) {
return e.executeOpenAIImageStream(ctx, auth, req, opts)
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
apiKey, baseURL := codexCreds(auth)

View File

@@ -0,0 +1,678 @@
package executor
import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"mime"
"mime/multipart"
"net/http"
"strconv"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
"github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v7/internal/thinking"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Constants used when bridging OpenAI image requests onto the Codex Responses API.
const (
// codexOpenAIImageSourceFormat is the source format name that identifies OpenAI image requests.
codexOpenAIImageSourceFormat = "openai-image"
// codexImagesGenerationsPath is the request path (suffix) of the image generation endpoint.
codexImagesGenerationsPath = "/v1/images/generations"
// codexImagesEditsPath is the request path (suffix) of the image edit endpoint.
codexImagesEditsPath = "/v1/images/edits"
// codexOpenAIImagesMainModel is written into the "model" field of the upstream request
// and is the model name handed to the usage reporter; the image tool carries its own model.
codexOpenAIImagesMainModel = "gpt-5.4-mini"
)
// codexOpenAIImagePreparedRequest is the outcome of translating an OpenAI image
// request (generation or edit) into an upstream request body.
type codexOpenAIImagePreparedRequest struct {
// Body is the JSON request body produced by codexBuildImagesResponsesRequest.
Body []byte
// ResponseFormat is the normalized client response format: "url" or "b64_json".
ResponseFormat string
// StreamPrefix is the event-name prefix for streamed frames
// ("image_generation" or "image_edit").
StreamPrefix string
}
// codexImageCallResult holds the fields read from one "image_generation_call"
// output item of a completed upstream response. All values are whitespace-trimmed.
type codexImageCallResult struct {
// Result is the item's "result" value; it is emitted as b64_json or embedded
// in a base64 data URL when building client responses.
Result string
// RevisedPrompt is the item's "revised_prompt" value, if any.
RevisedPrompt string
// OutputFormat is the item's "output_format" value (used to pick the MIME type).
OutputFormat string
// Size is the item's "size" value.
Size string
// Background is the item's "background" value.
Background string
// Quality is the item's "quality" value.
Quality string
}
// isCodexOpenAIImageRequest reports whether opts describes a request whose source
// format is the OpenAI image format (compared case-insensitively, ignoring
// surrounding whitespace) and whose request path is one of the image endpoints.
func isCodexOpenAIImageRequest(opts cliproxyexecutor.Options) bool {
	sourceFormat := strings.TrimSpace(opts.SourceFormat.String())
	if strings.EqualFold(sourceFormat, codexOpenAIImageSourceFormat) {
		return codexIsImagesEndpointPath(helps.PayloadRequestPath(opts))
	}
	return false
}
// codexIsImagesEndpointPath reports whether path, after trimming surrounding
// whitespace, equals or ends with one of the image endpoint paths.
func codexIsImagesEndpointPath(path string) bool {
	trimmed := strings.TrimSpace(path)
	// An exact match is also a suffix match, so a single suffix test per
	// endpoint covers both cases.
	for _, endpoint := range []string{codexImagesGenerationsPath, codexImagesEditsPath} {
		if strings.HasSuffix(trimmed, endpoint) {
			return true
		}
	}
	return false
}
// executeOpenAIImage serves a non-streaming OpenAI image request through the
// Codex "/responses" endpoint.
//
// The client request is translated into an upstream request body, sent upstream,
// and the whole upstream body is read into memory. The body is then scanned line
// by line for data events: "response.output_item.done" items are collected, and on
// "response.completed" the images are extracted and an Images API style JSON
// payload is returned. If no "response.completed" event is found, a gateway
// timeout status error is returned.
func (e *CodexExecutor) executeOpenAIImage(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
// Translate the client payload (JSON or multipart) into the upstream request body.
prepared, errPrepare := codexPrepareOpenAIImageRequest(req, opts)
if errPrepare != nil {
return resp, errPrepare
}
apiKey, baseURL := codexCreds(auth)
if baseURL == "" {
// Fall back to the default base URL when the auth entry does not provide one.
baseURL = "https://chatgpt.com/backend-api/codex"
}
// The deferred TrackFailure receives a pointer to the named result err, so it
// observes whatever error this function finally returns.
reporter := helps.NewUsageReporter(ctx, e.Identifier(), codexOpenAIImagesMainModel, auth)
defer reporter.TrackFailure(ctx, &err)
body, errBuild := e.prepareCodexOpenAIImageBody(prepared.Body, req, opts)
if errBuild != nil {
return resp, errBuild
}
url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, errCache := e.cacheHelper(ctx, sdktranslator.FromString(codexOpenAIImageSourceFormat), url, req, body)
if errCache != nil {
return resp, errCache
}
applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
recordCodexOpenAIImageRequest(ctx, e.cfg, e.Identifier(), auth, url, httpReq.Header.Clone(), body)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
return resp, errDo
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("codex executor: close response body error: %v", errClose)
}
}()
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
// The upstream body is read completely before any event is processed.
data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return resp, errRead
}
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
err = newCodexStatusErr(httpResp.StatusCode, data)
return resp, err
}
// Output items reported by "response.output_item.done" events are kept so they
// can be patched into the final "response.completed" event.
outputItemsByIndex := make(map[int64][]byte)
var outputItemsFallback [][]byte
for _, line := range bytes.Split(data, []byte("\n")) {
// Only lines carrying the data prefix are events; everything else is skipped.
if !bytes.HasPrefix(line, dataTag) {
continue
}
eventData := bytes.TrimSpace(line[len(dataTag):])
switch gjson.GetBytes(eventData, "type").String() {
case "response.output_item.done":
collectCodexOutputItemDone(eventData, outputItemsByIndex, &outputItemsFallback)
case "response.completed":
// Publish usage first, then build the client response from the completed event.
if detail, ok := helps.ParseCodexUsage(eventData); ok {
reporter.Publish(ctx, detail)
}
publishCodexImageToolUsage(ctx, reporter, body, eventData)
completedData := patchCodexCompletedOutput(eventData, outputItemsByIndex, outputItemsFallback)
results, createdAt, usageRaw, firstMeta, errExtract := codexExtractImagesFromResponsesCompleted(completedData)
if errExtract != nil {
return resp, errExtract
}
if len(results) == 0 {
return resp, statusErr{code: http.StatusBadGateway, msg: "upstream did not return image output"}
}
out, errOutput := codexBuildImagesAPIResponse(results, createdAt, usageRaw, firstMeta, prepared.ResponseFormat)
if errOutput != nil {
return resp, errOutput
}
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
}
}
// No "response.completed" event was present in the upstream body.
err = statusErr{code: http.StatusGatewayTimeout, msg: "stream error: stream disconnected before completion"}
return resp, err
}
// executeOpenAIImageStream serves a streaming OpenAI image request through the
// Codex "/responses" endpoint.
//
// The client request is translated into an upstream request body and sent
// upstream. A non-2xx upstream status is returned as an error before any stream
// is created. Otherwise a goroutine scans the upstream body line by line and
// forwards frames on the returned channel:
//   - "response.image_generation_call.partial_image" events become
//     "<prefix>.partial_image" frames;
//   - "response.completed" yields one "<prefix>.completed" frame per image and
//     ends the stream.
//
// The goroutine closes the channel and the upstream body when it finishes. If
// the upstream stream ends without a "response.completed" event, an error chunk
// is sent so the caller does not see a silently truncated stream; this matches
// the error the non-streaming path returns in the same situation.
func (e *CodexExecutor) executeOpenAIImageStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
	prepared, errPrepare := codexPrepareOpenAIImageRequest(req, opts)
	if errPrepare != nil {
		return nil, errPrepare
	}

	apiKey, baseURL := codexCreds(auth)
	if baseURL == "" {
		// Fall back to the default base URL when the auth entry does not provide one.
		baseURL = "https://chatgpt.com/backend-api/codex"
	}

	// TrackFailure receives a pointer to the named result err, so it observes
	// whatever error this function finally returns.
	reporter := helps.NewUsageReporter(ctx, e.Identifier(), codexOpenAIImagesMainModel, auth)
	defer reporter.TrackFailure(ctx, &err)

	body, errBuild := e.prepareCodexOpenAIImageBody(prepared.Body, req, opts)
	if errBuild != nil {
		return nil, errBuild
	}

	url := strings.TrimSuffix(baseURL, "/") + "/responses"
	httpReq, errCache := e.cacheHelper(ctx, sdktranslator.FromString(codexOpenAIImageSourceFormat), url, req, body)
	if errCache != nil {
		return nil, errCache
	}
	applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
	recordCodexOpenAIImageRequest(ctx, e.cfg, e.Identifier(), auth, url, httpReq.Header.Clone(), body)

	httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
	httpResp, errDo := httpClient.Do(httpReq)
	if errDo != nil {
		helps.RecordAPIResponseError(ctx, e.cfg, errDo)
		return nil, errDo
	}
	helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())

	if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
		// Error responses are read fully and closed here; no goroutine is started.
		data, errRead := io.ReadAll(httpResp.Body)
		if errClose := httpResp.Body.Close(); errClose != nil {
			log.Errorf("codex executor: close response body error: %v", errClose)
		}
		if errRead != nil {
			helps.RecordAPIResponseError(ctx, e.cfg, errRead)
			return nil, errRead
		}
		helps.AppendAPIResponseChunk(ctx, e.cfg, data)
		helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
		err = newCodexStatusErr(httpResp.StatusCode, data)
		return nil, err
	}

	out := make(chan cliproxyexecutor.StreamChunk)
	go func() {
		defer close(out)
		defer func() {
			if errClose := httpResp.Body.Close(); errClose != nil {
				log.Errorf("codex executor: close response body error: %v", errClose)
			}
		}()

		// Both senders give up when ctx is cancelled so the goroutine cannot
		// block forever on a receiver that has gone away.
		sendPayload := func(payload []byte) bool {
			select {
			case out <- cliproxyexecutor.StreamChunk{Payload: payload}:
				return true
			case <-ctx.Done():
				return false
			}
		}
		sendError := func(errSend error) bool {
			select {
			case out <- cliproxyexecutor.StreamChunk{Err: errSend}:
				return true
			case <-ctx.Done():
				return false
			}
		}

		scanner := bufio.NewScanner(httpResp.Body)
		scanner.Buffer(nil, 52_428_800) // 50MB

		// Output items reported by "response.output_item.done" events are kept
		// so they can be patched into the final "response.completed" event.
		outputItemsByIndex := make(map[int64][]byte)
		var outputItemsFallback [][]byte

		for scanner.Scan() {
			line := scanner.Bytes()
			helps.AppendAPIResponseChunk(ctx, e.cfg, line)
			if !bytes.HasPrefix(line, dataTag) {
				continue
			}
			eventData := bytes.TrimSpace(line[len(dataTag):])

			switch gjson.GetBytes(eventData, "type").String() {
			case "response.output_item.done":
				collectCodexOutputItemDone(eventData, outputItemsByIndex, &outputItemsFallback)
			case "response.image_generation_call.partial_image":
				frame := codexBuildImagePartialFrame(eventData, prepared.ResponseFormat, prepared.StreamPrefix)
				if len(frame) > 0 && !sendPayload(frame) {
					return
				}
			case "response.completed":
				if detail, ok := helps.ParseCodexUsage(eventData); ok {
					reporter.Publish(ctx, detail)
				}
				publishCodexImageToolUsage(ctx, reporter, body, eventData)

				completedData := patchCodexCompletedOutput(eventData, outputItemsByIndex, outputItemsFallback)
				results, _, usageRaw, _, errExtract := codexExtractImagesFromResponsesCompleted(completedData)
				if errExtract != nil {
					sendError(errExtract)
					return
				}
				if len(results) == 0 {
					sendError(statusErr{code: http.StatusBadGateway, msg: "upstream did not return image output"})
					return
				}
				for _, img := range results {
					frame := codexBuildImageCompletedFrame(img, usageRaw, prepared.ResponseFormat, prepared.StreamPrefix)
					if len(frame) > 0 && !sendPayload(frame) {
						return
					}
				}
				return
			}
		}

		if errScan := scanner.Err(); errScan != nil {
			helps.RecordAPIResponseError(ctx, e.cfg, errScan)
			reporter.PublishFailure(ctx, errScan)
			sendError(errScan)
			return
		}

		// Reaching this point means the upstream closed the stream cleanly but
		// never sent "response.completed". Surface that as an error instead of
		// ending the client stream without a result.
		errIncomplete := statusErr{code: http.StatusGatewayTimeout, msg: "stream error: stream disconnected before completion"}
		helps.RecordAPIResponseError(ctx, e.cfg, errIncomplete)
		reporter.PublishFailure(ctx, errIncomplete)
		sendError(errIncomplete)
	}()

	return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// prepareCodexOpenAIImageBody finalizes the upstream request body for an image
// request: it applies thinking configuration and payload configuration, forces
// the main model and streaming, removes fields that are not forwarded upstream,
// and normalizes the instructions field.
//
// It returns an error only when applying the thinking configuration fails.
func (e *CodexExecutor) prepareCodexOpenAIImageBody(body []byte, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) ([]byte, error) {
	prepared, errThinking := thinking.ApplyThinking(body, codexOpenAIImagesMainModel, codexOpenAIImageSourceFormat, "codex", e.Identifier())
	if errThinking != nil {
		return nil, errThinking
	}

	clientModel := helps.PayloadRequestedModel(opts, req.Model)
	clientPath := helps.PayloadRequestPath(opts)
	prepared = helps.ApplyPayloadConfigWithRequest(e.cfg, codexOpenAIImagesMainModel, "codex", codexOpenAIImageSourceFormat, "", prepared, body, clientModel, clientPath, opts.Headers)

	// The upstream call always uses the main model and is always streamed.
	prepared, _ = sjson.SetBytes(prepared, "model", codexOpenAIImagesMainModel)
	prepared, _ = sjson.SetBytes(prepared, "stream", true)

	// These fields are removed before the request is sent upstream.
	for _, field := range []string{"previous_response_id", "prompt_cache_retention", "safety_identifier", "stream_options"} {
		prepared, _ = sjson.DeleteBytes(prepared, field)
	}

	return normalizeCodexInstructions(prepared), nil
}
// recordCodexOpenAIImageRequest records the outgoing upstream POST request
// (URL, headers, body, provider) together with the auth identity, when auth is
// non-nil, via helps.RecordAPIRequest.
func recordCodexOpenAIImageRequest(ctx context.Context, cfg *config.Config, provider string, auth *cliproxyauth.Auth, url string, headers http.Header, body []byte) {
	entry := helps.UpstreamRequestLog{
		URL:      url,
		Method:   http.MethodPost,
		Headers:  headers,
		Body:     body,
		Provider: provider,
	}
	// The auth fields stay at their zero value when no auth entry is available.
	if auth != nil {
		entry.AuthID = auth.ID
		entry.AuthLabel = auth.Label
		entry.AuthType, entry.AuthValue = auth.AccountInfo()
	}
	helps.RecordAPIRequest(ctx, cfg, entry)
}
// codexPrepareOpenAIImageRequest translates an OpenAI image request into an
// upstream request, dispatching on the request path:
//   - a path ending in the generations endpoint is handled as JSON;
//   - a path ending in the edits endpoint is handled as multipart when the
//     Content-Type media type starts with "multipart/", and as JSON otherwise.
//
// Any other path yields an error. The path is trimmed of surrounding whitespace
// first, consistent with codexIsImagesEndpointPath, so a request accepted by that
// check is not rejected here.
func codexPrepareOpenAIImageRequest(req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (codexOpenAIImagePreparedRequest, error) {
	path := strings.TrimSpace(helps.PayloadRequestPath(opts))

	switch {
	case strings.HasSuffix(path, codexImagesGenerationsPath):
		return codexPrepareOpenAIImageGenerationJSON(req.Payload, req.Model)
	case strings.HasSuffix(path, codexImagesEditsPath):
		contentType := codexImageContentType(opts.Headers)
		// A parse failure leaves mediaType empty, which selects the JSON path.
		mediaType, _, _ := mime.ParseMediaType(contentType)
		if strings.HasPrefix(strings.ToLower(mediaType), "multipart/") {
			return codexPrepareOpenAIImageEditMultipart(req.Payload, req.Model, contentType)
		}
		return codexPrepareOpenAIImageEditJSON(req.Payload, req.Model)
	default:
		return codexOpenAIImagePreparedRequest{}, fmt.Errorf("unsupported OpenAI image endpoint path %q", path)
	}
}
// codexPrepareOpenAIImageGenerationJSON converts a JSON image generation request
// into an upstream request whose image tool uses the "generate" action.
// It returns an error when rawJSON is not valid JSON.
func codexPrepareOpenAIImageGenerationJSON(rawJSON []byte, routeModel string) (codexOpenAIImagePreparedRequest, error) {
	var prepared codexOpenAIImagePreparedRequest
	if !json.Valid(rawJSON) {
		return prepared, fmt.Errorf("invalid OpenAI image generation request JSON")
	}

	// Fields copied from the request onto the image tool, by value kind.
	textFields := []string{"size", "quality", "background", "output_format", "moderation"}
	numericFields := []string{"output_compression", "partial_images"}
	toolJSON := codexBuildOpenAIImageTool(rawJSON, routeModel, "generate", textFields, numericFields)

	promptText := strings.TrimSpace(gjson.GetBytes(rawJSON, "prompt").String())

	prepared.Body = codexBuildImagesResponsesRequest(promptText, nil, toolJSON)
	prepared.ResponseFormat = codexOpenAIImageResponseFormatFromJSON(rawJSON)
	prepared.StreamPrefix = "image_generation"
	return prepared, nil
}
// codexPrepareOpenAIImageEditJSON converts a JSON image edit request into an
// upstream request whose image tool uses the "edit" action.
//
// Input images are taken from the "image_url" of each entry in the "images"
// array (blank entries are skipped), and an optional mask from "mask.image_url".
// It returns an error when rawJSON is not valid JSON.
func codexPrepareOpenAIImageEditJSON(rawJSON []byte, routeModel string) (codexOpenAIImagePreparedRequest, error) {
	var prepared codexOpenAIImagePreparedRequest
	if !json.Valid(rawJSON) {
		return prepared, fmt.Errorf("invalid OpenAI image edit request JSON")
	}

	promptText := strings.TrimSpace(gjson.GetBytes(rawJSON, "prompt").String())

	imageURLs := make([]string, 0)
	if entries := gjson.GetBytes(rawJSON, "images"); entries.IsArray() {
		list := entries.Array()
		for i := range list {
			imageURL := strings.TrimSpace(list[i].Get("image_url").String())
			if imageURL == "" {
				continue
			}
			imageURLs = append(imageURLs, imageURL)
		}
	}

	textFields := []string{"size", "quality", "background", "output_format", "input_fidelity", "moderation"}
	numericFields := []string{"output_compression", "partial_images"}
	toolJSON := codexBuildOpenAIImageTool(rawJSON, routeModel, "edit", textFields, numericFields)

	maskURL := strings.TrimSpace(gjson.GetBytes(rawJSON, "mask.image_url").String())
	if maskURL != "" {
		toolJSON, _ = sjson.SetBytes(toolJSON, "input_image_mask.image_url", maskURL)
	}

	prepared.Body = codexBuildImagesResponsesRequest(promptText, imageURLs, toolJSON)
	prepared.ResponseFormat = codexOpenAIImageResponseFormatFromJSON(rawJSON)
	prepared.StreamPrefix = "image_edit"
	return prepared, nil
}
// codexPrepareOpenAIImageEditMultipart converts a multipart/form-data image edit
// request into an upstream request whose image tool uses the "edit" action.
//
// contentType must carry the multipart boundary. Uploaded image files (and the
// optional "mask" file) are embedded as base64 data URLs. Numeric form fields
// that do not parse as integers are skipped rather than rejected.
func codexPrepareOpenAIImageEditMultipart(rawBody []byte, routeModel string, contentType string) (codexOpenAIImagePreparedRequest, error) {
	var empty codexOpenAIImagePreparedRequest

	_, mediaParams, errMedia := mime.ParseMediaType(contentType)
	if errMedia != nil {
		return empty, fmt.Errorf("parse multipart content type failed: %w", errMedia)
	}
	boundary := strings.TrimSpace(mediaParams["boundary"])
	if boundary == "" {
		return empty, fmt.Errorf("multipart boundary is required")
	}

	form, errForm := multipart.NewReader(bytes.NewReader(rawBody), boundary).ReadForm(32 << 20)
	if errForm != nil {
		return empty, fmt.Errorf("parse multipart form failed: %w", errForm)
	}
	// Remove any temporary files the form parser created, on every return path.
	defer func() {
		if errRemove := form.RemoveAll(); errRemove != nil {
			log.Errorf("codex openai images: remove multipart temp files error: %v", errRemove)
		}
	}()

	toolJSON := []byte(`{"type":"image_generation","action":"edit"}`)
	toolJSON, _ = sjson.SetBytes(toolJSON, "model", codexOpenAIImageToolModel(codexFormValue(form, "model"), routeModel))

	// Text fields are copied onto the tool when non-blank.
	for _, name := range []string{"size", "quality", "background", "output_format", "input_fidelity", "moderation"} {
		text := strings.TrimSpace(codexFormValue(form, name))
		if text == "" {
			continue
		}
		toolJSON, _ = sjson.SetBytes(toolJSON, name, text)
	}

	// Integer fields are copied only when they parse as base-10 integers.
	for _, name := range []string{"output_compression", "partial_images"} {
		text := strings.TrimSpace(codexFormValue(form, name))
		if text == "" {
			continue
		}
		number, errParse := strconv.ParseInt(text, 10, 64)
		if errParse != nil {
			continue
		}
		toolJSON, _ = sjson.SetBytes(toolJSON, name, number)
	}

	imageURLs := make([]string, 0)
	for _, upload := range codexMultipartImageFiles(form) {
		imageURL, errData := codexMultipartFileToDataURL(upload)
		if errData != nil {
			return empty, errData
		}
		imageURLs = append(imageURLs, imageURL)
	}

	// Only the first "mask" upload is used.
	if masks := form.File["mask"]; len(masks) > 0 && masks[0] != nil {
		maskURL, errData := codexMultipartFileToDataURL(masks[0])
		if errData != nil {
			return empty, errData
		}
		toolJSON, _ = sjson.SetBytes(toolJSON, "input_image_mask.image_url", maskURL)
	}

	promptText := strings.TrimSpace(codexFormValue(form, "prompt"))
	return codexOpenAIImagePreparedRequest{
		Body:           codexBuildImagesResponsesRequest(promptText, imageURLs, toolJSON),
		ResponseFormat: codexNormalizeImageResponseFormat(codexFormValue(form, "response_format")),
		StreamPrefix:   "image_edit",
	}, nil
}
// codexImageContentType returns the whitespace-trimmed Content-Type header value,
// or the empty string when headers is nil or the header is absent.
func codexImageContentType(headers http.Header) string {
	var contentType string
	if headers != nil {
		contentType = headers.Get("Content-Type")
	}
	return strings.TrimSpace(contentType)
}
func codexOpenAIImageResponseFormatFromJSON(rawJSON []byte) string {
return codexNormalizeImageResponseFormat(gjson.GetBytes(rawJSON, "response_format").String())
}
// codexNormalizeImageResponseFormat maps a client-supplied response format to one
// of the two supported values. "url" (case-insensitive, surrounding whitespace
// ignored) yields "url"; anything else, including the empty string, yields
// "b64_json".
func codexNormalizeImageResponseFormat(responseFormat string) string {
	normalized := "b64_json"
	if trimmed := strings.TrimSpace(responseFormat); strings.EqualFold(trimmed, "url") {
		normalized = "url"
	}
	return normalized
}
// codexOpenAIImageToolModel picks the model name for the image tool: the trimmed
// requestModel if non-blank, otherwise the trimmed routeModel if non-blank,
// otherwise codexDefaultImageToolModel.
func codexOpenAIImageToolModel(requestModel string, routeModel string) string {
	for _, candidate := range []string{requestModel, routeModel} {
		if trimmed := strings.TrimSpace(candidate); trimmed != "" {
			return trimmed
		}
	}
	return codexDefaultImageToolModel
}
// codexBuildOpenAIImageTool builds the "image_generation" tool JSON for the given
// action from a JSON request.
//
// stringFields are copied when their trimmed string value is non-empty.
// numberFields are copied, as integers, only when present with a JSON number type.
func codexBuildOpenAIImageTool(rawJSON []byte, routeModel string, action string, stringFields []string, numberFields []string) []byte {
	toolJSON := []byte(`{"type":"image_generation","action":""}`)
	toolJSON, _ = sjson.SetBytes(toolJSON, "action", action)

	requestModel := gjson.GetBytes(rawJSON, "model").String()
	toolJSON, _ = sjson.SetBytes(toolJSON, "model", codexOpenAIImageToolModel(requestModel, routeModel))

	for i := range stringFields {
		name := stringFields[i]
		text := strings.TrimSpace(gjson.GetBytes(rawJSON, name).String())
		if text == "" {
			continue
		}
		toolJSON, _ = sjson.SetBytes(toolJSON, name, text)
	}

	for i := range numberFields {
		name := numberFields[i]
		number := gjson.GetBytes(rawJSON, name)
		if !number.Exists() || number.Type != gjson.Number {
			continue
		}
		toolJSON, _ = sjson.SetBytes(toolJSON, name, number.Int())
	}

	return toolJSON
}
// codexBuildImagesResponsesRequest assembles the upstream request JSON: a single
// user message containing the prompt text followed by one input_image part per
// non-blank image, and a "tools" array holding toolJSON when it is valid JSON.
func codexBuildImagesResponsesRequest(prompt string, images []string, toolJSON []byte) []byte {
	request := []byte(`{"instructions":"","stream":true,"reasoning":{"effort":"medium","summary":"auto"},"parallel_tool_calls":true,"include":["reasoning.encrypted_content"],"model":"","store":false,"tool_choice":{"type":"image_generation"}}`)
	request, _ = sjson.SetBytes(request, "model", codexOpenAIImagesMainModel)

	message := []byte(`[{"type":"message","role":"user","content":[{"type":"input_text","text":""}]}]`)
	message, _ = sjson.SetBytes(message, "0.content.0.text", prompt)

	// Content index 0 is the prompt text; image parts follow from index 1.
	nextPart := 1
	for i := range images {
		imageURL := images[i]
		if strings.TrimSpace(imageURL) == "" {
			continue
		}
		imagePart := []byte(`{"type":"input_image","image_url":""}`)
		imagePart, _ = sjson.SetBytes(imagePart, "image_url", imageURL)
		message, _ = sjson.SetRawBytes(message, fmt.Sprintf("0.content.%d", nextPart), imagePart)
		nextPart++
	}
	request, _ = sjson.SetRawBytes(request, "input", message)

	request, _ = sjson.SetRawBytes(request, "tools", []byte(`[]`))
	if hasTool := len(toolJSON) > 0 && json.Valid(toolJSON); hasTool {
		request, _ = sjson.SetRawBytes(request, "tools.-1", toolJSON)
	}
	return request
}
// codexFormValue returns the first value of the form field key with surrounding
// whitespace removed, or the empty string when form is nil or the field has no
// values.
func codexFormValue(form *multipart.Form, key string) string {
	if form == nil {
		return ""
	}
	values := form.Value[key]
	if len(values) == 0 {
		return ""
	}
	return strings.TrimSpace(values[0])
}
// codexMultipartImageFiles returns the uploaded image files of a multipart form.
// Files under the "image[]" field take precedence; when there are none, the files
// under "image" are returned. A nil form yields nil.
func codexMultipartImageFiles(form *multipart.Form) []*multipart.FileHeader {
	if form == nil {
		return nil
	}
	bracketed := form.File["image[]"]
	if len(bracketed) == 0 {
		return form.File["image"]
	}
	return bracketed
}
// codexMultipartFileToDataURL reads an uploaded multipart file fully into memory
// and returns it as a base64 data URL ("data:<media type>;base64,<payload>").
//
// It returns an error when fileHeader is nil or the file cannot be opened or read.
func codexMultipartFileToDataURL(fileHeader *multipart.FileHeader) (string, error) {
	if fileHeader == nil {
		return "", fmt.Errorf("upload file is nil")
	}
	f, errOpen := fileHeader.Open()
	if errOpen != nil {
		return "", fmt.Errorf("open upload file failed: %w", errOpen)
	}
	defer func() {
		if errClose := f.Close(); errClose != nil {
			log.Errorf("codex openai images: close upload file error: %v", errClose)
		}
	}()
	data, errRead := io.ReadAll(f)
	if errRead != nil {
		return "", fmt.Errorf("read upload file failed: %w", errRead)
	}
	mediaType := codexUploadMediaType(fileHeader.Header.Get("Content-Type"), data)
	return "data:" + mediaType + ";base64," + base64.StdEncoding.EncodeToString(data), nil
}

// codexUploadMediaType chooses the media type to embed in a data URL for an upload.
//
// Parameters on the declared Content-Type (for example "; charset=binary") are
// dropped so they do not end up inside the data URL. When the declared type is
// missing or is the generic "application/octet-stream", the type is sniffed from
// the file content instead. A declared value that cannot be parsed is used as-is.
func codexUploadMediaType(declared string, data []byte) string {
	declared = strings.TrimSpace(declared)
	if parsed, _, errParse := mime.ParseMediaType(declared); errParse == nil && parsed != "" {
		declared = parsed
	}
	if declared == "" || strings.EqualFold(declared, "application/octet-stream") {
		return http.DetectContentType(data)
	}
	return declared
}
// codexExtractImagesFromResponsesCompleted pulls image results out of a
// "response.completed" event payload.
//
// It returns every "image_generation_call" output item that has a non-blank
// result, the response creation time (the current Unix time when the payload has
// no positive "response.created_at"), the raw "response.tool_usage.image_gen"
// object if present, and the first extracted result as firstMeta. An error is
// returned only when the payload type is not "response.completed".
func codexExtractImagesFromResponsesCompleted(payload []byte) (results []codexImageCallResult, createdAt int64, usageRaw []byte, firstMeta codexImageCallResult, err error) {
	if eventType := gjson.GetBytes(payload, "type").String(); eventType != "response.completed" {
		return nil, 0, nil, codexImageCallResult{}, fmt.Errorf("unexpected event type")
	}

	if createdAt = gjson.GetBytes(payload, "response.created_at").Int(); createdAt <= 0 {
		createdAt = time.Now().Unix()
	}

	if items := gjson.GetBytes(payload, "response.output"); items.IsArray() {
		list := items.Array()
		for i := range list {
			item := list[i]
			if item.Get("type").String() != "image_generation_call" {
				continue
			}
			encoded := strings.TrimSpace(item.Get("result").String())
			if encoded == "" {
				// Items without image data are ignored.
				continue
			}
			image := codexImageCallResult{
				Result:        encoded,
				RevisedPrompt: strings.TrimSpace(item.Get("revised_prompt").String()),
				OutputFormat:  strings.TrimSpace(item.Get("output_format").String()),
				Size:          strings.TrimSpace(item.Get("size").String()),
				Background:    strings.TrimSpace(item.Get("background").String()),
				Quality:       strings.TrimSpace(item.Get("quality").String()),
			}
			if len(results) == 0 {
				firstMeta = image
			}
			results = append(results, image)
		}
	}

	toolUsage := gjson.GetBytes(payload, "response.tool_usage.image_gen")
	if toolUsage.Exists() && toolUsage.IsObject() {
		usageRaw = []byte(toolUsage.Raw)
	}

	return results, createdAt, usageRaw, firstMeta, nil
}
// codexBuildImagesAPIResponse builds the JSON body returned to the client for a
// non-streaming image request.
//
// Each result becomes one entry in "data", carrying either a base64 data URL
// ("url") or the raw base64 payload ("b64_json") depending on responseFormat,
// plus "revised_prompt" when present. Non-empty metadata from firstMeta and a
// valid usageRaw object are added at the top level. The returned error is always
// nil.
func codexBuildImagesAPIResponse(results []codexImageCallResult, createdAt int64, usageRaw []byte, firstMeta codexImageCallResult, responseFormat string) ([]byte, error) {
	response := []byte(`{"created":0,"data":[]}`)
	response, _ = sjson.SetBytes(response, "created", createdAt)

	wantURL := codexNormalizeImageResponseFormat(responseFormat) == "url"
	for i := range results {
		image := results[i]
		entry := []byte(`{}`)
		if wantURL {
			dataURL := "data:" + codexMimeTypeFromOutputFormat(image.OutputFormat) + ";base64," + image.Result
			entry, _ = sjson.SetBytes(entry, "url", dataURL)
		} else {
			entry, _ = sjson.SetBytes(entry, "b64_json", image.Result)
		}
		if image.RevisedPrompt != "" {
			entry, _ = sjson.SetBytes(entry, "revised_prompt", image.RevisedPrompt)
		}
		response, _ = sjson.SetRawBytes(response, "data.-1", entry)
	}

	// Top-level metadata is written in this fixed order, skipping empty values.
	metadata := []struct {
		key   string
		value string
	}{
		{"background", firstMeta.Background},
		{"output_format", firstMeta.OutputFormat},
		{"quality", firstMeta.Quality},
		{"size", firstMeta.Size},
	}
	for _, field := range metadata {
		if field.value == "" {
			continue
		}
		response, _ = sjson.SetBytes(response, field.key, field.value)
	}

	if len(usageRaw) > 0 && json.Valid(usageRaw) {
		response, _ = sjson.SetRawBytes(response, "usage", usageRaw)
	}
	return response, nil
}
// codexBuildImagePartialFrame converts an upstream partial-image event into an SSE
// frame named "<streamPrefix>.partial_image". It returns nil when the event has no
// "partial_image_b64" content.
func codexBuildImagePartialFrame(payload []byte, responseFormat string, streamPrefix string) []byte {
	encoded := strings.TrimSpace(gjson.GetBytes(payload, "partial_image_b64").String())
	if encoded == "" {
		return nil
	}

	event := strings.TrimSpace(streamPrefix) + ".partial_image"
	frameJSON := []byte(`{"type":"","partial_image_index":0}`)
	frameJSON, _ = sjson.SetBytes(frameJSON, "type", event)
	frameJSON, _ = sjson.SetBytes(frameJSON, "partial_image_index", gjson.GetBytes(payload, "partial_image_index").Int())

	switch codexNormalizeImageResponseFormat(responseFormat) {
	case "url":
		format := strings.TrimSpace(gjson.GetBytes(payload, "output_format").String())
		frameJSON, _ = sjson.SetBytes(frameJSON, "url", "data:"+codexMimeTypeFromOutputFormat(format)+";base64,"+encoded)
	default:
		frameJSON, _ = sjson.SetBytes(frameJSON, "b64_json", encoded)
	}

	return codexBuildSSEFrame(event, frameJSON)
}
// codexBuildImageCompletedFrame builds the SSE frame named
// "<streamPrefix>.completed" for one finished image, carrying the image as a data
// URL or as b64_json depending on responseFormat, plus "usage" when usageRaw is
// valid JSON.
func codexBuildImageCompletedFrame(img codexImageCallResult, usageRaw []byte, responseFormat string, streamPrefix string) []byte {
	event := strings.TrimSpace(streamPrefix) + ".completed"
	frameJSON, _ := sjson.SetBytes([]byte(`{"type":""}`), "type", event)

	switch codexNormalizeImageResponseFormat(responseFormat) {
	case "url":
		frameJSON, _ = sjson.SetBytes(frameJSON, "url", "data:"+codexMimeTypeFromOutputFormat(img.OutputFormat)+";base64,"+img.Result)
	default:
		frameJSON, _ = sjson.SetBytes(frameJSON, "b64_json", img.Result)
	}

	if hasUsage := len(usageRaw) > 0 && json.Valid(usageRaw); hasUsage {
		frameJSON, _ = sjson.SetRawBytes(frameJSON, "usage", usageRaw)
	}
	return codexBuildSSEFrame(event, frameJSON)
}
// codexBuildSSEFrame formats one server-sent-events frame: an optional
// "event: <name>" line (omitted when eventName is blank), a "data: <payload>"
// line, and a terminating blank line.
func codexBuildSSEFrame(eventName string, data []byte) []byte {
	// Capacity covers both fixed prefixes and the line terminators.
	frame := bytes.NewBuffer(make([]byte, 0, len(eventName)+len(data)+16))
	if hasEvent := strings.TrimSpace(eventName) != ""; hasEvent {
		frame.WriteString("event: " + eventName + "\n")
	}
	frame.WriteString("data: ")
	frame.Write(data)
	frame.WriteString("\n\n")
	return frame.Bytes()
}
// codexMimeTypeFromOutputFormat maps an image output format name to its MIME type.
// "jpg"/"jpeg" and "webp" are recognized case-insensitively, ignoring surrounding
// whitespace; every other value, including the empty string, maps to "image/png".
func codexMimeTypeFromOutputFormat(outputFormat string) string {
	normalized := strings.ToLower(strings.TrimSpace(outputFormat))
	if normalized == "jpg" || normalized == "jpeg" {
		return "image/jpeg"
	}
	if normalized == "webp" {
		return "image/webp"
	}
	return "image/png"
}

View File

@@ -4,9 +4,13 @@ import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"mime"
"mime/multipart"
"net/http"
"net/textproto"
"strings"
"time"
@@ -21,6 +25,14 @@ import (
"github.com/tidwall/sjson"
)
// Constants for image endpoint support in the OpenAI-compatible executor.
const (
// openAICompatImageHandlerType is the handler/source type name for image requests.
// NOTE(review): presumably compared against the request's source format; confirm at the use site.
openAICompatImageHandlerType = "openai-image"
// openAICompatImagesGenerationsPath is the image generation endpoint path.
// NOTE(review): it carries no version prefix, presumably because it is appended to the provider base URL; confirm.
openAICompatImagesGenerationsPath = "/images/generations"
// openAICompatImagesEditsPath is the image edit endpoint path.
openAICompatImagesEditsPath = "/images/edits"
// openAICompatDefaultImageEndpoint is the endpoint used by default; it is the generations path.
openAICompatDefaultImageEndpoint = openAICompatImagesGenerationsPath
// openAICompatMultipartMemory is 32 MiB (32 << 20 bytes).
// NOTE(review): presumably the in-memory limit passed when parsing multipart forms; confirm at the use site.
openAICompatMultipartMemory int64 = 32 << 20
)
// OpenAICompatExecutor implements a stateless executor for OpenAI-compatible providers.
// It performs request/response translation and executes against the provider base URL
// using per-auth credentials (API key) and per-auth HTTP transport (proxy) from context.
@@ -71,6 +83,10 @@ func (e *OpenAICompatExecutor) HttpRequest(ctx context.Context, auth *cliproxyau
}
func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
if endpointPath := openAICompatImageEndpointPath(opts); endpointPath != "" {
return e.executeImages(ctx, auth, req, opts, endpointPath)
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
@@ -179,7 +195,98 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
return resp, nil
}
// executeImages forwards a non-streaming image request to the provider's image
// endpoint and returns the provider's response body unchanged.
//
// endpointPath is appended to the provider base URL. The request payload is
// prepared by prepareOpenAICompatImagesPayload using the client's Content-Type;
// when that yields no content type, "application/json" is used. A missing base
// URL or a non-2xx provider status is returned as a statusErr.
func (e *OpenAICompatExecutor) executeImages(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, endpointPath string) (resp cliproxyexecutor.Response, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
// The deferred TrackFailure receives a pointer to the named result err, so it
// observes whatever error this function finally returns.
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
baseURL, apiKey := e.resolveCredentials(auth)
if baseURL == "" {
err = statusErr{code: http.StatusUnauthorized, msg: "missing provider baseURL"}
return resp, err
}
// The final false argument is true in the streaming variant of this call.
payload, contentType, errPrepare := prepareOpenAICompatImagesPayload(req.Payload, baseModel, opts.Headers.Get("Content-Type"), false)
if errPrepare != nil {
err = errPrepare
return resp, err
}
if contentType == "" {
contentType = "application/json"
}
url := strings.TrimSuffix(baseURL, "/") + endpointPath
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
if err != nil {
return resp, err
}
httpReq.Header.Set("Content-Type", contentType)
// The Authorization header is only sent when an API key is configured.
if apiKey != "" {
httpReq.Header.Set("Authorization", "Bearer "+apiKey)
}
httpReq.Header.Set("User-Agent", "cli-proxy-openai-compat")
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
// Custom headers are applied after the defaults above.
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: payload,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("openai compat executor: close response body error: %v", errClose)
}
}()
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
// The provider response is read fully into memory before the status is checked.
body, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
err = errRead
return resp, err
}
helps.AppendAPIResponseChunk(ctx, e.cfg, body)
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), body))
err = statusErr{code: httpResp.StatusCode, msg: string(body)}
return resp, err
}
// Usage is parsed from the provider body and reported before returning.
reporter.Publish(ctx, helps.ParseOpenAIUsage(body))
reporter.EnsurePublished(ctx)
resp = cliproxyexecutor.Response{Payload: body, Headers: httpResp.Header.Clone()}
return resp, nil
}
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if endpointPath := openAICompatImageEndpointPath(opts); endpointPath != "" {
return e.executeImagesStream(ctx, auth, req, opts, endpointPath)
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
@@ -342,6 +449,121 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// executeImagesStream sends a streaming image request to the provider's
// OpenAI-compatible image endpoint and relays the raw response bytes.
//
// endpointPath is appended to the provider base URL after a trailing "/" is
// trimmed. The payload is produced by prepareOpenAICompatImagesPayload with
// stream=true, using the caller's Content-Type header.
//
// A non-2xx upstream status is returned as a statusErr carrying the response
// body. Otherwise the returned StreamResult holds a clone of the upstream
// headers and a channel of raw body chunks. The channel is closed when the
// body reaches EOF, a read fails, or ctx is cancelled while sending.
func (e *OpenAICompatExecutor) executeImagesStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, endpointPath string) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
// The deferred call receives a pointer to the named result so it sees the
// error value this function finally returns.
defer reporter.TrackFailure(ctx, &err)
baseURL, apiKey := e.resolveCredentials(auth)
if baseURL == "" {
err = statusErr{code: http.StatusUnauthorized, msg: "missing provider baseURL"}
return nil, err
}
payload, contentType, errPrepare := prepareOpenAICompatImagesPayload(req.Payload, baseModel, opts.Headers.Get("Content-Type"), true)
if errPrepare != nil {
err = errPrepare
return nil, err
}
// Fall back to JSON when payload preparation did not yield a content type.
if contentType == "" {
contentType = "application/json"
}
url := strings.TrimSuffix(baseURL, "/") + endpointPath
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
if err != nil {
return nil, err
}
httpReq.Header.Set("Content-Type", contentType)
httpReq.Header.Set("Accept", "text/event-stream")
httpReq.Header.Set("Cache-Control", "no-cache")
// The Authorization header is only sent when an API key was resolved.
if apiKey != "" {
httpReq.Header.Set("Authorization", "Bearer "+apiKey)
}
httpReq.Header.Set("User-Agent", "cli-proxy-openai-compat")
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
// Applied after the defaults above; presumably lets per-auth attributes
// override them — confirm against util.ApplyCustomHeadersFromAttrs.
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: payload,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
// Error statuses are not streamed: the whole body is read, the response is
// closed here, and the body text becomes the statusErr message.
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
body, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("openai compat executor: close response body error: %v", errClose)
}
if errRead != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return nil, errRead
}
helps.AppendAPIResponseChunk(ctx, e.cfg, body)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), body))
return nil, statusErr{code: httpResp.StatusCode, msg: string(body)}
}
// Unbuffered channel: each send waits for the consumer or for ctx.Done.
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
defer close(out)
// From here on the goroutine owns the response body and closes it on exit.
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("openai compat executor: close response body error: %v", errClose)
}
reporter.EnsurePublished(ctx)
}()
buffer := make([]byte, 32*1024)
for {
n, errRead := httpResp.Body.Read(buffer)
if n > 0 {
// buffer is reused on the next Read, so the bytes are copied before
// being handed to the consumer.
chunk := bytes.Clone(buffer[:n])
helps.AppendAPIResponseChunk(ctx, e.cfg, chunk)
select {
case out <- cliproxyexecutor.StreamChunk{Payload: chunk}:
case <-ctx.Done():
return
}
}
if errRead != nil {
// io.EOF ends the stream normally; any other read error is recorded
// and forwarded as a final error chunk.
if errRead != io.EOF {
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
reporter.PublishFailure(ctx, errRead)
select {
case out <- cliproxyexecutor.StreamChunk{Err: errRead}:
case <-ctx.Done():
}
}
return
}
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
@@ -380,6 +602,124 @@ func (e *OpenAICompatExecutor) Refresh(ctx context.Context, auth *cliproxyauth.A
return auth, nil
}
// openAICompatImageEndpointPath picks the upstream image endpoint for a request.
//
// It returns "" when the request's source format is not the OpenAI-compatible
// image handler type. Otherwise the request path suffix selects the edits or
// generations endpoint, and any other path falls back to the default image
// endpoint.
func openAICompatImageEndpointPath(opts cliproxyexecutor.Options) string {
	if opts.SourceFormat.String() != openAICompatImageHandlerType {
		return ""
	}
	requestPath := helps.PayloadRequestPath(opts)
	switch {
	case strings.HasSuffix(requestPath, "/images/edits"):
		return openAICompatImagesEditsPath
	case strings.HasSuffix(requestPath, "/images/generations"):
		return openAICompatImagesGenerationsPath
	default:
		return openAICompatDefaultImageEndpoint
	}
}
func prepareOpenAICompatImagesPayload(payload []byte, model string, contentType string, stream bool) ([]byte, string, error) {
model = strings.TrimSpace(model)
contentType = strings.TrimSpace(contentType)
if json.Valid(payload) {
if model != "" {
payload, _ = sjson.SetBytes(payload, "model", model)
}
if stream {
payload, _ = sjson.SetBytes(payload, "stream", true)
} else {
payload, _ = sjson.DeleteBytes(payload, "stream")
}
return payload, "application/json", nil
}
mediaType, params, errParse := mime.ParseMediaType(contentType)
if errParse != nil || !strings.HasPrefix(strings.ToLower(strings.TrimSpace(mediaType)), "multipart/") {
return payload, contentType, nil
}
boundary := strings.TrimSpace(params["boundary"])
if boundary == "" {
return nil, "", fmt.Errorf("multipart boundary is missing")
}
return rewriteOpenAICompatImagesMultipartPayload(payload, model, boundary, stream)
}
// cloneOpenAICompatMIMEHeader returns a copy of src whose value slices do not
// share backing storage with the original, so the copy can be modified freely.
// The result is always a non-nil map, even when src is nil.
func cloneOpenAICompatMIMEHeader(src textproto.MIMEHeader) textproto.MIMEHeader {
	cloned := make(textproto.MIMEHeader, len(src))
	for name, vals := range src {
		var copied []string
		copied = append(copied, vals...)
		cloned[name] = copied
	}
	return cloned
}
// rewriteOpenAICompatImagesMultipartPayload re-encodes a multipart image
// request so the upstream receives the routed model and stream flag.
//
// payload is the raw multipart body and boundary is the boundary taken from
// the incoming Content-Type. The form is parsed and written back with a newly
// generated boundary:
//   - "model" is replaced by model when model is non-empty. When model is
//     empty the client's original "model" values are kept, matching the JSON
//     branch of prepareOpenAICompatImagesPayload, which only overwrites the
//     model when a replacement is available.
//   - "stream" is written as "true" when stream is set and dropped otherwise.
//   - every other value field and every file part is copied through. File
//     parts keep their original MIME headers; Content-Type defaults to
//     application/octet-stream when absent.
//
// Value fields and file keys are visited in map iteration order, so the field
// order in the output is not stable.
//
// It returns the new body, its Content-Type (including the new boundary), and
// an error when the form cannot be read or rewritten.
func rewriteOpenAICompatImagesMultipartPayload(payload []byte, model string, boundary string, stream bool) ([]byte, string, error) {
	reader := multipart.NewReader(bytes.NewReader(payload), boundary)
	form, errRead := reader.ReadForm(openAICompatMultipartMemory)
	if errRead != nil {
		return nil, "", fmt.Errorf("read multipart form failed: %w", errRead)
	}
	// ReadForm may spill large parts to temporary files; remove them on exit.
	defer func() {
		if errRemove := form.RemoveAll(); errRemove != nil {
			log.Errorf("openai compat executor: remove multipart form files error: %v", errRemove)
		}
	}()

	var body bytes.Buffer
	writer := multipart.NewWriter(&body)

	replaceModel := model != ""
	if replaceModel {
		if errWrite := writer.WriteField("model", model); errWrite != nil {
			return nil, "", fmt.Errorf("write model field failed: %w", errWrite)
		}
	}
	if stream {
		if errWrite := writer.WriteField("stream", "true"); errWrite != nil {
			return nil, "", fmt.Errorf("write stream field failed: %w", errWrite)
		}
	}

	for key, values := range form.Value {
		// "stream" is always controlled by the stream argument. "model" is
		// only skipped when a replacement was written above; otherwise the
		// client's value must survive or the upstream gets no model at all.
		if key == "stream" || (key == "model" && replaceModel) {
			continue
		}
		for _, value := range values {
			if errWrite := writer.WriteField(key, value); errWrite != nil {
				return nil, "", fmt.Errorf("write form field %s failed: %w", key, errWrite)
			}
		}
	}

	for key, files := range form.File {
		for _, fileHeader := range files {
			if fileHeader == nil {
				continue
			}
			header := cloneOpenAICompatMIMEHeader(fileHeader.Header)
			header.Set("Content-Disposition", multipart.FileContentDisposition(key, fileHeader.Filename))
			if header.Get("Content-Type") == "" {
				header.Set("Content-Type", "application/octet-stream")
			}
			part, errCreate := writer.CreatePart(header)
			if errCreate != nil {
				return nil, "", fmt.Errorf("create file field %s failed: %w", key, errCreate)
			}
			src, errOpen := fileHeader.Open()
			if errOpen != nil {
				return nil, "", fmt.Errorf("open upload file failed: %w", errOpen)
			}
			_, errCopy := io.Copy(part, src)
			// Close right away rather than deferring inside the loop; a close
			// failure is reported only when the copy itself succeeded.
			if errClose := src.Close(); errClose != nil {
				log.Errorf("openai compat executor: close upload file error: %v", errClose)
				if errCopy == nil {
					errCopy = errClose
				}
			}
			if errCopy != nil {
				return nil, "", fmt.Errorf("copy upload file failed: %w", errCopy)
			}
		}
	}

	if errClose := writer.Close(); errClose != nil {
		return nil, "", fmt.Errorf("close multipart writer failed: %w", errClose)
	}
	return body.Bytes(), writer.FormDataContentType(), nil
}
func (e *OpenAICompatExecutor) resolveCredentials(auth *cliproxyauth.Auth) (baseURL, apiKey string) {
if auth == nil {
return "", ""

View File

@@ -1,10 +1,14 @@
package executor
import (
"bytes"
"context"
"io"
"mime"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/textproto"
"strings"
"testing"
@@ -102,6 +106,265 @@ func TestOpenAICompatExecutorPayloadOverrideWinsOverThinkingSuffix(t *testing.T)
}
}
// TestOpenAICompatExecutorImagesGenerationsPassthrough checks that a
// non-streaming JSON image request reaches the upstream /v1/images/generations
// path as application/json with the model replaced by the request's upstream
// model, and that the upstream body is returned as the response payload.
func TestOpenAICompatExecutorImagesGenerationsPassthrough(t *testing.T) {
	var (
		gotPath        string
		gotBody        []byte
		gotContentType string
	)
	upstream := func(w http.ResponseWriter, r *http.Request) {
		gotPath = r.URL.Path
		gotContentType = r.Header.Get("Content-Type")
		gotBody, _ = io.ReadAll(r.Body)
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"created":123,"data":[{"b64_json":"AA=="}],"usage":{"total_tokens":1}}`))
	}
	server := httptest.NewServer(http.HandlerFunc(upstream))
	defer server.Close()

	exec := NewOpenAICompatExecutor("openai-compatibility", &config.Config{})
	auth := &cliproxyauth.Auth{Attributes: map[string]string{
		"base_url": server.URL + "/v1",
		"api_key":  "test",
	}}
	request := cliproxyexecutor.Request{
		Model:   "upstream-image",
		Payload: []byte(`{"model":"compat-image","prompt":"draw"}`),
	}
	options := cliproxyexecutor.Options{
		SourceFormat: sdktranslator.FromString("openai-image"),
		Stream:       false,
		Headers: http.Header{
			"Content-Type": []string{"application/json"},
		},
		Metadata: map[string]any{
			cliproxyexecutor.RequestPathMetadataKey: "/v1/images/generations",
		},
	}

	resp, err := exec.Execute(context.Background(), auth, request, options)
	if err != nil {
		t.Fatalf("Execute error: %v", err)
	}

	const wantPath = "/v1/images/generations"
	if gotPath != wantPath {
		t.Fatalf("path = %q, want %q", gotPath, wantPath)
	}
	if gotContentType != "application/json" {
		t.Fatalf("content type = %q, want application/json", gotContentType)
	}
	if upstreamModel := gjson.GetBytes(gotBody, "model").String(); upstreamModel != "upstream-image" {
		t.Fatalf("model = %q, want upstream-image; body=%s", upstreamModel, string(gotBody))
	}
	if image := gjson.GetBytes(resp.Payload, "data.0.b64_json").String(); image != "AA==" {
		t.Fatalf("response payload = %s", string(resp.Payload))
	}
}
// TestOpenAICompatExecutorImagesGenerationsStreamsUpstream checks the streaming
// image path: the upstream request asks for text/event-stream, carries the
// rewritten model and a true "stream" flag, and the SSE bytes written by the
// upstream are relayed through the returned chunk channel.
func TestOpenAICompatExecutorImagesGenerationsStreamsUpstream(t *testing.T) {
	var (
		gotPath   string
		gotBody   []byte
		gotAccept string
	)
	upstream := func(w http.ResponseWriter, r *http.Request) {
		gotPath = r.URL.Path
		gotAccept = r.Header.Get("Accept")
		gotBody, _ = io.ReadAll(r.Body)
		w.Header().Set("Content-Type", "text/event-stream")
		_, _ = w.Write([]byte("event: image_generation.partial\ndata: {\"type\":\"image_generation.partial\"}\n\n"))
		// Flush between the two writes so the first event goes out on its own.
		if flusher, ok := w.(http.Flusher); ok {
			flusher.Flush()
		}
		_, _ = w.Write([]byte("data: [DONE]\n\n"))
	}
	server := httptest.NewServer(http.HandlerFunc(upstream))
	defer server.Close()

	exec := NewOpenAICompatExecutor("openai-compatibility", &config.Config{})
	auth := &cliproxyauth.Auth{Attributes: map[string]string{
		"base_url": server.URL + "/v1",
		"api_key":  "test",
	}}
	request := cliproxyexecutor.Request{
		Model:   "upstream-image",
		Payload: []byte(`{"model":"compat-image","prompt":"draw","stream":true}`),
	}
	options := cliproxyexecutor.Options{
		SourceFormat: sdktranslator.FromString("openai-image"),
		Stream:       true,
		Headers: http.Header{
			"Content-Type": []string{"application/json"},
		},
		Metadata: map[string]any{
			cliproxyexecutor.RequestPathMetadataKey: "/v1/images/generations",
		},
	}

	streamResult, err := exec.ExecuteStream(context.Background(), auth, request, options)
	if err != nil {
		t.Fatalf("ExecuteStream error: %v", err)
	}

	var streamed bytes.Buffer
	for chunk := range streamResult.Chunks {
		if chunk.Err != nil {
			t.Fatalf("stream chunk error: %v", chunk.Err)
		}
		streamed.Write(chunk.Payload)
	}

	const wantPath = "/v1/images/generations"
	if gotPath != wantPath {
		t.Fatalf("path = %q, want %q", gotPath, wantPath)
	}
	if gotAccept != "text/event-stream" {
		t.Fatalf("accept = %q, want text/event-stream", gotAccept)
	}
	if upstreamModel := gjson.GetBytes(gotBody, "model").String(); upstreamModel != "upstream-image" {
		t.Fatalf("model = %q, want upstream-image; body=%s", upstreamModel, string(gotBody))
	}
	if !gjson.GetBytes(gotBody, "stream").Bool() {
		t.Fatalf("stream flag missing from upstream body: %s", string(gotBody))
	}
	output := streamed.String()
	hasPartial := strings.Contains(output, "event: image_generation.partial")
	hasDone := strings.Contains(output, "data: [DONE]")
	if !hasPartial || !hasDone {
		t.Fatalf("streamed body = %q", output)
	}
}
// TestOpenAICompatExecutorImagesEditsMultipartRewritesModel checks that a
// multipart /v1/images/edits request is forwarded with the "model" field
// replaced by the upstream model while the prompt, file bytes and file
// Content-Type are preserved.
//
// The upstream handler runs on the httptest server's goroutine, so it reports
// problems with t.Errorf and an HTTP error status instead of t.Fatalf: FailNow
// must only be called from the goroutine running the test. The error status
// makes Execute fail, which the test goroutine then turns into a fatal failure.
func TestOpenAICompatExecutorImagesEditsMultipartRewritesModel(t *testing.T) {
	var body bytes.Buffer
	writer := multipart.NewWriter(&body)
	if errWrite := writer.WriteField("model", "compat-image"); errWrite != nil {
		t.Fatalf("write model field: %v", errWrite)
	}
	if errWrite := writer.WriteField("prompt", "edit"); errWrite != nil {
		t.Fatalf("write prompt field: %v", errWrite)
	}
	header := make(textproto.MIMEHeader)
	header.Set("Content-Disposition", multipart.FileContentDisposition("image", "image.png"))
	header.Set("Content-Type", "image/png")
	part, errCreate := writer.CreatePart(header)
	if errCreate != nil {
		t.Fatalf("create image field: %v", errCreate)
	}
	if _, errWrite := part.Write([]byte("png-data")); errWrite != nil {
		t.Fatalf("write image field: %v", errWrite)
	}
	if errClose := writer.Close(); errClose != nil {
		t.Fatalf("close multipart writer: %v", errClose)
	}
	contentType := writer.FormDataContentType()

	var gotPath string
	var gotModel string
	var gotPrompt string
	var gotFile string
	var gotFileContentType string
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotPath = r.URL.Path
		if errParse := r.ParseMultipartForm(32 << 20); errParse != nil {
			t.Errorf("parse multipart form: %v", errParse)
			http.Error(w, "parse multipart form failed", http.StatusBadRequest)
			return
		}
		gotModel = r.FormValue("model")
		gotPrompt = r.FormValue("prompt")
		file, fileHeader, errFile := r.FormFile("image")
		if errFile != nil {
			t.Errorf("read image file: %v", errFile)
			http.Error(w, "image file missing", http.StatusBadRequest)
			return
		}
		gotFileContentType = fileHeader.Header.Get("Content-Type")
		data, errRead := io.ReadAll(file)
		if errClose := file.Close(); errClose != nil {
			t.Errorf("close image file: %v", errClose)
		}
		if errRead != nil {
			t.Errorf("read image file: %v", errRead)
			http.Error(w, "image file unreadable", http.StatusBadRequest)
			return
		}
		gotFile = string(data)
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"created":123,"data":[{"b64_json":"AA=="}]}`))
	}))
	defer server.Close()

	executor := NewOpenAICompatExecutor("openai-compatibility", &config.Config{})
	auth := &cliproxyauth.Auth{Attributes: map[string]string{
		"base_url": server.URL + "/v1",
		"api_key":  "test",
	}}
	_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
		Model:   "upstream-image",
		Payload: body.Bytes(),
	}, cliproxyexecutor.Options{
		SourceFormat: sdktranslator.FromString("openai-image"),
		Stream:       false,
		Headers: http.Header{
			"Content-Type": []string{contentType},
		},
		Metadata: map[string]any{
			cliproxyexecutor.RequestPathMetadataKey: "/v1/images/edits",
		},
	})
	if err != nil {
		t.Fatalf("Execute error: %v", err)
	}
	if gotPath != "/v1/images/edits" {
		t.Fatalf("path = %q, want %q", gotPath, "/v1/images/edits")
	}
	if gotModel != "upstream-image" {
		t.Fatalf("model = %q, want upstream-image", gotModel)
	}
	if gotPrompt != "edit" {
		t.Fatalf("prompt = %q, want edit", gotPrompt)
	}
	if gotFile != "png-data" {
		t.Fatalf("file = %q, want png-data", gotFile)
	}
	if gotFileContentType != "image/png" {
		t.Fatalf("file content type = %q, want image/png", gotFileContentType)
	}
}
// TestRewriteOpenAICompatImagesMultipartPayloadPreservesStreamAndFileContentType
// feeds a multipart body containing model, stream=false and a WebP file through
// prepareOpenAICompatImagesPayload with stream enabled. It expects exactly one
// "model" value (the upstream model), exactly one "stream" value ("true"), and
// the file part to keep its image/webp Content-Type.
func TestRewriteOpenAICompatImagesMultipartPayloadPreservesStreamAndFileContentType(t *testing.T) {
	var input bytes.Buffer
	mw := multipart.NewWriter(&input)
	if errWrite := mw.WriteField("model", "compat-image"); errWrite != nil {
		t.Fatalf("write model field: %v", errWrite)
	}
	if errWrite := mw.WriteField("stream", "false"); errWrite != nil {
		t.Fatalf("write stream field: %v", errWrite)
	}
	fileHeader := make(textproto.MIMEHeader)
	fileHeader.Set("Content-Disposition", multipart.FileContentDisposition("image", "image.webp"))
	fileHeader.Set("Content-Type", "image/webp")
	filePart, errCreate := mw.CreatePart(fileHeader)
	if errCreate != nil {
		t.Fatalf("create image field: %v", errCreate)
	}
	if _, errWrite := filePart.Write([]byte("webp-data")); errWrite != nil {
		t.Fatalf("write image field: %v", errWrite)
	}
	if errClose := mw.Close(); errClose != nil {
		t.Fatalf("close multipart writer: %v", errClose)
	}

	rewritten, rewrittenType, err := prepareOpenAICompatImagesPayload(input.Bytes(), "upstream-image", mw.FormDataContentType(), true)
	if err != nil {
		t.Fatalf("prepareOpenAICompatImagesPayload error: %v", err)
	}

	mediaType, params, errParse := mime.ParseMediaType(rewrittenType)
	if errParse != nil {
		t.Fatalf("parse content type: %v", errParse)
	}
	if mediaType != "multipart/form-data" {
		t.Fatalf("media type = %q, want multipart/form-data", mediaType)
	}

	form, errRead := multipart.NewReader(bytes.NewReader(rewritten), params["boundary"]).ReadForm(32 << 20)
	if errRead != nil {
		t.Fatalf("read rewritten form: %v", errRead)
	}
	defer func() {
		if errRemove := form.RemoveAll(); errRemove != nil {
			t.Fatalf("remove form files: %v", errRemove)
		}
	}()

	models := form.Value["model"]
	if len(models) != 1 || models[0] != "upstream-image" {
		t.Fatalf("model values = %#v, want upstream-image", models)
	}
	streams := form.Value["stream"]
	if len(streams) != 1 || streams[0] != "true" {
		t.Fatalf("stream values = %#v, want true", streams)
	}
	images := form.File["image"]
	if len(images) != 1 || images[0].Header.Get("Content-Type") != "image/webp" {
		t.Fatalf("image headers = %#v, want image/webp", images)
	}
}
func TestOpenAICompatExecutorStreamRejectsPlainJSONAfterBlankLines(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")

View File

@@ -4,6 +4,7 @@ import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"sort"
"strings"
@@ -20,7 +21,7 @@ func ComputeOpenAICompatModelsHash(models []config.OpenAICompatibilityModel) str
if name == "" && alias == "" {
continue
}
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
out(strings.ToLower(name) + "|" + strings.ToLower(alias) + "|" + fmt.Sprintf("image=%t", model.Image))
}
})
return hashJoined(keys)

View File

@@ -25,6 +25,17 @@ func TestComputeOpenAICompatModelsHash_Deterministic(t *testing.T) {
}
}
// TestComputeOpenAICompatModelsHash_IncludesImageFlag checks that two model
// lists differing only in the Image flag produce different, non-empty hashes.
func TestComputeOpenAICompatModelsHash_IncludesImageFlag(t *testing.T) {
	hashFor := func(image bool) string {
		return ComputeOpenAICompatModelsHash([]config.OpenAICompatibilityModel{
			{Name: "gpt-image", Alias: "image", Image: image},
		})
	}
	withoutFlag, withFlag := hashFor(false), hashFor(true)
	if withoutFlag == "" || withFlag == "" {
		t.Fatal("hashes should not be empty")
	}
	if withoutFlag == withFlag {
		t.Fatal("hash should change when image flag changes")
	}
}
func TestComputeOpenAICompatModelsHash_NormalizesAndDedups(t *testing.T) {
a := []config.OpenAICompatibilityModel{
{Name: "gpt-4", Alias: "gpt4"},

View File

@@ -153,7 +153,7 @@ func openAICompatSignature(entry config.OpenAICompatibility) string {
if name == "" && alias == "" {
continue
}
models = append(models, strings.ToLower(name)+"|"+strings.ToLower(alias))
models = append(models, strings.ToLower(name)+"|"+strings.ToLower(alias)+"|"+fmt.Sprintf("image=%t", model.Image))
}
if len(models) > 0 {
sort.Strings(models)

View File

@@ -535,7 +535,16 @@ func appendAPIResponse(c *gin.Context, data []byte) {
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
// This path is the only supported execution route.
//
// It delegates to executeWithAuthManager with allowImageModel set to false.
// The delegate resolves the request details itself, so no resolution is done
// here.
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, http.Header, *interfaces.ErrorMessage) {
	return h.executeWithAuthManager(ctx, handlerType, modelName, rawJSON, alt, false)
}
// ExecuteImageWithAuthManager executes a non-streaming OpenAI-compatible image endpoint request.
// It delegates to executeWithAuthManager with allowImageModel set to true and
// returns that call's payload, headers and error message unchanged.
func (h *BaseAPIHandler) ExecuteImageWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, http.Header, *interfaces.ErrorMessage) {
return h.executeWithAuthManager(ctx, handlerType, modelName, rawJSON, alt, true)
}
func (h *BaseAPIHandler) executeWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string, allowImageModel bool) ([]byte, http.Header, *interfaces.ErrorMessage) {
providers, normalizedModel, errMsg := h.getRequestDetailsWithOptions(modelName, allowImageModel)
if errMsg != nil {
return nil, nil, errMsg
}
@@ -632,7 +641,16 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
// This path is the only supported execution route.
// The returned http.Header carries upstream response headers captured before streaming begins.
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, http.Header, <-chan *interfaces.ErrorMessage) {
	// Delegates with allowImageModel set to false. The delegate resolves the
	// request details itself, so no resolution is done here.
	return h.executeStreamWithAuthManager(ctx, handlerType, modelName, rawJSON, alt, false)
}
// ExecuteImageStreamWithAuthManager executes a streaming OpenAI-compatible image endpoint request.
// It delegates to executeStreamWithAuthManager with allowImageModel set to
// true and returns that call's data channel, headers and error channel unchanged.
func (h *BaseAPIHandler) ExecuteImageStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, http.Header, <-chan *interfaces.ErrorMessage) {
return h.executeStreamWithAuthManager(ctx, handlerType, modelName, rawJSON, alt, true)
}
func (h *BaseAPIHandler) executeStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string, allowImageModel bool) (<-chan []byte, http.Header, <-chan *interfaces.ErrorMessage) {
providers, normalizedModel, errMsg := h.getRequestDetailsWithOptions(modelName, allowImageModel)
if errMsg != nil {
errChan := make(chan *interfaces.ErrorMessage, 1)
errChan <- errMsg
@@ -848,6 +866,10 @@ func statusFromError(err error) int {
}
// getRequestDetails resolves the providers and normalized model for modelName
// by calling getRequestDetailsWithOptions with allowImageModel set to false.
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, err *interfaces.ErrorMessage) {
return h.getRequestDetailsWithOptions(modelName, false)
}
func (h *BaseAPIHandler) getRequestDetailsWithOptions(modelName string, allowImageModel bool) (providers []string, normalizedModel string, err *interfaces.ErrorMessage) {
resolvedModelName := modelName
initialSuffix := thinking.ParseSuffix(modelName)
if initialSuffix.ModelName == "auto" {
@@ -872,10 +894,10 @@ func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string
parsed := thinking.ParseSuffix(resolvedModelName)
baseModel := strings.TrimSpace(parsed.ModelName)
if strings.EqualFold(baseModel, "gpt-image-2") {
if strings.EqualFold(routeModelBaseName(baseModel), "gpt-image-2") && !allowImageModel {
return nil, "", &interfaces.ErrorMessage{
StatusCode: http.StatusServiceUnavailable,
Error: fmt.Errorf("model %s is only supported on /v1/images/generations and /v1/images/edits", baseModel),
Error: fmt.Errorf("model %s is only supported on /v1/images/generations and /v1/images/edits", routeModelBaseName(baseModel)),
}
}
@@ -902,6 +924,14 @@ func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string
return providers, resolvedModelName, nil
}
// routeModelBaseName returns the part of model after its last "/", with
// surrounding whitespace trimmed. When there is no "/" or the "/" is the final
// character, the whole whitespace-trimmed model is returned instead.
func routeModelBaseName(model string) string {
	trimmed := strings.TrimSpace(model)
	slash := strings.LastIndex(trimmed, "/")
	if slash < 0 || slash >= len(trimmed)-1 {
		return trimmed
	}
	return strings.TrimSpace(trimmed[slash+1:])
}
func cloneBytes(src []byte) []byte {
if len(src) == 0 {
return nil

View File

@@ -104,6 +104,9 @@ func applyCodexClientModelMetadata(entry map[string]any, id string, model map[st
if info.ContextLength > 0 {
contextWindow = info.ContextLength
}
if info.Type == registry.OpenAIImageModelType {
entry["visibility"] = "hide"
}
applyCodexClientThinkingMetadata(entry, info.Thinking)
}

View File

@@ -9,6 +9,7 @@ import (
"io"
"mime/multipart"
"net/http"
"net/textproto"
"strconv"
"strings"
"time"
@@ -16,6 +17,7 @@ import (
"github.com/gin-gonic/gin"
internalconfig "github.com/router-for-me/CLIProxyAPI/v7/internal/config"
"github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
@@ -143,7 +145,20 @@ func isSupportedImagesModel(model string) bool {
if baseModel == defaultImagesToolModel {
return true
}
return isXAIImagesModel(model)
return isXAIImagesModel(model) || isOpenAICompatImagesModel(model)
}
// isDefaultImagesToolModel reports whether model's base name, as computed by
// imagesModelBase, equals defaultImagesToolModel.
func isDefaultImagesToolModel(model string) bool {
	base := imagesModelBase(model)
	return base == defaultImagesToolModel
}
// isOpenAICompatImagesModel reports whether model is registered with the
// OpenAI-compatible image model type. A blank name, or a name the registry
// lookup returns nil for, yields false.
func isOpenAICompatImagesModel(model string) bool {
	name := strings.TrimSpace(model)
	if name == "" {
		return false
	}
	if info := registry.LookupModelInfo(name); info != nil {
		return info.Type == registry.OpenAIImageModelType
	}
	return false
}
func rejectUnsupportedImagesModel(c *gin.Context, model string) bool {
@@ -153,7 +168,7 @@ func rejectUnsupportedImagesModel(c *gin.Context, model string) bool {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Model %s is not supported on %s or %s. Use %s, %s, or %s.", model, imagesGenerationsPath, imagesEditsPath, defaultImagesToolModel, defaultXAIImagesModel, xaiImagesQualityModel),
Message: fmt.Sprintf("Model %s is not supported on %s or %s. Use %s, %s, %s, or a configured openai-compatibility image model.", model, imagesGenerationsPath, imagesEditsPath, defaultImagesToolModel, defaultXAIImagesModel, xaiImagesQualityModel),
Type: "invalid_request_error",
},
})
@@ -376,6 +391,90 @@ func multipartFileToDataURL(fileHeader *multipart.FileHeader) (string, error) {
return "data:" + mediaType + ";base64," + b64, nil
}
// buildOpenAICompatImagesJSONRequest derives the JSON body to forward for an
// image request. When imageModel is non-blank after trimming it overwrites
// "model"; "stream" is set to true when stream is requested and removed
// otherwise. Errors from the JSON edits are ignored.
func buildOpenAICompatImagesJSONRequest(rawJSON []byte, imageModel string, stream bool) []byte {
	out := rawJSON
	if trimmed := strings.TrimSpace(imageModel); trimmed != "" {
		out, _ = sjson.SetBytes(out, "model", trimmed)
	}
	if !stream {
		out, _ = sjson.DeleteBytes(out, "stream")
		return out
	}
	out, _ = sjson.SetBytes(out, "stream", true)
	return out
}
// cloneMIMEHeader returns a copy of src whose value slices do not share
// backing storage with the original. The result is always a non-nil map.
func cloneMIMEHeader(src textproto.MIMEHeader) textproto.MIMEHeader {
	cloned := make(textproto.MIMEHeader, len(src))
	for name, vals := range src {
		var copied []string
		copied = append(copied, vals...)
		cloned[name] = copied
	}
	return cloned
}
// buildOpenAICompatImagesMultipartRequest re-encodes a parsed multipart form
// for forwarding to an image endpoint.
//
// The output always starts with a "model" field set to imageModel, followed by
// stream=true when stream is requested. The form's own "model" and "stream"
// values are dropped; all other value fields and every file part are copied.
// File parts keep their original MIME headers, with Content-Type defaulting to
// application/octet-stream when absent.
//
// It returns the encoded body and its Content-Type (with the new boundary).
// A nil form, or any write, open or copy failure, is returned as an error.
func buildOpenAICompatImagesMultipartRequest(form *multipart.Form, imageModel string, stream bool) ([]byte, string, error) {
	if form == nil {
		return nil, "", fmt.Errorf("multipart form is nil")
	}

	var encoded bytes.Buffer
	mw := multipart.NewWriter(&encoded)

	if err := mw.WriteField("model", imageModel); err != nil {
		return nil, "", fmt.Errorf("write model field failed: %w", err)
	}
	if stream {
		if err := mw.WriteField("stream", "true"); err != nil {
			return nil, "", fmt.Errorf("write stream field failed: %w", err)
		}
	}

	for name, entries := range form.Value {
		switch name {
		case "model", "stream":
			// Already emitted above from the function arguments.
			continue
		}
		for _, entry := range entries {
			if err := mw.WriteField(name, entry); err != nil {
				return nil, "", fmt.Errorf("write form field %s failed: %w", name, err)
			}
		}
	}

	for name, headers := range form.File {
		for _, fh := range headers {
			if fh == nil {
				continue
			}
			partHeader := cloneMIMEHeader(fh.Header)
			partHeader.Set("Content-Disposition", multipart.FileContentDisposition(name, fh.Filename))
			if partHeader.Get("Content-Type") == "" {
				partHeader.Set("Content-Type", "application/octet-stream")
			}
			part, err := mw.CreatePart(partHeader)
			if err != nil {
				return nil, "", fmt.Errorf("create file field %s failed: %w", name, err)
			}
			upload, err := fh.Open()
			if err != nil {
				return nil, "", fmt.Errorf("open upload file failed: %w", err)
			}
			_, copyErr := io.Copy(part, upload)
			closeErr := upload.Close()
			if closeErr != nil {
				log.Errorf("openai images: close upload file error: %v", closeErr)
			}
			// A close failure only matters when the copy itself succeeded.
			if copyErr == nil {
				copyErr = closeErr
			}
			if copyErr != nil {
				return nil, "", fmt.Errorf("copy upload file failed: %w", copyErr)
			}
		}
	}

	if err := mw.Close(); err != nil {
		return nil, "", fmt.Errorf("close multipart writer failed: %w", err)
	}
	return encoded.Bytes(), mw.FormDataContentType(), nil
}
func parseIntField(raw string, fallback int64) int64 {
raw = strings.TrimSpace(raw)
if raw == "" {
@@ -454,11 +553,21 @@ func (h *OpenAIAPIHandler) ImagesGenerations(c *gin.Context) {
}
stream := gjson.GetBytes(rawJSON, "stream").Bool()
if isDefaultImagesToolModel(imageModel) {
imageReq := buildOpenAICompatImagesJSONRequest(rawJSON, imageModel, stream)
h.handleRoutedImages(c, imageReq, imageModel, stream)
return
}
if isXAIImagesModel(imageModel) {
xaiReq := buildXAIImagesGenerationsRequest(rawJSON, imageModel, responseFormat)
h.handleXAIImages(c, xaiReq, responseFormat, "image_generation", stream)
return
}
if isOpenAICompatImagesModel(imageModel) {
compatReq := buildOpenAICompatImagesJSONRequest(rawJSON, imageModel, stream)
h.handleOpenAICompatImages(c, compatReq, imageModel, responseFormat, "image_generation", stream)
return
}
tool := []byte(`{"type":"image_generation","action":"generate"}`)
tool, _ = sjson.SetBytes(tool, "model", imageModel)
@@ -589,6 +698,21 @@ func (h *OpenAIAPIHandler) imagesEditsFromMultipart(c *gin.Context) {
}
stream := parseBoolField(c.PostForm("stream"), false)
if isDefaultImagesToolModel(imageModel) {
imageReq, contentType, errBuild := buildOpenAICompatImagesMultipartRequest(form, imageModel, stream)
if errBuild != nil {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", errBuild),
Type: "invalid_request_error",
},
})
return
}
c.Request.Header.Set("Content-Type", contentType)
h.handleRoutedImages(c, imageReq, imageModel, stream)
return
}
if isXAIImagesModel(imageModel) {
aspectRatio := xaiImagesAspectRatio(c.PostForm("aspect_ratio"), "")
aspectRatio = xaiImagesAspectRatioFromSize(c.PostForm("size"), aspectRatio)
@@ -598,6 +722,21 @@ func (h *OpenAIAPIHandler) imagesEditsFromMultipart(c *gin.Context) {
h.handleXAIImages(c, xaiReq, responseFormat, "image_edit", stream)
return
}
if isOpenAICompatImagesModel(imageModel) {
compatReq, contentType, errBuild := buildOpenAICompatImagesMultipartRequest(form, imageModel, stream)
if errBuild != nil {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", errBuild),
Type: "invalid_request_error",
},
})
return
}
c.Request.Header.Set("Content-Type", contentType)
h.handleOpenAICompatImages(c, compatReq, imageModel, responseFormat, "image_edit", stream)
return
}
var maskDataURL *string
if maskFiles := form.File["mask"]; len(maskFiles) > 0 && maskFiles[0] != nil {
@@ -701,6 +840,11 @@ func (h *OpenAIAPIHandler) imagesEditsFromJSON(c *gin.Context) {
}
stream := gjson.GetBytes(rawJSON, "stream").Bool()
if isDefaultImagesToolModel(imageModel) {
imageReq := buildOpenAICompatImagesJSONRequest(rawJSON, imageModel, stream)
h.handleRoutedImages(c, imageReq, imageModel, stream)
return
}
if isXAIImagesModel(imageModel) {
images := collectXAIImagesFromJSON(rawJSON)
if len(images) == 0 {
@@ -717,6 +861,11 @@ func (h *OpenAIAPIHandler) imagesEditsFromJSON(c *gin.Context) {
h.handleXAIImages(c, xaiReq, responseFormat, "image_edit", stream)
return
}
if isOpenAICompatImagesModel(imageModel) {
compatReq := buildOpenAICompatImagesJSONRequest(rawJSON, imageModel, stream)
h.handleOpenAICompatImages(c, compatReq, imageModel, responseFormat, "image_edit", stream)
return
}
var images []string
imagesResult := gjson.GetBytes(rawJSON, "images")
@@ -904,14 +1053,247 @@ func (h *OpenAIAPIHandler) handleXAIImages(c *gin.Context, xaiReq []byte, respon
h.collectXAIImages(c, xaiReq, responseFormat)
}
// handleOpenAICompatImages dispatches an OpenAI-compatible image request to
// the streaming handler when stream is set and to the collecting handler
// otherwise. responseFormat is only passed on the non-streaming path;
// streamPrefix is accepted but not used in this function.
func (h *OpenAIAPIHandler) handleOpenAICompatImages(c *gin.Context, compatReq []byte, imageModel string, responseFormat string, streamPrefix string, stream bool) {
	if !stream {
		h.collectImagesWithModel(c, compatReq, imageModel, responseFormat)
		return
	}
	h.streamOpenAICompatImages(c, compatReq, imageModel)
}
// handleRoutedImages dispatches a routed image request to streamRoutedImages
// when stream is set and to collectRoutedImages otherwise.
func (h *OpenAIAPIHandler) handleRoutedImages(c *gin.Context, imageReq []byte, imageModel string, stream bool) {
	if !stream {
		h.collectRoutedImages(c, imageReq, imageModel)
		return
	}
	h.streamRoutedImages(c, imageReq, imageModel)
}
// collectRoutedImages executes a non-streaming routed image request and writes
// the upstream response body back to the client as JSON.
//
// The execution context is wrapped with handlers.WithDisallowFreeAuth before it
// is used, and a non-streaming keep-alive runs for the duration of the upstream
// call. On failure the error is written with WriteErrorResponse and the
// context is cancelled with the underlying error when one is present.
func (h *OpenAIAPIHandler) collectRoutedImages(c *gin.Context, imageReq []byte, imageModel string) {
	c.Header("Content-Type", "application/json")

	execCtx, cancel := h.GetContextWithCancel(h, c, context.Background())
	execCtx = handlers.WithDisallowFreeAuth(execCtx)
	stopKeepAlive := h.StartNonStreamingKeepAlive(c, execCtx)

	payload, headers, execErr := h.ExecuteImageWithAuthManager(execCtx, xaiImagesHandlerType, strings.TrimSpace(imageModel), imageReq, "")
	// The keep-alive is stopped before anything is written to the response.
	stopKeepAlive()

	if execErr == nil {
		handlers.WriteUpstreamHeaders(c.Writer.Header(), headers)
		_, _ = c.Writer.Write(payload)
		cancel(nil)
		return
	}

	h.WriteErrorResponse(c, execErr)
	if execErr.Error == nil {
		cancel(nil)
		return
	}
	cancel(execErr.Error)
}
// streamRoutedImages executes a streaming routed image request and relays the
// upstream stream to the client as server-sent events.
//
// Response headers are not written until the first event arrives from the
// executor, so an error that arrives before any data can still be returned
// through WriteErrorResponse. After the first chunk is written, the rest of
// the stream is handed to forwardRawImageStream.
func (h *OpenAIAPIHandler) streamRoutedImages(c *gin.Context, imageReq []byte, imageModel string) {
	flusher, canFlush := c.Writer.(http.Flusher)
	if !canFlush {
		c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: "Streaming not supported",
				Type:    "server_error",
			},
		})
		return
	}

	cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
	cliCtx = handlers.WithDisallowFreeAuth(cliCtx)
	chunks, upstreamHeaders, failures := h.ExecuteImageStreamWithAuthManager(cliCtx, xaiImagesHandlerType, strings.TrimSpace(imageModel), imageReq, "")

	// beginSSE sets the SSE response headers and then the upstream headers.
	beginSSE := func() {
		c.Header("Content-Type", "text/event-stream")
		c.Header("Cache-Control", "no-cache")
		c.Header("Connection", "keep-alive")
		c.Header("Access-Control-Allow-Origin", "*")
		handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
	}

	for {
		select {
		case <-c.Request.Context().Done():
			cliCancel(c.Request.Context().Err())
			return

		case failure, open := <-failures:
			if !open {
				// A closed error channel carries no information; stop selecting on it.
				failures = nil
				continue
			}
			h.WriteErrorResponse(c, failure)
			if failure == nil {
				cliCancel(nil)
			} else {
				cliCancel(failure.Error)
			}
			return

		case first, open := <-chunks:
			beginSSE()
			if !open {
				// The stream ended without producing any chunk: send a bare newline.
				_, _ = c.Writer.Write([]byte("\n"))
				flusher.Flush()
				cliCancel(nil)
				return
			}
			_, _ = c.Writer.Write(first)
			flusher.Flush()
			h.forwardRawImageStream(cliCtx, c, func(err error) { cliCancel(err) }, chunks, failures)
			return
		}
	}
}
// forwardRawImageStream relays chunks from data to the client verbatim,
// flushing after each one, until data closes, a terminal error arrives on
// errs, or one of the two contexts is done.
//
// Parameters:
//   - ctx: execution context; when it is done the loop stops and cancel
//     receives ctx.Err().
//   - c: gin context whose writer receives the chunks; its request context is
//     watched as well.
//   - cancel: called once on every return path, with the terminating error or
//     nil on normal completion.
//   - data: chunk source; closing it ends the stream.
//   - errs: terminal error source; a non-nil message is written in-band as an
//     "event: error" SSE frame before returning.
func (h *OpenAIAPIHandler) forwardRawImageStream(ctx context.Context, c *gin.Context, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
	// Resolve the flusher once instead of type-asserting on every chunk.
	flusher, canFlush := c.Writer.(http.Flusher)
	flush := func() {
		if canFlush {
			flusher.Flush()
		}
	}

	// emitError writes errMsg to the stream as an SSE error event.
	emitError := func(errMsg *interfaces.ErrorMessage) {
		if errMsg == nil {
			return
		}
		status := http.StatusInternalServerError
		if errMsg.StatusCode > 0 {
			status = errMsg.StatusCode
		}
		errText := http.StatusText(status)
		if errMsg.Error != nil && strings.TrimSpace(errMsg.Error.Error()) != "" {
			errText = errMsg.Error.Error()
		}
		body := handlers.BuildErrorResponseBody(status, errText)
		_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body))
		flush()
	}

	for {
		select {
		case <-c.Request.Context().Done():
			cancel(c.Request.Context().Err())
			return
		case <-ctx.Done():
			cancel(ctx.Err())
			return
		case errMsg, ok := <-errs:
			if ok && errMsg != nil {
				emitError(errMsg)
				cancel(errMsg.Error)
				return
			}
			// Closed channel or nil message: stop selecting on errs.
			errs = nil
		case chunk, ok := <-data:
			if !ok {
				// select chooses at random among ready cases, so the close of
				// data can win over a terminal error that is already queued.
				// Check errs once, without blocking, so that error is not
				// dropped. A nil errs is never ready and falls to default.
				select {
				case errMsg, okErr := <-errs:
					if okErr && errMsg != nil {
						emitError(errMsg)
						cancel(errMsg.Error)
						return
					}
				default:
				}
				cancel(nil)
				return
			}
			_, _ = c.Writer.Write(chunk)
			flush()
		}
	}
}
// streamOpenAICompatImages executes a streaming OpenAI-compatible image request
// and relays the upstream stream to the client as server-sent events.
//
// Response headers are not written until the first event arrives from the
// executor, so an error that arrives before any data can still be returned
// through WriteErrorResponse. After the first chunk is written, the remaining
// chunks and any terminal error are delivered through ForwardStream.
func (h *OpenAIAPIHandler) streamOpenAICompatImages(c *gin.Context, compatReq []byte, imageModel string) {
	flusher, canFlush := c.Writer.(http.Flusher)
	if !canFlush {
		c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: "Streaming not supported",
				Type:    "server_error",
			},
		})
		return
	}

	cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
	upstreamModel := strings.TrimSpace(imageModel)
	chunks, upstreamHeaders, failures := h.ExecuteStreamWithAuthManager(cliCtx, xaiImagesHandlerType, upstreamModel, compatReq, "")

	// beginSSE sets the SSE response headers and then the upstream headers.
	beginSSE := func() {
		c.Header("Content-Type", "text/event-stream")
		c.Header("Cache-Control", "no-cache")
		c.Header("Connection", "keep-alive")
		c.Header("Access-Control-Allow-Origin", "*")
		handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
	}

	// writeTerminalError writes a mid-stream failure as an SSE error event.
	writeTerminalError := func(errMsg *interfaces.ErrorMessage) {
		if errMsg == nil {
			return
		}
		status := errMsg.StatusCode
		if status <= 0 {
			status = http.StatusInternalServerError
		}
		errText := http.StatusText(status)
		if errMsg.Error != nil {
			if text := errMsg.Error.Error(); text != "" {
				errText = text
			}
		}
		body := handlers.BuildErrorResponseBody(status, errText)
		_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body))
	}

	for {
		select {
		case <-c.Request.Context().Done():
			cliCancel(c.Request.Context().Err())
			return

		case failure, open := <-failures:
			if !open {
				// A closed error channel carries no information; stop selecting on it.
				failures = nil
				continue
			}
			h.WriteErrorResponse(c, failure)
			if failure == nil {
				cliCancel(nil)
			} else {
				cliCancel(failure.Error)
			}
			return

		case first, open := <-chunks:
			beginSSE()
			if !open {
				// The stream ended without producing any chunk: flush the headers only.
				flusher.Flush()
				cliCancel(nil)
				return
			}
			_, _ = c.Writer.Write(first)
			flusher.Flush()
			h.ForwardStream(c, flusher, func(err error) { cliCancel(err) }, chunks, failures, handlers.StreamForwardOptions{
				WriteChunk: func(next []byte) {
					_, _ = c.Writer.Write(next)
				},
				WriteTerminalError: writeTerminalError,
			})
			return
		}
	}
}
// collectXAIImages runs a non-streaming image request, using the trimmed
// "model" field of the request body as the model name.
func (h *OpenAIAPIHandler) collectXAIImages(c *gin.Context, xaiReq []byte, responseFormat string) {
	requestedModel := gjson.GetBytes(xaiReq, "model").String()
	h.collectImagesWithModel(c, xaiReq, strings.TrimSpace(requestedModel), responseFormat)
}
func (h *OpenAIAPIHandler) collectImagesWithModel(c *gin.Context, imageReq []byte, model string, responseFormat string) {
c.Header("Content-Type", "application/json")
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
model := strings.TrimSpace(gjson.GetBytes(xaiReq, "model").String())
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiImagesHandlerType, model, xaiReq, "")
model = strings.TrimSpace(model)
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiImagesHandlerType, model, imageReq, "")
stopKeepAlive()
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
@@ -937,6 +1319,11 @@ func (h *OpenAIAPIHandler) collectXAIImages(c *gin.Context, xaiReq []byte, respo
}
// streamXAIImages runs a streaming image request, using the trimmed "model"
// field of the request body as the model name.
func (h *OpenAIAPIHandler) streamXAIImages(c *gin.Context, xaiReq []byte, responseFormat string, streamPrefix string) {
	requestedModel := gjson.GetBytes(xaiReq, "model").String()
	h.streamImagesWithModel(c, xaiReq, strings.TrimSpace(requestedModel), responseFormat, streamPrefix)
}
func (h *OpenAIAPIHandler) streamImagesWithModel(c *gin.Context, imageReq []byte, model string, responseFormat string, streamPrefix string) {
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
@@ -949,8 +1336,8 @@ func (h *OpenAIAPIHandler) streamXAIImages(c *gin.Context, xaiReq []byte, respon
}
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
model := strings.TrimSpace(gjson.GetBytes(xaiReq, "model").String())
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiImagesHandlerType, model, xaiReq, "")
model = strings.TrimSpace(model)
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiImagesHandlerType, model, imageReq, "")
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
if errMsg.Error != nil {

View File

@@ -3,14 +3,17 @@ package openai
import (
"bytes"
"io"
"mime"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/textproto"
"strings"
"testing"
"github.com/gin-gonic/gin"
internalconfig "github.com/router-for-me/CLIProxyAPI/v7/internal/config"
"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v7/sdk/config"
"github.com/tidwall/gjson"
@@ -40,7 +43,7 @@ func assertUnsupportedImagesModelResponse(t *testing.T, resp *httptest.ResponseR
}
message := gjson.GetBytes(resp.Body.Bytes(), "error.message").String()
expectedMessage := "Model " + model + " is not supported on " + imagesGenerationsPath + " or " + imagesEditsPath + ". Use " + defaultImagesToolModel + ", " + defaultXAIImagesModel + ", or " + xaiImagesQualityModel + "."
expectedMessage := "Model " + model + " is not supported on " + imagesGenerationsPath + " or " + imagesEditsPath + ". Use " + defaultImagesToolModel + ", " + defaultXAIImagesModel + ", " + xaiImagesQualityModel + ", or a configured openai-compatibility image model."
if message != expectedMessage {
t.Fatalf("error message = %q, want %q", message, expectedMessage)
}
@@ -63,6 +66,25 @@ func TestImagesModelValidationAllowsGPTImage2AndXAIModels(t *testing.T) {
}
}
// TestImagesModelValidationAllowsOpenAICompatImageModels registers one image
// model and one chat model under an openai-compatibility client and checks that
// only the image-typed model is accepted by isSupportedImagesModel.
func TestImagesModelValidationAllowsOpenAICompatImageModels(t *testing.T) {
	const clientID = "test-openai-compat-image-model-validation"

	reg := registry.GetGlobalRegistry()
	reg.RegisterClient(clientID, "openai-compatibility", []*registry.ModelInfo{
		{ID: "compat-image-model", Object: "model", OwnedBy: "compat", Type: registry.OpenAIImageModelType},
		{ID: "compat-chat-model", Object: "model", OwnedBy: "compat", Type: "openai-compatibility"},
	})
	// Remove the test client so the global registry does not leak into other tests.
	t.Cleanup(func() {
		reg.UnregisterClient(clientID)
	})

	checks := []struct {
		model     string
		supported bool
		failure   string
	}{
		{"compat-image-model", true, "expected configured openai-compatibility image model to be supported"},
		{"compat-chat-model", false, "expected non-image openai-compatibility model to be rejected"},
	}
	for _, check := range checks {
		if isSupportedImagesModel(check.model) != check.supported {
			t.Fatal(check.failure)
		}
	}
}
func TestBuildXAIImagesGenerationsRequest(t *testing.T) {
rawJSON := []byte(`{"model":"xai/grok-imagine-image-quality","prompt":"abstract art","aspect_ratio":"landscape","resolution":"2k","n":2,"response_format":"url"}`)
@@ -122,6 +144,100 @@ func TestBuildXAIImagesEditRequestSingleImage(t *testing.T) {
}
}
// TestBuildOpenAICompatImagesJSONRequestPreservesStreamForStreaming checks that
// a streaming build rewrites the model name and yields a truthy "stream" field
// even when the incoming payload had stream=false.
func TestBuildOpenAICompatImagesJSONRequestPreservesStreamForStreaming(t *testing.T) {
	input := []byte(`{"model":"compat-image","prompt":"draw","stream":false}`)
	out := buildOpenAICompatImagesJSONRequest(input, "upstream-image", true)

	model := gjson.GetBytes(out, "model").String()
	if model != "upstream-image" {
		t.Fatalf("model = %q, want upstream-image; body=%s", model, string(out))
	}
	if streamFlag := gjson.GetBytes(out, "stream"); !streamFlag.Bool() {
		t.Fatalf("stream flag missing: %s", string(out))
	}
}
// TestBuildOpenAICompatImagesJSONRequestDropsStreamForNonStreaming checks that
// a non-streaming build rewrites the model name and removes the "stream" field
// entirely, even when the incoming payload had stream=true.
func TestBuildOpenAICompatImagesJSONRequestDropsStreamForNonStreaming(t *testing.T) {
	input := []byte(`{"model":"compat-image","prompt":"draw","stream":true}`)
	out := buildOpenAICompatImagesJSONRequest(input, "upstream-image", false)

	model := gjson.GetBytes(out, "model").String()
	if model != "upstream-image" {
		t.Fatalf("model = %q, want upstream-image; body=%s", model, string(out))
	}
	if streamFlag := gjson.GetBytes(out, "stream"); streamFlag.Exists() {
		t.Fatalf("stream flag should be removed from non-streaming request: %s", string(out))
	}
}
// TestBuildOpenAICompatImagesMultipartRequestPreservesStreamAndFileContentType
// builds a multipart form with model/stream/prompt fields and a PNG file part,
// runs it through buildOpenAICompatImagesMultipartRequest in streaming mode, and
// re-parses the output to check that the model and stream fields were rewritten,
// the prompt is unchanged, and the file part kept its Content-Type.
func TestBuildOpenAICompatImagesMultipartRequestPreservesStreamAndFileContentType(t *testing.T) {
	// Assemble the source form.
	var source bytes.Buffer
	srcWriter := multipart.NewWriter(&source)
	textFields := []struct {
		name, value, failure string
	}{
		{"model", "compat-image", "write model field: %v"},
		{"stream", "false", "write stream field: %v"},
		{"prompt", "edit", "write prompt field: %v"},
	}
	for _, field := range textFields {
		if errWrite := srcWriter.WriteField(field.name, field.value); errWrite != nil {
			t.Fatalf(field.failure, errWrite)
		}
	}

	// CreatePart is used so the file part carries an explicit image/png
	// Content-Type header.
	imageHeader := textproto.MIMEHeader{}
	imageHeader.Set("Content-Disposition", multipart.FileContentDisposition("image", "image.png"))
	imageHeader.Set("Content-Type", "image/png")
	imagePart, errCreate := srcWriter.CreatePart(imageHeader)
	if errCreate != nil {
		t.Fatalf("create image field: %v", errCreate)
	}
	if _, errWrite := imagePart.Write([]byte("png-data")); errWrite != nil {
		t.Fatalf("write image field: %v", errWrite)
	}
	if errClose := srcWriter.Close(); errClose != nil {
		t.Fatalf("close multipart writer: %v", errClose)
	}

	// Parse the source form the same way a handler would receive it.
	srcForm, errSource := multipart.NewReader(bytes.NewReader(source.Bytes()), srcWriter.Boundary()).ReadForm(32 << 20)
	if errSource != nil {
		t.Fatalf("read source form: %v", errSource)
	}
	defer func() {
		if errRemove := srcForm.RemoveAll(); errRemove != nil {
			t.Fatalf("remove source form files: %v", errRemove)
		}
	}()

	rewritten, contentType, errBuild := buildOpenAICompatImagesMultipartRequest(srcForm, "upstream-image", true)
	if errBuild != nil {
		t.Fatalf("buildOpenAICompatImagesMultipartRequest error: %v", errBuild)
	}

	mediaType, params, errParse := mime.ParseMediaType(contentType)
	if errParse != nil {
		t.Fatalf("parse content type: %v", errParse)
	}
	if mediaType != "multipart/form-data" {
		t.Fatalf("media type = %q, want multipart/form-data", mediaType)
	}

	// Re-parse the rewritten body using the boundary from the returned content type.
	outForm, errOut := multipart.NewReader(bytes.NewReader(rewritten), params["boundary"]).ReadForm(32 << 20)
	if errOut != nil {
		t.Fatalf("read rewritten form: %v", errOut)
	}
	defer func() {
		if errRemove := outForm.RemoveAll(); errRemove != nil {
			t.Fatalf("remove rewritten form files: %v", errRemove)
		}
	}()

	expectations := []struct {
		name, want, failure string
	}{
		{"model", "upstream-image", "model values = %#v, want upstream-image"},
		{"stream", "true", "stream values = %#v, want true"},
		{"prompt", "edit", "prompt values = %#v, want edit"},
	}
	for _, expect := range expectations {
		got := outForm.Value[expect.name]
		if len(got) != 1 || got[0] != expect.want {
			t.Fatalf(expect.failure, got)
		}
	}

	files := outForm.File["image"]
	if len(files) != 1 || files[0].Header.Get("Content-Type") != "image/png" {
		t.Fatalf("image headers = %#v, want image/png", files)
	}
}
func TestBuildImagesAPIResponseFromXAI(t *testing.T) {
payload := []byte(`{"created":123,"data":[{"b64_json":"AA==","revised_prompt":"refined","mime_type":"image/png"}],"usage":{"total_tokens":0}}`)

View File

@@ -1208,30 +1208,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
}
if strings.EqualFold(compat.Name, compatName) {
isCompatAuth = true
// Convert compatibility models to registry models
ms := make([]*ModelInfo, 0, len(compat.Models))
for j := range compat.Models {
m := compat.Models[j]
// Use alias as model ID, fallback to name if alias is empty
modelID := m.Alias
if modelID == "" {
modelID = m.Name
}
thinking := m.Thinking
if thinking == nil {
thinking = &registry.ThinkingSupport{Levels: []string{"low", "medium", "high"}}
}
ms = append(ms, &ModelInfo{
ID: modelID,
Object: "model",
Created: time.Now().Unix(),
OwnedBy: compat.Name,
Type: "openai-compatibility",
DisplayName: modelID,
UserDefined: false,
Thinking: thinking,
})
}
ms := buildOpenAICompatibilityConfigModels(compat)
// Register and return
if len(ms) > 0 {
if providerKey == "" {
@@ -1578,6 +1555,43 @@ type modelEntry interface {
GetAlias() string
}
// buildOpenAICompatibilityConfigModels converts the models configured on an
// OpenAI-compatibility provider into registry ModelInfo entries.
//
// The model ID is the trimmed alias, falling back to the trimmed name; entries
// with neither are skipped. Models flagged Image get the
// registry.OpenAIImageModelType type and keep whatever Thinking value was
// configured (possibly nil). All other models get the "openai-compatibility"
// type and default to low/medium/high thinking levels when none is configured.
// Returns nil when compat is nil or has no models.
func buildOpenAICompatibilityConfigModels(compat *config.OpenAICompatibility) []*ModelInfo {
	if compat == nil {
		return nil
	}
	if len(compat.Models) == 0 {
		return nil
	}

	// One timestamp is shared by every model built in this call.
	created := time.Now().Unix()
	out := make([]*ModelInfo, 0, len(compat.Models))
	for _, entry := range compat.Models {
		id := strings.TrimSpace(entry.Alias)
		if id == "" {
			id = strings.TrimSpace(entry.Name)
		}
		if id == "" {
			continue
		}

		info := &ModelInfo{
			ID:          id,
			Object:      "model",
			Created:     created,
			OwnedBy:     compat.Name,
			Type:        "openai-compatibility",
			DisplayName: id,
			UserDefined: false,
			Thinking:    entry.Thinking,
		}
		switch {
		case entry.Image:
			info.Type = registry.OpenAIImageModelType
		case info.Thinking == nil:
			info.Thinking = &registry.ThinkingSupport{Levels: []string{"low", "medium", "high"}}
		}
		out = append(out, info)
	}
	return out
}
func buildConfigModels[T modelEntry](models []T, ownedBy, modelType string) []*ModelInfo {
if len(models) == 0 {
return nil

View File

@@ -4,6 +4,7 @@ import (
"strings"
"testing"
internalregistry "github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v7/sdk/config"
)
@@ -63,3 +64,71 @@ func TestRegisterModelsForAuth_UsesPreMergedExcludedModelsAttribute(t *testing.T
t.Fatal("expected global excluded model to be present when attribute override is set")
}
}
func TestRegisterModelsForAuth_OpenAICompatibilityImageModelType(t *testing.T) {
service := &Service{
cfg: &config.Config{
OpenAICompatibility: []config.OpenAICompatibility{
{
Name: "images",
BaseURL: "https://example.com/v1",
Models: []config.OpenAICompatibilityModel{
{Name: "upstream-image", Alias: "compat-image", Image: true},
{Name: "upstream-chat", Alias: "compat-chat"},
},
},
},
},
}
auth := &coreauth.Auth{
ID: "auth-openai-compat-image",
Provider: "openai-compatibility",
Status: coreauth.StatusActive,
Attributes: map[string]string{
"auth_kind": "api_key",
"compat_name": "images",
"provider_key": "images",
},
}
modelRegistry := internalregistry.GetGlobalRegistry()
modelRegistry.UnregisterClient(auth.ID)
t.Cleanup(func() {
modelRegistry.UnregisterClient(auth.ID)
})
service.registerModelsForAuth(auth)
models := modelRegistry.GetModelsForClient(auth.ID)
var imageModel *internalregistry.ModelInfo
var chatModel *internalregistry.ModelInfo
for _, model := range models {
if model == nil {
continue
}
switch strings.TrimSpace(model.ID) {
case "compat-image":
imageModel = model
case "compat-chat":
chatModel = model
}
}
if imageModel == nil {
t.Fatal("expected compat-image to be registered")
}
if imageModel.Type != internalregistry.OpenAIImageModelType {
t.Fatalf("image model type = %q, want %q", imageModel.Type, internalregistry.OpenAIImageModelType)
}
if imageModel.Thinking != nil {
t.Fatalf("image model thinking = %+v, want nil", imageModel.Thinking)
}
if chatModel == nil {
t.Fatal("expected compat-chat to be registered")
}
if chatModel.Type != "openai-compatibility" {
t.Fatalf("chat model type = %q, want openai-compatibility", chatModel.Type)
}
if chatModel.Thinking == nil {
t.Fatal("expected chat model to keep default thinking support")
}
}