fix(proxy): dedupe streaming message_delta (#2366)

- Deduplicate repeated upstream `finish_reason` chunks so only one Anthropic `message_delta` is emitted.
- Preserve late `choices: []` usage-only chunks before sending the final `message_delta`.
- Keep stream error paths from emitting successful terminal events.
- Add regressions for duplicate finish reasons, usage-only chunks, missing `[DONE]`, and truncated streams.
This commit is contained in:
codeasier
2026-04-28 17:08:43 +08:00
committed by GitHub
parent 4536b95ac9
commit 6441bc5c01

View File

@@ -6,14 +6,17 @@ use crate::proxy::sse::{strip_sse_field, take_sse_block};
use bytes::Bytes; use bytes::Bytes;
use futures::stream::{Stream, StreamExt}; use futures::stream::{Stream, StreamExt};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::json; use serde_json::{json, Value};
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
/// OpenAI 流式响应数据结构 /// OpenAI 流式响应数据结构
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct OpenAIStreamChunk { struct OpenAIStreamChunk {
#[serde(default)]
id: String, id: String,
#[serde(default)]
model: String, model: String,
#[serde(default)]
choices: Vec<StreamChoice>, choices: Vec<StreamChoice>,
#[serde(default)] #[serde(default)]
usage: Option<Usage>, usage: Option<Usage>,
@@ -95,6 +98,20 @@ struct ToolBlockState {
/// 无限空白 bug 的连续空白字符阈值 /// 无限空白 bug 的连续空白字符阈值
const INFINITE_WHITESPACE_THRESHOLD: usize = 20; const INFINITE_WHITESPACE_THRESHOLD: usize = 20;
fn build_anthropic_usage_json(usage: &Usage) -> Value {
let mut usage_json = json!({
"input_tokens": usage.prompt_tokens,
"output_tokens": usage.completion_tokens
});
if let Some(cached) = extract_cache_read_tokens(usage) {
usage_json["cache_read_input_tokens"] = json!(cached);
}
if let Some(created) = usage.cache_creation_input_tokens {
usage_json["cache_creation_input_tokens"] = json!(created);
}
usage_json
}
/// 创建 Anthropic SSE 流 /// 创建 Anthropic SSE 流
pub fn create_anthropic_sse_stream<E: std::error::Error + Send + 'static>( pub fn create_anthropic_sse_stream<E: std::error::Error + Send + 'static>(
stream: impl Stream<Item = Result<Bytes, E>> + Send + 'static, stream: impl Stream<Item = Result<Bytes, E>> + Send + 'static,
@@ -106,6 +123,16 @@ pub fn create_anthropic_sse_stream<E: std::error::Error + Send + 'static>(
let mut current_model = None; let mut current_model = None;
let mut next_content_index: u32 = 0; let mut next_content_index: u32 = 0;
let mut has_sent_message_start = false; let mut has_sent_message_start = false;
// 某些上游 provider(如 OpenRouter 的 kimi-k2.6)会在 tool_use 后发送多个
// 带 finish_reason 的 SSE chunk。Anthropic 协议要求每个消息流只能有一个
// message_delta,重复会导致 Claude Code abort 连接。因此需要:
// 1) has_emitted_message_delta: 去重,只处理第一个 finish_reason
// 2) pending_message_delta: 缓存,延迟到 [DONE] 发送,确保 usage 完整
let mut has_emitted_message_delta = false;
let mut pending_message_delta: Option<(Option<String>, Option<Value>)> = None;
let mut has_sent_message_stop = false;
let mut stream_ended_with_error = false;
let mut latest_usage: Option<Value> = None;
let mut current_non_tool_block_type: Option<&'static str> = None; let mut current_non_tool_block_type: Option<&'static str> = None;
let mut current_non_tool_block_index: Option<u32> = None; let mut current_non_tool_block_index: Option<u32> = None;
let mut tool_blocks_by_index: HashMap<usize, ToolBlockState> = HashMap::new(); let mut tool_blocks_by_index: HashMap<usize, ToolBlockState> = HashMap::new();
@@ -127,24 +154,53 @@ pub fn create_anthropic_sse_stream<E: std::error::Error + Send + 'static>(
if let Some(data) = strip_sse_field(l, "data") { if let Some(data) = strip_sse_field(l, "data") {
if data.trim() == "[DONE]" { if data.trim() == "[DONE]" {
log::debug!("[Claude/OpenRouter] <<< OpenAI SSE: [DONE]"); log::debug!("[Claude/OpenRouter] <<< OpenAI SSE: [DONE]");
// 流正常结束,发出缓存的 message_delta(含完整 usage)
if let Some((stop_reason, usage_json)) = pending_message_delta.take() {
let mut event = json!({
"type": "message_delta",
"delta": {
"stop_reason": stop_reason,
"stop_sequence": null
}
});
if let Some(uj) = usage_json {
event["usage"] = uj;
}
let sse_data = format!("event: message_delta\ndata: {}\n\n",
serde_json::to_string(&event).unwrap_or_default());
log::debug!("[Claude/OpenRouter] >>> Anthropic SSE: message_delta (from pending)");
yield Ok(Bytes::from(sse_data));
}
let event = json!({"type": "message_stop"}); let event = json!({"type": "message_stop"});
let sse_data = format!("event: message_stop\ndata: {}\n\n", let sse_data = format!("event: message_stop\ndata: {}\n\n",
serde_json::to_string(&event).unwrap_or_default()); serde_json::to_string(&event).unwrap_or_default());
log::debug!("[Claude/OpenRouter] >>> Anthropic SSE: message_stop"); log::debug!("[Claude/OpenRouter] >>> Anthropic SSE: message_stop");
yield Ok(Bytes::from(sse_data)); yield Ok(Bytes::from(sse_data));
has_sent_message_stop = true;
continue; continue;
} }
if let Ok(chunk) = serde_json::from_str::<OpenAIStreamChunk>(data) { if let Ok(chunk) = serde_json::from_str::<OpenAIStreamChunk>(data) {
log::debug!("[Claude/OpenRouter] <<< SSE chunk received"); log::debug!("[Claude/OpenRouter] <<< SSE chunk received");
if message_id.is_none() { if message_id.is_none() && !chunk.id.is_empty() {
message_id = Some(chunk.id.clone()); message_id = Some(chunk.id.clone());
} }
if current_model.is_none() { if current_model.is_none() && !chunk.model.is_empty() {
current_model = Some(chunk.model.clone()); current_model = Some(chunk.model.clone());
} }
let chunk_usage_json =
chunk.usage.as_ref().map(build_anthropic_usage_json);
if let Some(usage_json) = &chunk_usage_json {
latest_usage = Some(usage_json.clone());
if let Some((_, pending_usage)) = pending_message_delta.as_mut() {
*pending_usage = Some(usage_json.clone());
}
}
if let Some(choice) = chunk.choices.first() { if let Some(choice) = chunk.choices.first() {
if !has_sent_message_start { if !has_sent_message_start {
// Build usage with cache tokens if available from first chunk // Build usage with cache tokens if available from first chunk
@@ -420,8 +476,24 @@ pub fn create_anthropic_sse_stream<E: std::error::Error + Send + 'static>(
} }
} }
// 处理 finish_reason // 处理 finish_reason
// 注意:OpenRouter 某些 provider 会发送多个带 finish_reason 的 chunk
// (第一个 usage 为 null,后续才补全)。此处只做缓存,不立即发送,
// 等到 [DONE] 或流末尾再统一发出,确保 usage 完整且只发一次。
if let Some(finish_reason) = &choice.finish_reason { if let Some(finish_reason) = &choice.finish_reason {
let stop_reason = map_stop_reason(Some(finish_reason));
let usage_json =
chunk_usage_json.clone().or_else(|| latest_usage.clone());
if has_emitted_message_delta {
// 更新缓存的 message_delta usage如果有更完整的 usage
if let (Some((_, ref mut usage)), Some(uj)) = (&mut pending_message_delta, usage_json) {
*usage = Some(uj);
}
continue;
}
has_emitted_message_delta = true;
if let Some(index) = current_non_tool_block_index.take() { if let Some(index) = current_non_tool_block_index.take() {
let event = json!({ let event = json!({
"type": "content_block_stop", "type": "content_block_stop",
@@ -511,32 +583,8 @@ pub fn create_anthropic_sse_stream<E: std::error::Error + Send + 'static>(
open_tool_block_indices.clear(); open_tool_block_indices.clear();
} }
let stop_reason = map_stop_reason(Some(finish_reason)); // 缓存 message_delta,等到 [DONE] 时发送(以便收集完整的 usage)
// Build usage with cache token fields pending_message_delta = Some((stop_reason, usage_json));
let usage_json = chunk.usage.as_ref().map(|u| {
let mut uj = json!({
"input_tokens": u.prompt_tokens,
"output_tokens": u.completion_tokens
});
if let Some(cached) = extract_cache_read_tokens(u) {
uj["cache_read_input_tokens"] = json!(cached);
}
if let Some(created) = u.cache_creation_input_tokens {
uj["cache_creation_input_tokens"] = json!(created);
}
uj
});
let event = json!({
"type": "message_delta",
"delta": {
"stop_reason": stop_reason,
"stop_sequence": null
},
"usage": usage_json
});
let sse_data = format!("event: message_delta\ndata: {}\n\n",
serde_json::to_string(&event).unwrap_or_default());
yield Ok(Bytes::from(sse_data));
} }
} }
} }
@@ -546,6 +594,7 @@ pub fn create_anthropic_sse_stream<E: std::error::Error + Send + 'static>(
} }
Err(e) => { Err(e) => {
log::error!("Stream error: {e}"); log::error!("Stream error: {e}");
stream_ended_with_error = true;
let error_event = json!({ let error_event = json!({
"type": "error", "type": "error",
"error": { "error": {
@@ -560,6 +609,40 @@ pub fn create_anthropic_sse_stream<E: std::error::Error + Send + 'static>(
} }
} }
} }
// 流自然结束但未收到 [DONE] 时,确保发送缓存的 message_delta 和 message_stop。
// 若上游已显式报错,则只保留 error 事件,避免把失败伪装成成功完成。
if !stream_ended_with_error {
let emitted_pending_message_delta = if let Some((stop_reason, usage_json)) =
pending_message_delta.take()
{
let mut event = json!({
"type": "message_delta",
"delta": {
"stop_reason": stop_reason,
"stop_sequence": null
}
});
if let Some(uj) = usage_json {
event["usage"] = uj;
}
let sse_data = format!("event: message_delta\ndata: {}\n\n",
serde_json::to_string(&event).unwrap_or_default());
log::debug!("[Claude/OpenRouter] >>> Anthropic SSE: message_delta (at stream end)");
yield Ok(Bytes::from(sse_data));
true
} else {
false
};
if emitted_pending_message_delta && !has_sent_message_stop {
let event = json!({"type": "message_stop"});
let sse_data = format!("event: message_stop\ndata: {}\n\n",
serde_json::to_string(&event).unwrap_or_default());
log::debug!("[Claude/OpenRouter] >>> Anthropic SSE: message_stop (at stream end)");
yield Ok(Bytes::from(sse_data));
}
}
} }
} }
@@ -602,6 +685,32 @@ mod tests {
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap; use std::collections::HashMap;
async fn collect_anthropic_events(input: &str) -> Vec<Value> {
let upstream = stream::iter(vec![Ok::<_, std::io::Error>(Bytes::from(
input.as_bytes().to_vec(),
))]);
let converted = create_anthropic_sse_stream(upstream);
let chunks: Vec<_> = converted.collect().await;
let merged = chunks
.into_iter()
.map(|chunk| String::from_utf8_lossy(chunk.unwrap().as_ref()).to_string())
.collect::<String>();
merged
.split("\n\n")
.filter_map(|block| {
let data = block
.lines()
.find_map(|line| strip_sse_field(line, "data"))?;
serde_json::from_str::<Value>(data).ok()
})
.collect()
}
fn event_type(event: &Value) -> Option<&str> {
event.get("type").and_then(|v| v.as_str())
}
#[test] #[test]
fn test_map_stop_reason_legacy_and_filtered_values() { fn test_map_stop_reason_legacy_and_filtered_values() {
assert_eq!( assert_eq!(
@@ -818,4 +927,174 @@ mod tests {
"output must not contain U+FFFD replacement characters" "output must not contain U+FFFD replacement characters"
); );
} }
#[tokio::test]
async fn test_duplicate_finish_reason_emits_only_one_message_delta() {
// Simulates OpenRouter behavior where two chunks carry finish_reason:
// first with null usage, second with populated usage.
let input = concat!(
"data: {\"id\":\"chatcmpl_dup\",\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{},\"finish_reason\":\"tool_calls\"}]}\n\n",
"data: {\"id\":\"chatcmpl_dup\",\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{},\"finish_reason\":\"tool_calls\"}],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":5}}\n\n",
"data: [DONE]\n\n"
);
let upstream = stream::iter(vec![Ok::<_, std::io::Error>(Bytes::from(
input.as_bytes().to_vec(),
))]);
let converted = create_anthropic_sse_stream(upstream);
let chunks: Vec<_> = converted.collect().await;
let merged = chunks
.into_iter()
.map(|chunk| String::from_utf8_lossy(chunk.unwrap().as_ref()).to_string())
.collect::<String>();
let events: Vec<Value> = merged
.split("\n\n")
.filter_map(|block| {
let data = block
.lines()
.find_map(|line| strip_sse_field(line, "data"))?;
serde_json::from_str::<Value>(data).ok()
})
.collect();
let message_deltas: Vec<&Value> = events
.iter()
.filter(|e| e.get("type").and_then(|v| v.as_str()) == Some("message_delta"))
.collect();
assert_eq!(
message_deltas.len(),
1,
"duplicate finish_reason chunks must produce exactly one message_delta, got {}: {:?}",
message_deltas.len(),
message_deltas
);
assert_eq!(message_deltas[0]["usage"]["input_tokens"], 10);
assert_eq!(message_deltas[0]["usage"]["output_tokens"], 5);
let message_stops = events
.iter()
.filter(|e| e.get("type").and_then(|v| v.as_str()) == Some("message_stop"))
.count();
assert_eq!(message_stops, 1, "message_stop must only be emitted once");
}
#[tokio::test]
async fn test_usage_only_chunk_after_finish_reason_updates_message_delta_usage() {
let input = concat!(
"data: {\"id\":\"chatcmpl_split\",\"model\":\"glm-5.1\",\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"tool-0924\",\"type\":\"function\",\"function\":{\"name\":\"Bash\",\"arguments\":\"{\\\"command\\\":\\\"pwd\\\"}\"}}]}}]}\n\n",
"data: {\"id\":\"chatcmpl_split\",\"model\":\"glm-5.1\",\"choices\":[{\"delta\":{},\"finish_reason\":\"tool_calls\"}]}\n\n",
"data: {\"choices\":[],\"usage\":{\"prompt_tokens\":13312,\"completion_tokens\":79,\"prompt_tokens_details\":{\"cached_tokens\":100}}}\n\n",
"data: [DONE]\n\n"
);
let events = collect_anthropic_events(input).await;
let message_deltas: Vec<&Value> = events
.iter()
.filter(|event| event_type(event) == Some("message_delta"))
.collect();
let message_stops = events
.iter()
.filter(|event| event_type(event) == Some("message_stop"))
.count();
assert_eq!(message_deltas.len(), 1);
assert_eq!(message_stops, 1);
let message_delta = message_deltas[0];
assert_eq!(
message_delta
.pointer("/delta/stop_reason")
.and_then(|v| v.as_str()),
Some("tool_use")
);
assert_eq!(
message_delta
.pointer("/usage/input_tokens")
.and_then(|v| v.as_u64()),
Some(13312)
);
assert_eq!(
message_delta
.pointer("/usage/output_tokens")
.and_then(|v| v.as_u64()),
Some(79)
);
assert_eq!(
message_delta
.pointer("/usage/cache_read_input_tokens")
.and_then(|v| v.as_u64()),
Some(100)
);
}
#[tokio::test]
async fn test_streaming_finalizes_after_finish_when_done_is_missing() {
let input = concat!(
"data: {\"id\":\"chatcmpl_no_done\",\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{\"content\":\"hello\"}}]}\n\n",
"data: {\"id\":\"chatcmpl_no_done\",\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\"}]}\n\n"
);
let events = collect_anthropic_events(input).await;
assert!(events.iter().any(|event| {
event_type(event) == Some("message_delta")
&& event.pointer("/delta/stop_reason").and_then(|v| v.as_str()) == Some("end_turn")
}));
assert_eq!(
events.last().and_then(|event| event_type(event)),
Some("message_stop")
);
}
#[tokio::test]
async fn test_stream_end_without_finish_reason_does_not_emit_success_terminal_events() {
let input = "data: {\"id\":\"chatcmpl_truncated\",\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{\"content\":\"hello\"}}]}\n\n";
let events = collect_anthropic_events(input).await;
assert!(!events
.iter()
.any(|event| event_type(event) == Some("message_delta")));
assert!(!events
.iter()
.any(|event| event_type(event) == Some("message_stop")));
}
#[tokio::test]
async fn test_stream_error_does_not_emit_success_terminal_events() {
let upstream = stream::iter(vec![Err::<Bytes, _>(std::io::Error::other(
"upstream disconnected",
))]);
let converted = create_anthropic_sse_stream(upstream);
let chunks: Vec<_> = converted.collect().await;
let merged = chunks
.into_iter()
.map(|chunk| String::from_utf8_lossy(chunk.unwrap().as_ref()).to_string())
.collect::<String>();
let events: Vec<Value> = merged
.split("\n\n")
.filter_map(|block| {
let data = block
.lines()
.find_map(|line| strip_sse_field(line, "data"))?;
serde_json::from_str::<Value>(data).ok()
})
.collect();
assert!(events
.iter()
.any(|e| e.get("type").and_then(|v| v.as_str()) == Some("error")));
assert!(!events
.iter()
.any(|e| e.get("type").and_then(|v| v.as_str()) == Some("message_delta")));
assert!(!events
.iter()
.any(|e| e.get("type").and_then(|v| v.as_str()) == Some("message_stop")));
}
} }