mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-05-17 04:09:57 +08:00
- Updated `ApplyPayloadConfigWithRoot` to prioritize `disable-image-generation` filtering before applying payload rules. - Ensured payload overrides can explicitly re-enable `image_generation` when required. - Added unit tests to validate `image_generation` restoration through overrides.
455 lines
12 KiB
Go
455 lines
12 KiB
Go
package helps
|
|
|
|
import (
|
|
"encoding/json"
|
|
"strings"
|
|
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/sjson"
|
|
)
|
|
|
|
// ApplyPayloadConfigWithRoot behaves like applyPayloadConfig but treats all parameter
|
|
// paths as relative to the provided root path (for example, "request" for Gemini CLI)
|
|
// and restricts matches to the given protocol when supplied. Defaults are checked
|
|
// against the original payload when provided. requestedModel carries the client-visible
|
|
// model name before alias resolution so payload rules can target aliases precisely.
|
|
// requestPath is the inbound HTTP request path (when available) used for endpoint-scoped gates.
|
|
func ApplyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string, requestPath string) []byte {
|
|
if cfg == nil || len(payload) == 0 {
|
|
return payload
|
|
}
|
|
out := payload
|
|
|
|
// Apply disable-image-generation filtering before payload rules so config payload
|
|
// overrides can explicitly re-enable image_generation when desired.
|
|
if cfg.DisableImageGeneration != config.DisableImageGenerationOff {
|
|
if cfg.DisableImageGeneration != config.DisableImageGenerationChat || !isImagesEndpointRequestPath(requestPath) {
|
|
out = removeToolTypeFromPayloadWithRoot(out, root, "image_generation")
|
|
out = removeToolChoiceFromPayloadWithRoot(out, root, "image_generation")
|
|
}
|
|
}
|
|
|
|
rules := cfg.Payload
|
|
hasPayloadRules := len(rules.Default) != 0 || len(rules.DefaultRaw) != 0 || len(rules.Override) != 0 || len(rules.OverrideRaw) != 0 || len(rules.Filter) != 0
|
|
if hasPayloadRules {
|
|
model = strings.TrimSpace(model)
|
|
requestedModel = strings.TrimSpace(requestedModel)
|
|
if model != "" || requestedModel != "" {
|
|
candidates := payloadModelCandidates(model, requestedModel)
|
|
source := original
|
|
if len(source) == 0 {
|
|
source = payload
|
|
}
|
|
appliedDefaults := make(map[string]struct{})
|
|
// Apply default rules: first write wins per field across all matching rules.
|
|
for i := range rules.Default {
|
|
rule := &rules.Default[i]
|
|
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
|
continue
|
|
}
|
|
for path, value := range rule.Params {
|
|
fullPath := buildPayloadPath(root, path)
|
|
if fullPath == "" {
|
|
continue
|
|
}
|
|
if gjson.GetBytes(source, fullPath).Exists() {
|
|
continue
|
|
}
|
|
if _, ok := appliedDefaults[fullPath]; ok {
|
|
continue
|
|
}
|
|
updated, errSet := sjson.SetBytes(out, fullPath, value)
|
|
if errSet != nil {
|
|
continue
|
|
}
|
|
out = updated
|
|
appliedDefaults[fullPath] = struct{}{}
|
|
}
|
|
}
|
|
// Apply default raw rules: first write wins per field across all matching rules.
|
|
for i := range rules.DefaultRaw {
|
|
rule := &rules.DefaultRaw[i]
|
|
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
|
continue
|
|
}
|
|
for path, value := range rule.Params {
|
|
fullPath := buildPayloadPath(root, path)
|
|
if fullPath == "" {
|
|
continue
|
|
}
|
|
if gjson.GetBytes(source, fullPath).Exists() {
|
|
continue
|
|
}
|
|
if _, ok := appliedDefaults[fullPath]; ok {
|
|
continue
|
|
}
|
|
rawValue, ok := payloadRawValue(value)
|
|
if !ok {
|
|
continue
|
|
}
|
|
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
|
|
if errSet != nil {
|
|
continue
|
|
}
|
|
out = updated
|
|
appliedDefaults[fullPath] = struct{}{}
|
|
}
|
|
}
|
|
// Apply override rules: last write wins per field across all matching rules.
|
|
for i := range rules.Override {
|
|
rule := &rules.Override[i]
|
|
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
|
continue
|
|
}
|
|
for path, value := range rule.Params {
|
|
fullPath := buildPayloadPath(root, path)
|
|
if fullPath == "" {
|
|
continue
|
|
}
|
|
updated, errSet := sjson.SetBytes(out, fullPath, value)
|
|
if errSet != nil {
|
|
continue
|
|
}
|
|
out = updated
|
|
}
|
|
}
|
|
// Apply override raw rules: last write wins per field across all matching rules.
|
|
for i := range rules.OverrideRaw {
|
|
rule := &rules.OverrideRaw[i]
|
|
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
|
continue
|
|
}
|
|
for path, value := range rule.Params {
|
|
fullPath := buildPayloadPath(root, path)
|
|
if fullPath == "" {
|
|
continue
|
|
}
|
|
rawValue, ok := payloadRawValue(value)
|
|
if !ok {
|
|
continue
|
|
}
|
|
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
|
|
if errSet != nil {
|
|
continue
|
|
}
|
|
out = updated
|
|
}
|
|
}
|
|
// Apply filter rules: remove matching paths from payload.
|
|
for i := range rules.Filter {
|
|
rule := &rules.Filter[i]
|
|
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
|
continue
|
|
}
|
|
for _, path := range rule.Params {
|
|
fullPath := buildPayloadPath(root, path)
|
|
if fullPath == "" {
|
|
continue
|
|
}
|
|
updated, errDel := sjson.DeleteBytes(out, fullPath)
|
|
if errDel != nil {
|
|
continue
|
|
}
|
|
out = updated
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return out
|
|
}
|
|
|
|
func isImagesEndpointRequestPath(path string) bool {
|
|
path = strings.TrimSpace(path)
|
|
if path == "" {
|
|
return false
|
|
}
|
|
if path == "/v1/images/generations" || path == "/v1/images/edits" {
|
|
return true
|
|
}
|
|
// Be tolerant of prefix routers that may report a longer matched route.
|
|
if strings.HasSuffix(path, "/v1/images/generations") || strings.HasSuffix(path, "/v1/images/edits") {
|
|
return true
|
|
}
|
|
if strings.HasSuffix(path, "/images/generations") || strings.HasSuffix(path, "/images/edits") {
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func payloadModelRulesMatch(rules []config.PayloadModelRule, protocol string, models []string) bool {
|
|
if len(rules) == 0 || len(models) == 0 {
|
|
return false
|
|
}
|
|
for _, model := range models {
|
|
for _, entry := range rules {
|
|
name := strings.TrimSpace(entry.Name)
|
|
if name == "" {
|
|
continue
|
|
}
|
|
if ep := strings.TrimSpace(entry.Protocol); ep != "" && protocol != "" && !strings.EqualFold(ep, protocol) {
|
|
continue
|
|
}
|
|
if matchModelPattern(name, model) {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func payloadModelCandidates(model, requestedModel string) []string {
|
|
model = strings.TrimSpace(model)
|
|
requestedModel = strings.TrimSpace(requestedModel)
|
|
if model == "" && requestedModel == "" {
|
|
return nil
|
|
}
|
|
candidates := make([]string, 0, 3)
|
|
seen := make(map[string]struct{}, 3)
|
|
addCandidate := func(value string) {
|
|
value = strings.TrimSpace(value)
|
|
if value == "" {
|
|
return
|
|
}
|
|
key := strings.ToLower(value)
|
|
if _, ok := seen[key]; ok {
|
|
return
|
|
}
|
|
seen[key] = struct{}{}
|
|
candidates = append(candidates, value)
|
|
}
|
|
if model != "" {
|
|
addCandidate(model)
|
|
}
|
|
if requestedModel != "" {
|
|
parsed := thinking.ParseSuffix(requestedModel)
|
|
base := strings.TrimSpace(parsed.ModelName)
|
|
if base != "" {
|
|
addCandidate(base)
|
|
}
|
|
if parsed.HasSuffix {
|
|
addCandidate(requestedModel)
|
|
}
|
|
}
|
|
return candidates
|
|
}
|
|
|
|
// buildPayloadPath combines an optional root path with a relative parameter path.
|
|
// When root is empty, the parameter path is used as-is. When root is non-empty,
|
|
// the parameter path is treated as relative to root.
|
|
func buildPayloadPath(root, path string) string {
|
|
r := strings.TrimSpace(root)
|
|
p := strings.TrimSpace(path)
|
|
if r == "" {
|
|
return p
|
|
}
|
|
if p == "" {
|
|
return r
|
|
}
|
|
if strings.HasPrefix(p, ".") {
|
|
p = p[1:]
|
|
}
|
|
return r + "." + p
|
|
}
|
|
|
|
func removeToolTypeFromPayloadWithRoot(payload []byte, root string, toolType string) []byte {
|
|
if len(payload) == 0 {
|
|
return payload
|
|
}
|
|
toolType = strings.TrimSpace(toolType)
|
|
if toolType == "" {
|
|
return payload
|
|
}
|
|
toolsPath := buildPayloadPath(root, "tools")
|
|
return removeToolTypeFromToolsArray(payload, toolsPath, toolType)
|
|
}
|
|
|
|
func removeToolChoiceFromPayloadWithRoot(payload []byte, root string, toolType string) []byte {
|
|
if len(payload) == 0 {
|
|
return payload
|
|
}
|
|
toolType = strings.TrimSpace(toolType)
|
|
if toolType == "" {
|
|
return payload
|
|
}
|
|
toolChoicePath := buildPayloadPath(root, "tool_choice")
|
|
return removeToolChoiceFromPayload(payload, toolChoicePath, toolType)
|
|
}
|
|
|
|
func removeToolChoiceFromPayload(payload []byte, toolChoicePath string, toolType string) []byte {
|
|
choice := gjson.GetBytes(payload, toolChoicePath)
|
|
if !choice.Exists() {
|
|
return payload
|
|
}
|
|
if choice.Type == gjson.String {
|
|
if strings.EqualFold(strings.TrimSpace(choice.String()), toolType) {
|
|
updated, errDel := sjson.DeleteBytes(payload, toolChoicePath)
|
|
if errDel == nil {
|
|
return updated
|
|
}
|
|
}
|
|
return payload
|
|
}
|
|
if choice.Type != gjson.JSON {
|
|
return payload
|
|
}
|
|
choiceType := strings.TrimSpace(choice.Get("type").String())
|
|
if strings.EqualFold(choiceType, toolType) {
|
|
updated, errDel := sjson.DeleteBytes(payload, toolChoicePath)
|
|
if errDel == nil {
|
|
return updated
|
|
}
|
|
return payload
|
|
}
|
|
if strings.EqualFold(choiceType, "tool") {
|
|
name := strings.TrimSpace(choice.Get("name").String())
|
|
if strings.EqualFold(name, toolType) {
|
|
updated, errDel := sjson.DeleteBytes(payload, toolChoicePath)
|
|
if errDel == nil {
|
|
return updated
|
|
}
|
|
}
|
|
}
|
|
return payload
|
|
}
|
|
|
|
func removeToolTypeFromToolsArray(payload []byte, toolsPath string, toolType string) []byte {
|
|
tools := gjson.GetBytes(payload, toolsPath)
|
|
if !tools.Exists() || !tools.IsArray() {
|
|
return payload
|
|
}
|
|
removed := false
|
|
filtered := []byte(`[]`)
|
|
for _, tool := range tools.Array() {
|
|
if tool.Get("type").String() == toolType {
|
|
removed = true
|
|
continue
|
|
}
|
|
updated, errSet := sjson.SetRawBytes(filtered, "-1", []byte(tool.Raw))
|
|
if errSet != nil {
|
|
continue
|
|
}
|
|
filtered = updated
|
|
}
|
|
if !removed {
|
|
return payload
|
|
}
|
|
updated, errSet := sjson.SetRawBytes(payload, toolsPath, filtered)
|
|
if errSet != nil {
|
|
return payload
|
|
}
|
|
return updated
|
|
}
|
|
|
|
func payloadRawValue(value any) ([]byte, bool) {
|
|
if value == nil {
|
|
return nil, false
|
|
}
|
|
switch typed := value.(type) {
|
|
case string:
|
|
return []byte(typed), true
|
|
case []byte:
|
|
return typed, true
|
|
default:
|
|
raw, errMarshal := json.Marshal(typed)
|
|
if errMarshal != nil {
|
|
return nil, false
|
|
}
|
|
return raw, true
|
|
}
|
|
}
|
|
|
|
func PayloadRequestedModel(opts cliproxyexecutor.Options, fallback string) string {
|
|
fallback = strings.TrimSpace(fallback)
|
|
if len(opts.Metadata) == 0 {
|
|
return fallback
|
|
}
|
|
raw, ok := opts.Metadata[cliproxyexecutor.RequestedModelMetadataKey]
|
|
if !ok || raw == nil {
|
|
return fallback
|
|
}
|
|
switch v := raw.(type) {
|
|
case string:
|
|
if strings.TrimSpace(v) == "" {
|
|
return fallback
|
|
}
|
|
return strings.TrimSpace(v)
|
|
case []byte:
|
|
if len(v) == 0 {
|
|
return fallback
|
|
}
|
|
trimmed := strings.TrimSpace(string(v))
|
|
if trimmed == "" {
|
|
return fallback
|
|
}
|
|
return trimmed
|
|
default:
|
|
return fallback
|
|
}
|
|
}
|
|
|
|
func PayloadRequestPath(opts cliproxyexecutor.Options) string {
|
|
if len(opts.Metadata) == 0 {
|
|
return ""
|
|
}
|
|
raw, ok := opts.Metadata[cliproxyexecutor.RequestPathMetadataKey]
|
|
if !ok || raw == nil {
|
|
return ""
|
|
}
|
|
switch v := raw.(type) {
|
|
case string:
|
|
return strings.TrimSpace(v)
|
|
case []byte:
|
|
return strings.TrimSpace(string(v))
|
|
default:
|
|
return ""
|
|
}
|
|
}
|
|
|
|
// matchModelPattern performs simple wildcard matching where '*' matches zero or more characters.
|
|
// Examples:
|
|
//
|
|
// "*-5" matches "gpt-5"
|
|
// "gpt-*" matches "gpt-5" and "gpt-4"
|
|
// "gemini-*-pro" matches "gemini-2.5-pro" and "gemini-3-pro".
|
|
func matchModelPattern(pattern, model string) bool {
|
|
pattern = strings.TrimSpace(pattern)
|
|
model = strings.TrimSpace(model)
|
|
if pattern == "" {
|
|
return false
|
|
}
|
|
if pattern == "*" {
|
|
return true
|
|
}
|
|
// Iterative glob-style matcher supporting only '*' wildcard.
|
|
pi, si := 0, 0
|
|
starIdx := -1
|
|
matchIdx := 0
|
|
for si < len(model) {
|
|
if pi < len(pattern) && (pattern[pi] == model[si]) {
|
|
pi++
|
|
si++
|
|
continue
|
|
}
|
|
if pi < len(pattern) && pattern[pi] == '*' {
|
|
starIdx = pi
|
|
matchIdx = si
|
|
pi++
|
|
continue
|
|
}
|
|
if starIdx != -1 {
|
|
pi = starIdx + 1
|
|
matchIdx++
|
|
si = matchIdx
|
|
continue
|
|
}
|
|
return false
|
|
}
|
|
for pi < len(pattern) && pattern[pi] == '*' {
|
|
pi++
|
|
}
|
|
return pi == len(pattern)
|
|
}
|