diff --git a/internal/signature/claude.go b/internal/signature/claude.go
new file mode 100644
index 000000000..4b3fbde25
--- /dev/null
+++ b/internal/signature/claude.go
@@ -0,0 +1,145 @@
+package signature
+
+import (
+	"bytes"
+	"strings"
+
+	"github.com/tidwall/gjson"
+	"github.com/tidwall/sjson"
+)
+
+// StripInvalidClaudeThinkingBlocks removes Claude thinking blocks whose
+// signatures are empty or not valid Claude thinking signatures after stripping
+// an optional cache prefix, unless the validation options allow an empty
+// thinking placeholder.
+//
+// Only the first element of opts is consulted. The payload is returned
+// unchanged when it has no "messages" array, when nothing had to be removed,
+// or when the rewritten JSON cannot be produced.
+func StripInvalidClaudeThinkingBlocks(payload []byte, opts ...ClaudeSignatureValidationOptions) []byte {
+	messages := gjson.GetBytes(payload, "messages")
+	if !messages.IsArray() {
+		return payload
+	}
+	opt := claudeSignatureValidationOptions(opts)
+	messageResults := messages.Array()
+	keptMessages := make([]string, 0, len(messageResults))
+	modified := false
+	for _, msg := range messageResults {
+		content := msg.Get("content")
+		if !content.IsArray() {
+			// Non-array content has no parts to inspect; keep it verbatim.
+			keptMessages = append(keptMessages, msg.Raw)
+			continue
+		}
+		contentResults := content.Array()
+		keptParts := make([]string, 0, len(contentResults))
+		for _, part := range contentResults {
+			if part.Get("type").String() == "thinking" && shouldStripClaudeThinkingBlock(part, opt) {
+				continue
+			}
+			keptParts = append(keptParts, part.Raw)
+		}
+		if len(keptParts) == len(contentResults) {
+			keptMessages = append(keptMessages, msg.Raw)
+			continue
+		}
+		updated, err := sjson.SetRaw(msg.Raw, "content", "["+strings.Join(keptParts, ",")+"]")
+		if err != nil {
+			// Splicing a failed rewrite into the array would corrupt the
+			// request body, so fall back to the untouched payload.
+			return payload
+		}
+		keptMessages = append(keptMessages, updated)
+		modified = true
+	}
+	if !modified {
+		return payload
+	}
+	output, err := sjson.SetRawBytes(payload, "messages", []byte("["+strings.Join(keptMessages, ",")+"]"))
+	if err != nil {
+		return payload
+	}
+	return output
+}
+
+// StripInvalidClaudeThinkingBlocksAndEmptyMessages runs
+// StripInvalidClaudeThinkingBlocks and, if that changed the payload, also
+// removes every message whose content is an empty array, so messages left
+// without content after stripping are not forwarded. A payload that needed no
+// stripping is returned as-is.
+func StripInvalidClaudeThinkingBlocksAndEmptyMessages(payload []byte, opts ...ClaudeSignatureValidationOptions) []byte {
+	stripped := StripInvalidClaudeThinkingBlocks(payload, opts...)
+	if bytes.Equal(stripped, payload) {
+		return payload
+	}
+	messages := gjson.GetBytes(stripped, "messages")
+	if !messages.IsArray() {
+		return stripped
+	}
+	messageResults := messages.Array()
+	kept := make([]string, 0, len(messageResults))
+	for _, message := range messageResults {
+		content := message.Get("content")
+		if content.IsArray() && len(content.Array()) == 0 {
+			continue
+		}
+		kept = append(kept, message.Raw)
+	}
+	if len(kept) == len(messageResults) {
+		// No message was emptied; the stripped payload is already final.
+		return stripped
+	}
+	output, err := sjson.SetRawBytes(stripped, "messages", []byte("["+strings.Join(kept, ",")+"]"))
+	if err != nil {
+		// Keep the stripped payload rather than the result of a failed edit.
+		return stripped
+	}
+	return output
+}
+
+// shouldStripClaudeThinkingBlock reports whether a thinking part must be
+// removed: empty placeholders are kept when the options allow them, otherwise
+// the part is kept only if its signature validates under opt.
+func shouldStripClaudeThinkingBlock(part gjson.Result, opt ClaudeSignatureValidationOptions) bool {
+	if opt.AllowEmptySignatureWithEmptyText && isEmptyClaudeThinkingPlaceholder(part) {
+		return false
+	}
+	return !IsValidClaudeThinkingSignature(part.Get("signature").String(), opt)
+}
+
+// isEmptyClaudeThinkingPlaceholder reports whether part has neither a
+// signature nor any thinking text once surrounding whitespace is ignored.
+func isEmptyClaudeThinkingPlaceholder(part gjson.Result) bool {
+	if strings.TrimSpace(part.Get("signature").String()) != "" {
+		return false
+	}
+	return strings.TrimSpace(claudeThinkingBlockText(part)) == ""
+}
+
+// claudeThinkingBlockText returns the text carried by a thinking part. It
+// checks, in order, a string "text" field, a string "thinking" field, and the
+// "text" or "thinking" string inside an object-valued "thinking" field, and
+// returns "" when none is present.
+func claudeThinkingBlockText(part gjson.Result) string {
+	if text := part.Get("text"); text.Exists() && text.Type == gjson.String {
+		return text.String()
+	}
+
+	thinkingField := part.Get("thinking")
+	if !thinkingField.Exists() {
+		return ""
+	}
+	if thinkingField.Type == gjson.String {
+		return thinkingField.String()
+	}
+	if thinkingField.IsObject() {
+		if inner := thinkingField.Get("text"); inner.Exists() && inner.Type == gjson.String {
+			return inner.String()
+		}
+		if inner := thinkingField.Get("thinking"); inner.Exists() && inner.Type == gjson.String {
+			return inner.String()
+		}
+	}
+	return ""
+}
diff --git a/internal/signature/claude_messages_sanitize.go b/internal/signature/claude_messages_sanitize.go
new file mode 100644
index 000000000..aec08879d
--- /dev/null
+++ b/internal/signature/claude_messages_sanitize.go
@@ -0,0 +1,313 @@
+package signature
+
+import (
+	"fmt"
+	"strings"
+
+	"github.com/tidwall/gjson"
+	"github.com/tidwall/sjson"
+)
+
+// ClaudeMessagesSignatureSanitizeOptions selects the target of a sanitize pass
+// over Claude /v1/messages history and how aggressively it cleans up.
+type ClaudeMessagesSignatureSanitizeOptions struct {
+	// TargetProvider is the provider family the history will be sent to.
+	TargetProvider SignatureProvider
+	// TargetModel is used to derive the provider when TargetProvider
+	// normalizes to SignatureProviderUnknown.
+	TargetModel string
+	// DropEmptyMessages drops a message whose content array ends up empty
+	// after the pass modified it.
+	DropEmptyMessages bool
+	// DropToolSignatures removes every signature field from tool_use parts
+	// without consulting the compatibility rules.
+	DropToolSignatures bool
+}
+
+// SignatureSanitizeReport summarizes what a sanitize pass did.
+type SignatureSanitizeReport struct {
+	// TargetProvider is the provider the pass resolved and sanitized for.
+	TargetProvider SignatureProvider
+	// Preserved counts signatures that were kept, possibly normalized.
+	Preserved int
+	// DroppedBlocks counts thinking blocks that were removed entirely.
+	DroppedBlocks int
+	// DroppedSignatures counts signature removals that kept the part (one
+	// per modified tool_use part when DropToolSignatures is set).
+	DroppedSignatures int
+	// ReplacedSignatures counts signatures swapped for a replacement value.
+	ReplacedSignatures int
+	// Decisions holds the compatibility decisions in the order parts were
+	// visited.
+	Decisions []SignatureCompatibilityDecision
+}
+
+// SanitizeClaudeMessagesSignaturesForModel removes or preserves Claude
+// /v1/messages signed history according to the provider family implied by
+// targetModel. Messages left without content by the pass are dropped.
+func SanitizeClaudeMessagesSignaturesForModel(payload []byte, targetModel string) ([]byte, SignatureSanitizeReport) {
+	return SanitizeClaudeMessagesSignaturesForTarget(payload, ClaudeMessagesSignatureSanitizeOptions{
+		TargetProvider:    SignatureProviderFromModelName(targetModel),
+		TargetModel:       targetModel,
+		DropEmptyMessages: true,
+	})
+}
+
+// SanitizeClaudeMessagesSignaturesForTarget applies provider-aware signature
+// compatibility rules to Claude /v1/messages history. Compatible thinking
+// signatures are preserved. Incompatible thinking blocks are removed so a user
+// can continue a conversation after switching between Claude, GPT/Codex,
+// and Gemini models.
+//
+// The payload is returned unchanged when it has no "messages" array or when
+// nothing needed to change. A part or message whose JSON rewrite fails is kept
+// in its original form so the output is never corrupted by a failed edit.
+func SanitizeClaudeMessagesSignaturesForTarget(payload []byte, opts ClaudeMessagesSignatureSanitizeOptions) ([]byte, SignatureSanitizeReport) {
+	targetProvider := normalizeSignatureTargetProvider(opts.TargetProvider)
+	if targetProvider == SignatureProviderUnknown && opts.TargetModel != "" {
+		targetProvider = SignatureProviderFromModelName(opts.TargetModel)
+	}
+	report := SignatureSanitizeReport{TargetProvider: targetProvider}
+
+	messages := gjson.GetBytes(payload, "messages")
+	if !messages.IsArray() {
+		return payload, report
+	}
+
+	messageResults := messages.Array()
+	keptMessages := make([]string, 0, len(messageResults))
+	modified := false
+
+	for i, message := range messageResults {
+		content := message.Get("content")
+		if !content.IsArray() {
+			keptMessages = append(keptMessages, message.Raw)
+			continue
+		}
+
+		contentResults := content.Array()
+		keptParts := make([]string, 0, len(contentResults))
+		messageModified := false
+
+		for j, part := range contentResults {
+			partType := part.Get("type").String()
+			if partType == "tool_use" {
+				if opts.DropToolSignatures {
+					updatedPart, changed := stripClaudeToolUseSignatureFields(part)
+					if changed {
+						messageModified = true
+						report.DroppedSignatures++
+					}
+					keptParts = append(keptParts, updatedPart)
+					continue
+				}
+				updatedPart, changed, decisions := sanitizeClaudeToolUseSignature(part, targetProvider, i, j)
+				report.Decisions = append(report.Decisions, decisions...)
+				if changed {
+					messageModified = true
+				}
+				for _, decision := range decisions {
+					switch decision.Action {
+					case SignatureActionPreserve:
+						report.Preserved++
+					case SignatureActionReplaceWithGeminiBypass:
+						report.ReplacedSignatures++
+					default:
+						report.DroppedSignatures++
+					}
+				}
+				keptParts = append(keptParts, updatedPart)
+				continue
+			}
+
+			if partType != "thinking" {
+				keptParts = append(keptParts, part.Raw)
+				continue
+			}
+
+			if targetProvider == SignatureProviderClaude && isEmptyClaudeThinkingPlaceholder(part) {
+				keptParts = append(keptParts, part.Raw)
+				continue
+			}
+
+			rawSignature := part.Get("signature").String()
+			decision := DecideSignatureCompatibility(targetProvider, rawSignature, SignatureBlockKindClaudeThinking)
+			decision.Reason = fmt.Sprintf("messages[%d].content[%d]: %s", i, j, decision.Reason)
+			report.Decisions = append(report.Decisions, decision)
+
+			// keepRewritten appends the rewritten part, or the original part
+			// when the sjson edit failed.
+			keepRewritten := func(updated string, err error) {
+				if err != nil {
+					keptParts = append(keptParts, part.Raw)
+					return
+				}
+				keptParts = append(keptParts, updated)
+				messageModified = true
+			}
+
+			switch decision.Action {
+			case SignatureActionPreserve:
+				report.Preserved++
+				if decision.NormalizedSignature != "" && decision.NormalizedSignature != rawSignature {
+					keepRewritten(sjson.Set(part.Raw, "signature", decision.NormalizedSignature))
+					continue
+				}
+				keptParts = append(keptParts, part.Raw)
+			case SignatureActionReplaceWithGeminiBypass:
+				report.ReplacedSignatures++
+				keepRewritten(sjson.Set(part.Raw, "signature", decision.ReplacementSignature))
+			case SignatureActionDropSignature:
+				report.DroppedSignatures++
+				keepRewritten(sjson.Delete(part.Raw, "signature"))
+			default:
+				report.DroppedBlocks++
+				messageModified = true
+			}
+		}
+
+		if !messageModified {
+			keptMessages = append(keptMessages, message.Raw)
+			continue
+		}
+		modified = true
+		if len(keptParts) == 0 && opts.DropEmptyMessages {
+			continue
+		}
+		updated, err := sjson.SetRaw(message.Raw, "content", "["+strings.Join(keptParts, ",")+"]")
+		if err != nil {
+			// Keep the message as received rather than a failed rewrite.
+			updated = message.Raw
+		}
+		keptMessages = append(keptMessages, updated)
+	}
+
+	if !modified {
+		return payload, report
+	}
+	output, err := sjson.SetRawBytes(payload, "messages", []byte("["+strings.Join(keptMessages, ",")+"]"))
+	if err != nil {
+		return payload, report
+	}
+	return output, report
+}
+
+// stripClaudeToolUseSignatureFields deletes every known signature field from a
+// tool_use part and prunes the extra_content containers that become empty. It
+// returns the resulting JSON and whether anything was removed.
+func stripClaudeToolUseSignatureFields(part gjson.Result) (string, bool) {
+	updated := part.Raw
+	changed := false
+	for _, sigPath := range claudeToolUseSignaturePaths() {
+		if !gjson.Get(updated, sigPath).Exists() {
+			continue
+		}
+		next, err := sjson.Delete(updated, sigPath)
+		if err != nil {
+			// Keep the last good JSON instead of the failed edit.
+			continue
+		}
+		updated = next
+		changed = true
+	}
+	if cleaned, ok := pruneEmptyClaudeToolUseContainers(updated); ok {
+		updated = cleaned
+		changed = true
+	}
+	return updated, changed
+}
+
+// sanitizeClaudeToolUseSignature applies the compatibility decision for
+// targetProvider to each signature field present on a tool_use part. It returns
+// the resulting JSON, whether it differs from the input, and one decision per
+// signature field found (its Reason prefixed with the field location). A field
+// whose rewrite fails is left as it was.
+func sanitizeClaudeToolUseSignature(part gjson.Result, targetProvider SignatureProvider, messageIdx, partIdx int) (string, bool, []SignatureCompatibilityDecision) {
+	// The block kind depends only on the target, so resolve it once.
+	blockKind := SignatureBlockKindGeminiFunctionCall
+	switch targetProvider {
+	case SignatureProviderClaude:
+		blockKind = SignatureBlockKindClaudeThinking
+	case SignatureProviderGPT:
+		blockKind = SignatureBlockKindGPTReasoning
+	}
+
+	updated := part.Raw
+	changed := false
+	var decisions []SignatureCompatibilityDecision
+	for _, sigPath := range claudeToolUseSignaturePaths() {
+		sigResult := part.Get(sigPath)
+		if !sigResult.Exists() {
+			continue
+		}
+		rawSignature := sigResult.String()
+		decision := DecideSignatureCompatibility(targetProvider, rawSignature, blockKind)
+		decision.Reason = fmt.Sprintf("messages[%d].content[%d].%s: %s", messageIdx, partIdx, sigPath, decision.Reason)
+		decisions = append(decisions, decision)
+
+		var (
+			next string
+			err  error
+		)
+		switch decision.Action {
+		case SignatureActionPreserve:
+			if decision.NormalizedSignature == "" || decision.NormalizedSignature == rawSignature {
+				continue
+			}
+			next, err = sjson.Set(updated, sigPath, decision.NormalizedSignature)
+		case SignatureActionReplaceWithGeminiBypass:
+			next, err = sjson.Set(updated, sigPath, decision.ReplacementSignature)
+		default:
+			next, err = sjson.Delete(updated, sigPath)
+		}
+		if err != nil {
+			continue
+		}
+		updated = next
+		changed = true
+	}
+
+	if cleaned, ok := pruneEmptyClaudeToolUseContainers(updated); ok {
+		updated = cleaned
+		changed = true
+	}
+	return updated, changed, decisions
+}
+
+// pruneEmptyClaudeToolUseContainers removes extra_content.google and then
+// extra_content when they are empty objects, innermost first, and reports
+// whether anything was removed.
+func pruneEmptyClaudeToolUseContainers(raw string) (string, bool) {
+	pruned := false
+	for _, path := range []string{"extra_content.google", "extra_content"} {
+		if cleaned, ok := deleteEmptyJSONObjectPath(raw, path); ok {
+			raw = cleaned
+			pruned = true
+		}
+	}
+	return raw, pruned
+}
+
+// claudeToolUseSignaturePaths lists the JSON paths on a tool_use part that can
+// carry a signature.
+func claudeToolUseSignaturePaths() []string {
+	return []string{
+		"signature",
+		"thought_signature",
+		"extra_content.google.thought_signature",
+	}
+}
+
+// deleteEmptyJSONObjectPath deletes path from raw when it holds an empty JSON
+// object and reports whether it did; otherwise raw is returned unchanged.
+func deleteEmptyJSONObjectPath(raw, path string) (string, bool) {
+	result := gjson.Get(raw, path)
+	if !result.Exists() || !result.IsObject() || len(result.Map()) != 0 {
+		return raw, false
+	}
+	updated, err := sjson.Delete(raw, path)
+	if err != nil {
+		return raw, false
+	}
+	return updated, true
+}
diff --git a/internal/signature/claude_test.go b/internal/signature/claude_test.go
new file mode 100644
index 000000000..4c929dc21
--- /dev/null
+++ b/internal/signature/claude_test.go
@@ -0,0 +1,161 @@
+package signature
+
+import (
+	"encoding/base64"
+	"strings"
+	"testing"
+
+	"github.com/tidwall/gjson"
+)
+
+func TestStripInvalidClaudeThinkingBlocks_RemovesGPTEncryptedContent(t *testing.T) {
+	input := []byte(`{
+		"messages": [
+			{"role":"assistant","content":[
+				{"type":"thinking","thinking":"codex reasoning","signature":"gAAAAABopenai-encrypted-content"},
+				{"type":"text","text":"Answer"}
+			]},
+			{"role":"user","content":[{"type":"text","text":"next"}]}
+		]
+	}`)
+
+	out := StripInvalidClaudeThinkingBlocks(input)
+	content := gjson.GetBytes(out, "messages.0.content").Array()
+	if len(content) != 1 {
+		t.Fatalf("messages.0.content length = %d, want 1: %s", len(content), string(out))
+	}
+	if got := content[0].Get("text").String(); got != "Answer" {
+		t.Fatalf("remaining content text = %q, want Answer", got)
+	}
+	if strings.Contains(string(out), "gAAAAABopenai-encrypted-content") || strings.Contains(string(out), "codex reasoning") {
+		t.Fatalf("invalid thinking block was preserved: %s", string(out))
+	}
+}
+
+func TestStripInvalidClaudeThinkingBlocksAndEmptyMessages_DropsMessagesLeftEmpty(t *testing.T) {
+	input := []byte(`{
+		"messages": [
+			{"role":"assistant","content":[
+				{"type":"thinking","thinking":"codex reasoning","signature":"gAAAAABopenai-encrypted-content"}
+			]},
+			{"role":"user","content":[{"type":"text","text":"next"}]}
+		]
+	}`)
+
+	out := StripInvalidClaudeThinkingBlocksAndEmptyMessages(input)
+	messages := gjson.GetBytes(out, "messages").Array()
+	if len(messages) != 1 {
+		t.Fatalf("messages length = %d, want 1: %s", len(messages), string(out))
+	}
+	if got := messages[0].Get("role").String(); got != "user" {
+		t.Fatalf("remaining role = %q, want user", got)
+	}
+	if strings.Contains(string(out), "gAAAAABopenai-encrypted-content") || strings.Contains(string(out), "codex reasoning") {
+		t.Fatalf("invalid thinking block was preserved: %s", string(out))
+	}
+}
+
+func TestStripInvalidClaudeThinkingBlocks_RemovesMalformedEPrefix(t *testing.T) {
+	input := []byte(`{
+		"messages": [{"role":"assistant","content":[
+			{"type":"thinking","thinking":"bad","signature":"Ebad"},
+			{"type":"text","text":"Answer"}
+		]}]
+	}`)
+
+	out := StripInvalidClaudeThinkingBlocks(input)
+	content := gjson.GetBytes(out, "messages.0.content").Array()
+	if len(content) != 1 {
+		t.Fatalf("content length = %d, want 1: %s", len(content), string(out))
+	}
+	if strings.Contains(string(out), "Ebad") || strings.Contains(string(out), "bad") {
+		t.Fatalf("malformed E-prefix thinking block was preserved: %s", string(out))
+	}
+}
+
+func TestStripInvalidClaudeThinkingBlocks_Base64OnlyKeepsDecodableEPrefix(t *testing.T) {
+	input := []byte(`{
+		"messages": [{"role":"assistant","content":[
+			{"type":"thinking","thinking":"bad","signature":"Ebad"},
+			{"type":"text","text":"Answer"}
+		]}]
+	}`)
+
+	out := StripInvalidClaudeThinkingBlocks(input, ClaudeSignatureValidationOptions{Base64Only: true})
+	content := gjson.GetBytes(out, "messages.0.content").Array()
+	if len(content) != 2 {
+		t.Fatalf("content length = %d, want 2: %s", len(content), string(out))
+	}
+}
+
+func TestStripInvalidClaudeThinkingBlocks_Base64OnlyRemovesInvalidBase64(t *testing.T) {
+	input := []byte(`{
+		"messages": [{"role":"assistant","content":[
+			{"type":"thinking","thinking":"bad","signature":"E!!!invalid!!!"},
+			{"type":"text","text":"Answer"}
+		]}]
+	}`)
+
+	out := StripInvalidClaudeThinkingBlocks(input, ClaudeSignatureValidationOptions{Base64Only: true})
+	content := gjson.GetBytes(out, "messages.0.content").Array()
+	if len(content) != 1 {
+		t.Fatalf("content length = %d, want 1: %s", len(content), string(out))
+	}
+	if strings.Contains(string(out), "E!!!invalid!!!") || strings.Contains(string(out), "bad") {
+		t.Fatalf("invalid-base64 thinking block was preserved: %s", string(out))
+	}
+}
+
+func TestStripInvalidClaudeThinkingBlocks_AllowsEmptySignatureEmptyTextPlaceholder(t *testing.T) {
+	input := []byte(`{
+		"messages": [{"role":"assistant","content":[
+			{"type":"thinking","text":"","signature":""},
+			{"type":"text","text":"Answer"}
+		]}]
+	}`)
+
+	out := StripInvalidClaudeThinkingBlocks(input, ClaudeSignatureValidationOptions{
+		Base64Only:                       true,
+		AllowEmptySignatureWithEmptyText: true,
+	})
+	content := gjson.GetBytes(out, "messages.0.content").Array()
+	if len(content) != 2 {
+		t.Fatalf("content length = %d, want 2: %s", len(content), string(out))
+	}
+}
+
+func TestStripInvalidClaudeThinkingBlocks_StrictRemovesMalformedClaudeTree(t *testing.T) {
+	sig := base64.StdEncoding.EncodeToString([]byte{0x12, 0xFF, 0xFE, 0xFD})
+	input := []byte(`{
+		"messages": [{"role":"assistant","content":[
+			{"type":"thinking","thinking":"bad","signature":"` + sig + `"},
+			{"type":"text","text":"Answer"}
+		]}]
+	}`)
+
+	out := StripInvalidClaudeThinkingBlocks(input, ClaudeSignatureValidationOptions{Strict: true})
+	content := gjson.GetBytes(out, "messages.0.content").Array()
+	if len(content) != 1 {
+		t.Fatalf("content length = %d, want 1: %s", len(content), string(out))
+	}
+	if strings.Contains(string(out), sig) || strings.Contains(string(out), "bad") {
+		t.Fatalf("strict-invalid thinking block was preserved: %s", string(out))
+	}
+}
+
+func TestStripInvalidClaudeThinkingBlocks_KeepsClaudeSignaturePrefixes(t *testing.T) {
+	singleLayer := base64.StdEncoding.EncodeToString([]byte{0x12, 0x34})
+	doubleLayer := base64.StdEncoding.EncodeToString([]byte(singleLayer))
+	input := []byte(`{
+		"messages": [{"role":"assistant","content":[
+			{"type":"thinking","thinking":"one","signature":"` + singleLayer + `"},
+			{"type":"thinking","thinking":"two","signature":"modelGroup#` + doubleLayer + `"}
+		]}]
+	}`)
+
+	out := StripInvalidClaudeThinkingBlocks(input)
+	content := gjson.GetBytes(out, "messages.0.content").Array()
+	if len(content) != 2 {
+		t.Fatalf("content length = %d, want 2: %s", len(content), string(out))
+	}
+}
diff --git a/internal/signature/claude_validation.go b/internal/signature/claude_validation.go
new file mode 100644
index 000000000..4bad747ed
--- /dev/null
+++ b/internal/signature/claude_validation.go
@@ -0,0 +1,484 @@
+// Claude thinking signature validation.
+//
+// Spec reference: SIGNATURE-CHANNEL-SPEC.md
+//
+// # Encoding detection (Spec section 3)
+//
+// Claude signatures use base64 encoding in one or two layers. The raw string's
+// first character determines the encoding depth, which mirrors the spec's
+// "decode first, check byte" approach; default and strict modes still verify the byte:
+//
+//   - E prefix: single-layer, payload[0] == 0x12, first 6 bits = 000100,
+//     base64 index 4 = E.
+//   - R prefix: double-layer, inner[0] == E (0x45), first 6 bits = 010001,
+//     base64 index 17 = R.
+//
+// Valid signatures can be normalized to R-form (double-layer base64) before
+// sending to the Antigravity backend.
+//
+// # Protobuf structure (Spec sections 4.1 and 4.2) in strict mode only
+//
+// After base64 decoding to raw bytes, the first byte must be 0x12:
+//
+//	Top-level protobuf
+//	|- Field 2 (bytes): container -> extractClaudeBytesField(payload, 2)
+//	|  |- Field 1 (bytes): channel block -> extractClaudeBytesField(container, 1)
+//	|  |  |- Field 1 (varint): channel_id [required] -> routing_class (11 | 12)
+//	|  |  |- Field 2 (varint): infra [optional] -> infrastructure_class (aws=1 | google=2)
+//	|  |  |- Field 3 (varint): version=2 -> skipped
+//	|  |  |- Field 5 (bytes): ECDSA sig -> skipped, per Spec section 11
+//	|  |  |- Field 6 (bytes): model_text [optional] -> schema_features
+//	|  |  `- Field 7 (varint): unknown [optional] -> schema_features
+//	|  |- Field 2 (bytes): nonce 12B -> skipped
+//	|  |- Field 3 (bytes): session 12B -> skipped
+//	|  |- Field 4 (bytes): SHA-384 48B -> skipped
+//	|  `- Field 5 (bytes): metadata -> skipped, per Spec section 11
+//	`- Field 3 (varint): =1 -> skipped
+//
+// # Output dimensions (Spec section 8)
+//
+//	routing_class: routing_class_11 | routing_class_12 | unknown
+//	infrastructure_class: infra_default (absent) | infra_aws (1) | infra_google (2) | infra_unknown
+//	schema_features: compact_schema (len 70-72, no f6/f7) | extended_model_tagged_schema (f6 exists) | unknown_schema_features
+//	legacy_route_hint: only for ch=11, legacy_default_group | legacy_aws_group | legacy_vertex_direct | legacy_vertex_proxy
+//
+// # Compatibility
+//
+// Verified against all confirmed spec samples (Anthropic Max 20x, Azure,
+// Vertex, Bedrock) and legacy ch=11 signatures. Both single-layer (E) and
+// double-layer (R) encodings are supported. Historical cache-mode modelGroup#
+// prefixes are stripped.
+package signature
+
+import (
+	"encoding/base64"
+	"fmt"
+	"strings"
+	"unicode/utf8"
+
+	"github.com/tidwall/gjson"
+	"google.golang.org/protobuf/encoding/protowire"
+)
+
+const MaxClaudeThinkingSignatureLen = 32 * 1024 * 1024 // 32 MiB cap, checked after the cache prefix is stripped
+
+// ClaudeSignatureValidationOptions controls how far Claude thinking signatures
+// are inspected. By default validation checks the cache prefix, the base64
+// layers and the decoded 0x12 payload marker; PrefixOnly and Base64Only stop
+// earlier, and Strict additionally walks the known protobuf tree.
+type ClaudeSignatureValidationOptions struct {
+	// PrefixOnly only checks for an optional cache prefix followed by an E/R
+	// Claude signature prefix. Use it to preserve legacy shallow cleanup.
+	PrefixOnly bool
+	// Base64Only checks the optional cache prefix, E/R Claude signature prefix,
+	// and base64 layers without validating the decoded Claude marker or protobuf
+	// tree. Use it for conservative request cleanup.
+	Base64Only bool
+	// AllowEmptySignatureWithEmptyText preserves empty thinking placeholders with
+	// no signature and no thinking/text payload during strip operations.
+	AllowEmptySignatureWithEmptyText bool
+	Strict                           bool // also verify the protobuf tree of the decoded payload
+}
+
+// ClaudeSignatureTree holds the protobuf fields read from a Claude thinking
+// signature together with the classifications derived from them.
+type ClaudeSignatureTree struct {
+	EncodingLayers      int     // base64 layers: 1 (E-form) or 2 (R-form)
+	ChannelID           uint64  // Field 2.1.1 channel_id (required)
+	Field2              *uint64 // Field 2.1.2; nil when the field is absent
+	RoutingClass        string  // "routing_class_11", "routing_class_12" or "unknown"
+	InfrastructureClass string  // "infra_default", "infra_aws", "infra_google" or "infra_unknown"
+	SchemaFeatures      string  // "compact_schema", "extended_model_tagged_schema" or "unknown_schema_features"
+	ModelText           string  // Field 2.1.6 model_text (valid UTF-8)
+	LegacyRouteHint     string  // set only when ChannelID is 11
+	HasField7           bool    // Field 2.1.7 is present
+}
+
+func claudeSignatureValidationOptions(opts []ClaudeSignatureValidationOptions) ClaudeSignatureValidationOptions {
+	if len(opts) > 0 {
+		return opts[0] // only the first option set is honoured
+	}
+	return ClaudeSignatureValidationOptions{}
+}
+
+// IsValidClaudeThinkingSignature returns whether rawSignature is a valid Claude
+// thinking signature under the requested validation options.
+func IsValidClaudeThinkingSignature(rawSignature string, opts ...ClaudeSignatureValidationOptions) bool {
+	opt := claudeSignatureValidationOptions(opts)
+	switch {
+	case opt.PrefixOnly:
+		return HasClaudeThinkingSignaturePrefix(rawSignature)
+	case opt.Base64Only:
+		return HasDecodableClaudeThinkingSignature(rawSignature)
+	}
+	_, err := NormalizeClaudeThinkingSignature(rawSignature, opts...)
+	return err == nil
+}
+
+// HasDecodableClaudeThinkingSignature reports whether rawSignature, once its
+// cache prefix is stripped, starts with E or R and base64-decodes at each layer.
+func HasDecodableClaudeThinkingSignature(rawSignature string) bool {
+	trimmed := stripClaudeSignaturePrefix(rawSignature)
+	if trimmed == "" || len(trimmed) > MaxClaudeThinkingSignatureLen {
+		return false
+	}
+	switch trimmed[0] {
+	case 'E':
+		payload, err := base64.StdEncoding.DecodeString(trimmed)
+		return err == nil && len(payload) > 0
+	case 'R':
+		inner, err := base64.StdEncoding.DecodeString(trimmed)
+		if err != nil || len(inner) == 0 || inner[0] != 'E' {
+			return false
+		}
+		payload, err := base64.StdEncoding.DecodeString(string(inner))
+		return err == nil && len(payload) > 0
+	default:
+		return false
+	}
+}
+
+// HasClaudeThinkingSignaturePrefix reports whether rawSignature starts with E
+// or R once surrounding whitespace and an optional cache prefix are removed.
+func HasClaudeThinkingSignaturePrefix(rawSignature string) bool {
+	trimmed := stripClaudeSignaturePrefix(rawSignature)
+	if len(trimmed) == 0 {
+		return false
+	}
+	return trimmed[0] == 'R' || trimmed[0] == 'E'
+}
+
+// stripClaudeSignaturePrefix trims whitespace and drops everything up to and including the first '#'.
+func stripClaudeSignaturePrefix(rawSignature string) string {
+	trimmed := strings.TrimSpace(rawSignature)
+	if len(trimmed) == 0 {
+		return ""
+	}
+	if hash := strings.IndexByte(trimmed, '#'); hash >= 0 {
+		return strings.TrimSpace(trimmed[hash+1:])
+	}
+	return trimmed
+}
+
+// ValidateClaudeThinkingSignatures validates every thinking block signature in a
+// Claude messages payload.
+func ValidateClaudeThinkingSignatures(inputRawJSON []byte, opts ...ClaudeSignatureValidationOptions) error {
+	messages := gjson.GetBytes(inputRawJSON, "messages")
+	if !messages.IsArray() {
+		return nil
+	}
+
+	opt := claudeSignatureValidationOptions(opts)
+	for i, message := range messages.Array() {
+		content := message.Get("content")
+		if !content.IsArray() {
+			continue
+		}
+		for j, part := range content.Array() {
+			if part.Get("type").String() != "thinking" {
+				continue
+			}
+
+			// A thinking block with a blank signature is reported as an error.
+			rawSignature := strings.TrimSpace(part.Get("signature").String())
+			if rawSignature == "" {
+				return fmt.Errorf("messages[%d].content[%d]: missing thinking signature", i, j)
+			}
+
+			// Validation only: the normalized form is discarded.
+			if _, err := NormalizeClaudeThinkingSignature(rawSignature, opt); err != nil {
+				return fmt.Errorf("messages[%d].content[%d]: %w", i, j, err)
+			}
+		}
+	}
+
+	return nil
+}
+
+// NormalizeClaudeThinkingSignature strips any cache prefix, validates the
+// signature, and returns the double-layer R-form expected by Antigravity bypass
+// mode. A single-layer E-form signature is wrapped in one more base64 layer; an
+// R-form signature is returned as-is.
+func NormalizeClaudeThinkingSignature(rawSignature string, opts ...ClaudeSignatureValidationOptions) (string, error) {
+	opt := claudeSignatureValidationOptions(opts)
+	trimmed := stripClaudeSignaturePrefix(rawSignature)
+	if trimmed == "" {
+		return "", fmt.Errorf("empty signature")
+	}
+	if len(trimmed) > MaxClaudeThinkingSignatureLen {
+		return "", fmt.Errorf("signature exceeds maximum length (%d bytes)", MaxClaudeThinkingSignatureLen)
+	}
+	switch trimmed[0] {
+	case 'E':
+		if err := validateClaudeSingleLayerSignature(trimmed, opt); err != nil {
+			return "", err
+		}
+		return base64.StdEncoding.EncodeToString([]byte(trimmed)), nil
+	case 'R':
+		if err := validateClaudeDoubleLayerSignature(trimmed, opt); err != nil {
+			return "", err
+		}
+		return trimmed, nil
+	default:
+		return "", fmt.Errorf("invalid signature: expected 'E' or 'R' prefix, got %q", string(trimmed[0]))
+	}
+}
+
+// validateClaudeDoubleLayerSignature peels the outer base64 layer and validates the inner E-form signature.
+func validateClaudeDoubleLayerSignature(sig string, opt ClaudeSignatureValidationOptions) error {
+	inner, err := base64.StdEncoding.DecodeString(sig)
+	if err != nil {
+		return fmt.Errorf("invalid double-layer signature: base64 decode failed: %w", err)
+	}
+	if len(inner) == 0 {
+		return fmt.Errorf("invalid double-layer signature: empty after decode")
+	}
+	if inner[0] != 'E' {
+		return fmt.Errorf("invalid double-layer signature: inner does not start with 'E', got 0x%02x", inner[0])
+	}
+	return validateClaudeSingleLayerSignatureContent(string(inner), 2, opt)
+}
+
+// validateClaudeSingleLayerSignature validates an E-form signature as a one-layer encoding.
+func validateClaudeSingleLayerSignature(sig string, opt ClaudeSignatureValidationOptions) error {
+	return validateClaudeSingleLayerSignatureContent(sig, 1, opt)
+}
+
+// validateClaudeSingleLayerSignatureContent checks the decoded 0x12 marker and, in strict mode, the protobuf tree.
+func validateClaudeSingleLayerSignatureContent(sig string, encodingLayers int, opt ClaudeSignatureValidationOptions) error {
+	payload, err := base64.StdEncoding.DecodeString(sig)
+	if err != nil {
+		return fmt.Errorf("invalid single-layer signature: base64 decode failed: %w", err)
+	}
+	if len(payload) == 0 {
+		return fmt.Errorf("invalid single-layer signature: empty after decode")
+	}
+	if payload[0] != 0x12 {
+		return fmt.Errorf("invalid Claude signature: expected first byte 0x12, got 0x%02x", payload[0])
+	}
+	if opt.Strict {
+		_, err = InspectClaudeSignaturePayload(payload, encodingLayers)
+	}
+	return err
+}
+
+// InspectClaudeDoubleLayerSignature peels the outer base64 layer of an R-form
+// signature and inspects the inner E-form signature as a two-layer encoding.
+func InspectClaudeDoubleLayerSignature(sig string) (*ClaudeSignatureTree, error) {
+	inner, err := base64.StdEncoding.DecodeString(sig)
+	if err != nil {
+		return nil, fmt.Errorf("invalid double-layer signature: base64 decode failed: %w", err)
+	}
+	if len(inner) == 0 {
+		return nil, fmt.Errorf("invalid double-layer signature: empty after decode")
+	}
+	if inner[0] != 'E' {
+		return nil, fmt.Errorf("invalid double-layer signature: inner does not start with 'E', got 0x%02x", inner[0])
+	}
+	return inspectClaudeSingleLayerSignatureWithLayers(string(inner), 2)
+}
+
+// InspectClaudeSingleLayerSignature base64-decodes an E-form signature and
+// inspects its payload as a one-layer encoding.
+func InspectClaudeSingleLayerSignature(sig string) (*ClaudeSignatureTree, error) {
+	return inspectClaudeSingleLayerSignatureWithLayers(sig, 1)
+}
+
+func inspectClaudeSingleLayerSignatureWithLayers(sig string, encodingLayers int) (*ClaudeSignatureTree, error) {
+	payload, err := base64.StdEncoding.DecodeString(sig)
+	if err != nil {
+		return nil, fmt.Errorf("invalid single-layer signature: base64 decode failed: %w", err)
+	}
+	if len(payload) == 0 {
+		return nil, fmt.Errorf("invalid single-layer signature: empty after decode")
+	}
+	return InspectClaudeSignaturePayload(payload, encodingLayers)
+}
+
+// InspectClaudeSignaturePayload inspects the decoded Claude thinking signature
+// protobuf payload.
+func InspectClaudeSignaturePayload(payload []byte, encodingLayers int) (*ClaudeSignatureTree, error) {
+	if len(payload) == 0 {
+		return nil, fmt.Errorf("invalid Claude signature: empty payload")
+	}
+	if payload[0] != 0x12 {
+		return nil, fmt.Errorf("invalid Claude signature: expected first byte 0x12, got 0x%02x", payload[0])
+	}
+	outer, err := extractClaudeBytesField(payload, 2, "top-level protobuf")
+	if err != nil {
+		return nil, err
+	}
+	block, err := extractClaudeBytesField(outer, 1, "Claude Field 2 container")
+	if err != nil {
+		return nil, err
+	}
+	return inspectClaudeChannelBlock(block, encodingLayers)
+}
+
+// inspectClaudeChannelBlock decodes the channel block (Field 2.1) and derives
+// the routing, infrastructure, schema and legacy-route classifications from it.
+func inspectClaudeChannelBlock(channelBlock []byte, encodingLayers int) (*ClaudeSignatureTree, error) {
+	tree := &ClaudeSignatureTree{
+		EncodingLayers:      encodingLayers,
+		RoutingClass:        "unknown",
+		InfrastructureClass: "infra_unknown",
+		SchemaFeatures:      "unknown_schema_features",
+	}
+	sawChannelID := false
+	sawModelText := false
+
+	walkErr := walkClaudeProtobufFields(channelBlock, func(num protowire.Number, typ protowire.Type, raw []byte) error {
+		switch num {
+		case 1:
+			if typ != protowire.VarintType {
+				return fmt.Errorf("invalid Claude signature: Field 2.1.1 channel_id must be varint")
+			}
+			id, err := decodeClaudeVarintField(raw, "Field 2.1.1 channel_id")
+			if err != nil {
+				return err
+			}
+			tree.ChannelID = id
+			sawChannelID = true
+		case 2:
+			if typ != protowire.VarintType {
+				return fmt.Errorf("invalid Claude signature: Field 2.1.2 field2 must be varint")
+			}
+			infra, err := decodeClaudeVarintField(raw, "Field 2.1.2 field2")
+			if err != nil {
+				return err
+			}
+			tree.Field2 = &infra
+		case 6:
+			if typ != protowire.BytesType {
+				return fmt.Errorf("invalid Claude signature: Field 2.1.6 model_text must be bytes")
+			}
+			text, err := decodeClaudeBytesField(raw, "Field 2.1.6 model_text")
+			if err != nil {
+				return err
+			}
+			if !utf8.Valid(text) {
+				return fmt.Errorf("invalid Claude signature: Field 2.1.6 model_text is not valid UTF-8")
+			}
+			tree.ModelText = string(text)
+			sawModelText = true
+		case 7:
+			if typ != protowire.VarintType {
+				return fmt.Errorf("invalid Claude signature: Field 2.1.7 must be varint")
+			}
+			if _, err := decodeClaudeVarintField(raw, "Field 2.1.7"); err != nil {
+				return err
+			}
+			tree.HasField7 = true
+		}
+		return nil
+	})
+	if walkErr != nil {
+		return nil, walkErr
+	}
+	if !sawChannelID {
+		return nil, fmt.Errorf("invalid Claude signature: missing Field 2.1.1 channel_id")
+	}
+
+	// Channel IDs other than 11 and 12 keep the "unknown" routing class.
+	switch tree.ChannelID {
+	case 11:
+		tree.RoutingClass = "routing_class_11"
+	case 12:
+		tree.RoutingClass = "routing_class_12"
+	}
+
+	switch {
+	case tree.Field2 == nil:
+		tree.InfrastructureClass = "infra_default"
+	case *tree.Field2 == 1:
+		tree.InfrastructureClass = "infra_aws"
+	case *tree.Field2 == 2:
+		tree.InfrastructureClass = "infra_google"
+	default:
+		tree.InfrastructureClass = "infra_unknown"
+	}
+
+	// compact_schema needs a 70-72 byte block with neither field 6 nor field 7.
+	switch {
+	case sawModelText:
+		tree.SchemaFeatures = "extended_model_tagged_schema"
+	case !tree.HasField7 && len(channelBlock) >= 70 && len(channelBlock) <= 72:
+		tree.SchemaFeatures = "compact_schema"
+	}
+
+	if tree.ChannelID == 11 {
+		switch {
+		case tree.Field2 == nil:
+			tree.LegacyRouteHint = "legacy_default_group"
+		case *tree.Field2 == 1:
+			tree.LegacyRouteHint = "legacy_aws_group"
+		case *tree.Field2 == 2 && tree.EncodingLayers == 2:
+			tree.LegacyRouteHint = "legacy_vertex_direct"
+		case *tree.Field2 == 2 && tree.EncodingLayers == 1:
+			tree.LegacyRouteHint = "legacy_vertex_proxy"
+		}
+	}
+
+	return tree, nil
+}
+
+// extractClaudeBytesField returns the last occurrence of fieldNum in msg, which must be bytes-typed; scope labels errors.
+func extractClaudeBytesField(msg []byte, fieldNum protowire.Number, scope string) ([]byte, error) {
+	var found []byte
+	walkErr := walkClaudeProtobufFields(msg, func(num protowire.Number, typ protowire.Type, raw []byte) error {
+		if num != fieldNum {
+			return nil
+		}
+		if typ != protowire.BytesType {
+			return fmt.Errorf("invalid Claude signature: %s field %d must be bytes", scope, fieldNum)
+		}
+		var err error
+		found, err = decodeClaudeBytesField(raw, fmt.Sprintf("%s field %d", scope, fieldNum))
+		return err
+	})
+	if walkErr != nil {
+		return nil, walkErr
+	}
+	if found == nil {
+		return nil, fmt.Errorf("invalid Claude signature: missing %s field %d", scope, fieldNum)
+	}
+	return found, nil
+}
+
+// walkClaudeProtobufFields calls visit with the raw value bytes of each field in msg, stopping at the first error.
+func walkClaudeProtobufFields(msg []byte, visit func(num protowire.Number, typ protowire.Type, raw []byte) error) error {
+	for pos := 0; pos < len(msg); {
+		num, typ, tagLen := protowire.ConsumeTag(msg[pos:])
+		if tagLen < 0 {
+			return fmt.Errorf("invalid Claude signature: malformed protobuf tag: %w", protowire.ParseError(tagLen))
+		}
+		pos += tagLen
+		size := protowire.ConsumeFieldValue(num, typ, msg[pos:])
+		if size < 0 {
+			return fmt.Errorf("invalid Claude signature: malformed protobuf field %d: %w", num, protowire.ParseError(size))
+		}
+		if err := visit(num, typ, msg[pos:pos+size]); err != nil {
+			return err
+		}
+		pos += size
+	}
+	return nil
+}
+
+// decodeClaudeVarintField decodes raw as a protobuf varint; label names the field in the error.
+func decodeClaudeVarintField(raw []byte, label string) (uint64, error) {
+	decoded, size := protowire.ConsumeVarint(raw)
+	if size < 0 {
+		return 0, fmt.Errorf("invalid Claude signature: failed to decode %s: %w", label, protowire.ParseError(size))
+	}
+	return decoded, nil
+}
+
+// decodeClaudeBytesField decodes raw as a length-prefixed protobuf bytes value; label names the field in the error.
+func decodeClaudeBytesField(raw []byte, label string) ([]byte, error) {
+	decoded, size := protowire.ConsumeBytes(raw)
+	if size < 0 {
+		return nil, fmt.Errorf("invalid Claude signature: failed to decode %s: %w", label, protowire.ParseError(size))
+	}
+	return decoded, nil
+}
diff --git a/internal/signature/gemini_validation.go b/internal/signature/gemini_validation.go
new file mode 100644
index 000000000..d3a655112
--- /dev/null
+++ b/internal/signature/gemini_validation.go
@@ -0,0 +1,497 @@
+// Gemini thought signature validation notes.
+// +// The Antigravity Gemini request translator can preserve provider-compatible +// Gemini thought signatures and uses the skip sentinel only for synthetic or +// incompatible model parts. +// +// Gemini 3 and later models can return thoughtSignature on model content parts. +// Function-call parts are the strict case: when a model functionCall is replayed +// with a following functionResponse, Gemini validates that the original +// functionCall part still carries its provider-issued thoughtSignature. Text or +// other non-functionCall parts may also carry a signature; those should be +// preserved when replaying native Gemini history, but they are not the primary +// validation gate. +// +// Synthetic history and migration from other model families are different. If a +// functionCall part was not produced by the Gemini API, there is no real signature +// to preserve. Gemini documents two bypass sentinels for that case: +// +// - "skip_thought_signature_validator" +// - "context_engineering_is_the_way_to_go" +// +// This repo currently emits "skip_thought_signature_validator" for non-Claude +// Antigravity Gemini model parts that contain functionCall, thought, or an +// existing thoughtSignature. That is a request-shape compatibility policy, not a +// proof that the replaced signature was malformed. +// +// This validator is intentionally more conservative than a decrypting verifier. +// Claude has a known E/R base64 envelope and a protobuf tree in this package. +// Gemini thought signatures are opaque provider state here, so local validation +// checks only the transport-level protobuf envelope and leaves the wrapped +// provider payload uninterpreted. +// +// Validation tiers: +// +// - Sentinel tier: accept the documented bypass sentinels only when the +// model functionCall is synthetic, migrated, or otherwise not traceable to a +// prior Gemini model response in the same conversation. 
+// - Opaque-shape tier: for real Gemini signatures, require a non-empty string, +// bounded length, successful standard base64 decoding, and a known protobuf +// envelope when the caller needs provider compatibility. Observed samples +// currently include Gemini 3.x field-2 -> field-1 payloads and Gemini 2.5 +// repeated field-1 payloads. Base64 UUID payloads are classified separately +// and should be replaced with the bypass sentinel rather than replayed. +// - Replay tier: real validation means preserving the exact model part that +// came from Gemini, including its thoughtSignature, id/name/function args, +// part index, and ordering relative to sibling parallel function calls. +// - Tool pairing tier: functionResponse parts must match the preceding +// functionCall id/name and must not be interleaved between parallel calls. +// The valid shape is all model functionCalls first, then their responses. +// - Compatibility tier: GPT-compatible Gemini traffic stores the same state +// under tool_calls[].extra_content.google.thought_signature. If that path is +// translated back to native Gemini, the value must stay attached to the same +// assistant tool call. +// +// Important non-goals: +// +// - Do not treat a Gemini thoughtSignature as a Claude signature. Similar +// base64 prefixes are not provenance. +// - Do not attach a signature to user functionResponse/tool-result parts. +// - Do not log complete signatures during validation failures; log only field +// paths, lengths, and redacted prefixes. +// - Do not preserve client-provided signatures across model/provider/session +// boundaries unless the request pipeline can prove they came from the same +// Gemini conversation state. 
package signature

import (
	"encoding/base64"
	"fmt"
	"strings"

	"github.com/tidwall/gjson"
	"google.golang.org/protobuf/encoding/protowire"
)

// MaxGeminiThoughtSignatureLen is the largest encoded signature, measured in
// bytes of the trimmed base64 string, that is accepted before decoding
// (32 MiB).
//
// GeminiSkipThoughtSignatureValidator and GeminiContextEngineeringBypass are
// the two bypass sentinel strings. They are plain text rather than base64 and
// are accepted only when AllowBypassSentinel is set.
const (
	MaxGeminiThoughtSignatureLen = 32 * 1024 * 1024

	GeminiSkipThoughtSignatureValidator = "skip_thought_signature_validator"
	GeminiContextEngineeringBypass      = "context_engineering_is_the_way_to_go"
)

// GeminiThoughtSignatureValidationOptions controls how much local validation is
// applied to Gemini thought signatures. This validation checks only the opaque
// transport envelope; it does not prove that a signature came from Gemini or can
// be decrypted by Gemini.
type GeminiThoughtSignatureValidationOptions struct {
	// AllowBypassSentinel accepts Gemini's documented synthetic-history bypass
	// sentinels. Keep this false when validating provider-issued signatures.
	AllowBypassSentinel bool
	// RequireKnownEnvelope requires the decoded payload to match one of the
	// protobuf envelopes observed in Gemini samples. This rejects opaque base64
	// values such as base64 UUIDs.
	RequireKnownEnvelope bool
	// RequireObservedMarker requires the decoded payload to start with 0x12.
	// Current Gemini 3.x samples show this marker, but Gemini 2.5 samples use a
	// different protobuf prefix, so this should be used only for narrow Gemini 3
	// experiments.
	RequireObservedMarker bool
}

// GeminiThoughtSignatureEnvelope names the outer shape that a decoded thought
// signature payload was classified as.
type GeminiThoughtSignatureEnvelope string

// Envelope classifications:
//
//   - Unknown: the payload matched none of the recognized shapes.
//   - ProtobufField1: one or more length-delimited field-1 records, each
//     wrapping a payload whose first byte is 0x01.
//   - ProtobufField2: a single length-delimited field 2 that wraps exactly one
//     field-1 record whose payload starts with 0x01.
//   - ASCIIUUID: the payload is the 36-byte text form of a UUID. It is
//     reported for diagnostics but is not treated as a known envelope.
const (
	GeminiThoughtSignatureEnvelopeUnknown        GeminiThoughtSignatureEnvelope = "unknown"
	GeminiThoughtSignatureEnvelopeProtobufField1 GeminiThoughtSignatureEnvelope = "protobuf_field_1"
	GeminiThoughtSignatureEnvelopeProtobufField2 GeminiThoughtSignatureEnvelope = "protobuf_field_2"
	GeminiThoughtSignatureEnvelopeASCIIUUID      GeminiThoughtSignatureEnvelope = "ascii_uuid"
)

// GeminiThoughtSignatureInfo describes the locally inspectable properties of an
// opaque Gemini thought signature.
+type GeminiThoughtSignatureInfo struct { + IsBypassSentinel bool + BypassSentinel string + DecodedLen int + FirstByte byte + HasObservedMarker bool + KnownEnvelope bool + Envelope GeminiThoughtSignatureEnvelope + RecordCount int + OpaquePayloadLen int +} + +type geminiFunctionCallRef struct { + id string + name string + path string +} + +type geminiFunctionResponseRef struct { + part gjson.Result + path string +} + +func geminiThoughtSignatureValidationOptions(opts []GeminiThoughtSignatureValidationOptions) GeminiThoughtSignatureValidationOptions { + if len(opts) == 0 { + return GeminiThoughtSignatureValidationOptions{} + } + return opts[0] +} + +// IsGeminiThoughtSignatureBypass reports whether rawSignature is one of +// Gemini's documented bypass sentinels for synthetic or migrated function-call +// history. +func IsGeminiThoughtSignatureBypass(rawSignature string) bool { + switch strings.TrimSpace(rawSignature) { + case GeminiSkipThoughtSignatureValidator, GeminiContextEngineeringBypass: + return true + default: + return false + } +} + +// IsValidGeminiThoughtSignature returns whether rawSignature has a valid local +// Gemini thought-signature shape under opts. +func IsValidGeminiThoughtSignature(rawSignature string, opts ...GeminiThoughtSignatureValidationOptions) bool { + _, err := InspectGeminiThoughtSignature(rawSignature, opts...) + return err == nil +} + +// InspectGeminiThoughtSignature validates and inspects the local transport +// shape of a Gemini thought signature. It intentionally treats provider-issued +// signatures as opaque base64 payloads. 
+func InspectGeminiThoughtSignature(rawSignature string, opts ...GeminiThoughtSignatureValidationOptions) (*GeminiThoughtSignatureInfo, error) { + opt := geminiThoughtSignatureValidationOptions(opts) + sig := strings.TrimSpace(rawSignature) + if sig == "" { + return nil, fmt.Errorf("empty Gemini thought signature") + } + + if IsGeminiThoughtSignatureBypass(sig) { + if !opt.AllowBypassSentinel { + return nil, fmt.Errorf("Gemini thought signature bypass sentinel is not allowed") + } + return &GeminiThoughtSignatureInfo{ + IsBypassSentinel: true, + BypassSentinel: sig, + }, nil + } + + decoded, err := decodeGeminiThoughtSignature(sig) + if err != nil { + return nil, err + } + if len(decoded) == 0 { + return nil, fmt.Errorf("invalid Gemini thought signature: empty decoded payload") + } + + info := &GeminiThoughtSignatureInfo{ + DecodedLen: len(decoded), + FirstByte: decoded[0], + HasObservedMarker: decoded[0] == 0x12, + } + info.Envelope, info.KnownEnvelope = classifyGeminiThoughtSignatureEnvelope(decoded) + info.RecordCount, info.OpaquePayloadLen = inspectGeminiEnvelope(decoded, info.Envelope) + if opt.RequireKnownEnvelope && !info.KnownEnvelope { + return nil, fmt.Errorf("invalid Gemini thought signature: unknown envelope %q", info.Envelope) + } + if opt.RequireObservedMarker && !info.HasObservedMarker { + return nil, fmt.Errorf("invalid Gemini thought signature: expected observed marker 0x12, got 0x%02x", info.FirstByte) + } + + return info, nil +} + +// ValidateGeminiThoughtSignatures validates thoughtSignature fields in a Gemini +// native payload. Function-call parts must have a valid signature. Other parts +// are optional, but if a thoughtSignature field is present it must be valid. 
+func ValidateGeminiThoughtSignatures(inputRawJSON []byte, opts ...GeminiThoughtSignatureValidationOptions) error { + contents, contentsPath := geminiContents(inputRawJSON) + if !contents.IsArray() { + return nil + } + + contentResults := contents.Array() + for i := 0; i < len(contentResults); i++ { + parts := contentResults[i].Get("parts") + if !parts.IsArray() { + continue + } + + partResults := parts.Array() + for j := 0; j < len(partResults); j++ { + part := partResults[j] + hasFunctionCall := part.Get("functionCall").Exists() + hasSignature := part.Get("thoughtSignature").Exists() + if !hasFunctionCall && !hasSignature { + continue + } + + partPath := fmt.Sprintf("%s[%d].parts[%d]", contentsPath, i, j) + rawSignature := strings.TrimSpace(part.Get("thoughtSignature").String()) + if rawSignature == "" { + if hasFunctionCall { + return fmt.Errorf("%s: missing thoughtSignature on functionCall", partPath) + } + return fmt.Errorf("%s: empty thoughtSignature", partPath) + } + + if _, err := InspectGeminiThoughtSignature(rawSignature, opts...); err != nil { + return fmt.Errorf("%s: %w", partPath, err) + } + } + } + + return nil +} + +// ValidateGeminiFunctionCallPairing validates the replay shape around Gemini +// functionCall and functionResponse parts. It checks id/name pairing and +// prevents response parts from being interleaved inside the same content as +// function calls. It allows a final pending functionCall group because callers +// may validate a freshly returned model step before tool outputs exist. 
+func ValidateGeminiFunctionCallPairing(inputRawJSON []byte) error { + contents, contentsPath := geminiContents(inputRawJSON) + if !contents.IsArray() { + return nil + } + + var pending []geminiFunctionCallRef + contentResults := contents.Array() + for i := 0; i < len(contentResults); i++ { + parts := contentResults[i].Get("parts") + if !parts.IsArray() { + continue + } + + var calls []geminiFunctionCallRef + var responses []geminiFunctionResponseRef + partResults := parts.Array() + for j := 0; j < len(partResults); j++ { + part := partResults[j] + partPath := fmt.Sprintf("%s[%d].parts[%d]", contentsPath, i, j) + if call := part.Get("functionCall"); call.Exists() { + if call.Get("name").String() == "" { + return fmt.Errorf("%s: missing functionCall.name", partPath) + } + calls = append(calls, geminiFunctionCallRef{ + id: call.Get("id").String(), + name: call.Get("name").String(), + path: partPath, + }) + } + if response := part.Get("functionResponse"); response.Exists() { + responses = append(responses, geminiFunctionResponseRef{ + part: part, + path: partPath, + }) + } + } + + if len(calls) > 0 && len(responses) > 0 { + return fmt.Errorf("%s[%d]: functionCall and functionResponse parts must not be interleaved in the same content", contentsPath, i) + } + + if len(calls) > 0 { + if len(pending) > 0 { + return fmt.Errorf("%s[%d]: functionCall appears before %d pending functionResponse part(s)", contentsPath, i, len(pending)) + } + pending = calls + continue + } + + if len(responses) == 0 { + continue + } + if len(pending) == 0 { + return fmt.Errorf("%s[%d]: functionResponse without preceding functionCall", contentsPath, i) + } + if len(responses) != len(pending) { + return fmt.Errorf("%s[%d]: functionResponse count %d does not match pending functionCall count %d", contentsPath, i, len(responses), len(pending)) + } + + for j := 0; j < len(responses); j++ { + partPath := responses[j].path + response := responses[j].part.Get("functionResponse") + call := pending[j] + 
responseID := response.Get("id").String() + responseName := response.Get("name").String() + + if call.id != "" && responseID == "" { + return fmt.Errorf("%s: missing functionResponse.id for %s", partPath, call.path) + } + if call.id != "" && responseID != call.id { + return fmt.Errorf("%s: functionResponse.id %q does not match functionCall.id %q at %s", partPath, responseID, call.id, call.path) + } + if responseName == "" { + return fmt.Errorf("%s: missing functionResponse.name", partPath) + } + if call.name != "" && responseName != call.name { + return fmt.Errorf("%s: functionResponse.name %q does not match functionCall.name %q at %s", partPath, responseName, call.name, call.path) + } + } + + pending = nil + } + + return nil +} + +func decodeGeminiThoughtSignature(sig string) ([]byte, error) { + if len(sig) > MaxGeminiThoughtSignatureLen { + return nil, fmt.Errorf("Gemini thought signature exceeds maximum length (%d bytes)", MaxGeminiThoughtSignatureLen) + } + + decoded, err := base64.StdEncoding.DecodeString(sig) + if err == nil { + return decoded, nil + } + if decoded, rawErr := base64.RawStdEncoding.DecodeString(sig); rawErr == nil { + return decoded, nil + } + + return nil, fmt.Errorf("invalid Gemini thought signature: base64 decode failed: %w", err) +} + +func classifyGeminiThoughtSignatureEnvelope(decoded []byte) (GeminiThoughtSignatureEnvelope, bool) { + if len(decoded) == 0 { + return GeminiThoughtSignatureEnvelopeUnknown, false + } + if isASCIIUUIDBytes(decoded) { + return GeminiThoughtSignatureEnvelopeASCIIUUID, false + } + switch { + case isGeminiField1Envelope(decoded): + return GeminiThoughtSignatureEnvelopeProtobufField1, true + case isGeminiField2Envelope(decoded): + return GeminiThoughtSignatureEnvelopeProtobufField2, true + default: + return GeminiThoughtSignatureEnvelopeUnknown, false + } +} + +func isGeminiField1Envelope(decoded []byte) bool { + info, ok := inspectGeminiField1Envelope(decoded) + return ok && info.RecordCount > 0 +} + +func 
isGeminiField2Envelope(decoded []byte) bool { + info, ok := inspectGeminiField2Envelope(decoded) + return ok && info.RecordCount == 1 && info.OpaquePayloadLen > 0 +} + +func inspectGeminiEnvelope(decoded []byte, envelope GeminiThoughtSignatureEnvelope) (recordCount int, opaquePayloadLen int) { + switch envelope { + case GeminiThoughtSignatureEnvelopeProtobufField1: + if info, ok := inspectGeminiField1Envelope(decoded); ok { + return info.RecordCount, info.OpaquePayloadLen + } + case GeminiThoughtSignatureEnvelopeProtobufField2: + if info, ok := inspectGeminiField2Envelope(decoded); ok { + return info.RecordCount, info.OpaquePayloadLen + } + } + return 0, 0 +} + +type geminiEnvelopeInfo struct { + RecordCount int + OpaquePayloadLen int +} + +func inspectGeminiField1Envelope(decoded []byte) (geminiEnvelopeInfo, bool) { + var info geminiEnvelopeInfo + offset := 0 + for offset < len(decoded) { + num, typ, n := protowire.ConsumeTag(decoded[offset:]) + if n < 0 || num != 1 || typ != protowire.BytesType { + return geminiEnvelopeInfo{}, false + } + offset += n + value, n := protowire.ConsumeBytes(decoded[offset:]) + if n < 0 || !isLikelyGeminiOpaquePayload(value) { + return geminiEnvelopeInfo{}, false + } + info.RecordCount++ + info.OpaquePayloadLen += len(value) + offset += n + } + return info, offset == len(decoded) && info.RecordCount > 0 +} + +func inspectGeminiField2Envelope(decoded []byte) (geminiEnvelopeInfo, bool) { + value, ok := consumeGeminiField2Field1Value(decoded) + if !ok || !isLikelyGeminiOpaquePayload(value) { + return geminiEnvelopeInfo{}, false + } + return geminiEnvelopeInfo{ + RecordCount: 1, + OpaquePayloadLen: len(value), + }, true +} + +func consumeGeminiField2Field1Value(decoded []byte) ([]byte, bool) { + num, typ, n := protowire.ConsumeTag(decoded) + if n < 0 || num != 2 || typ != protowire.BytesType { + return nil, false + } + offset := n + container, n := protowire.ConsumeBytes(decoded[offset:]) + if n < 0 { + return nil, false + } + offset += n 
+ if offset != len(decoded) { + return nil, false + } + + num, typ, n = protowire.ConsumeTag(container) + if n < 0 || num != 1 || typ != protowire.BytesType { + return nil, false + } + containerOffset := n + value, n := protowire.ConsumeBytes(container[containerOffset:]) + if n < 0 { + return nil, false + } + containerOffset += n + if containerOffset != len(container) { + return nil, false + } + return value, true +} + +func isLikelyGeminiOpaquePayload(value []byte) bool { + // Observed Gemini 2.5 and Gemini 3.x envelopes wrap provider-opaque + // payloads that start with an internal version byte 0x01. The bytes after + // that are high-entropy provider state and must remain opaque. + return len(value) > 0 && value[0] == 0x01 +} + +func isASCIIUUIDBytes(decoded []byte) bool { + if len(decoded) != 36 { + return false + } + for i, b := range decoded { + switch i { + case 8, 13, 18, 23: + if b != '-' { + return false + } + default: + if !((b >= '0' && b <= '9') || (b >= 'a' && b <= 'f') || (b >= 'A' && b <= 'F')) { + return false + } + } + } + return true +} + +func geminiContents(inputRawJSON []byte) (gjson.Result, string) { + if contents := gjson.GetBytes(inputRawJSON, "contents"); contents.Exists() { + return contents, "contents" + } + return gjson.GetBytes(inputRawJSON, "request.contents"), "request.contents" +} diff --git a/internal/signature/gemini_validation_test.go b/internal/signature/gemini_validation_test.go new file mode 100644 index 000000000..add57a6b3 --- /dev/null +++ b/internal/signature/gemini_validation_test.go @@ -0,0 +1,393 @@ +package signature + +import ( + "encoding/base64" + "strings" + "testing" + + "google.golang.org/protobuf/encoding/protowire" +) + +func testGeminiThoughtSignature(payload []byte) string { + return base64.StdEncoding.EncodeToString(payload) +} + +func testGemini25ThoughtSignature(records ...[]byte) string { + var payload []byte + for _, record := range records { + payload = protowire.AppendTag(payload, 1, 
protowire.BytesType) + payload = protowire.AppendBytes(payload, record) + } + return testGeminiThoughtSignature(payload) +} + +func testGemini3ThoughtSignature(payload []byte) string { + var inner []byte + inner = protowire.AppendTag(inner, 1, protowire.BytesType) + inner = protowire.AppendBytes(inner, payload) + + var outer []byte + outer = protowire.AppendTag(outer, 2, protowire.BytesType) + outer = protowire.AppendBytes(outer, inner) + return testGeminiThoughtSignature(outer) +} + +func TestInspectGeminiThoughtSignature_AcceptsOpaqueBase64(t *testing.T) { + sig := testGeminiThoughtSignature([]byte{0x12, 0x34, 0x56}) + + info, err := InspectGeminiThoughtSignature(sig) + if err != nil { + t.Fatalf("InspectGeminiThoughtSignature failed: %v", err) + } + if info.IsBypassSentinel { + t.Fatal("real signature should not be marked as bypass sentinel") + } + if info.DecodedLen != 3 { + t.Fatalf("DecodedLen = %d, want 3", info.DecodedLen) + } + if info.FirstByte != 0x12 { + t.Fatalf("FirstByte = 0x%02x, want 0x12", info.FirstByte) + } + if !info.HasObservedMarker { + t.Fatal("HasObservedMarker should be true") + } + if info.Envelope != GeminiThoughtSignatureEnvelopeUnknown { + t.Fatalf("Envelope = %q, want %q", info.Envelope, GeminiThoughtSignatureEnvelopeUnknown) + } + if info.KnownEnvelope { + t.Fatal("KnownEnvelope should be false for incomplete opaque payload") + } +} + +func TestInspectGeminiThoughtSignature_AcceptsGemini31ProField2Envelope(t *testing.T) { + // Shape observed in CPA-API/signatures/gemini/gemini-3.1-pro.txt. 
+ sig := testGemini3ThoughtSignature([]byte{0x01, 0x0c, 0x39, 0xd6, 0xc7, 0x34}) + + info, err := InspectGeminiThoughtSignature(sig, GeminiThoughtSignatureValidationOptions{RequireKnownEnvelope: true}) + if err != nil { + t.Fatalf("Gemini 3.1 Pro field-2 envelope should be known: %v", err) + } + if info.Envelope != GeminiThoughtSignatureEnvelopeProtobufField2 { + t.Fatalf("Envelope = %q, want %q", info.Envelope, GeminiThoughtSignatureEnvelopeProtobufField2) + } + if !info.HasObservedMarker { + t.Fatal("Gemini 3.1 Pro envelope should be marked as 0x12") + } + if info.RecordCount != 1 { + t.Fatalf("RecordCount = %d, want 1", info.RecordCount) + } + if info.OpaquePayloadLen != 6 { + t.Fatalf("OpaquePayloadLen = %d, want 6", info.OpaquePayloadLen) + } +} + +func TestInspectGeminiThoughtSignature_AcceptsCapturedGemini31FlashLiteEnvelope(t *testing.T) { + // Captured in CPA-API/signatures/gemini/gemini-3.1-flash-lite.txt. + const sig = "EjQKMgEMOdbHO0Gd+c9Mxk4ELwPGbpCEcp2mFfYYLix2UVtBH3fL8GECc4+JITVnHF4qZDsA" + + info, err := InspectGeminiThoughtSignature(sig, GeminiThoughtSignatureValidationOptions{RequireKnownEnvelope: true}) + if err != nil { + t.Fatalf("captured Gemini 3.1 Flash Lite envelope should be known: %v", err) + } + if info.Envelope != GeminiThoughtSignatureEnvelopeProtobufField2 { + t.Fatalf("Envelope = %q, want %q", info.Envelope, GeminiThoughtSignatureEnvelopeProtobufField2) + } + if info.RecordCount != 1 { + t.Fatalf("RecordCount = %d, want 1", info.RecordCount) + } + if info.OpaquePayloadLen != 50 { + t.Fatalf("OpaquePayloadLen = %d, want 50", info.OpaquePayloadLen) + } +} + +func TestInspectGeminiThoughtSignature_AcceptsGemini25Field1Envelope(t *testing.T) { + sig := testGemini25ThoughtSignature([]byte{0x01, 0x8f}, []byte{0x01, 0x90, 0x91}) + + info, err := InspectGeminiThoughtSignature(sig, GeminiThoughtSignatureValidationOptions{RequireKnownEnvelope: true}) + if err != nil { + t.Fatalf("Gemini 2.5 field-1 envelope should be known: %v", err) + } + if 
info.Envelope != GeminiThoughtSignatureEnvelopeProtobufField1 { + t.Fatalf("Envelope = %q, want %q", info.Envelope, GeminiThoughtSignatureEnvelopeProtobufField1) + } + if info.HasObservedMarker { + t.Fatal("Gemini 2.5 field-1 envelope should not be marked as 0x12") + } + if info.RecordCount != 2 { + t.Fatalf("RecordCount = %d, want 2", info.RecordCount) + } + if info.OpaquePayloadLen != 5 { + t.Fatalf("OpaquePayloadLen = %d, want 5", info.OpaquePayloadLen) + } +} + +func TestInspectGeminiThoughtSignature_RejectsMalformedKnownEnvelope(t *testing.T) { + // Field 2 with a nested field 1 is not enough. Observed Gemini 3 payloads + // wrap an opaque blob that starts with internal version byte 0x01. + sig := testGemini3ThoughtSignature([]byte{0x02, 0x0c, 0x39}) + + if IsValidGeminiThoughtSignature(sig, GeminiThoughtSignatureValidationOptions{RequireKnownEnvelope: true}) { + t.Fatal("malformed Gemini 3 envelope should fail known-envelope validation") + } +} + +func TestInspectGeminiThoughtSignature_ClassifiesASCIIUUIDAsOpaque(t *testing.T) { + sig := testGeminiThoughtSignature([]byte("e24830a7-5cd6-42fe-998b-ee539e72b9c3")) + + info, err := InspectGeminiThoughtSignature(sig) + if err != nil { + t.Fatalf("opaque base64 UUID should pass default validation: %v", err) + } + if info.Envelope != GeminiThoughtSignatureEnvelopeASCIIUUID { + t.Fatalf("Envelope = %q, want %q", info.Envelope, GeminiThoughtSignatureEnvelopeASCIIUUID) + } + if info.KnownEnvelope { + t.Fatal("base64 UUID should not be a known protobuf envelope") + } + if IsValidGeminiThoughtSignature(sig, GeminiThoughtSignatureValidationOptions{RequireKnownEnvelope: true}) { + t.Fatal("base64 UUID should fail when known envelope is required") + } +} + +func TestInspectGeminiThoughtSignature_ObservedMarkerOption(t *testing.T) { + sig := testGeminiThoughtSignature([]byte{0x45, 0x12}) + + if _, err := InspectGeminiThoughtSignature(sig); err != nil { + t.Fatalf("default validation should accept opaque base64 payload: %v", 
err) + } + _, err := InspectGeminiThoughtSignature(sig, GeminiThoughtSignatureValidationOptions{RequireObservedMarker: true}) + if err == nil { + t.Fatal("RequireObservedMarker should reject payloads without 0x12 marker") + } + if !strings.Contains(err.Error(), "expected observed marker") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestInspectGeminiThoughtSignature_BypassSentinelRequiresOption(t *testing.T) { + if IsValidGeminiThoughtSignature(GeminiSkipThoughtSignatureValidator) { + t.Fatal("bypass sentinel should not be valid by default") + } + + info, err := InspectGeminiThoughtSignature(GeminiSkipThoughtSignatureValidator, GeminiThoughtSignatureValidationOptions{AllowBypassSentinel: true}) + if err != nil { + t.Fatalf("bypass sentinel should be accepted when explicitly allowed: %v", err) + } + if !info.IsBypassSentinel { + t.Fatal("sentinel should be marked as bypass") + } + if info.BypassSentinel != GeminiSkipThoughtSignatureValidator { + t.Fatalf("BypassSentinel = %q, want %q", info.BypassSentinel, GeminiSkipThoughtSignatureValidator) + } +} + +func TestInspectGeminiThoughtSignature_RejectsInvalidBase64(t *testing.T) { + if IsValidGeminiThoughtSignature("not valid base64!!!") { + t.Fatal("invalid base64 should be rejected") + } +} + +func TestValidateGeminiThoughtSignatures_FunctionCallRequiresSignature(t *testing.T) { + input := []byte(`{ + "contents": [{ + "role": "model", + "parts": [ + {"functionCall": {"id": "call-1", "name": "read_file", "args": {}}} + ] + }] + }`) + + err := ValidateGeminiThoughtSignatures(input) + if err == nil { + t.Fatal("missing functionCall thoughtSignature should fail") + } + if !strings.Contains(err.Error(), "missing thoughtSignature on functionCall") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestValidateGeminiThoughtSignatures_AcceptsWrappedRequestAndSentinelWhenAllowed(t *testing.T) { + input := []byte(`{ + "request": { + "contents": [{ + "role": "model", + "parts": [ + { + "functionCall": {"id": 
"call-1", "name": "read_file", "args": {}}, + "thoughtSignature": "skip_thought_signature_validator" + } + ] + }] + } + }`) + + err := ValidateGeminiThoughtSignatures(input, GeminiThoughtSignatureValidationOptions{AllowBypassSentinel: true}) + if err != nil { + t.Fatalf("sentinel should be valid when explicitly allowed: %v", err) + } +} + +func TestValidateGeminiThoughtSignatures_RejectsInvalidTextPartSignature(t *testing.T) { + input := []byte(`{ + "contents": [{ + "role": "model", + "parts": [ + {"text": "previous answer", "thoughtSignature": "bad!!!"} + ] + }] + }`) + + err := ValidateGeminiThoughtSignatures(input) + if err == nil { + t.Fatal("invalid text-part thoughtSignature should fail") + } + if !strings.Contains(err.Error(), "base64 decode failed") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestValidateGeminiFunctionCallPairing_ValidParallelGroup(t *testing.T) { + input := []byte(`{ + "contents": [ + { + "role": "model", + "parts": [ + {"functionCall": {"id": "call-1", "name": "weather", "args": {"city": "Paris"}}}, + {"functionCall": {"id": "call-2", "name": "weather", "args": {"city": "London"}}} + ] + }, + { + "role": "user", + "parts": [ + {"functionResponse": {"id": "call-1", "name": "weather", "response": {"temp": "15C"}}}, + {"functionResponse": {"id": "call-2", "name": "weather", "response": {"temp": "12C"}}} + ] + } + ] + }`) + + if err := ValidateGeminiFunctionCallPairing(input); err != nil { + t.Fatalf("valid pairing failed: %v", err) + } +} + +func TestValidateGeminiFunctionCallPairing_RejectsResponseCountMismatch(t *testing.T) { + input := []byte(`{ + "contents": [ + { + "role": "model", + "parts": [ + {"functionCall": {"id": "call-1", "name": "weather", "args": {}}}, + {"functionCall": {"id": "call-2", "name": "weather", "args": {}}} + ] + }, + { + "role": "user", + "parts": [ + {"functionResponse": {"id": "call-1", "name": "weather", "response": {}}} + ] + } + ] + }`) + + err := ValidateGeminiFunctionCallPairing(input) + if err 
== nil { + t.Fatal("response count mismatch should fail") + } + if !strings.Contains(err.Error(), "does not match pending functionCall count") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestValidateGeminiFunctionCallPairing_RejectsMissingFunctionCallName(t *testing.T) { + input := []byte(`{ + "contents": [{ + "role": "model", + "parts": [ + {"functionCall": {"id": "call-1", "args": {}}} + ] + }] + }`) + + err := ValidateGeminiFunctionCallPairing(input) + if err == nil { + t.Fatal("missing functionCall name should fail") + } + if !strings.Contains(err.Error(), "missing functionCall.name") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestValidateGeminiFunctionCallPairing_RejectsIDMismatch(t *testing.T) { + input := []byte(`{ + "contents": [ + { + "role": "model", + "parts": [ + {"functionCall": {"id": "call-1", "name": "weather", "args": {}}} + ] + }, + { + "role": "user", + "parts": [ + {"functionResponse": {"id": "call-other", "name": "weather", "response": {}}} + ] + } + ] + }`) + + err := ValidateGeminiFunctionCallPairing(input) + if err == nil { + t.Fatal("id mismatch should fail") + } + if !strings.Contains(err.Error(), "does not match functionCall.id") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestValidateGeminiFunctionCallPairing_RejectsMissingResponseName(t *testing.T) { + input := []byte(`{ + "contents": [ + { + "role": "model", + "parts": [ + {"functionCall": {"id": "call-1", "name": "weather", "args": {}}} + ] + }, + { + "role": "user", + "parts": [ + {"functionResponse": {"id": "call-1", "response": {}}} + ] + } + ] + }`) + + err := ValidateGeminiFunctionCallPairing(input) + if err == nil { + t.Fatal("missing response name should fail") + } + if !strings.Contains(err.Error(), "missing functionResponse.name") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestValidateGeminiFunctionCallPairing_RejectsSameContentInterleaving(t *testing.T) { + input := []byte(`{ + "contents": [{ + "role": "model", + "parts": 
// MaxGPTReasoningSignatureLen caps the length, in bytes, of a trimmed GPT
// reasoning signature that InspectGPTReasoningSignature will try to decode.
const MaxGPTReasoningSignatureLen = 32 * 1024 * 1024

// Fixed parts of a decoded signature. The overhead is 1 + 8 + 16 + 32 bytes;
// in the Fernet token layout this check is modeled on, those are the version
// byte, timestamp, IV and HMAC, with the ciphertext occupying the remainder.
const (
	// gptReasoningSignaturePrefix is the base64url text produced by a payload
	// that starts with byte 0x80 followed by zero bytes.
	gptReasoningSignaturePrefix = "gAAAA"
	// gptReasoningSignatureVersion is the required first decoded byte.
	gptReasoningSignatureVersion = 0x80
	// gptReasoningSignatureBlockLen is the block size the ciphertext length
	// must be a multiple of.
	gptReasoningSignatureBlockLen = 16
	// gptReasoningSignatureOverheadLen is every decoded byte that is not
	// ciphertext.
	gptReasoningSignatureOverheadLen = 1 + 8 + 16 + 32
	// gptReasoningSignatureMinDecodedLen is the overhead plus one ciphertext
	// block (73 bytes).
	gptReasoningSignatureMinDecodedLen = gptReasoningSignatureOverheadLen + gptReasoningSignatureBlockLen
)

// GPTReasoningSignatureInfo describes the shape of a signature that passed
// InspectGPTReasoningSignature.
type GPTReasoningSignatureInfo struct {
	// DecodedLen is the number of bytes obtained after base64url decoding.
	DecodedLen int
	// CiphertextLen is DecodedLen minus the fixed overhead.
	CiphertextLen int
}

// IsValidGPTReasoningSignature reports whether rawSignature passes
// InspectGPTReasoningSignature.
func IsValidGPTReasoningSignature(rawSignature string) bool {
	_, err := InspectGPTReasoningSignature(rawSignature)
	return err == nil
}

// InspectGPTReasoningSignature validates the Fernet-like outer format used
// by GPT/Codex reasoning encrypted_content. This is only a transport-shape
// check; it does not prove decryptability.
//
// Surrounding whitespace is ignored. On success it returns the decoded and
// ciphertext lengths; on failure it returns a nil info and an error naming
// the first check that failed. When base64 decoding fails, the underlying
// decode error is wrapped and reachable through errors.As / errors.Unwrap.
func InspectGPTReasoningSignature(rawSignature string) (*GPTReasoningSignatureInfo, error) {
	sig := strings.TrimSpace(rawSignature)
	if sig == "" {
		return nil, fmt.Errorf("empty GPT reasoning signature")
	}
	if len(sig) > MaxGPTReasoningSignatureLen {
		return nil, fmt.Errorf("GPT reasoning signature exceeds maximum length (%d bytes)", MaxGPTReasoningSignatureLen)
	}
	// Alphabet check runs before decoding so the error can name the exact
	// offending code point and byte offset.
	if index, r, ok := firstInvalidGPTReasoningSignatureChar(sig); ok {
		return nil, fmt.Errorf("invalid GPT reasoning signature: contains non-base64url character U+%04X at byte %d", r, index)
	}
	if !strings.HasPrefix(sig, gptReasoningSignaturePrefix) {
		return nil, fmt.Errorf("invalid GPT reasoning signature: expected gAAAA prefix")
	}

	decoded, err := decodeGPTReasoningSignature(sig)
	if err != nil {
		return nil, err
	}
	if len(decoded) < gptReasoningSignatureMinDecodedLen {
		return nil, fmt.Errorf("invalid GPT reasoning signature: decoded payload too short")
	}
	// The prefix check above already implies this byte; it is kept as an
	// explicit guard on the decoded data.
	if decoded[0] != gptReasoningSignatureVersion {
		return nil, fmt.Errorf("invalid GPT reasoning signature: expected version 0x80, got 0x%02x", decoded[0])
	}

	ciphertextLen := len(decoded) - gptReasoningSignatureOverheadLen
	if ciphertextLen <= 0 || ciphertextLen%gptReasoningSignatureBlockLen != 0 {
		return nil, fmt.Errorf("invalid GPT reasoning signature: ciphertext length %d is not a positive AES block multiple", ciphertextLen)
	}

	return &GPTReasoningSignatureInfo{
		DecodedLen:    len(decoded),
		CiphertextLen: ciphertextLen,
	}, nil
}

// decodeGPTReasoningSignature decodes sig as base64url, accepting both the
// unpadded and the '='-padded form. If neither decodes, the returned error
// wraps the failure from the form the input appears to use (padded when sig
// contains '=', unpadded otherwise) so the byte offset is not lost.
func decodeGPTReasoningSignature(sig string) ([]byte, error) {
	decoded, rawErr := base64.RawURLEncoding.DecodeString(sig)
	if rawErr == nil {
		return decoded, nil
	}
	decoded, paddedErr := base64.URLEncoding.DecodeString(sig)
	if paddedErr == nil {
		return decoded, nil
	}
	cause := rawErr
	if strings.Contains(sig, "=") {
		cause = paddedErr
	}
	return nil, fmt.Errorf("invalid GPT reasoning signature: base64url decode failed: %w", cause)
}

// firstInvalidGPTReasoningSignatureChar scans sig for the first rune outside
// the base64url alphabet (A-Z, a-z, 0-9, '-', '_') and the padding character
// '='. It returns the rune's byte offset, the rune, and true when one is
// found; otherwise it returns 0, 0, false.
func firstInvalidGPTReasoningSignatureChar(sig string) (int, rune, bool) {
	for index, r := range sig {
		switch {
		case r >= 'A' && r <= 'Z':
		case r >= 'a' && r <= 'z':
		case r >= '0' && r <= '9':
		case r == '-' || r == '_' || r == '=':
		default:
			return index, r, true
		}
	}
	return 0, 0, false
}
"unknown" + SignatureBlockKindClaudeThinking SignatureBlockKind = "claude_thinking" + SignatureBlockKindGeminiModelPart SignatureBlockKind = "gemini_model_part" + SignatureBlockKindGeminiFunctionCall SignatureBlockKind = "gemini_function_call" + SignatureBlockKindGPTReasoning SignatureBlockKind = "gpt_reasoning" +) + +type SignatureCompatibilityAction string + +const ( + SignatureActionPreserve SignatureCompatibilityAction = "preserve" + SignatureActionDropBlock SignatureCompatibilityAction = "drop_block" + SignatureActionDropSignature SignatureCompatibilityAction = "drop_signature" + SignatureActionReplaceWithGeminiBypass SignatureCompatibilityAction = "replace_with_gemini_bypass" + SignatureActionNoCompatibleReplacement SignatureCompatibilityAction = "no_compatible_replacement" +) + +type SignatureCompatibilityDecision struct { + TargetProvider SignatureProvider + DetectedProvider SignatureProvider + BlockKind SignatureBlockKind + Compatible bool + Action SignatureCompatibilityAction + ReplacementSignature string + NormalizedSignature string + Reason string +} + +// SignatureProviderFromModelName maps common model names to the provider family +// whose signed history can be safely replayed for that model. +func SignatureProviderFromModelName(modelName string) SignatureProvider { + lower := strings.ToLower(strings.TrimSpace(modelName)) + switch { + case strings.Contains(lower, "claude"): + return SignatureProviderClaude + case strings.Contains(lower, "gemini"): + return SignatureProviderGemini + case strings.Contains(lower, "gpt"), + strings.Contains(lower, "openai"), + strings.Contains(lower, "codex"), + strings.HasPrefix(lower, "o1"), + strings.HasPrefix(lower, "o3"), + strings.HasPrefix(lower, "o4"): + return SignatureProviderGPT + default: + return SignatureProviderUnknown + } +} + +// DetectSignatureProvider classifies the provider family that can replay +// rawSignature. 
It intentionally uses Claude strict validation before Gemini +// detection because Gemini 3 signatures also decode from an E-prefixed base64 +// string and can look Claude-like under shallow prefix checks. +func DetectSignatureProvider(rawSignature string) SignatureProvider { + return DetectSignatureProviderForBlock(rawSignature, SignatureBlockKindUnknown) +} + +// DetectSignatureProviderForBlock classifies rawSignature with block-kind +// context. UUID-shaped payloads are deliberately not classified as replay-safe +// provider signatures; callers targeting Gemini should replace them with the +// bypass sentinel. +func DetectSignatureProviderForBlock(rawSignature string, blockKind SignatureBlockKind) SignatureProvider { + sig := strings.TrimSpace(rawSignature) + if sig == "" { + return SignatureProviderUnknown + } + + if prefixedProvider, unprefixed, ok := SplitSignatureProviderPrefix(sig); ok { + switch prefixedProvider { + case SignatureProviderGemini: + if IsGeminiThoughtSignatureBypass(unprefixed) { + return SignatureProviderGeminiBypass + } + if isRecognizedGeminiProviderSignature(unprefixed, blockKind) { + return SignatureProviderGemini + } + case SignatureProviderClaude: + if IsValidClaudeThinkingSignature(unprefixed, ClaudeSignatureValidationOptions{Strict: true}) { + return SignatureProviderClaude + } + case SignatureProviderGPT: + if IsValidGPTReasoningSignature(unprefixed) { + return SignatureProviderGPT + } + } + return SignatureProviderUnknown + } + if strings.Contains(sig, "#") { + return SignatureProviderUnknown + } + + if IsGeminiThoughtSignatureBypass(sig) { + return SignatureProviderGeminiBypass + } + if IsValidGPTReasoningSignature(sig) { + return SignatureProviderGPT + } + if IsValidClaudeThinkingSignature(sig, ClaudeSignatureValidationOptions{Strict: true}) { + return SignatureProviderClaude + } + if isRecognizedGeminiProviderSignature(sig, blockKind) { + return SignatureProviderGemini + } + return SignatureProviderUnknown +} + +func 
IsSignatureCompatibleWithProvider(targetProvider SignatureProvider, rawSignature string) bool { + decision := DecideSignatureCompatibility(targetProvider, rawSignature, SignatureBlockKindUnknown) + return decision.Compatible +} + +// DecideSignatureCompatibility returns the safe handling policy for replaying a +// signed block into targetProvider. +func DecideSignatureCompatibility(targetProvider SignatureProvider, rawSignature string, blockKind SignatureBlockKind) SignatureCompatibilityDecision { + targetProvider = normalizeSignatureTargetProvider(targetProvider) + if blockKind == "" { + blockKind = SignatureBlockKindUnknown + } + + detected := DetectSignatureProviderForBlock(rawSignature, blockKind) + decision := SignatureCompatibilityDecision{ + TargetProvider: targetProvider, + DetectedProvider: detected, + BlockKind: blockKind, + } + + if signatureProviderMatchesTarget(targetProvider, detected) { + decision.Compatible = true + decision.Action = SignatureActionPreserve + decision.NormalizedSignature = normalizeCompatibleSignatureForProvider(targetProvider, rawSignature, blockKind) + decision.Reason = "signature provider matches target provider" + return decision + } + + decision.Compatible = false + switch targetProvider { + case SignatureProviderGemini: + if blockKind == SignatureBlockKindGeminiFunctionCall || blockKind == SignatureBlockKindGeminiModelPart || blockKind == SignatureBlockKindUnknown { + decision.Action = SignatureActionReplaceWithGeminiBypass + decision.ReplacementSignature = GeminiSkipThoughtSignatureValidator + decision.Reason = "Gemini can bypass synthetic or incompatible model-part signatures with the documented sentinel" + return decision + } + decision.Action = SignatureActionDropBlock + decision.Reason = "signature is not compatible with Gemini and this block is not a bypass-safe Gemini model part" + case SignatureProviderClaude: + decision.Action = SignatureActionDropBlock + decision.Reason = "Claude has no cross-provider bypass sentinel 
for thinking blocks" + case SignatureProviderGPT: + decision.Action = SignatureActionDropBlock + decision.Reason = "GPT reasoning encrypted_content cannot be synthesized from another provider signature" + default: + decision.Action = SignatureActionNoCompatibleReplacement + decision.Reason = "unknown target provider" + } + return decision +} + +func SplitSignatureProviderPrefix(rawSignature string) (SignatureProvider, string, bool) { + prefix, rest, ok := strings.Cut(strings.TrimSpace(rawSignature), "#") + if !ok { + return SignatureProviderUnknown, rawSignature, false + } + provider := SignatureProviderFromCachePrefix(prefix) + if provider == SignatureProviderUnknown { + return SignatureProviderUnknown, rawSignature, false + } + return provider, strings.TrimSpace(rest), true +} + +// SignatureProviderFromCachePrefix maps this repo's explicit provider-prefix +// envelope to a provider family. This is intentionally stricter than +// SignatureProviderFromModelName so arbitrary model names such as +// "claude-cache#..." cannot be mistaken for trusted provider provenance. +func SignatureProviderFromCachePrefix(prefix string) SignatureProvider { + switch strings.ToLower(strings.TrimSpace(prefix)) { + case "claude", "anthropic": + return SignatureProviderClaude + case "gemini", "google": + return SignatureProviderGemini + case "openai", "gpt", "codex": + return SignatureProviderGPT + default: + return SignatureProviderUnknown + } +} + +// SignaturePayloadWithoutProviderPrefix strips this repo's provider cache prefix +// when present. The returned string is the value that should be replayed to an +// upstream provider. +func SignaturePayloadWithoutProviderPrefix(rawSignature string) string { + if _, unprefixed, ok := SplitSignatureProviderPrefix(rawSignature); ok { + return unprefixed + } + return strings.TrimSpace(rawSignature) +} + +// CompatibleSignatureForProvider returns a replayable provider-native signature +// for targetProvider. 
It strips this repo's provider prefix and normalizes +// Claude signatures to the format expected by the target when possible. +func CompatibleSignatureForProvider(targetProvider SignatureProvider, rawSignature string) (string, bool) { + return CompatibleSignatureForProviderBlock(targetProvider, rawSignature, SignatureBlockKindUnknown) +} + +// CompatibleSignatureForProviderBlock returns a replayable provider-native +// signature for targetProvider when the source block kind is known. +func CompatibleSignatureForProviderBlock(targetProvider SignatureProvider, rawSignature string, blockKind SignatureBlockKind) (string, bool) { + decision := DecideSignatureCompatibility(targetProvider, rawSignature, blockKind) + if !decision.Compatible || decision.NormalizedSignature == "" { + return "", false + } + return decision.NormalizedSignature, true +} + +func normalizeSignatureTargetProvider(provider SignatureProvider) SignatureProvider { + switch provider { + case SignatureProviderGeminiBypass: + return SignatureProviderGemini + default: + return provider + } +} + +func signatureProviderMatchesTarget(target, detected SignatureProvider) bool { + switch target { + case SignatureProviderGemini: + return detected == SignatureProviderGemini || detected == SignatureProviderGeminiBypass + case SignatureProviderClaude: + return detected == SignatureProviderClaude + case SignatureProviderGPT: + return detected == SignatureProviderGPT + default: + return false + } +} + +func normalizeCompatibleSignatureForProvider(targetProvider SignatureProvider, rawSignature string, blockKind SignatureBlockKind) string { + payload := SignaturePayloadWithoutProviderPrefix(rawSignature) + switch normalizeSignatureTargetProvider(targetProvider) { + case SignatureProviderClaude: + normalized, err := NormalizeClaudeThinkingSignature(payload) + if err != nil { + return "" + } + return normalized + case SignatureProviderGemini: + if IsGeminiThoughtSignatureBypass(payload) { + return payload + } + if 
isRecognizedGeminiProviderSignature(payload, blockKind) { + return payload + } + case SignatureProviderGPT: + if IsValidGPTReasoningSignature(payload) { + return payload + } + } + return "" +} + +func isRecognizedGeminiProviderSignature(rawSignature string, blockKind SignatureBlockKind) bool { + if IsValidGeminiThoughtSignature(rawSignature, GeminiThoughtSignatureValidationOptions{RequireKnownEnvelope: true}) { + return true + } + return false +} diff --git a/internal/signature/provider_compatibility_test.go b/internal/signature/provider_compatibility_test.go new file mode 100644 index 000000000..5768d11cb --- /dev/null +++ b/internal/signature/provider_compatibility_test.go @@ -0,0 +1,248 @@ +package signature + +import ( + "encoding/base64" + "strings" + "testing" + + "github.com/tidwall/gjson" + "google.golang.org/protobuf/encoding/protowire" +) + +func testClaudeThinkingSignature() string { + channelBlock := []byte{} + channelBlock = protowire.AppendTag(channelBlock, 1, protowire.VarintType) + channelBlock = protowire.AppendVarint(channelBlock, 12) + channelBlock = protowire.AppendTag(channelBlock, 2, protowire.VarintType) + channelBlock = protowire.AppendVarint(channelBlock, 2) + channelBlock = protowire.AppendTag(channelBlock, 6, protowire.BytesType) + channelBlock = protowire.AppendString(channelBlock, "claude-sonnet-4-6") + + container := []byte{} + container = protowire.AppendTag(container, 1, protowire.BytesType) + container = protowire.AppendBytes(container, channelBlock) + + payload := []byte{} + payload = protowire.AppendTag(payload, 2, protowire.BytesType) + payload = protowire.AppendBytes(payload, container) + payload = protowire.AppendTag(payload, 3, protowire.VarintType) + payload = protowire.AppendVarint(payload, 1) + return base64.StdEncoding.EncodeToString(payload) +} + +func TestDetectSignatureProvider_UsesProviderPrefix(t *testing.T) { + claudeSig := "claude#" + testClaudeThinkingSignature() + if got := DetectSignatureProvider(claudeSig); got 
!= SignatureProviderClaude { + t.Fatalf("DetectSignatureProvider(claude#...) = %q, want %q", got, SignatureProviderClaude) + } + + geminiSig := "gemini#" + testGemini3ThoughtSignature([]byte{0x01, 0x0c, 0x39}) + if got := DetectSignatureProvider(geminiSig); got != SignatureProviderGemini { + t.Fatalf("DetectSignatureProvider(gemini#...) = %q, want %q", got, SignatureProviderGemini) + } +} + +func TestDetectSignatureProvider_RejectsMisleadingClaudePrefix(t *testing.T) { + mislabeledGeminiSig := "claude#" + testGemini3ThoughtSignature([]byte{0x01, 0x0c, 0x39}) + if got := DetectSignatureProvider(mislabeledGeminiSig); got != SignatureProviderUnknown { + t.Fatalf("DetectSignatureProvider(mislabeled claude#Gemini) = %q, want %q", got, SignatureProviderUnknown) + } +} + +func TestDetectSignatureProvider_Gemini3EPrefixDoesNotLookClaude(t *testing.T) { + // This byte shape base64-encodes with an E prefix but is a Gemini field-2 + // envelope, not a Claude thinking-signature tree. + geminiSig := testGemini3ThoughtSignature([]byte{0x01, 0x0c, 0x39, 0xd6, 0xc7, 0x34}) + if !strings.HasPrefix(geminiSig, "E") { + t.Fatalf("test signature should start with E, got %q", geminiSig[:1]) + } + if got := DetectSignatureProvider(geminiSig); got != SignatureProviderGemini { + t.Fatalf("DetectSignatureProvider(Gemini E-prefix) = %q, want %q", got, SignatureProviderGemini) + } +} + +func TestDetectSignatureProvider_DoesNotClassifyArbitraryBase64AsGemini(t *testing.T) { + opaque := testGeminiThoughtSignature([]byte{0x45, 0x12}) + if got := DetectSignatureProvider(opaque); got != SignatureProviderUnknown { + t.Fatalf("DetectSignatureProvider(arbitrary base64) = %q, want %q", got, SignatureProviderUnknown) + } +} + +func TestGeminiASCIIUUIDSignatureUsesBypass(t *testing.T) { + plainUUID := "e24830a7-5cd6-42fe-998b-ee539e72b9c3" + sig := testGeminiThoughtSignature([]byte(plainUUID)) + + if got := DetectSignatureProvider(plainUUID); got != SignatureProviderUnknown { + 
t.Fatalf("DetectSignatureProvider(plain UUID) = %q, want %q", got, SignatureProviderUnknown) + } + if got := DetectSignatureProvider("gemini#" + plainUUID); got != SignatureProviderUnknown { + t.Fatalf("DetectSignatureProvider(gemini#plain UUID) = %q, want %q", got, SignatureProviderUnknown) + } + + if got := DetectSignatureProvider(sig); got != SignatureProviderUnknown { + t.Fatalf("DetectSignatureProvider(UUID) = %q, want %q", got, SignatureProviderUnknown) + } + if got := DetectSignatureProvider("gemini#" + sig); got != SignatureProviderUnknown { + t.Fatalf("DetectSignatureProvider(gemini#UUID) = %q, want %q", got, SignatureProviderUnknown) + } + if got := DetectSignatureProviderForBlock(sig, SignatureBlockKindGeminiFunctionCall); got != SignatureProviderUnknown { + t.Fatalf("DetectSignatureProviderForBlock(UUID tool call) = %q, want %q", got, SignatureProviderUnknown) + } + if _, ok := CompatibleSignatureForProvider(SignatureProviderGemini, sig); ok { + t.Fatal("UUID signature should not be compatible") + } + if normalized, ok := CompatibleSignatureForProviderBlock(SignatureProviderGemini, sig, SignatureBlockKindGeminiFunctionCall); ok || normalized != "" { + t.Fatalf("UUID tool-call signature normalized=%q ok=%v, want empty and false", normalized, ok) + } + decision := DecideSignatureCompatibility(SignatureProviderGemini, sig, SignatureBlockKindGeminiFunctionCall) + if decision.Action != SignatureActionReplaceWithGeminiBypass { + t.Fatalf("function-call UUID action = %q, want %q", decision.Action, SignatureActionReplaceWithGeminiBypass) + } + if decision.ReplacementSignature != GeminiSkipThoughtSignatureValidator { + t.Fatalf("function-call UUID replacement = %q, want %q", decision.ReplacementSignature, GeminiSkipThoughtSignatureValidator) + } + decision = DecideSignatureCompatibility(SignatureProviderGemini, sig, SignatureBlockKindGeminiModelPart) + if decision.Action != SignatureActionReplaceWithGeminiBypass { + t.Fatalf("model-part UUID action = %q, want 
%q", decision.Action, SignatureActionReplaceWithGeminiBypass) + } +} + +func TestGeminiWrappedUUIDFunctionCallSignatureIsUnknown(t *testing.T) { + sig := testGemini3ThoughtSignature([]byte("e24830a7-5cd6-42fe-998b-ee539e72b9c3")) + + if got := DetectSignatureProvider(sig); got != SignatureProviderUnknown { + t.Fatalf("DetectSignatureProvider(wrapped UUID) = %q, want %q", got, SignatureProviderUnknown) + } + if got := DetectSignatureProviderForBlock(sig, SignatureBlockKindGeminiFunctionCall); got != SignatureProviderUnknown { + t.Fatalf("DetectSignatureProviderForBlock(wrapped UUID tool call) = %q, want %q", got, SignatureProviderUnknown) + } + if normalized, ok := CompatibleSignatureForProviderBlock(SignatureProviderGemini, sig, SignatureBlockKindGeminiFunctionCall); ok || normalized != "" { + t.Fatalf("wrapped UUID tool-call signature normalized=%q ok=%v, want empty and false", normalized, ok) + } + decision := DecideSignatureCompatibility(SignatureProviderGemini, sig, SignatureBlockKindGeminiFunctionCall) + if decision.Action != SignatureActionReplaceWithGeminiBypass { + t.Fatalf("function-call wrapped UUID action = %q, want %q", decision.Action, SignatureActionReplaceWithGeminiBypass) + } + if decision.ReplacementSignature != GeminiSkipThoughtSignatureValidator { + t.Fatalf("function-call wrapped UUID replacement = %q, want %q", decision.ReplacementSignature, GeminiSkipThoughtSignatureValidator) + } + decision = DecideSignatureCompatibility(SignatureProviderGemini, sig, SignatureBlockKindGeminiModelPart) + if decision.Action != SignatureActionReplaceWithGeminiBypass { + t.Fatalf("model-part wrapped UUID action = %q, want %q", decision.Action, SignatureActionReplaceWithGeminiBypass) + } +} + +func TestCompatibleSignatureForProvider_StripsGeminiPrefix(t *testing.T) { + sig := testGemini3ThoughtSignature([]byte{0x01, 0x0c, 0x39}) + normalized, ok := CompatibleSignatureForProvider(SignatureProviderGemini, "gemini#"+sig) + if !ok { + t.Fatal("gemini-prefixed 
signature should be compatible with Gemini") + } + if normalized != sig { + t.Fatalf("normalized = %q, want %q", normalized, sig) + } +} + +func TestSplitSignatureProviderPrefix_UsesStrictProviderAliases(t *testing.T) { + gptSig := "gpt#" + testGPTReasoningSignature() + if got := DetectSignatureProvider(gptSig); got != SignatureProviderGPT { + t.Fatalf("DetectSignatureProvider(gpt#...) = %q, want %q", got, SignatureProviderGPT) + } + + mislabeledPrefix := "claude-cache#" + testClaudeThinkingSignature() + if _, _, ok := SplitSignatureProviderPrefix(mislabeledPrefix); ok { + t.Fatal("claude-cache# should not be accepted as an explicit provider prefix") + } + if got := DetectSignatureProvider(mislabeledPrefix); got != SignatureProviderUnknown { + t.Fatalf("DetectSignatureProvider(claude-cache#...) = %q, want %q", got, SignatureProviderUnknown) + } +} + +func TestDecideSignatureCompatibility_GeminiFunctionCallUsesBypass(t *testing.T) { + decision := DecideSignatureCompatibility(SignatureProviderGemini, "claude#"+testClaudeThinkingSignature(), SignatureBlockKindGeminiFunctionCall) + if decision.Action != SignatureActionReplaceWithGeminiBypass { + t.Fatalf("Action = %q, want %q", decision.Action, SignatureActionReplaceWithGeminiBypass) + } + if decision.ReplacementSignature != GeminiSkipThoughtSignatureValidator { + t.Fatalf("ReplacementSignature = %q, want %q", decision.ReplacementSignature, GeminiSkipThoughtSignatureValidator) + } +} + +func TestSanitizeClaudeMessagesSignaturesForModel_NormalizesSameProviderClaude(t *testing.T) { + nativeSig := testClaudeThinkingSignature() + sig := "claude#" + nativeSig + input := []byte(`{"model":"claude-sonnet","messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"keep","signature":"` + sig + `"},{"type":"text","text":"answer"}]}]}`) + expectedSig, err := NormalizeClaudeThinkingSignature(nativeSig) + if err != nil { + t.Fatalf("NormalizeClaudeThinkingSignature failed: %v", err) + } + + output, report := 
SanitizeClaudeMessagesSignaturesForModel(input, "claude-sonnet-4-5") + if report.Preserved != 1 || report.DroppedBlocks != 0 { + t.Fatalf("unexpected report: %+v", report) + } + if got := gjson.GetBytes(output, "messages.0.content.0.signature").String(); got != expectedSig { + t.Fatalf("signature = %q, want normalized %q", got, expectedSig) + } +} + +func TestSanitizeClaudeMessagesSignaturesForModel_DropsClaudeThinkingForGemini(t *testing.T) { + sig := "claude#" + testClaudeThinkingSignature() + input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"drop","signature":"` + sig + `"},{"type":"text","text":"answer"}]}]}`) + + output, report := SanitizeClaudeMessagesSignaturesForModel(input, "gemini-3.5-flash") + if report.DroppedBlocks != 1 { + t.Fatalf("DroppedBlocks = %d, want 1; report=%+v", report.DroppedBlocks, report) + } + content := gjson.GetBytes(output, "messages.0.content").Array() + if len(content) != 1 { + t.Fatalf("content length = %d, want 1: %s", len(content), output) + } + if got := content[0].Get("text").String(); got != "answer" { + t.Fatalf("remaining text = %q, want answer", got) + } +} + +func TestSanitizeClaudeMessagesSignaturesForModel_PreservesGeminiThinkingForGemini(t *testing.T) { + nativeSig := testGemini3ThoughtSignature([]byte{0x01, 0x0c, 0x39}) + sig := "gemini#" + nativeSig + input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"keep","signature":"` + sig + `"},{"type":"text","text":"answer"}]}]}`) + + output, report := SanitizeClaudeMessagesSignaturesForModel(input, "gemini-3.5-flash") + if report.Preserved != 1 || report.DroppedBlocks != 0 { + t.Fatalf("unexpected report: %+v", report) + } + if got := gjson.GetBytes(output, "messages.0.content.0.signature").String(); got != nativeSig { + t.Fatalf("signature = %q, want normalized %q", got, nativeSig) + } +} + +func TestSanitizeClaudeMessagesSignaturesForModel_PreservesGPTForGPT(t *testing.T) { + sig := 
testGPTReasoningSignature() + input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"keep","signature":"` + sig + `"},{"type":"text","text":"answer"}]}]}`) + + output, report := SanitizeClaudeMessagesSignaturesForModel(input, "gpt-5.2") + if report.Preserved != 1 || report.DroppedBlocks != 0 { + t.Fatalf("unexpected report: %+v", report) + } + if got := gjson.GetBytes(output, "messages.0.content.0.signature").String(); got != sig { + t.Fatalf("signature = %q, want preserved %q", got, sig) + } +} + +func TestSanitizeClaudeMessagesSignaturesForModel_DropsEmptyAssistantMessage(t *testing.T) { + sig := "claude#" + testClaudeThinkingSignature() + input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"drop","signature":"` + sig + `"}]},{"role":"user","content":[{"type":"text","text":"next"}]}]}`) + + output, report := SanitizeClaudeMessagesSignaturesForModel(input, "gpt-5.2") + if report.DroppedBlocks != 1 { + t.Fatalf("DroppedBlocks = %d, want 1", report.DroppedBlocks) + } + messages := gjson.GetBytes(output, "messages").Array() + if len(messages) != 1 { + t.Fatalf("messages length = %d, want 1: %s", len(messages), output) + } + if got := messages[0].Get("role").String(); got != "user" { + t.Fatalf("remaining role = %q, want user", got) + } +} diff --git a/internal/translator/antigravity/claude/signature_validation.go b/internal/translator/antigravity/claude/signature_validation.go index f82fc2e36..f0acbf8e7 100644 --- a/internal/translator/antigravity/claude/signature_validation.go +++ b/internal/translator/antigravity/claude/signature_validation.go @@ -1,448 +1,42 @@ -// Claude thinking signature validation for Antigravity bypass mode. -// -// Spec reference: SIGNATURE-CHANNEL-SPEC.md -// -// # Encoding Detection (Spec §3) -// -// Claude signatures use base64 encoding in one or two layers. 
The raw string's -// first character determines the encoding depth — this is mathematically equivalent -// to the spec's "decode first, check byte" approach: -// -// - 'E' prefix → single-layer: payload[0]==0x12, first 6 bits = 000100 = base64 index 4 = 'E' -// - 'R' prefix → double-layer: inner[0]=='E' (0x45), first 6 bits = 010001 = base64 index 17 = 'R' -// -// All valid signatures are normalized to R-form (double-layer base64) before -// sending to the Antigravity backend. -// -// # Protobuf Structure (Spec §4.1, §4.2) — strict mode only -// -// After base64 decoding to raw bytes (first byte must be 0x12): -// -// Top-level protobuf -// ├── Field 2 (bytes): container ← extractBytesField(payload, 2) -// │ ├── Field 1 (bytes): channel block ← extractBytesField(container, 1) -// │ │ ├── Field 1 (varint): channel_id [required] → routing_class (11 | 12) -// │ │ ├── Field 2 (varint): infra [optional] → infrastructure_class (aws=1 | google=2) -// │ │ ├── Field 3 (varint): version=2 [skipped] -// │ │ ├── Field 5 (bytes): ECDSA sig [skipped, per Spec §11] -// │ │ ├── Field 6 (bytes): model_text [optional] → schema_features -// │ │ └── Field 7 (varint): unknown [optional] → schema_features -// │ ├── Field 2 (bytes): nonce 12B [skipped] -// │ ├── Field 3 (bytes): session 12B [skipped] -// │ ├── Field 4 (bytes): SHA-384 48B [skipped] -// │ └── Field 5 (bytes): metadata [skipped, per Spec §11] -// └── Field 3 (varint): =1 [skipped] -// -// # Output Dimensions (Spec §8) -// -// routing_class: routing_class_11 | routing_class_12 | unknown -// infrastructure_class: infra_default (absent) | infra_aws (1) | infra_google (2) | infra_unknown -// schema_features: compact_schema (len 70-72, no f6/f7) | extended_model_tagged_schema (f6 exists) | unknown -// legacy_route_hint: only for ch=11 — legacy_default_group | legacy_aws_group | legacy_vertex_direct/proxy -// -// # Compatibility -// -// Verified against all confirmed spec samples (Anthropic Max 20x, Azure, Vertex, -// Bedrock) 
and legacy ch=11 signatures. Both single-layer (E) and double-layer (R) -// encodings are supported. Historical cache-mode 'modelGroup#' prefixes are stripped. +// Claude thinking signature validation wrappers for Antigravity bypass mode. package claude import ( - "encoding/base64" - "fmt" - "strings" - "unicode/utf8" - "github.com/router-for-me/CLIProxyAPI/v7/internal/cache" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" - "google.golang.org/protobuf/encoding/protowire" + "github.com/router-for-me/CLIProxyAPI/v7/internal/signature" ) -const maxBypassSignatureLen = 32 * 1024 * 1024 +const maxBypassSignatureLen = signature.MaxClaudeThinkingSignatureLen -type claudeSignatureTree struct { - EncodingLayers int - ChannelID uint64 - Field2 *uint64 - RoutingClass string - InfrastructureClass string - SchemaFeatures string - ModelText string - LegacyRouteHint string - HasField7 bool -} +type claudeSignatureTree = signature.ClaudeSignatureTree -// StripInvalidSignatureThinkingBlocks removes thinking blocks whose signatures -// are empty or not valid Claude format (must start with 'E' or 'R' after -// stripping any cache prefix). These come from proxy-generated responses -// (Antigravity/Gemini) where no real Claude signature exists. +// StripEmptySignatureThinkingBlocks removes thinking blocks whose signatures +// are empty or not valid Claude thinking signatures. These usually come from +// proxy-generated responses where no real Claude signature exists. 
func StripEmptySignatureThinkingBlocks(payload []byte) []byte { - messages := gjson.GetBytes(payload, "messages") - if !messages.IsArray() { - return payload - } - modified := false - for i, msg := range messages.Array() { - content := msg.Get("content") - if !content.IsArray() { - continue - } - var kept []string - stripped := false - for _, part := range content.Array() { - if part.Get("type").String() == "thinking" && !hasValidClaudeSignature(part.Get("signature").String()) { - stripped = true - continue - } - kept = append(kept, part.Raw) - } - if stripped { - modified = true - if len(kept) == 0 { - payload, _ = sjson.SetRawBytes(payload, fmt.Sprintf("messages.%d.content", i), []byte("[]")) - } else { - payload, _ = sjson.SetRawBytes(payload, fmt.Sprintf("messages.%d.content", i), []byte("["+strings.Join(kept, ",")+"]")) - } - } - } - if !modified { - return payload - } - return payload -} - -// hasValidClaudeSignature returns true if sig looks like a real Claude thinking -// signature: non-empty and starts with 'E' or 'R' (after stripping optional -// cache prefix like "modelGroup#"). 
-func hasValidClaudeSignature(sig string) bool { - sig = strings.TrimSpace(sig) - if sig == "" { - return false - } - if idx := strings.IndexByte(sig, '#'); idx >= 0 { - sig = strings.TrimSpace(sig[idx+1:]) - } - if sig == "" { - return false - } - return sig[0] == 'E' || sig[0] == 'R' + return signature.StripInvalidClaudeThinkingBlocks(payload, signature.ClaudeSignatureValidationOptions{PrefixOnly: true}) } func ValidateClaudeBypassSignatures(inputRawJSON []byte) error { - messages := gjson.GetBytes(inputRawJSON, "messages") - if !messages.IsArray() { - return nil - } - - messageResults := messages.Array() - for i := 0; i < len(messageResults); i++ { - contentResults := messageResults[i].Get("content") - if !contentResults.IsArray() { - continue - } - parts := contentResults.Array() - for j := 0; j < len(parts); j++ { - part := parts[j] - if part.Get("type").String() != "thinking" { - continue - } - - rawSignature := strings.TrimSpace(part.Get("signature").String()) - if rawSignature == "" { - return fmt.Errorf("messages[%d].content[%d]: missing thinking signature", i, j) - } - - if _, err := normalizeClaudeBypassSignature(rawSignature); err != nil { - return fmt.Errorf("messages[%d].content[%d]: %w", i, j, err) - } - } - } - - return nil + return signature.ValidateClaudeThinkingSignatures(inputRawJSON, claudeBypassSignatureValidationOptions()) } func normalizeClaudeBypassSignature(rawSignature string) (string, error) { - sig := strings.TrimSpace(rawSignature) - if sig == "" { - return "", fmt.Errorf("empty signature") - } - - if idx := strings.IndexByte(sig, '#'); idx >= 0 { - sig = strings.TrimSpace(sig[idx+1:]) - } - - if sig == "" { - return "", fmt.Errorf("empty signature after stripping prefix") - } - - if len(sig) > maxBypassSignatureLen { - return "", fmt.Errorf("signature exceeds maximum length (%d bytes)", maxBypassSignatureLen) - } - - switch sig[0] { - case 'R': - if err := validateDoubleLayerSignature(sig); err != nil { - return "", err - } - return 
sig, nil - case 'E': - if err := validateSingleLayerSignature(sig); err != nil { - return "", err - } - return base64.StdEncoding.EncodeToString([]byte(sig)), nil - default: - return "", fmt.Errorf("invalid signature: expected 'E' or 'R' prefix, got %q", string(sig[0])) - } -} - -func validateDoubleLayerSignature(sig string) error { - decoded, err := base64.StdEncoding.DecodeString(sig) - if err != nil { - return fmt.Errorf("invalid double-layer signature: base64 decode failed: %w", err) - } - if len(decoded) == 0 { - return fmt.Errorf("invalid double-layer signature: empty after decode") - } - if decoded[0] != 'E' { - return fmt.Errorf("invalid double-layer signature: inner does not start with 'E', got 0x%02x", decoded[0]) - } - return validateSingleLayerSignatureContent(string(decoded), 2) -} - -func validateSingleLayerSignature(sig string) error { - return validateSingleLayerSignatureContent(sig, 1) -} - -func validateSingleLayerSignatureContent(sig string, encodingLayers int) error { - decoded, err := base64.StdEncoding.DecodeString(sig) - if err != nil { - return fmt.Errorf("invalid single-layer signature: base64 decode failed: %w", err) - } - if len(decoded) == 0 { - return fmt.Errorf("invalid single-layer signature: empty after decode") - } - if decoded[0] != 0x12 { - return fmt.Errorf("invalid Claude signature: expected first byte 0x12, got 0x%02x", decoded[0]) - } - if !cache.SignatureBypassStrictMode() { - return nil - } - _, err = inspectClaudeSignaturePayload(decoded, encodingLayers) - return err + return signature.NormalizeClaudeThinkingSignature(rawSignature, claudeBypassSignatureValidationOptions()) } func inspectDoubleLayerSignature(sig string) (*claudeSignatureTree, error) { - decoded, err := base64.StdEncoding.DecodeString(sig) - if err != nil { - return nil, fmt.Errorf("invalid double-layer signature: base64 decode failed: %w", err) - } - if len(decoded) == 0 { - return nil, fmt.Errorf("invalid double-layer signature: empty after decode") - } - 
if decoded[0] != 'E' { - return nil, fmt.Errorf("invalid double-layer signature: inner does not start with 'E', got 0x%02x", decoded[0]) - } - return inspectSingleLayerSignatureWithLayers(string(decoded), 2) + return signature.InspectClaudeDoubleLayerSignature(sig) } func inspectSingleLayerSignature(sig string) (*claudeSignatureTree, error) { - return inspectSingleLayerSignatureWithLayers(sig, 1) -} - -func inspectSingleLayerSignatureWithLayers(sig string, encodingLayers int) (*claudeSignatureTree, error) { - decoded, err := base64.StdEncoding.DecodeString(sig) - if err != nil { - return nil, fmt.Errorf("invalid single-layer signature: base64 decode failed: %w", err) - } - if len(decoded) == 0 { - return nil, fmt.Errorf("invalid single-layer signature: empty after decode") - } - return inspectClaudeSignaturePayload(decoded, encodingLayers) + return signature.InspectClaudeSingleLayerSignature(sig) } func inspectClaudeSignaturePayload(payload []byte, encodingLayers int) (*claudeSignatureTree, error) { - if len(payload) == 0 { - return nil, fmt.Errorf("invalid Claude signature: empty payload") - } - if payload[0] != 0x12 { - return nil, fmt.Errorf("invalid Claude signature: expected first byte 0x12, got 0x%02x", payload[0]) - } - container, err := extractBytesField(payload, 2, "top-level protobuf") - if err != nil { - return nil, err - } - channelBlock, err := extractBytesField(container, 1, "Claude Field 2 container") - if err != nil { - return nil, err - } - return inspectClaudeChannelBlock(channelBlock, encodingLayers) + return signature.InspectClaudeSignaturePayload(payload, encodingLayers) } -func inspectClaudeChannelBlock(channelBlock []byte, encodingLayers int) (*claudeSignatureTree, error) { - tree := &claudeSignatureTree{ - EncodingLayers: encodingLayers, - RoutingClass: "unknown", - InfrastructureClass: "infra_unknown", - SchemaFeatures: "unknown_schema_features", - } - haveChannelID := false - hasField6 := false - hasField7 := false - - err := 
walkProtobufFields(channelBlock, func(num protowire.Number, typ protowire.Type, raw []byte) error { - switch num { - case 1: - if typ != protowire.VarintType { - return fmt.Errorf("invalid Claude signature: Field 2.1.1 channel_id must be varint") - } - channelID, err := decodeVarintField(raw, "Field 2.1.1 channel_id") - if err != nil { - return err - } - tree.ChannelID = channelID - haveChannelID = true - case 2: - if typ != protowire.VarintType { - return fmt.Errorf("invalid Claude signature: Field 2.1.2 field2 must be varint") - } - field2, err := decodeVarintField(raw, "Field 2.1.2 field2") - if err != nil { - return err - } - tree.Field2 = &field2 - case 6: - if typ != protowire.BytesType { - return fmt.Errorf("invalid Claude signature: Field 2.1.6 model_text must be bytes") - } - modelBytes, err := decodeBytesField(raw, "Field 2.1.6 model_text") - if err != nil { - return err - } - if !utf8.Valid(modelBytes) { - return fmt.Errorf("invalid Claude signature: Field 2.1.6 model_text is not valid UTF-8") - } - tree.ModelText = string(modelBytes) - hasField6 = true - case 7: - if typ != protowire.VarintType { - return fmt.Errorf("invalid Claude signature: Field 2.1.7 must be varint") - } - if _, err := decodeVarintField(raw, "Field 2.1.7"); err != nil { - return err - } - hasField7 = true - tree.HasField7 = true - } - return nil - }) - if err != nil { - return nil, err - } - if !haveChannelID { - return nil, fmt.Errorf("invalid Claude signature: missing Field 2.1.1 channel_id") - } - - switch tree.ChannelID { - case 11: - tree.RoutingClass = "routing_class_11" - case 12: - tree.RoutingClass = "routing_class_12" - } - - if tree.Field2 == nil { - tree.InfrastructureClass = "infra_default" - } else { - switch *tree.Field2 { - case 1: - tree.InfrastructureClass = "infra_aws" - case 2: - tree.InfrastructureClass = "infra_google" - default: - tree.InfrastructureClass = "infra_unknown" - } - } - - switch { - case hasField6: - tree.SchemaFeatures = 
"extended_model_tagged_schema" - case !hasField6 && !hasField7 && len(channelBlock) >= 70 && len(channelBlock) <= 72: - tree.SchemaFeatures = "compact_schema" - } - - if tree.ChannelID == 11 { - switch { - case tree.Field2 == nil: - tree.LegacyRouteHint = "legacy_default_group" - case *tree.Field2 == 1: - tree.LegacyRouteHint = "legacy_aws_group" - case *tree.Field2 == 2 && tree.EncodingLayers == 2: - tree.LegacyRouteHint = "legacy_vertex_direct" - case *tree.Field2 == 2 && tree.EncodingLayers == 1: - tree.LegacyRouteHint = "legacy_vertex_proxy" - } - } - - return tree, nil -} - -func extractBytesField(msg []byte, fieldNum protowire.Number, scope string) ([]byte, error) { - var value []byte - err := walkProtobufFields(msg, func(num protowire.Number, typ protowire.Type, raw []byte) error { - if num != fieldNum { - return nil - } - if typ != protowire.BytesType { - return fmt.Errorf("invalid Claude signature: %s field %d must be bytes", scope, fieldNum) - } - bytesValue, err := decodeBytesField(raw, fmt.Sprintf("%s field %d", scope, fieldNum)) - if err != nil { - return err - } - value = bytesValue - return nil - }) - if err != nil { - return nil, err - } - if value == nil { - return nil, fmt.Errorf("invalid Claude signature: missing %s field %d", scope, fieldNum) - } - return value, nil -} - -func walkProtobufFields(msg []byte, visit func(num protowire.Number, typ protowire.Type, raw []byte) error) error { - for offset := 0; offset < len(msg); { - num, typ, n := protowire.ConsumeTag(msg[offset:]) - if n < 0 { - return fmt.Errorf("invalid Claude signature: malformed protobuf tag: %w", protowire.ParseError(n)) - } - offset += n - valueLen := protowire.ConsumeFieldValue(num, typ, msg[offset:]) - if valueLen < 0 { - return fmt.Errorf("invalid Claude signature: malformed protobuf field %d: %w", num, protowire.ParseError(valueLen)) - } - fieldRaw := msg[offset : offset+valueLen] - if err := visit(num, typ, fieldRaw); err != nil { - return err - } - offset += valueLen - } 
- return nil -} - -func decodeVarintField(raw []byte, label string) (uint64, error) { - value, n := protowire.ConsumeVarint(raw) - if n < 0 { - return 0, fmt.Errorf("invalid Claude signature: failed to decode %s: %w", label, protowire.ParseError(n)) - } - return value, nil -} - -func decodeBytesField(raw []byte, label string) ([]byte, error) { - value, n := protowire.ConsumeBytes(raw) - if n < 0 { - return nil, fmt.Errorf("invalid Claude signature: failed to decode %s: %w", label, protowire.ParseError(n)) - } - return value, nil +func claudeBypassSignatureValidationOptions() signature.ClaudeSignatureValidationOptions { + return signature.ClaudeSignatureValidationOptions{Strict: cache.SignatureBypassStrictMode()} }