mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-05-31 03:41:59 +08:00
- Updated variable, function, and struct names from `FormProtocol` to `FromProtocol` for clarity. - Adjusted related payload matching and normalization logic. - Updated tests and examples to align with the new naming convention.
914 lines
22 KiB
Go
914 lines
22 KiB
Go
package helps
|
|
|
|
import (
|
|
"encoding/json"
|
|
"net/http"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/thinking"
|
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/sjson"
|
|
)
|
|
|
|
// ApplyPayloadConfigWithRoot behaves like applyPayloadConfig but treats all parameter
|
|
// paths as relative to the provided root path (for example, "request" for Gemini CLI)
|
|
// and restricts matches to the given protocol when supplied. Defaults are checked
|
|
// against the original payload when provided. requestedModel carries the client-visible
|
|
// model name before alias resolution so payload rules can target aliases precisely.
|
|
// requestPath is the inbound HTTP request path (when available) used for endpoint-scoped gates.
|
|
func ApplyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string, requestPath string) []byte {
|
|
return ApplyPayloadConfigWithRequest(cfg, model, protocol, "", root, payload, original, requestedModel, requestPath, nil)
|
|
}
|
|
|
|
// ApplyPayloadConfigWithRequest applies payload config using source protocol and request header gates.
|
|
func ApplyPayloadConfigWithRequest(cfg *config.Config, model, protocol, fromProtocol, root string, payload, original []byte, requestedModel string, requestPath string, headers http.Header) []byte {
|
|
if cfg == nil || len(payload) == 0 {
|
|
return payload
|
|
}
|
|
out := payload
|
|
|
|
// Apply disable-image-generation filtering before payload rules so config payload
|
|
// overrides can explicitly re-enable image_generation when desired.
|
|
if cfg.DisableImageGeneration != config.DisableImageGenerationOff {
|
|
if cfg.DisableImageGeneration != config.DisableImageGenerationChat || !isImagesEndpointRequestPath(requestPath) {
|
|
out = removeToolTypeFromPayloadWithRoot(out, root, "image_generation")
|
|
out = removeToolChoiceFromPayloadWithRoot(out, root, "image_generation")
|
|
}
|
|
}
|
|
|
|
rules := cfg.Payload
|
|
hasPayloadRules := len(rules.Default) != 0 || len(rules.DefaultRaw) != 0 || len(rules.Override) != 0 || len(rules.OverrideRaw) != 0 || len(rules.Filter) != 0
|
|
if hasPayloadRules {
|
|
model = strings.TrimSpace(model)
|
|
requestedModel = strings.TrimSpace(requestedModel)
|
|
if model != "" || requestedModel != "" {
|
|
candidates := payloadModelCandidates(model, requestedModel)
|
|
source := original
|
|
if len(source) == 0 {
|
|
source = payload
|
|
}
|
|
appliedDefaults := make(map[string]struct{})
|
|
// Apply default rules: first write wins per field across all matching rules.
|
|
for i := range rules.Default {
|
|
rule := &rules.Default[i]
|
|
if !payloadModelRulesMatch(rule.Models, protocol, fromProtocol, headers, out, root, candidates) {
|
|
continue
|
|
}
|
|
for path, value := range rule.Params {
|
|
fullPath := buildPayloadPath(root, path)
|
|
if fullPath == "" {
|
|
continue
|
|
}
|
|
for _, resolvedPath := range resolvePayloadRulePaths(out, fullPath) {
|
|
if gjson.GetBytes(source, resolvedPath).Exists() {
|
|
continue
|
|
}
|
|
if _, ok := appliedDefaults[resolvedPath]; ok {
|
|
continue
|
|
}
|
|
updated, errSet := sjson.SetBytes(out, resolvedPath, value)
|
|
if errSet != nil {
|
|
continue
|
|
}
|
|
out = updated
|
|
appliedDefaults[resolvedPath] = struct{}{}
|
|
}
|
|
}
|
|
}
|
|
// Apply default raw rules: first write wins per field across all matching rules.
|
|
for i := range rules.DefaultRaw {
|
|
rule := &rules.DefaultRaw[i]
|
|
if !payloadModelRulesMatch(rule.Models, protocol, fromProtocol, headers, out, root, candidates) {
|
|
continue
|
|
}
|
|
for path, value := range rule.Params {
|
|
fullPath := buildPayloadPath(root, path)
|
|
if fullPath == "" {
|
|
continue
|
|
}
|
|
for _, resolvedPath := range resolvePayloadRulePaths(out, fullPath) {
|
|
if gjson.GetBytes(source, resolvedPath).Exists() {
|
|
continue
|
|
}
|
|
if _, ok := appliedDefaults[resolvedPath]; ok {
|
|
continue
|
|
}
|
|
rawValue, ok := payloadRawValue(value)
|
|
if !ok {
|
|
continue
|
|
}
|
|
updated, errSet := sjson.SetRawBytes(out, resolvedPath, rawValue)
|
|
if errSet != nil {
|
|
continue
|
|
}
|
|
out = updated
|
|
appliedDefaults[resolvedPath] = struct{}{}
|
|
}
|
|
}
|
|
}
|
|
// Apply override rules: last write wins per field across all matching rules.
|
|
for i := range rules.Override {
|
|
rule := &rules.Override[i]
|
|
if !payloadModelRulesMatch(rule.Models, protocol, fromProtocol, headers, out, root, candidates) {
|
|
continue
|
|
}
|
|
for path, value := range rule.Params {
|
|
fullPath := buildPayloadPath(root, path)
|
|
if fullPath == "" {
|
|
continue
|
|
}
|
|
for _, resolvedPath := range resolvePayloadRulePaths(out, fullPath) {
|
|
updated, errSet := sjson.SetBytes(out, resolvedPath, value)
|
|
if errSet != nil {
|
|
continue
|
|
}
|
|
out = updated
|
|
}
|
|
}
|
|
}
|
|
// Apply override raw rules: last write wins per field across all matching rules.
|
|
for i := range rules.OverrideRaw {
|
|
rule := &rules.OverrideRaw[i]
|
|
if !payloadModelRulesMatch(rule.Models, protocol, fromProtocol, headers, out, root, candidates) {
|
|
continue
|
|
}
|
|
for path, value := range rule.Params {
|
|
fullPath := buildPayloadPath(root, path)
|
|
if fullPath == "" {
|
|
continue
|
|
}
|
|
rawValue, ok := payloadRawValue(value)
|
|
if !ok {
|
|
continue
|
|
}
|
|
for _, resolvedPath := range resolvePayloadRulePaths(out, fullPath) {
|
|
updated, errSet := sjson.SetRawBytes(out, resolvedPath, rawValue)
|
|
if errSet != nil {
|
|
continue
|
|
}
|
|
out = updated
|
|
}
|
|
}
|
|
}
|
|
// Apply filter rules: remove matching paths from payload.
|
|
for i := range rules.Filter {
|
|
rule := &rules.Filter[i]
|
|
if !payloadModelRulesMatch(rule.Models, protocol, fromProtocol, headers, out, root, candidates) {
|
|
continue
|
|
}
|
|
for _, path := range rule.Params {
|
|
fullPath := buildPayloadPath(root, path)
|
|
if fullPath == "" {
|
|
continue
|
|
}
|
|
resolvedPaths := resolvePayloadRulePaths(out, fullPath)
|
|
for i := len(resolvedPaths) - 1; i >= 0; i-- {
|
|
resolvedPath := resolvedPaths[i]
|
|
updated, errDel := sjson.DeleteBytes(out, resolvedPath)
|
|
if errDel != nil {
|
|
continue
|
|
}
|
|
out = updated
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return out
|
|
}
|
|
|
|
// isImagesEndpointRequestPath reports whether path (ignoring surrounding
// whitespace) ends with "/images/generations" or "/images/edits". This covers
// the exact "/v1/images/..." routes as well as longer routes reported by
// prefix routers.
func isImagesEndpointRequestPath(path string) bool {
	trimmed := strings.TrimSpace(path)
	if trimmed == "" {
		return false
	}
	for _, suffix := range [...]string{"/images/generations", "/images/edits"} {
		if strings.HasSuffix(trimmed, suffix) {
			return true
		}
	}
	return false
}
|
|
|
|
func payloadModelRulesMatch(rules []config.PayloadModelRule, protocol string, fromProtocol string, headers http.Header, payload []byte, root string, models []string) bool {
|
|
if len(rules) == 0 || len(models) == 0 {
|
|
return false
|
|
}
|
|
for _, model := range models {
|
|
for _, entry := range rules {
|
|
name := strings.TrimSpace(entry.Name)
|
|
if name == "" {
|
|
continue
|
|
}
|
|
if ep := strings.TrimSpace(entry.Protocol); ep != "" && protocol != "" && !strings.EqualFold(ep, protocol) {
|
|
continue
|
|
}
|
|
if !payloadFromProtocolMatches(entry.FromProtocol, fromProtocol) {
|
|
continue
|
|
}
|
|
if !payloadHeadersMatch(headers, entry.Headers) {
|
|
continue
|
|
}
|
|
if !matchModelPattern(name, model) {
|
|
continue
|
|
}
|
|
if payloadModelRuleConditionsMatch(payload, root, entry) {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func payloadModelRuleConditionsMatch(payload []byte, root string, rule config.PayloadModelRule) bool {
|
|
if !payloadMatchConditionsMatch(payload, root, rule.Match) {
|
|
return false
|
|
}
|
|
if !payloadNotMatchConditionsMatch(payload, root, rule.NotMatch) {
|
|
return false
|
|
}
|
|
if !payloadExistConditionsMatch(payload, root, rule.Exist) {
|
|
return false
|
|
}
|
|
if !payloadNotExistConditionsMatch(payload, root, rule.NotExist) {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
func payloadMatchConditionsMatch(payload []byte, root string, conditions []map[string]any) bool {
|
|
for _, condition := range conditions {
|
|
for path, value := range condition {
|
|
if strings.TrimSpace(path) == "" {
|
|
continue
|
|
}
|
|
if !payloadPathMatchesValue(payload, buildPayloadPath(root, path), value) {
|
|
return false
|
|
}
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
func payloadNotMatchConditionsMatch(payload []byte, root string, conditions []map[string]any) bool {
|
|
for _, condition := range conditions {
|
|
for path, value := range condition {
|
|
if strings.TrimSpace(path) == "" {
|
|
continue
|
|
}
|
|
if payloadPathMatchesValue(payload, buildPayloadPath(root, path), value) {
|
|
return false
|
|
}
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
func payloadExistConditionsMatch(payload []byte, root string, paths []string) bool {
|
|
for _, path := range paths {
|
|
if strings.TrimSpace(path) == "" {
|
|
continue
|
|
}
|
|
if !payloadPathExists(payload, buildPayloadPath(root, path)) {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
func payloadNotExistConditionsMatch(payload []byte, root string, paths []string) bool {
|
|
for _, path := range paths {
|
|
if strings.TrimSpace(path) == "" {
|
|
continue
|
|
}
|
|
if payloadPathExists(payload, buildPayloadPath(root, path)) {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
func payloadPathMatchesValue(payload []byte, path string, value any) bool {
|
|
for _, resolvedPath := range resolvePayloadRulePaths(payload, path) {
|
|
result := gjson.GetBytes(payload, resolvedPath)
|
|
if !result.Exists() {
|
|
continue
|
|
}
|
|
if payloadResultEquals(result, value) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func payloadPathExists(payload []byte, path string) bool {
|
|
for _, resolvedPath := range resolvePayloadRulePaths(payload, path) {
|
|
result := gjson.GetBytes(payload, resolvedPath)
|
|
if result.Exists() && result.Type != gjson.Null {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func payloadResultEquals(result gjson.Result, value any) bool {
|
|
actual, ok := normalizedPayloadResult(result)
|
|
if !ok {
|
|
return false
|
|
}
|
|
expected, ok := normalizedPayloadValue(value)
|
|
if !ok {
|
|
return false
|
|
}
|
|
return reflect.DeepEqual(actual, expected)
|
|
}
|
|
|
|
func normalizedPayloadResult(result gjson.Result) (any, bool) {
|
|
if !result.Exists() {
|
|
return nil, false
|
|
}
|
|
raw := strings.TrimSpace(result.Raw)
|
|
if raw == "" {
|
|
encoded, errMarshal := json.Marshal(result.Value())
|
|
if errMarshal != nil {
|
|
return nil, false
|
|
}
|
|
raw = string(encoded)
|
|
}
|
|
return normalizedPayloadJSON([]byte(raw))
|
|
}
|
|
|
|
func normalizedPayloadValue(value any) (any, bool) {
|
|
encoded, errMarshal := json.Marshal(value)
|
|
if errMarshal != nil {
|
|
return nil, false
|
|
}
|
|
return normalizedPayloadJSON(encoded)
|
|
}
|
|
|
|
// normalizedPayloadJSON decodes data into a generic Go value using
// encoding/json. The boolean is false when data is empty or whitespace-only,
// or when it is not valid JSON.
func normalizedPayloadJSON(data []byte) (any, bool) {
	if strings.TrimSpace(string(data)) == "" {
		return nil, false
	}
	var decoded any
	errUnmarshal := json.Unmarshal(data, &decoded)
	if errUnmarshal != nil {
		return nil, false
	}
	return decoded, true
}
|
|
|
|
func payloadFromProtocolMatches(pattern, fromProtocol string) bool {
|
|
pattern = normalizePayloadFromProtocol(pattern)
|
|
if pattern == "" {
|
|
return true
|
|
}
|
|
fromProtocol = normalizePayloadFromProtocol(fromProtocol)
|
|
if fromProtocol == "" {
|
|
return false
|
|
}
|
|
return strings.EqualFold(pattern, fromProtocol)
|
|
}
|
|
|
|
// normalizePayloadFromProtocol lower-cases and trims protocol and folds known
// aliases onto one name: "openai-response", "openai-responses" and "response"
// become "responses"; "gemini-cli" becomes "gemini". Any other value is
// returned in its trimmed, lower-cased form.
func normalizePayloadFromProtocol(protocol string) string {
	normalized := strings.ToLower(strings.TrimSpace(protocol))
	if normalized == "gemini-cli" {
		return "gemini"
	}
	if normalized == "response" || normalized == "openai-response" || normalized == "openai-responses" {
		return "responses"
	}
	return normalized
}
|
|
|
|
func payloadHeadersMatch(headers http.Header, rules map[string]string) bool {
|
|
if len(rules) == 0 {
|
|
return true
|
|
}
|
|
for key, pattern := range rules {
|
|
key = strings.TrimSpace(key)
|
|
if key == "" {
|
|
continue
|
|
}
|
|
values := payloadHeaderValues(headers, key)
|
|
if len(values) == 0 {
|
|
return false
|
|
}
|
|
matched := false
|
|
for _, value := range values {
|
|
if matchModelPattern(pattern, value) {
|
|
matched = true
|
|
break
|
|
}
|
|
}
|
|
if !matched {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// payloadHeaderValues collects the values of every entry in headers whose key
// equals key case-insensitively. The map keys are compared directly rather
// than through canonicalisation. It returns nil when headers is nil or no key
// matches.
func payloadHeaderValues(headers http.Header, key string) []string {
	if headers == nil {
		return nil
	}
	var collected []string
	for name := range headers {
		if !strings.EqualFold(name, key) {
			continue
		}
		collected = append(collected, headers[name]...)
	}
	return collected
}
|
|
|
|
func payloadModelCandidates(model, requestedModel string) []string {
|
|
model = strings.TrimSpace(model)
|
|
requestedModel = strings.TrimSpace(requestedModel)
|
|
if model == "" && requestedModel == "" {
|
|
return nil
|
|
}
|
|
candidates := make([]string, 0, 3)
|
|
seen := make(map[string]struct{}, 3)
|
|
addCandidate := func(value string) {
|
|
value = strings.TrimSpace(value)
|
|
if value == "" {
|
|
return
|
|
}
|
|
key := strings.ToLower(value)
|
|
if _, ok := seen[key]; ok {
|
|
return
|
|
}
|
|
seen[key] = struct{}{}
|
|
candidates = append(candidates, value)
|
|
}
|
|
if model != "" {
|
|
addCandidate(model)
|
|
}
|
|
if requestedModel != "" {
|
|
parsed := thinking.ParseSuffix(requestedModel)
|
|
base := strings.TrimSpace(parsed.ModelName)
|
|
if base != "" {
|
|
addCandidate(base)
|
|
}
|
|
if parsed.HasSuffix {
|
|
addCandidate(requestedModel)
|
|
}
|
|
}
|
|
return candidates
|
|
}
|
|
|
|
// buildPayloadPath joins an optional root path and a relative parameter path
// with a dot. Both inputs are trimmed first. With a blank root the parameter
// path is returned as-is; with a blank parameter path the root is returned.
// Otherwise a single leading "." on the parameter path is dropped before
// joining, so "request" + ".tools" yields "request.tools".
func buildPayloadPath(root, path string) string {
	base, rel := strings.TrimSpace(root), strings.TrimSpace(path)
	switch {
	case base == "":
		return rel
	case rel == "":
		return base
	}
	return base + "." + strings.TrimPrefix(rel, ".")
}
|
|
|
|
func resolvePayloadRulePaths(payload []byte, path string) []string {
|
|
path = strings.TrimSpace(path)
|
|
if path == "" {
|
|
return nil
|
|
}
|
|
if !strings.Contains(path, "#(") {
|
|
return []string{path}
|
|
}
|
|
parts := splitPayloadRulePath(path)
|
|
if len(parts) == 0 {
|
|
return nil
|
|
}
|
|
paths := []string{""}
|
|
for _, part := range parts {
|
|
query, allMatches, ok := parsePayloadQueryPathPart(part)
|
|
if !ok {
|
|
for i := range paths {
|
|
paths[i] = appendPayloadPathPart(paths[i], part)
|
|
}
|
|
continue
|
|
}
|
|
nextPaths := make([]string, 0, len(paths))
|
|
for _, basePath := range paths {
|
|
array := payloadValueAtPath(payload, basePath)
|
|
if !array.Exists() || !array.IsArray() {
|
|
continue
|
|
}
|
|
for index, item := range array.Array() {
|
|
if !payloadQueryMatches(item, query) {
|
|
continue
|
|
}
|
|
nextPaths = append(nextPaths, appendPayloadPathPart(basePath, strconv.Itoa(index)))
|
|
if !allMatches {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
paths = nextPaths
|
|
if len(paths) == 0 {
|
|
return nil
|
|
}
|
|
}
|
|
return paths
|
|
}
|
|
|
|
// splitPayloadRulePath splits path on "." separators that are outside
// parentheses. Dots inside single- or double-quoted text, dots inside "(...)"
// and any character preceded by a backslash are not treated as separators.
// The parts are returned verbatim (escapes and quotes are kept); the result
// always has at least one element.
func splitPayloadRulePath(path string) []string {
	var (
		parts   []string
		start   int
		depth   int
		quote   byte
		escaped bool
	)
	for i := 0; i < len(path); i++ {
		ch := path[i]
		// Case order matters: escape state first, then quoting, then structure.
		switch {
		case escaped:
			escaped = false
		case ch == '\\':
			escaped = true
		case quote != 0:
			if ch == quote {
				quote = 0
			}
		case ch == '"' || ch == '\'':
			quote = ch
		case ch == '(':
			depth++
		case ch == ')':
			if depth > 0 {
				depth--
			}
		case ch == '.' && depth == 0:
			parts = append(parts, path[start:i])
			start = i + 1
		}
	}
	return append(parts, path[start:])
}
|
|
|
|
func parsePayloadQueryPathPart(part string) (string, bool, bool) {
|
|
if !strings.HasPrefix(part, "#(") {
|
|
return "", false, false
|
|
}
|
|
closeIndex := findPayloadQueryClose(part)
|
|
if closeIndex < 0 {
|
|
return "", false, false
|
|
}
|
|
suffix := part[closeIndex+1:]
|
|
if suffix != "" && suffix != "#" {
|
|
return "", false, false
|
|
}
|
|
return strings.TrimSpace(part[2:closeIndex]), suffix == "#", true
|
|
}
|
|
|
|
// findPayloadQueryClose returns the index of the ")" that closes the query
// opened by the "#(" at the start of part, or -1 if there is none. Scanning
// starts at index 2 with one parenthesis already open. Nested parentheses are
// balanced, and parentheses inside quotes or after a backslash are ignored.
func findPayloadQueryClose(part string) int {
	var (
		quote   byte
		escaped bool
	)
	depth := 1
	for i := 2; i < len(part); i++ {
		ch := part[i]
		// Case order matters: escape state first, then quoting, then nesting.
		switch {
		case escaped:
			escaped = false
		case ch == '\\':
			escaped = true
		case quote != 0:
			if ch == quote {
				quote = 0
			}
		case ch == '"' || ch == '\'':
			quote = ch
		case ch == '(':
			depth++
		case ch == ')':
			depth--
			if depth == 0 {
				return i
			}
		}
	}
	return -1
}
|
|
|
|
// appendPayloadPathPart joins path and part with a "." separator. If either
// side is empty the other is returned unchanged, so no leading or trailing
// dot is produced.
func appendPayloadPathPart(path, part string) string {
	switch {
	case path == "":
		return part
	case part == "":
		return path
	default:
		return path + "." + part
	}
}
|
|
|
|
func payloadValueAtPath(payload []byte, path string) gjson.Result {
|
|
if path == "" {
|
|
return gjson.ParseBytes(payload)
|
|
}
|
|
return gjson.GetBytes(payload, path)
|
|
}
|
|
|
|
func payloadQueryMatches(item gjson.Result, query string) bool {
|
|
for _, orPart := range splitPayloadLogical(query, "||") {
|
|
if payloadQueryAndMatches(item, orPart) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func payloadQueryAndMatches(item gjson.Result, query string) bool {
|
|
parts := splitPayloadLogical(query, "&&")
|
|
if len(parts) == 0 {
|
|
return false
|
|
}
|
|
for _, part := range parts {
|
|
if !payloadQueryTermMatches(item, part) {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// splitPayloadLogical splits query at every occurrence of operator that is
// not inside single- or double-quoted text and whose first character is not
// preceded by a backslash. Each piece is whitespace-trimmed. The result
// always has at least one element.
//
// Callers in this file pass "||" or "&&"; the loop assumes a non-empty
// operator.
func splitPayloadLogical(query, operator string) []string {
	var (
		pieces  []string
		start   int
		quote   byte
		escaped bool
	)
	for i := 0; i < len(query); i++ {
		ch := query[i]
		// Case order matters: escape state first, then quoting, then operator.
		switch {
		case escaped:
			escaped = false
		case ch == '\\':
			escaped = true
		case quote != 0:
			if ch == quote {
				quote = 0
			}
		case ch == '"' || ch == '\'':
			quote = ch
		case strings.HasPrefix(query[i:], operator):
			pieces = append(pieces, strings.TrimSpace(query[start:i]))
			// Skip the rest of the operator; the loop's i++ steps past its last byte.
			i += len(operator) - 1
			start = i + 1
		}
	}
	return append(pieces, strings.TrimSpace(query[start:]))
}
|
|
|
|
func payloadQueryTermMatches(item gjson.Result, term string) bool {
|
|
term = strings.TrimSpace(term)
|
|
if term == "" || item.Raw == "" {
|
|
return false
|
|
}
|
|
wrapped := make([]byte, 0, len(item.Raw)+2)
|
|
wrapped = append(wrapped, '[')
|
|
wrapped = append(wrapped, item.Raw...)
|
|
wrapped = append(wrapped, ']')
|
|
return gjson.GetBytes(wrapped, "#("+term+")").Exists()
|
|
}
|
|
|
|
func removeToolTypeFromPayloadWithRoot(payload []byte, root string, toolType string) []byte {
|
|
if len(payload) == 0 {
|
|
return payload
|
|
}
|
|
toolType = strings.TrimSpace(toolType)
|
|
if toolType == "" {
|
|
return payload
|
|
}
|
|
toolsPath := buildPayloadPath(root, "tools")
|
|
return removeToolTypeFromToolsArray(payload, toolsPath, toolType)
|
|
}
|
|
|
|
func removeToolChoiceFromPayloadWithRoot(payload []byte, root string, toolType string) []byte {
|
|
if len(payload) == 0 {
|
|
return payload
|
|
}
|
|
toolType = strings.TrimSpace(toolType)
|
|
if toolType == "" {
|
|
return payload
|
|
}
|
|
toolChoicePath := buildPayloadPath(root, "tool_choice")
|
|
return removeToolChoiceFromPayload(payload, toolChoicePath, toolType)
|
|
}
|
|
|
|
func removeToolChoiceFromPayload(payload []byte, toolChoicePath string, toolType string) []byte {
|
|
choice := gjson.GetBytes(payload, toolChoicePath)
|
|
if !choice.Exists() {
|
|
return payload
|
|
}
|
|
if choice.Type == gjson.String {
|
|
if strings.EqualFold(strings.TrimSpace(choice.String()), toolType) {
|
|
updated, errDel := sjson.DeleteBytes(payload, toolChoicePath)
|
|
if errDel == nil {
|
|
return updated
|
|
}
|
|
}
|
|
return payload
|
|
}
|
|
if choice.Type != gjson.JSON {
|
|
return payload
|
|
}
|
|
choiceType := strings.TrimSpace(choice.Get("type").String())
|
|
if strings.EqualFold(choiceType, toolType) {
|
|
updated, errDel := sjson.DeleteBytes(payload, toolChoicePath)
|
|
if errDel == nil {
|
|
return updated
|
|
}
|
|
return payload
|
|
}
|
|
if strings.EqualFold(choiceType, "tool") {
|
|
name := strings.TrimSpace(choice.Get("name").String())
|
|
if strings.EqualFold(name, toolType) {
|
|
updated, errDel := sjson.DeleteBytes(payload, toolChoicePath)
|
|
if errDel == nil {
|
|
return updated
|
|
}
|
|
}
|
|
}
|
|
return payload
|
|
}
|
|
|
|
func removeToolTypeFromToolsArray(payload []byte, toolsPath string, toolType string) []byte {
|
|
tools := gjson.GetBytes(payload, toolsPath)
|
|
if !tools.Exists() || !tools.IsArray() {
|
|
return payload
|
|
}
|
|
removed := false
|
|
filtered := []byte(`[]`)
|
|
for _, tool := range tools.Array() {
|
|
if tool.Get("type").String() == toolType {
|
|
removed = true
|
|
continue
|
|
}
|
|
updated, errSet := sjson.SetRawBytes(filtered, "-1", []byte(tool.Raw))
|
|
if errSet != nil {
|
|
continue
|
|
}
|
|
filtered = updated
|
|
}
|
|
if !removed {
|
|
return payload
|
|
}
|
|
updated, errSet := sjson.SetRawBytes(payload, toolsPath, filtered)
|
|
if errSet != nil {
|
|
return payload
|
|
}
|
|
return updated
|
|
}
|
|
|
|
// payloadRawValue converts a raw-rule parameter into JSON bytes that can be
// spliced into a payload verbatim.
//
// A string or []byte value is taken to already be JSON text and is returned
// unchanged, but only if it is valid JSON: the raw setter inserts the bytes
// without checking them, so an empty or malformed value would otherwise
// corrupt the whole payload. Any other non-nil value is encoded with
// encoding/json.
//
// The boolean is false for a nil value, for string/[]byte input that is not
// valid JSON, and when marshalling fails; callers skip the parameter then.
func payloadRawValue(value any) ([]byte, bool) {
	var raw []byte
	switch typed := value.(type) {
	case nil:
		return nil, false
	case string:
		raw = []byte(typed)
	case []byte:
		raw = typed
	default:
		encoded, errMarshal := json.Marshal(typed)
		if errMarshal != nil {
			return nil, false
		}
		// Output of json.Marshal is valid JSON by construction.
		return encoded, true
	}
	if !json.Valid(raw) {
		return nil, false
	}
	return raw, true
}
|
|
|
|
func PayloadRequestedModel(opts cliproxyexecutor.Options, fallback string) string {
|
|
fallback = strings.TrimSpace(fallback)
|
|
if len(opts.Metadata) == 0 {
|
|
return fallback
|
|
}
|
|
raw, ok := opts.Metadata[cliproxyexecutor.RequestedModelMetadataKey]
|
|
if !ok || raw == nil {
|
|
return fallback
|
|
}
|
|
switch v := raw.(type) {
|
|
case string:
|
|
if strings.TrimSpace(v) == "" {
|
|
return fallback
|
|
}
|
|
return strings.TrimSpace(v)
|
|
case []byte:
|
|
if len(v) == 0 {
|
|
return fallback
|
|
}
|
|
trimmed := strings.TrimSpace(string(v))
|
|
if trimmed == "" {
|
|
return fallback
|
|
}
|
|
return trimmed
|
|
default:
|
|
return fallback
|
|
}
|
|
}
|
|
|
|
func PayloadRequestPath(opts cliproxyexecutor.Options) string {
|
|
if len(opts.Metadata) == 0 {
|
|
return ""
|
|
}
|
|
raw, ok := opts.Metadata[cliproxyexecutor.RequestPathMetadataKey]
|
|
if !ok || raw == nil {
|
|
return ""
|
|
}
|
|
switch v := raw.(type) {
|
|
case string:
|
|
return strings.TrimSpace(v)
|
|
case []byte:
|
|
return strings.TrimSpace(string(v))
|
|
default:
|
|
return ""
|
|
}
|
|
}
|
|
|
|
// matchModelPattern reports whether model matches pattern, where '*' in the
// pattern stands for zero or more characters and every other byte must match
// exactly (case-sensitive). Both arguments are trimmed first. A blank pattern
// never matches; the pattern "*" matches anything, including a blank model.
// Examples:
//
//	"*-5" matches "gpt-5"
//	"gpt-*" matches "gpt-5" and "gpt-4"
//	"gemini-*-pro" matches "gemini-2.5-pro" and "gemini-3-pro".
func matchModelPattern(pattern, model string) bool {
	pattern = strings.TrimSpace(pattern)
	model = strings.TrimSpace(model)
	switch pattern {
	case "":
		return false
	case "*":
		return true
	}

	// Iterative glob matcher with backtracking to the most recent '*'.
	p, m := 0, 0
	lastStar := -1 // index in pattern of the most recent '*', -1 if none yet
	resume := 0    // index in model where that '*' currently stops consuming
	for m < len(model) {
		switch {
		case p < len(pattern) && pattern[p] == model[m]:
			p++
			m++
		case p < len(pattern) && pattern[p] == '*':
			lastStar, resume = p, m
			p++
		case lastStar != -1:
			// Mismatch after a '*': let it swallow one more byte and retry.
			resume++
			p, m = lastStar+1, resume
		default:
			return false
		}
	}
	// The model is exhausted; only trailing '*' may remain in the pattern.
	return strings.TrimLeft(pattern[p:], "*") == ""
}
|