feat(executor): add tests for preserving key order in cache control functions

Added comprehensive tests to ensure key order is maintained when modifying payloads in `normalizeCacheControlTTL` and `enforceCacheControlLimit` functions. Removed unused helper functions and refactored implementations for better readability and efficiency.
This commit is contained in:
Luis Pater
2026-04-05 17:58:13 +08:00
parent ada8e2905e
commit 22a1a24cf5
2 changed files with 233 additions and 244 deletions

View File

@@ -8,7 +8,6 @@ import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
@@ -1463,182 +1462,6 @@ func countCacheControls(payload []byte) int {
return count
}
func parsePayloadObject(payload []byte) (map[string]any, bool) {
if len(payload) == 0 {
return nil, false
}
var root map[string]any
if err := json.Unmarshal(payload, &root); err != nil {
return nil, false
}
return root, true
}
func marshalPayloadObject(original []byte, root map[string]any) []byte {
if root == nil {
return original
}
out, err := json.Marshal(root)
if err != nil {
return original
}
return out
}
func asObject(v any) (map[string]any, bool) {
obj, ok := v.(map[string]any)
return obj, ok
}
func asArray(v any) ([]any, bool) {
arr, ok := v.([]any)
return arr, ok
}
func countCacheControlsMap(root map[string]any) int {
count := 0
if system, ok := asArray(root["system"]); ok {
for _, item := range system {
if obj, ok := asObject(item); ok {
if _, exists := obj["cache_control"]; exists {
count++
}
}
}
}
if tools, ok := asArray(root["tools"]); ok {
for _, item := range tools {
if obj, ok := asObject(item); ok {
if _, exists := obj["cache_control"]; exists {
count++
}
}
}
}
if messages, ok := asArray(root["messages"]); ok {
for _, msg := range messages {
msgObj, ok := asObject(msg)
if !ok {
continue
}
content, ok := asArray(msgObj["content"])
if !ok {
continue
}
for _, item := range content {
if obj, ok := asObject(item); ok {
if _, exists := obj["cache_control"]; exists {
count++
}
}
}
}
}
return count
}
func normalizeTTLForBlock(obj map[string]any, seen5m *bool) bool {
ccRaw, exists := obj["cache_control"]
if !exists {
return false
}
cc, ok := asObject(ccRaw)
if !ok {
*seen5m = true
return false
}
ttlRaw, ttlExists := cc["ttl"]
ttl, ttlIsString := ttlRaw.(string)
if !ttlExists || !ttlIsString || ttl != "1h" {
*seen5m = true
return false
}
if *seen5m {
delete(cc, "ttl")
return true
}
return false
}
func findLastCacheControlIndex(arr []any) int {
last := -1
for idx, item := range arr {
obj, ok := asObject(item)
if !ok {
continue
}
if _, exists := obj["cache_control"]; exists {
last = idx
}
}
return last
}
func stripCacheControlExceptIndex(arr []any, preserveIdx int, excess *int) {
for idx, item := range arr {
if *excess <= 0 {
return
}
obj, ok := asObject(item)
if !ok {
continue
}
if _, exists := obj["cache_control"]; exists && idx != preserveIdx {
delete(obj, "cache_control")
*excess--
}
}
}
func stripAllCacheControl(arr []any, excess *int) {
for _, item := range arr {
if *excess <= 0 {
return
}
obj, ok := asObject(item)
if !ok {
continue
}
if _, exists := obj["cache_control"]; exists {
delete(obj, "cache_control")
*excess--
}
}
}
func stripMessageCacheControl(messages []any, excess *int) {
for _, msg := range messages {
if *excess <= 0 {
return
}
msgObj, ok := asObject(msg)
if !ok {
continue
}
content, ok := asArray(msgObj["content"])
if !ok {
continue
}
for _, item := range content {
if *excess <= 0 {
return
}
obj, ok := asObject(item)
if !ok {
continue
}
if _, exists := obj["cache_control"]; exists {
delete(obj, "cache_control")
*excess--
}
}
}
}
// normalizeCacheControlTTL ensures cache_control TTL values don't violate the
// prompt-caching-scope-2026-01-05 ordering constraint: a 1h-TTL block must not
// appear after a 5m-TTL block anywhere in the evaluation order.
@@ -1651,58 +1474,75 @@ func stripMessageCacheControl(messages []any, excess *int) {
// Strategy: walk all cache_control blocks in evaluation order. Once a 5m block
// is seen, strip ttl from ALL subsequent 1h blocks (downgrading them to 5m).
func normalizeCacheControlTTL(payload []byte) []byte {
root, ok := parsePayloadObject(payload)
if !ok {
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return payload
}
original := payload
seen5m := false
modified := false
if tools, ok := asArray(root["tools"]); ok {
for _, tool := range tools {
if obj, ok := asObject(tool); ok {
if normalizeTTLForBlock(obj, &seen5m) {
modified = true
}
}
processBlock := func(path string, obj gjson.Result) {
cc := obj.Get("cache_control")
if !cc.Exists() {
return
}
if !cc.IsObject() {
seen5m = true
return
}
ttl := cc.Get("ttl")
if ttl.Type != gjson.String || ttl.String() != "1h" {
seen5m = true
return
}
if !seen5m {
return
}
ttlPath := path + ".cache_control.ttl"
updated, errDel := sjson.DeleteBytes(payload, ttlPath)
if errDel != nil {
return
}
payload = updated
modified = true
}
if system, ok := asArray(root["system"]); ok {
for _, item := range system {
if obj, ok := asObject(item); ok {
if normalizeTTLForBlock(obj, &seen5m) {
modified = true
}
}
}
tools := gjson.GetBytes(payload, "tools")
if tools.IsArray() {
tools.ForEach(func(idx, item gjson.Result) bool {
processBlock(fmt.Sprintf("tools.%d", int(idx.Int())), item)
return true
})
}
if messages, ok := asArray(root["messages"]); ok {
for _, msg := range messages {
msgObj, ok := asObject(msg)
if !ok {
continue
system := gjson.GetBytes(payload, "system")
if system.IsArray() {
system.ForEach(func(idx, item gjson.Result) bool {
processBlock(fmt.Sprintf("system.%d", int(idx.Int())), item)
return true
})
}
messages := gjson.GetBytes(payload, "messages")
if messages.IsArray() {
messages.ForEach(func(msgIdx, msg gjson.Result) bool {
content := msg.Get("content")
if !content.IsArray() {
return true
}
content, ok := asArray(msgObj["content"])
if !ok {
continue
}
for _, item := range content {
if obj, ok := asObject(item); ok {
if normalizeTTLForBlock(obj, &seen5m) {
modified = true
}
}
}
}
content.ForEach(func(itemIdx, item gjson.Result) bool {
processBlock(fmt.Sprintf("messages.%d.content.%d", int(msgIdx.Int()), int(itemIdx.Int())), item)
return true
})
return true
})
}
if !modified {
return payload
return original
}
return marshalPayloadObject(payload, root)
return payload
}
// enforceCacheControlLimit removes excess cache_control blocks from a payload
@@ -1722,64 +1562,166 @@ func normalizeCacheControlTTL(payload []byte) []byte {
// Phase 4: remaining system blocks (last system).
// Phase 5: remaining tool blocks (last tool).
func enforceCacheControlLimit(payload []byte, maxBlocks int) []byte {
root, ok := parsePayloadObject(payload)
if !ok {
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return payload
}
total := countCacheControlsMap(root)
total := countCacheControls(payload)
if total <= maxBlocks {
return payload
}
excess := total - maxBlocks
var system []any
if arr, ok := asArray(root["system"]); ok {
system = arr
}
var tools []any
if arr, ok := asArray(root["tools"]); ok {
tools = arr
}
var messages []any
if arr, ok := asArray(root["messages"]); ok {
messages = arr
}
if len(system) > 0 {
stripCacheControlExceptIndex(system, findLastCacheControlIndex(system), &excess)
system := gjson.GetBytes(payload, "system")
if system.IsArray() {
lastIdx := -1
system.ForEach(func(idx, item gjson.Result) bool {
if item.Get("cache_control").Exists() {
lastIdx = int(idx.Int())
}
return true
})
if lastIdx >= 0 {
system.ForEach(func(idx, item gjson.Result) bool {
if excess <= 0 {
return false
}
i := int(idx.Int())
if i == lastIdx {
return true
}
if !item.Get("cache_control").Exists() {
return true
}
path := fmt.Sprintf("system.%d.cache_control", i)
updated, errDel := sjson.DeleteBytes(payload, path)
if errDel != nil {
return true
}
payload = updated
excess--
return true
})
}
}
if excess <= 0 {
return marshalPayloadObject(payload, root)
return payload
}
if len(tools) > 0 {
stripCacheControlExceptIndex(tools, findLastCacheControlIndex(tools), &excess)
tools := gjson.GetBytes(payload, "tools")
if tools.IsArray() {
lastIdx := -1
tools.ForEach(func(idx, item gjson.Result) bool {
if item.Get("cache_control").Exists() {
lastIdx = int(idx.Int())
}
return true
})
if lastIdx >= 0 {
tools.ForEach(func(idx, item gjson.Result) bool {
if excess <= 0 {
return false
}
i := int(idx.Int())
if i == lastIdx {
return true
}
if !item.Get("cache_control").Exists() {
return true
}
path := fmt.Sprintf("tools.%d.cache_control", i)
updated, errDel := sjson.DeleteBytes(payload, path)
if errDel != nil {
return true
}
payload = updated
excess--
return true
})
}
}
if excess <= 0 {
return marshalPayloadObject(payload, root)
return payload
}
if len(messages) > 0 {
stripMessageCacheControl(messages, &excess)
messages := gjson.GetBytes(payload, "messages")
if messages.IsArray() {
messages.ForEach(func(msgIdx, msg gjson.Result) bool {
if excess <= 0 {
return false
}
content := msg.Get("content")
if !content.IsArray() {
return true
}
content.ForEach(func(itemIdx, item gjson.Result) bool {
if excess <= 0 {
return false
}
if !item.Get("cache_control").Exists() {
return true
}
path := fmt.Sprintf("messages.%d.content.%d.cache_control", int(msgIdx.Int()), int(itemIdx.Int()))
updated, errDel := sjson.DeleteBytes(payload, path)
if errDel != nil {
return true
}
payload = updated
excess--
return true
})
return true
})
}
if excess <= 0 {
return marshalPayloadObject(payload, root)
return payload
}
if len(system) > 0 {
stripAllCacheControl(system, &excess)
system = gjson.GetBytes(payload, "system")
if system.IsArray() {
system.ForEach(func(idx, item gjson.Result) bool {
if excess <= 0 {
return false
}
if !item.Get("cache_control").Exists() {
return true
}
path := fmt.Sprintf("system.%d.cache_control", int(idx.Int()))
updated, errDel := sjson.DeleteBytes(payload, path)
if errDel != nil {
return true
}
payload = updated
excess--
return true
})
}
if excess <= 0 {
return marshalPayloadObject(payload, root)
return payload
}
if len(tools) > 0 {
stripAllCacheControl(tools, &excess)
tools = gjson.GetBytes(payload, "tools")
if tools.IsArray() {
tools.ForEach(func(idx, item gjson.Result) bool {
if excess <= 0 {
return false
}
if !item.Get("cache_control").Exists() {
return true
}
path := fmt.Sprintf("tools.%d.cache_control", int(idx.Int()))
updated, errDel := sjson.DeleteBytes(payload, path)
if errDel != nil {
return true
}
payload = updated
excess--
return true
})
}
return marshalPayloadObject(payload, root)
return payload
}
// injectMessagesCacheControl adds cache_control to the second-to-last user turn for multi-turn caching.