diff --git a/internal/runtime/executor/gemini_executor.go b/internal/runtime/executor/gemini_executor.go index 21df454d3..4046c8ea0 100644 --- a/internal/runtime/executor/gemini_executor.go +++ b/internal/runtime/executor/gemini_executor.go @@ -13,6 +13,7 @@ import ( "strings" "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + "github.com/router-for-me/CLIProxyAPI/v7/internal/registry" "github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor/helps" "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking" "github.com/router-for-me/CLIProxyAPI/v7/internal/util" @@ -135,6 +136,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r requestPath := helps.PayloadRequestPath(opts) body = helps.ApplyPayloadConfigWithRequest(e.cfg, baseModel, to.String(), from.String(), "", body, originalTranslated, requestedModel, requestPath, opts.Headers) body, _ = sjson.SetBytes(body, "model", baseModel) + body = capGeminiMaxOutputTokens(body, baseModel) action := "generateContent" if req.Metadata != nil { @@ -243,6 +245,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A requestPath := helps.PayloadRequestPath(opts) body = helps.ApplyPayloadConfigWithRequest(e.cfg, baseModel, to.String(), from.String(), "", body, originalTranslated, requestedModel, requestPath, opts.Headers) body, _ = sjson.SetBytes(body, "model", baseModel) + body = capGeminiMaxOutputTokens(body, baseModel) baseURL := resolveGeminiBaseURL(auth) url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, baseModel, "streamGenerateContent") @@ -527,6 +530,26 @@ func applyGeminiHeaders(req *http.Request, auth *cliproxyauth.Auth) { util.ApplyCustomHeadersFromAttrs(req, attrs) } +func capGeminiMaxOutputTokens(body []byte, modelName string) []byte { + maxOut := gjson.GetBytes(body, "generationConfig.maxOutputTokens") + if !maxOut.Exists() || maxOut.Type != gjson.Number { + return body + } + modelInfo := registry.LookupModelInfo(modelName, "gemini") + if modelInfo == nil { + return body + } + limit := modelInfo.OutputTokenLimit + if limit <= 0 { + limit = modelInfo.MaxCompletionTokens + } + if limit <= 0 || maxOut.Int() <= int64(limit) { + return body + } + body, _ = sjson.SetBytes(body, "generationConfig.maxOutputTokens", limit) + return body +} + func fixGeminiImageAspectRatio(modelName string, rawJSON []byte) []byte { if modelName == "gemini-2.5-flash-image-preview" { aspectRatioResult := gjson.GetBytes(rawJSON, "generationConfig.imageConfig.aspectRatio") diff --git a/internal/runtime/executor/gemini_executor_test.go b/internal/runtime/executor/gemini_executor_test.go new file mode 100644 index 000000000..fbcd0d55d --- /dev/null +++ b/internal/runtime/executor/gemini_executor_test.go @@ -0,0 +1,90 @@ +package executor + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator" + "github.com/tidwall/gjson" +) + +func TestCapGeminiMaxOutputTokensUsesOutputTokenLimit(t *testing.T) { + body := []byte(`{"generationConfig":{"maxOutputTokens":500000,"temperature":0.2},"contents":[]}`) + + out := capGeminiMaxOutputTokens(body, "gemini-3.1-pro-preview") + + if got := gjson.GetBytes(out, "generationConfig.maxOutputTokens").Int(); got != 65536 { + t.Fatalf("maxOutputTokens = %d, want 65536", got) + } + if got := gjson.GetBytes(out, "generationConfig.temperature").Float(); got != 0.2 { + t.Fatalf("temperature = %v, want 0.2", got) + } +} + +func TestCapGeminiMaxOutputTokensLeavesAllowedOrUnknown(t *testing.T) { + tests := []struct { + name string + model string + body []byte + want int64 + }{ + { + name: "allowed value", + model: "gemini-3.1-pro-preview", + body: []byte(`{"generationConfig":{"maxOutputTokens":64000}}`), + want: 64000, + }, + { + name: "unknown model", + model: "custom-gemini-model", + body: []byte(`{"generationConfig":{"maxOutputTokens":500000}}`), + want: 500000, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out := capGeminiMaxOutputTokens(tt.body, tt.model) + if got := gjson.GetBytes(out, "generationConfig.maxOutputTokens").Int(); got != tt.want { + t.Fatalf("maxOutputTokens = %d, want %d", got, tt.want) + } + }) + } +} + +func TestGeminiExecutorExecuteCapsMaxOutputTokensBeforeUpstream(t *testing.T) { + var upstreamMaxOutputTokens int64 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read request body: %v", err) + } + upstreamMaxOutputTokens = gjson.GetBytes(body, "generationConfig.maxOutputTokens").Int() + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"candidates":[{"content":{"role":"model","parts":[{"text":"ok"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}`)) + })) + defer server.Close() + + exec := NewGeminiExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "test-key", + "base_url": server.URL, + }} + req := cliproxyexecutor.Request{ + Model: "gemini-3.1-pro-preview", + Payload: []byte(`{"contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"maxOutputTokens":500000}}`), + } + + if _, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{SourceFormat: sdktranslator.FormatGemini}); err != nil { + t.Fatalf("Execute() error = %v", err) + } + if upstreamMaxOutputTokens != 65536 { + t.Fatalf("upstream maxOutputTokens = %d, want 65536", upstreamMaxOutputTokens) + } +}