mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-06-20 02:17:27 +08:00
Fix Gemini thought signature replay
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
"""Message and tool format converters."""
|
||||
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
@@ -83,6 +84,24 @@ def _think_tag_content(reasoning: str) -> str:
|
||||
return f"<think>\n{reasoning}\n</think>"
|
||||
|
||||
|
||||
def _tool_call_from_tool_use(block: Any) -> dict[str, Any]:
|
||||
tool_input = get_block_attr(block, "input", {})
|
||||
tool_call: dict[str, Any] = {
|
||||
"id": get_block_attr(block, "id"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": get_block_attr(block, "name"),
|
||||
"arguments": json.dumps(tool_input)
|
||||
if isinstance(tool_input, dict)
|
||||
else str(tool_input),
|
||||
},
|
||||
}
|
||||
extra_content = get_block_attr(block, "extra_content", None)
|
||||
if isinstance(extra_content, dict) and extra_content:
|
||||
tool_call["extra_content"] = deepcopy(extra_content)
|
||||
return tool_call
|
||||
|
||||
|
||||
@dataclass
|
||||
class _PendingAfterTools:
|
||||
"""Assistant content that appears after ``tool_use`` in an Anthropic message.
|
||||
@@ -112,23 +131,11 @@ def _index_first_tool_use(blocks: list[Any]) -> int | None:
|
||||
|
||||
|
||||
def _iter_tool_uses_in_order(blocks: list[Any]) -> list[dict[str, Any]]:
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
for block in blocks:
|
||||
if get_block_type(block) == "tool_use":
|
||||
tool_input = get_block_attr(block, "input", {})
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": get_block_attr(block, "id"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": get_block_attr(block, "name"),
|
||||
"arguments": json.dumps(tool_input)
|
||||
if isinstance(tool_input, dict)
|
||||
else str(tool_input),
|
||||
},
|
||||
}
|
||||
)
|
||||
return tool_calls
|
||||
return [
|
||||
_tool_call_from_tool_use(block)
|
||||
for block in blocks
|
||||
if get_block_type(block) == "tool_use"
|
||||
]
|
||||
|
||||
|
||||
def _deferred_post_tool_blocks(
|
||||
@@ -362,19 +369,7 @@ class AnthropicToOpenAIConverter:
|
||||
# or reasoning_content for OpenAI chat upstreams.
|
||||
continue
|
||||
elif block_type == "tool_use":
|
||||
tool_input = get_block_attr(block, "input", {})
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": get_block_attr(block, "id"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": get_block_attr(block, "name"),
|
||||
"arguments": json.dumps(tool_input)
|
||||
if isinstance(tool_input, dict)
|
||||
else str(tool_input),
|
||||
},
|
||||
}
|
||||
)
|
||||
tool_calls.append(_tool_call_from_tool_use(block))
|
||||
else:
|
||||
_assert_no_forbidden_assistant_block(block)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import hashlib
|
||||
import json
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -54,6 +55,7 @@ class ToolCallState:
|
||||
block_index: int
|
||||
tool_id: str
|
||||
name: str
|
||||
extra_content: dict[str, Any] | None = None
|
||||
contents: list[str] = field(default_factory=list)
|
||||
started: bool = False
|
||||
task_arg_buffer: str = ""
|
||||
@@ -90,6 +92,15 @@ class ContentBlockManager:
|
||||
state = self.ensure_tool_state(index)
|
||||
state.tool_id = str(tool_id)
|
||||
|
||||
def set_tool_extra_content(
|
||||
self, index: int, extra_content: dict[str, Any] | None
|
||||
) -> None:
|
||||
"""Record provider-specific OpenAI tool-call metadata before block start."""
|
||||
if not extra_content:
|
||||
return
|
||||
state = self.ensure_tool_state(index)
|
||||
state.extra_content = extra_content
|
||||
|
||||
def register_tool_name(self, index: int, name: str) -> None:
|
||||
"""Record tool name fragments as they arrive from chunked OpenAI streams.
|
||||
|
||||
@@ -237,6 +248,9 @@ class SSEBuilder:
|
||||
content_block["id"] = kwargs.get("id", "")
|
||||
content_block["name"] = kwargs.get("name", "")
|
||||
content_block["input"] = kwargs.get("input", {})
|
||||
extra_content = kwargs.get("extra_content")
|
||||
if isinstance(extra_content, dict) and extra_content:
|
||||
content_block["extra_content"] = extra_content
|
||||
|
||||
return self._format_event(
|
||||
"content_block_start",
|
||||
@@ -302,21 +316,37 @@ class SSEBuilder:
|
||||
self.blocks.text_started = False
|
||||
return self.content_block_stop(self.blocks.text_index)
|
||||
|
||||
def start_tool_block(self, tool_index: int, tool_id: str, name: str) -> str:
|
||||
def start_tool_block(
|
||||
self,
|
||||
tool_index: int,
|
||||
tool_id: str,
|
||||
name: str,
|
||||
*,
|
||||
extra_content: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
block_idx = self.blocks.allocate_index()
|
||||
if tool_index in self.blocks.tool_states:
|
||||
state = self.blocks.tool_states[tool_index]
|
||||
state.block_index = block_idx
|
||||
state.tool_id = tool_id
|
||||
if extra_content:
|
||||
state.extra_content = extra_content
|
||||
state.started = True
|
||||
else:
|
||||
self.blocks.tool_states[tool_index] = ToolCallState(
|
||||
block_index=block_idx,
|
||||
tool_id=tool_id,
|
||||
name=name,
|
||||
extra_content=extra_content,
|
||||
started=True,
|
||||
)
|
||||
return self.content_block_start(block_idx, "tool_use", id=tool_id, name=name)
|
||||
return self.content_block_start(
|
||||
block_idx,
|
||||
"tool_use",
|
||||
id=tool_id,
|
||||
name=name,
|
||||
extra_content=extra_content,
|
||||
)
|
||||
|
||||
def emit_tool_delta(self, tool_index: int, partial_json: str) -> str:
|
||||
state = self.blocks.tool_states[tool_index]
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from providers.base import ProviderConfig
|
||||
@@ -10,6 +11,8 @@ from providers.openai_compat import OpenAIChatTransport
|
||||
|
||||
from .request import build_request_body
|
||||
|
||||
_MAX_TOOL_CALL_EXTRA_CONTENT_CACHE = 4096
|
||||
|
||||
|
||||
class GeminiProvider(OpenAIChatTransport):
|
||||
"""Gemini API using ``https://generativelanguage.googleapis.com/v1beta/openai/``."""
|
||||
@@ -21,6 +24,20 @@ class GeminiProvider(OpenAIChatTransport):
|
||||
base_url=config.base_url or GEMINI_DEFAULT_BASE,
|
||||
api_key=config.api_key,
|
||||
)
|
||||
self._tool_call_extra_content_by_id: dict[str, dict[str, Any]] = {}
|
||||
|
||||
def _record_tool_call_extra_content(
|
||||
self, tool_call_id: str, extra_content: dict[str, Any]
|
||||
) -> None:
|
||||
if (
|
||||
tool_call_id not in self._tool_call_extra_content_by_id
|
||||
and len(self._tool_call_extra_content_by_id)
|
||||
>= _MAX_TOOL_CALL_EXTRA_CONTENT_CACHE
|
||||
):
|
||||
self._tool_call_extra_content_by_id.pop(
|
||||
next(iter(self._tool_call_extra_content_by_id))
|
||||
)
|
||||
self._tool_call_extra_content_by_id[tool_call_id] = deepcopy(extra_content)
|
||||
|
||||
def _build_request_body(
|
||||
self, request: Any, thinking_enabled: bool | None = None
|
||||
@@ -28,4 +45,5 @@ class GeminiProvider(OpenAIChatTransport):
|
||||
return build_request_body(
|
||||
request,
|
||||
thinking_enabled=self._is_thinking_enabled(request, thinking_enabled),
|
||||
tool_call_extra_content_by_id=self._tool_call_extra_content_by_id,
|
||||
)
|
||||
|
||||
@@ -11,6 +11,8 @@ from core.anthropic import ReasoningReplayMode, build_base_request_body
|
||||
from core.anthropic.conversion import OpenAIConversionError
|
||||
from providers.exceptions import InvalidRequestError
|
||||
|
||||
GEMINI_SKIP_THOUGHT_SIGNATURE_VALIDATOR = "skip_thought_signature_validator"
|
||||
|
||||
|
||||
def _ensure_dict(container: dict[str, Any], key: str) -> dict[str, Any]:
|
||||
value = container.get(key)
|
||||
@@ -30,7 +32,130 @@ def _apply_thinking_config(extra_body: dict[str, Any]) -> None:
|
||||
thinking_cfg.setdefault("include_thoughts", True)
|
||||
|
||||
|
||||
def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict:
|
||||
def _is_gemini_3_model(model: Any) -> bool:
|
||||
return "gemini-3" in str(model).lower()
|
||||
|
||||
|
||||
def _thought_signature_from_extra_content(extra_content: Any) -> str | None:
|
||||
if not isinstance(extra_content, dict):
|
||||
return None
|
||||
google = extra_content.get("google")
|
||||
if not isinstance(google, dict):
|
||||
return None
|
||||
signature = google.get("thought_signature")
|
||||
return signature if isinstance(signature, str) and signature else None
|
||||
|
||||
|
||||
def _tool_call_thought_signature(tool_call: dict[str, Any]) -> str | None:
|
||||
return _thought_signature_from_extra_content(tool_call.get("extra_content"))
|
||||
|
||||
|
||||
def _set_tool_call_thought_signature(tool_call: dict[str, Any], signature: str) -> None:
|
||||
extra_content = tool_call.get("extra_content")
|
||||
if not isinstance(extra_content, dict):
|
||||
extra_content = {}
|
||||
tool_call["extra_content"] = extra_content
|
||||
google = extra_content.get("google")
|
||||
if not isinstance(google, dict):
|
||||
google = {}
|
||||
extra_content["google"] = google
|
||||
google["thought_signature"] = signature
|
||||
|
||||
|
||||
def _message_has_standard_user_content(message: dict[str, Any]) -> bool:
|
||||
if message.get("role") != "user":
|
||||
return False
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
return bool(content.strip())
|
||||
if isinstance(content, list):
|
||||
return any(
|
||||
isinstance(part, dict)
|
||||
and isinstance(part.get("text"), str)
|
||||
and bool(part["text"].strip())
|
||||
for part in content
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
def _current_turn_start_index(messages: list[Any]) -> int:
|
||||
for index in range(len(messages) - 1, -1, -1):
|
||||
message = messages[index]
|
||||
if isinstance(message, dict) and _message_has_standard_user_content(message):
|
||||
return index
|
||||
return -1
|
||||
|
||||
|
||||
def _apply_cached_tool_call_signatures(
|
||||
messages: list[Any], tool_call_extra_content_by_id: dict[str, dict[str, Any]]
|
||||
) -> None:
|
||||
if not tool_call_extra_content_by_id:
|
||||
return
|
||||
for message in messages:
|
||||
if not isinstance(message, dict) or message.get("role") != "assistant":
|
||||
continue
|
||||
tool_calls = message.get("tool_calls")
|
||||
if not isinstance(tool_calls, list):
|
||||
continue
|
||||
for tool_call in tool_calls:
|
||||
if not isinstance(tool_call, dict) or _tool_call_thought_signature(
|
||||
tool_call
|
||||
):
|
||||
continue
|
||||
tool_call_id = tool_call.get("id")
|
||||
if tool_call_id is None:
|
||||
continue
|
||||
cached_extra_content = tool_call_extra_content_by_id.get(str(tool_call_id))
|
||||
if not cached_extra_content:
|
||||
continue
|
||||
cached_signature = _thought_signature_from_extra_content(
|
||||
cached_extra_content
|
||||
)
|
||||
if cached_signature:
|
||||
tool_call["extra_content"] = deepcopy(cached_extra_content)
|
||||
|
||||
|
||||
def _apply_gemini_3_missing_current_turn_signatures(
|
||||
body: dict[str, Any], messages: list[Any]
|
||||
) -> None:
|
||||
if not _is_gemini_3_model(body.get("model")):
|
||||
return
|
||||
|
||||
start_index = _current_turn_start_index(messages)
|
||||
for message in messages[start_index + 1 :]:
|
||||
if not isinstance(message, dict) or message.get("role") != "assistant":
|
||||
continue
|
||||
tool_calls = message.get("tool_calls")
|
||||
if not isinstance(tool_calls, list) or not tool_calls:
|
||||
continue
|
||||
first_tool_call = tool_calls[0]
|
||||
if not isinstance(first_tool_call, dict):
|
||||
continue
|
||||
if _tool_call_thought_signature(first_tool_call):
|
||||
continue
|
||||
_set_tool_call_thought_signature(
|
||||
first_tool_call, GEMINI_SKIP_THOUGHT_SIGNATURE_VALIDATOR
|
||||
)
|
||||
|
||||
|
||||
def _apply_gemini_tool_call_signatures(
|
||||
body: dict[str, Any],
|
||||
*,
|
||||
tool_call_extra_content_by_id: dict[str, dict[str, Any]] | None,
|
||||
) -> None:
|
||||
messages = body.get("messages")
|
||||
if not isinstance(messages, list):
|
||||
return
|
||||
_apply_cached_tool_call_signatures(messages, tool_call_extra_content_by_id or {})
|
||||
_apply_gemini_3_missing_current_turn_signatures(body, messages)
|
||||
|
||||
|
||||
def build_request_body(
|
||||
request_data: Any,
|
||||
*,
|
||||
thinking_enabled: bool,
|
||||
tool_call_extra_content_by_id: dict[str, dict[str, Any]] | None = None,
|
||||
) -> dict:
|
||||
"""Build OpenAI-format request body from an Anthropic request for Gemini."""
|
||||
logger.debug(
|
||||
"GEMINI_REQUEST: conversion start model={} msgs={}",
|
||||
@@ -60,6 +185,11 @@ def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict:
|
||||
if extra_body:
|
||||
body["extra_body"] = extra_body
|
||||
|
||||
_apply_gemini_tool_call_signatures(
|
||||
body,
|
||||
tool_call_extra_content_by_id=tool_call_extra_content_by_id,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"GEMINI_REQUEST: conversion done model={} msgs={} tools={}",
|
||||
body.get("model"),
|
||||
|
||||
@@ -57,6 +57,30 @@ def _iter_heuristic_tool_use_sse(
|
||||
yield sse.content_block_stop(block_idx)
|
||||
|
||||
|
||||
def _tool_call_extra_content(tool_call: Any) -> dict[str, Any] | None:
|
||||
if isinstance(tool_call, dict):
|
||||
value = tool_call.get("extra_content")
|
||||
return value if isinstance(value, dict) else None
|
||||
|
||||
value = getattr(tool_call, "extra_content", None)
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
|
||||
model_extra = getattr(tool_call, "model_extra", None)
|
||||
if isinstance(model_extra, dict):
|
||||
value = model_extra.get("extra_content")
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
|
||||
pydantic_extra = getattr(tool_call, "__pydantic_extra__", None)
|
||||
if isinstance(pydantic_extra, dict):
|
||||
value = pydantic_extra.get("extra_content")
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class OpenAIChatTransport(BaseProvider):
|
||||
"""Base for OpenAI-compatible ``/chat/completions`` adapters (NIM, …)."""
|
||||
|
||||
@@ -133,6 +157,11 @@ class OpenAIChatTransport(BaseProvider):
|
||||
"""Return the body passed to the upstream OpenAI-compatible client."""
|
||||
return body
|
||||
|
||||
def _record_tool_call_extra_content(
|
||||
self, tool_call_id: str, extra_content: dict[str, Any]
|
||||
) -> None:
|
||||
"""Hook for providers that must replay OpenAI tool-call metadata later."""
|
||||
|
||||
def _tool_argument_aliases(self, body: dict[str, Any]) -> dict[str, dict[str, str]]:
|
||||
"""Return provider-specific per-tool argument aliases for this request."""
|
||||
return {}
|
||||
@@ -246,6 +275,15 @@ class OpenAIChatTransport(BaseProvider):
|
||||
if tc.get("id") is not None:
|
||||
sse.blocks.set_stream_tool_id(tc_index, tc.get("id"))
|
||||
|
||||
raw_extra_content = tc.get("extra_content")
|
||||
extra_content = (
|
||||
raw_extra_content
|
||||
if isinstance(raw_extra_content, dict) and raw_extra_content
|
||||
else None
|
||||
)
|
||||
if extra_content:
|
||||
sse.blocks.set_tool_extra_content(tc_index, extra_content)
|
||||
|
||||
if incoming_name is not None:
|
||||
sse.blocks.register_tool_name(tc_index, incoming_name)
|
||||
|
||||
@@ -260,7 +298,15 @@ class OpenAIChatTransport(BaseProvider):
|
||||
if name_ok:
|
||||
tool_id = str(resolved_id) if resolved_id else f"tool_{uuid.uuid4()}"
|
||||
display_name = (resolved_name or "").strip() or "tool_call"
|
||||
yield sse.start_tool_block(tc_index, tool_id, display_name)
|
||||
start_extra_content = state.extra_content if state else extra_content
|
||||
if start_extra_content:
|
||||
self._record_tool_call_extra_content(tool_id, start_extra_content)
|
||||
yield sse.start_tool_block(
|
||||
tc_index,
|
||||
tool_id,
|
||||
display_name,
|
||||
extra_content=start_extra_content,
|
||||
)
|
||||
state = sse.blocks.tool_states[tc_index]
|
||||
if state.pre_start_args:
|
||||
pre = state.pre_start_args
|
||||
@@ -274,6 +320,8 @@ class OpenAIChatTransport(BaseProvider):
|
||||
)
|
||||
|
||||
state = sse.blocks.tool_states.get(tc_index)
|
||||
if state is not None and state.tool_id and extra_content:
|
||||
self._record_tool_call_extra_content(state.tool_id, extra_content)
|
||||
if not arguments:
|
||||
return
|
||||
if state is None or not state.started:
|
||||
@@ -441,6 +489,7 @@ class OpenAIChatTransport(BaseProvider):
|
||||
for event in sse.close_content_blocks():
|
||||
yield event
|
||||
for tc in delta.tool_calls:
|
||||
extra_content = _tool_call_extra_content(tc)
|
||||
tc_info = {
|
||||
"index": tc.index,
|
||||
"id": tc.id,
|
||||
@@ -449,6 +498,8 @@ class OpenAIChatTransport(BaseProvider):
|
||||
"arguments": tc.function.arguments,
|
||||
},
|
||||
}
|
||||
if extra_content:
|
||||
tc_info["extra_content"] = extra_content
|
||||
for event in self._process_tool_call(
|
||||
tc_info,
|
||||
sse,
|
||||
|
||||
@@ -342,6 +342,24 @@ def test_convert_assistant_message_tool_use():
|
||||
assert json.loads(tc["function"]["arguments"]) == {"query": "python"}
|
||||
|
||||
|
||||
def test_convert_assistant_tool_use_preserves_extra_content():
|
||||
content = [
|
||||
MockBlock(
|
||||
type="tool_use",
|
||||
id="call_1",
|
||||
name="search",
|
||||
input={"query": "python"},
|
||||
extra_content={"google": {"thought_signature": "sig"}},
|
||||
),
|
||||
]
|
||||
messages = [MockMessage("assistant", content)]
|
||||
result = AnthropicToOpenAIConverter.convert_messages(messages)
|
||||
|
||||
assert result[0]["tool_calls"][0]["extra_content"] == {
|
||||
"google": {"thought_signature": "sig"}
|
||||
}
|
||||
|
||||
|
||||
def test_convert_assistant_message_empty_content():
|
||||
# Verify that empty content becomes a single space (NIM requirement)
|
||||
# if no tool calls are present.
|
||||
|
||||
@@ -7,6 +7,7 @@ import pytest
|
||||
|
||||
from providers.base import ProviderConfig
|
||||
from providers.gemini import GEMINI_DEFAULT_BASE, GeminiProvider
|
||||
from providers.gemini.request import GEMINI_SKIP_THOUGHT_SIGNATURE_VALIDATOR
|
||||
|
||||
|
||||
class MockMessage:
|
||||
@@ -194,6 +195,137 @@ def test_build_request_body_merges_caller_nested_google(gemini_provider):
|
||||
assert thinking_config.get("include_thoughts") is True
|
||||
|
||||
|
||||
def test_build_request_body_preserves_tool_call_extra_content(gemini_provider):
|
||||
req = MockRequest(
|
||||
system=None,
|
||||
messages=[
|
||||
MockMessage("user", "Find files"),
|
||||
MockMessage(
|
||||
"assistant",
|
||||
[
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "function-call-1",
|
||||
"name": "Glob",
|
||||
"input": {"pattern": "*.py"},
|
||||
"extra_content": {
|
||||
"google": {"thought_signature": "sig-from-client"}
|
||||
},
|
||||
}
|
||||
],
|
||||
),
|
||||
MockMessage(
|
||||
"user",
|
||||
[
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "function-call-1",
|
||||
"content": "[]",
|
||||
}
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
body = gemini_provider._build_request_body(req)
|
||||
|
||||
tool_call = body["messages"][1]["tool_calls"][0]
|
||||
assert tool_call["extra_content"] == {
|
||||
"google": {"thought_signature": "sig-from-client"}
|
||||
}
|
||||
|
||||
|
||||
def test_build_request_body_uses_cached_tool_call_signature(gemini_provider):
|
||||
gemini_provider._record_tool_call_extra_content(
|
||||
"function-call-1", {"google": {"thought_signature": "sig-from-cache"}}
|
||||
)
|
||||
req = MockRequest(
|
||||
system=None,
|
||||
messages=[
|
||||
MockMessage("user", "Find files"),
|
||||
MockMessage(
|
||||
"assistant",
|
||||
[
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "function-call-1",
|
||||
"name": "Glob",
|
||||
"input": {"pattern": "*.py"},
|
||||
}
|
||||
],
|
||||
),
|
||||
MockMessage(
|
||||
"user",
|
||||
[
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "function-call-1",
|
||||
"content": "[]",
|
||||
}
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
body = gemini_provider._build_request_body(req)
|
||||
|
||||
tool_call = body["messages"][1]["tool_calls"][0]
|
||||
assert tool_call["extra_content"] == {
|
||||
"google": {"thought_signature": "sig-from-cache"}
|
||||
}
|
||||
|
||||
|
||||
def test_build_request_body_adds_gemini3_current_turn_fallback_signature(
|
||||
gemini_provider,
|
||||
):
|
||||
req = MockRequest(
|
||||
system=None,
|
||||
messages=[
|
||||
MockMessage("user", "Find files"),
|
||||
MockMessage(
|
||||
"assistant",
|
||||
[
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "function-call-1",
|
||||
"name": "Glob",
|
||||
"input": {"pattern": "*.py"},
|
||||
},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "function-call-2",
|
||||
"name": "Read",
|
||||
"input": {"file_path": "a.py"},
|
||||
},
|
||||
],
|
||||
),
|
||||
MockMessage(
|
||||
"user",
|
||||
[
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "function-call-1",
|
||||
"content": "[]",
|
||||
},
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "function-call-2",
|
||||
"content": "contents",
|
||||
},
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
body = gemini_provider._build_request_body(req)
|
||||
|
||||
tool_calls = body["messages"][1]["tool_calls"]
|
||||
assert tool_calls[0]["extra_content"] == {
|
||||
"google": {"thought_signature": GEMINI_SKIP_THOUGHT_SIGNATURE_VALIDATOR}
|
||||
}
|
||||
assert "extra_content" not in tool_calls[1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_response_text(gemini_provider):
|
||||
req = MockRequest()
|
||||
@@ -237,6 +369,54 @@ async def test_stream_response_text(gemini_provider):
|
||||
assert thinking_config.get("include_thoughts") is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_response_preserves_tool_call_extra_content(gemini_provider):
|
||||
req = MockRequest()
|
||||
|
||||
mock_tc = MagicMock()
|
||||
mock_tc.index = 0
|
||||
mock_tc.id = "function-call-1"
|
||||
mock_tc.extra_content = {"google": {"thought_signature": "sig-stream"}}
|
||||
mock_tc.function = MagicMock()
|
||||
mock_tc.function.name = "Glob"
|
||||
mock_tc.function.arguments = '{"pattern":"*.py"}'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.choices = [
|
||||
MagicMock(
|
||||
delta=MagicMock(
|
||||
content=None,
|
||||
reasoning_content=None,
|
||||
tool_calls=[mock_tc],
|
||||
),
|
||||
finish_reason="tool_calls",
|
||||
)
|
||||
]
|
||||
mock_chunk.usage = MagicMock(completion_tokens=5, prompt_tokens=10)
|
||||
|
||||
async def mock_stream():
|
||||
yield mock_chunk
|
||||
|
||||
with patch.object(
|
||||
gemini_provider._client.chat.completions, "create", new_callable=AsyncMock
|
||||
) as mock_create:
|
||||
mock_create.return_value = mock_stream()
|
||||
|
||||
events = [event async for event in gemini_provider.stream_response(req)]
|
||||
|
||||
tool_starts = [
|
||||
event
|
||||
for event in events
|
||||
if '"content_block_start"' in event and '"tool_use"' in event
|
||||
]
|
||||
assert any(
|
||||
'"extra_content"' in event and "sig-stream" in event for event in tool_starts
|
||||
)
|
||||
assert gemini_provider._tool_call_extra_content_by_id["function-call-1"] == {
|
||||
"google": {"thought_signature": "sig-stream"}
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_response_reasoning_content(gemini_provider):
|
||||
req = MockRequest()
|
||||
|
||||
@@ -173,6 +173,22 @@ class TestSSEBuilderContentBlocks:
|
||||
assert data["content_block"]["name"] == "Read"
|
||||
assert data["content_block"]["input"] == {}
|
||||
|
||||
def test_content_block_start_tool_use_extra_content(self):
|
||||
builder = SSEBuilder("msg_1", "model")
|
||||
sse = builder.content_block_start(
|
||||
2,
|
||||
"tool_use",
|
||||
id="tool_123",
|
||||
name="Read",
|
||||
input={},
|
||||
extra_content={"google": {"thought_signature": "sig"}},
|
||||
)
|
||||
|
||||
data = _parse_sse(sse)
|
||||
assert data["content_block"]["extra_content"] == {
|
||||
"google": {"thought_signature": "sig"}
|
||||
}
|
||||
|
||||
def test_content_block_delta_text(self):
|
||||
builder = SSEBuilder("msg_1", "model")
|
||||
sse = builder.content_block_delta(0, "text_delta", "hello world")
|
||||
|
||||
Reference in New Issue
Block a user