Fix Gemini thought signature replay

This commit is contained in:
Alishahryar1
2026-05-31 15:13:37 -07:00
parent 885c26d977
commit fedcc0a32b
8 changed files with 472 additions and 34 deletions

View File

@@ -1,6 +1,7 @@
"""Message and tool format converters."""
import json
from copy import deepcopy
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Any
@@ -83,6 +84,24 @@ def _think_tag_content(reasoning: str) -> str:
return f"<think>\n{reasoning}\n</think>"
def _tool_call_from_tool_use(block: Any) -> dict[str, Any]:
    """Convert one ``tool_use`` block into an OpenAI-style function tool call.

    The block's ``input`` is serialized with ``json.dumps`` when it is a dict
    and with ``str`` otherwise. A non-empty dict ``extra_content`` attribute is
    deep-copied onto the result so later edits to the tool call cannot touch
    the source block.
    """
    raw_input = get_block_attr(block, "input", {})
    call_id = get_block_attr(block, "id")
    function_name = get_block_attr(block, "name")

    if isinstance(raw_input, dict):
        serialized_arguments = json.dumps(raw_input)
    else:
        serialized_arguments = str(raw_input)

    converted: dict[str, Any] = {
        "id": call_id,
        "type": "function",
        "function": {"name": function_name, "arguments": serialized_arguments},
    }

    metadata = get_block_attr(block, "extra_content", None)
    if not isinstance(metadata, dict) or not metadata:
        return converted
    converted["extra_content"] = deepcopy(metadata)
    return converted
@dataclass
class _PendingAfterTools:
"""Assistant content that appears after ``tool_use`` in an Anthropic message.
@@ -112,23 +131,11 @@ def _index_first_tool_use(blocks: list[Any]) -> int | None:
def _iter_tool_uses_in_order(blocks: list[Any]) -> list[dict[str, Any]]:
    """Return OpenAI-style tool calls for every ``tool_use`` block, in order.

    Each matching block is converted by ``_tool_call_from_tool_use`` so the
    result stays consistent with the other conversion paths (including any
    ``extra_content`` that helper carries over). Blocks of any other type are
    ignored.
    """
    # A single helper-based return replaces the stale inline dict construction,
    # which returned first, omitted ``extra_content`` and made this path dead.
    return [
        _tool_call_from_tool_use(block)
        for block in blocks
        if get_block_type(block) == "tool_use"
    ]
def _deferred_post_tool_blocks(
@@ -362,19 +369,7 @@ class AnthropicToOpenAIConverter:
# or reasoning_content for OpenAI chat upstreams.
continue
elif block_type == "tool_use":
tool_input = get_block_attr(block, "input", {})
tool_calls.append(
{
"id": get_block_attr(block, "id"),
"type": "function",
"function": {
"name": get_block_attr(block, "name"),
"arguments": json.dumps(tool_input)
if isinstance(tool_input, dict)
else str(tool_input),
},
}
)
tool_calls.append(_tool_call_from_tool_use(block))
else:
_assert_no_forbidden_assistant_block(block)

View File

@@ -4,6 +4,7 @@ import hashlib
import json
from collections.abc import Iterator
from dataclasses import dataclass, field
from typing import Any
from loguru import logger
@@ -54,6 +55,7 @@ class ToolCallState:
block_index: int
tool_id: str
name: str
extra_content: dict[str, Any] | None = None
contents: list[str] = field(default_factory=list)
started: bool = False
task_arg_buffer: str = ""
@@ -90,6 +92,15 @@ class ContentBlockManager:
state = self.ensure_tool_state(index)
state.tool_id = str(tool_id)
def set_tool_extra_content(
self, index: int, extra_content: dict[str, Any] | None
) -> None:
"""Record provider-specific OpenAI tool-call metadata before block start."""
if not extra_content:
return
state = self.ensure_tool_state(index)
state.extra_content = extra_content
def register_tool_name(self, index: int, name: str) -> None:
"""Record tool name fragments as they arrive from chunked OpenAI streams.
@@ -237,6 +248,9 @@ class SSEBuilder:
content_block["id"] = kwargs.get("id", "")
content_block["name"] = kwargs.get("name", "")
content_block["input"] = kwargs.get("input", {})
extra_content = kwargs.get("extra_content")
if isinstance(extra_content, dict) and extra_content:
content_block["extra_content"] = extra_content
return self._format_event(
"content_block_start",
@@ -302,21 +316,37 @@ class SSEBuilder:
self.blocks.text_started = False
return self.content_block_stop(self.blocks.text_index)
def start_tool_block(self, tool_index: int, tool_id: str, name: str) -> str:
def start_tool_block(
self,
tool_index: int,
tool_id: str,
name: str,
*,
extra_content: dict[str, Any] | None = None,
) -> str:
block_idx = self.blocks.allocate_index()
if tool_index in self.blocks.tool_states:
state = self.blocks.tool_states[tool_index]
state.block_index = block_idx
state.tool_id = tool_id
if extra_content:
state.extra_content = extra_content
state.started = True
else:
self.blocks.tool_states[tool_index] = ToolCallState(
block_index=block_idx,
tool_id=tool_id,
name=name,
extra_content=extra_content,
started=True,
)
return self.content_block_start(block_idx, "tool_use", id=tool_id, name=name)
return self.content_block_start(
block_idx,
"tool_use",
id=tool_id,
name=name,
extra_content=extra_content,
)
def emit_tool_delta(self, tool_index: int, partial_json: str) -> str:
state = self.blocks.tool_states[tool_index]

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
from copy import deepcopy
from typing import Any
from providers.base import ProviderConfig
@@ -10,6 +11,8 @@ from providers.openai_compat import OpenAIChatTransport
from .request import build_request_body
# Upper bound on cached tool-call ids: once this many entries are held,
# recording a new id first evicts the oldest-inserted entry
# (see GeminiProvider._record_tool_call_extra_content).
_MAX_TOOL_CALL_EXTRA_CONTENT_CACHE = 4096
class GeminiProvider(OpenAIChatTransport):
"""Gemini API using ``https://generativelanguage.googleapis.com/v1beta/openai/``."""
@@ -21,6 +24,20 @@ class GeminiProvider(OpenAIChatTransport):
base_url=config.base_url or GEMINI_DEFAULT_BASE,
api_key=config.api_key,
)
self._tool_call_extra_content_by_id: dict[str, dict[str, Any]] = {}
def _record_tool_call_extra_content(
    self, tool_call_id: str, extra_content: dict[str, Any]
) -> None:
    """Remember a deep copy of ``extra_content`` under ``tool_call_id``.

    The cache is bounded by ``_MAX_TOOL_CALL_EXTRA_CONTENT_CACHE``: when a new
    id arrives and the cache is already at that size, the first key in
    iteration order (the oldest insertion) is dropped to make room.
    Re-recording an id that is already cached never evicts anything.
    """
    cache = self._tool_call_extra_content_by_id
    is_new_id = tool_call_id not in cache
    if is_new_id and len(cache) >= _MAX_TOOL_CALL_EXTRA_CONTENT_CACHE:
        oldest_id = next(iter(cache))
        cache.pop(oldest_id)
    cache[tool_call_id] = deepcopy(extra_content)
def _build_request_body(
self, request: Any, thinking_enabled: bool | None = None
@@ -28,4 +45,5 @@ class GeminiProvider(OpenAIChatTransport):
return build_request_body(
request,
thinking_enabled=self._is_thinking_enabled(request, thinking_enabled),
tool_call_extra_content_by_id=self._tool_call_extra_content_by_id,
)

View File

@@ -11,6 +11,8 @@ from core.anthropic import ReasoningReplayMode, build_base_request_body
from core.anthropic.conversion import OpenAIConversionError
from providers.exceptions import InvalidRequestError
# Placeholder thought signature written onto the first tool call of a
# current-turn assistant message that carries no signature, applied only when
# the request model name contains "gemini-3".
GEMINI_SKIP_THOUGHT_SIGNATURE_VALIDATOR = "skip_thought_signature_validator"
def _ensure_dict(container: dict[str, Any], key: str) -> dict[str, Any]:
value = container.get(key)
@@ -30,7 +32,130 @@ def _apply_thinking_config(extra_body: dict[str, Any]) -> None:
thinking_cfg.setdefault("include_thoughts", True)
def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict:
def _is_gemini_3_model(model: Any) -> bool:
return "gemini-3" in str(model).lower()
def _thought_signature_from_extra_content(extra_content: Any) -> str | None:
if not isinstance(extra_content, dict):
return None
google = extra_content.get("google")
if not isinstance(google, dict):
return None
signature = google.get("thought_signature")
return signature if isinstance(signature, str) and signature else None
def _tool_call_thought_signature(tool_call: dict[str, Any]) -> str | None:
    """Return the thought signature stored in ``tool_call["extra_content"]``, if any."""
    extra_content = tool_call.get("extra_content")
    return _thought_signature_from_extra_content(extra_content)
def _set_tool_call_thought_signature(tool_call: dict[str, Any], signature: str) -> None:
extra_content = tool_call.get("extra_content")
if not isinstance(extra_content, dict):
extra_content = {}
tool_call["extra_content"] = extra_content
google = extra_content.get("google")
if not isinstance(google, dict):
google = {}
extra_content["google"] = google
google["thought_signature"] = signature
def _message_has_standard_user_content(message: dict[str, Any]) -> bool:
if message.get("role") != "user":
return False
content = message.get("content")
if isinstance(content, str):
return bool(content.strip())
if isinstance(content, list):
return any(
isinstance(part, dict)
and isinstance(part.get("text"), str)
and bool(part["text"].strip())
for part in content
)
return False
def _current_turn_start_index(messages: list[Any]) -> int:
    """Return the index of the last user message with real text, or ``-1``.

    The list is scanned from the end; non-dict entries are skipped.
    """
    last_index = len(messages) - 1
    for offset, candidate in enumerate(reversed(messages)):
        if isinstance(candidate, dict) and _message_has_standard_user_content(
            candidate
        ):
            return last_index - offset
    return -1
def _apply_cached_tool_call_signatures(
    messages: list[Any], tool_call_extra_content_by_id: dict[str, dict[str, Any]]
) -> None:
    """Restore cached ``extra_content`` on assistant tool calls lacking a signature.

    For each assistant message, every dict tool call without a thought
    signature is looked up in the cache by ``str(id)``. When the cached entry
    holds a signature, a deep copy of it replaces the tool call's
    ``extra_content``. Messages are mutated in place; an empty cache is a no-op.
    """
    if not tool_call_extra_content_by_id:
        return

    assistant_messages = (
        entry
        for entry in messages
        if isinstance(entry, dict) and entry.get("role") == "assistant"
    )
    for assistant_message in assistant_messages:
        calls = assistant_message.get("tool_calls")
        if not isinstance(calls, list):
            continue
        for call in calls:
            if not isinstance(call, dict):
                continue
            if _tool_call_thought_signature(call):
                # Already signed: never overwrite what the client sent.
                continue
            call_id = call.get("id")
            if call_id is None:
                continue
            cached = tool_call_extra_content_by_id.get(str(call_id))
            if cached and _thought_signature_from_extra_content(cached):
                call["extra_content"] = deepcopy(cached)
def _apply_gemini_3_missing_current_turn_signatures(
    body: dict[str, Any], messages: list[Any]
) -> None:
    """Add the skip-validator signature to unsigned current-turn tool calls.

    Only runs when ``body["model"]`` is a "gemini-3" model. Looking at messages
    after the current turn's starting user message (all messages when none is
    found), the first tool call of each assistant message gets
    ``GEMINI_SKIP_THOUGHT_SIGNATURE_VALIDATOR`` if it has no signature. Later
    tool calls in the same message are left untouched.
    """
    if not _is_gemini_3_model(body.get("model")):
        return

    turn_start = _current_turn_start_index(messages)
    for entry in messages[turn_start + 1 :]:
        is_assistant = isinstance(entry, dict) and entry.get("role") == "assistant"
        if not is_assistant:
            continue
        calls = entry.get("tool_calls")
        if isinstance(calls, list) and calls:
            leading_call = calls[0]
            if isinstance(leading_call, dict) and not _tool_call_thought_signature(
                leading_call
            ):
                _set_tool_call_thought_signature(
                    leading_call, GEMINI_SKIP_THOUGHT_SIGNATURE_VALIDATOR
                )
def _apply_gemini_tool_call_signatures(
    body: dict[str, Any],
    *,
    tool_call_extra_content_by_id: dict[str, dict[str, Any]] | None,
) -> None:
    """Fill in tool-call thought signatures on ``body["messages"]`` in place.

    Cached signatures are applied first, then the gemini-3 current-turn
    fallback. Nothing happens when ``messages`` is not a list.
    """
    messages = body.get("messages")
    if isinstance(messages, list):
        cache = tool_call_extra_content_by_id if tool_call_extra_content_by_id else {}
        _apply_cached_tool_call_signatures(messages, cache)
        _apply_gemini_3_missing_current_turn_signatures(body, messages)
def build_request_body(
request_data: Any,
*,
thinking_enabled: bool,
tool_call_extra_content_by_id: dict[str, dict[str, Any]] | None = None,
) -> dict:
"""Build OpenAI-format request body from an Anthropic request for Gemini."""
logger.debug(
"GEMINI_REQUEST: conversion start model={} msgs={}",
@@ -60,6 +185,11 @@ def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict:
if extra_body:
body["extra_body"] = extra_body
_apply_gemini_tool_call_signatures(
body,
tool_call_extra_content_by_id=tool_call_extra_content_by_id,
)
logger.debug(
"GEMINI_REQUEST: conversion done model={} msgs={} tools={}",
body.get("model"),

View File

@@ -57,6 +57,30 @@ def _iter_heuristic_tool_use_sse(
yield sse.content_block_stop(block_idx)
def _tool_call_extra_content(tool_call: Any) -> dict[str, Any] | None:
if isinstance(tool_call, dict):
value = tool_call.get("extra_content")
return value if isinstance(value, dict) else None
value = getattr(tool_call, "extra_content", None)
if isinstance(value, dict):
return value
model_extra = getattr(tool_call, "model_extra", None)
if isinstance(model_extra, dict):
value = model_extra.get("extra_content")
if isinstance(value, dict):
return value
pydantic_extra = getattr(tool_call, "__pydantic_extra__", None)
if isinstance(pydantic_extra, dict):
value = pydantic_extra.get("extra_content")
if isinstance(value, dict):
return value
return None
class OpenAIChatTransport(BaseProvider):
"""Base for OpenAI-compatible ``/chat/completions`` adapters (NIM, …)."""
@@ -133,6 +157,11 @@ class OpenAIChatTransport(BaseProvider):
"""Return the body passed to the upstream OpenAI-compatible client."""
return body
def _record_tool_call_extra_content(
self, tool_call_id: str, extra_content: dict[str, Any]
) -> None:
"""Hook for providers that must replay OpenAI tool-call metadata later."""
def _tool_argument_aliases(self, body: dict[str, Any]) -> dict[str, dict[str, str]]:
"""Return provider-specific per-tool argument aliases for this request."""
return {}
@@ -246,6 +275,15 @@ class OpenAIChatTransport(BaseProvider):
if tc.get("id") is not None:
sse.blocks.set_stream_tool_id(tc_index, tc.get("id"))
raw_extra_content = tc.get("extra_content")
extra_content = (
raw_extra_content
if isinstance(raw_extra_content, dict) and raw_extra_content
else None
)
if extra_content:
sse.blocks.set_tool_extra_content(tc_index, extra_content)
if incoming_name is not None:
sse.blocks.register_tool_name(tc_index, incoming_name)
@@ -260,7 +298,15 @@ class OpenAIChatTransport(BaseProvider):
if name_ok:
tool_id = str(resolved_id) if resolved_id else f"tool_{uuid.uuid4()}"
display_name = (resolved_name or "").strip() or "tool_call"
yield sse.start_tool_block(tc_index, tool_id, display_name)
start_extra_content = state.extra_content if state else extra_content
if start_extra_content:
self._record_tool_call_extra_content(tool_id, start_extra_content)
yield sse.start_tool_block(
tc_index,
tool_id,
display_name,
extra_content=start_extra_content,
)
state = sse.blocks.tool_states[tc_index]
if state.pre_start_args:
pre = state.pre_start_args
@@ -274,6 +320,8 @@ class OpenAIChatTransport(BaseProvider):
)
state = sse.blocks.tool_states.get(tc_index)
if state is not None and state.tool_id and extra_content:
self._record_tool_call_extra_content(state.tool_id, extra_content)
if not arguments:
return
if state is None or not state.started:
@@ -441,6 +489,7 @@ class OpenAIChatTransport(BaseProvider):
for event in sse.close_content_blocks():
yield event
for tc in delta.tool_calls:
extra_content = _tool_call_extra_content(tc)
tc_info = {
"index": tc.index,
"id": tc.id,
@@ -449,6 +498,8 @@ class OpenAIChatTransport(BaseProvider):
"arguments": tc.function.arguments,
},
}
if extra_content:
tc_info["extra_content"] = extra_content
for event in self._process_tool_call(
tc_info,
sse,

View File

@@ -342,6 +342,24 @@ def test_convert_assistant_message_tool_use():
assert json.loads(tc["function"]["arguments"]) == {"query": "python"}
def test_convert_assistant_tool_use_preserves_extra_content():
    """``extra_content`` on a ``tool_use`` block must reach the converted tool call."""
    tool_use_block = MockBlock(
        type="tool_use",
        id="call_1",
        name="search",
        input={"query": "python"},
        extra_content={"google": {"thought_signature": "sig"}},
    )
    assistant_message = MockMessage("assistant", [tool_use_block])

    converted = AnthropicToOpenAIConverter.convert_messages([assistant_message])

    first_tool_call = converted[0]["tool_calls"][0]
    expected_metadata = {"google": {"thought_signature": "sig"}}
    assert first_tool_call["extra_content"] == expected_metadata
def test_convert_assistant_message_empty_content():
# Verify that empty content becomes a single space (NIM requirement)
# if no tool calls are present.

View File

@@ -7,6 +7,7 @@ import pytest
from providers.base import ProviderConfig
from providers.gemini import GEMINI_DEFAULT_BASE, GeminiProvider
from providers.gemini.request import GEMINI_SKIP_THOUGHT_SIGNATURE_VALIDATOR
class MockMessage:
@@ -194,6 +195,137 @@ def test_build_request_body_merges_caller_nested_google(gemini_provider):
assert thinking_config.get("include_thoughts") is True
def test_build_request_body_preserves_tool_call_extra_content(gemini_provider):
    """A client-supplied thought signature survives request-body conversion."""
    signed_tool_use = {
        "type": "tool_use",
        "id": "function-call-1",
        "name": "Glob",
        "input": {"pattern": "*.py"},
        "extra_content": {"google": {"thought_signature": "sig-from-client"}},
    }
    tool_result = {
        "type": "tool_result",
        "tool_use_id": "function-call-1",
        "content": "[]",
    }
    conversation = [
        MockMessage("user", "Find files"),
        MockMessage("assistant", [signed_tool_use]),
        MockMessage("user", [tool_result]),
    ]
    req = MockRequest(system=None, messages=conversation)

    body = gemini_provider._build_request_body(req)

    converted_call = body["messages"][1]["tool_calls"][0]
    expected_metadata = {"google": {"thought_signature": "sig-from-client"}}
    assert converted_call["extra_content"] == expected_metadata
def test_build_request_body_uses_cached_tool_call_signature(gemini_provider):
    """A signature recorded on the provider is replayed onto an unsigned tool call."""
    cached_metadata = {"google": {"thought_signature": "sig-from-cache"}}
    gemini_provider._record_tool_call_extra_content("function-call-1", cached_metadata)

    unsigned_tool_use = {
        "type": "tool_use",
        "id": "function-call-1",
        "name": "Glob",
        "input": {"pattern": "*.py"},
    }
    tool_result = {
        "type": "tool_result",
        "tool_use_id": "function-call-1",
        "content": "[]",
    }
    conversation = [
        MockMessage("user", "Find files"),
        MockMessage("assistant", [unsigned_tool_use]),
        MockMessage("user", [tool_result]),
    ]
    req = MockRequest(system=None, messages=conversation)

    body = gemini_provider._build_request_body(req)

    converted_call = body["messages"][1]["tool_calls"][0]
    assert converted_call["extra_content"] == {
        "google": {"thought_signature": "sig-from-cache"}
    }
def test_build_request_body_adds_gemini3_current_turn_fallback_signature(
    gemini_provider,
):
    """Only the first unsigned tool call of the turn gets the fallback signature."""
    glob_call = {
        "type": "tool_use",
        "id": "function-call-1",
        "name": "Glob",
        "input": {"pattern": "*.py"},
    }
    read_call = {
        "type": "tool_use",
        "id": "function-call-2",
        "name": "Read",
        "input": {"file_path": "a.py"},
    }
    glob_result = {
        "type": "tool_result",
        "tool_use_id": "function-call-1",
        "content": "[]",
    }
    read_result = {
        "type": "tool_result",
        "tool_use_id": "function-call-2",
        "content": "contents",
    }
    conversation = [
        MockMessage("user", "Find files"),
        MockMessage("assistant", [glob_call, read_call]),
        MockMessage("user", [glob_result, read_result]),
    ]
    req = MockRequest(system=None, messages=conversation)

    body = gemini_provider._build_request_body(req)

    first_call, second_call = body["messages"][1]["tool_calls"][:2]
    assert first_call["extra_content"] == {
        "google": {"thought_signature": GEMINI_SKIP_THOUGHT_SIGNATURE_VALIDATOR}
    }
    assert "extra_content" not in second_call
@pytest.mark.asyncio
async def test_stream_response_text(gemini_provider):
req = MockRequest()
@@ -237,6 +369,54 @@ async def test_stream_response_text(gemini_provider):
assert thinking_config.get("include_thoughts") is True
@pytest.mark.asyncio
async def test_stream_response_preserves_tool_call_extra_content(gemini_provider):
    """Streamed tool-call ``extra_content`` is emitted on block start and cached."""
    req = MockRequest()

    streamed_call = MagicMock()
    streamed_call.index = 0
    streamed_call.id = "function-call-1"
    streamed_call.extra_content = {"google": {"thought_signature": "sig-stream"}}
    streamed_call.function = MagicMock()
    # ``name`` is assigned after construction because MagicMock treats a
    # ``name=`` constructor argument specially.
    streamed_call.function.name = "Glob"
    streamed_call.function.arguments = '{"pattern":"*.py"}'

    delta = MagicMock(
        content=None,
        reasoning_content=None,
        tool_calls=[streamed_call],
    )
    choice = MagicMock(delta=delta, finish_reason="tool_calls")
    chunk = MagicMock()
    chunk.choices = [choice]
    chunk.usage = MagicMock(completion_tokens=5, prompt_tokens=10)

    async def single_chunk_stream():
        yield chunk

    completions = gemini_provider._client.chat.completions
    with patch.object(completions, "create", new_callable=AsyncMock) as mock_create:
        mock_create.return_value = single_chunk_stream()
        events = [event async for event in gemini_provider.stream_response(req)]

    tool_start_events = [
        event
        for event in events
        if '"content_block_start"' in event and '"tool_use"' in event
    ]
    assert any(
        '"extra_content"' in event and "sig-stream" in event
        for event in tool_start_events
    )

    cached_metadata = gemini_provider._tool_call_extra_content_by_id["function-call-1"]
    assert cached_metadata == {"google": {"thought_signature": "sig-stream"}}
@pytest.mark.asyncio
async def test_stream_response_reasoning_content(gemini_provider):
req = MockRequest()

View File

@@ -173,6 +173,22 @@ class TestSSEBuilderContentBlocks:
assert data["content_block"]["name"] == "Read"
assert data["content_block"]["input"] == {}
def test_content_block_start_tool_use_extra_content(self):
    """A ``tool_use`` start event includes the ``extra_content`` it was given."""
    signature_metadata = {"google": {"thought_signature": "sig"}}
    event = SSEBuilder("msg_1", "model").content_block_start(
        2,
        "tool_use",
        id="tool_123",
        name="Read",
        input={},
        extra_content=signature_metadata,
    )

    payload = _parse_sse(event)

    assert payload["content_block"]["extra_content"] == {
        "google": {"thought_signature": "sig"}
    }
def test_content_block_delta_text(self):
builder = SSEBuilder("msg_1", "model")
sse = builder.content_block_delta(0, "text_delta", "hello world")