Merge remote-tracking branch 'origin/main' into feat/usage-dashboard-refinement

This commit is contained in:
Jason
2026-04-15 11:36:52 +08:00
3 changed files with 130 additions and 10 deletions

View File

@@ -20,7 +20,8 @@ use super::{
},
response_processor::{
create_logged_passthrough_stream, process_response, read_decoded_body,
strip_entity_headers_for_rebuilt_body, SseUsageCollector,
strip_entity_headers_for_rebuilt_body, strip_hop_by_hop_response_headers,
SseUsageCollector,
},
server::ProxyState,
types::*,
@@ -216,10 +217,6 @@ async fn handle_claude_transform(
"Cache-Control",
axum::http::HeaderValue::from_static("no-cache"),
);
headers.insert(
"Connection",
axum::http::HeaderValue::from_static("keep-alive"),
);
let body = axum::body::Body::from_stream(logged_stream);
return Ok((headers, body).into_response());
@@ -287,6 +284,7 @@ async fn handle_claude_transform(
// 构建响应
let mut builder = axum::response::Response::builder().status(status);
strip_entity_headers_for_rebuilt_body(&mut response_headers);
strip_hop_by_hop_response_headers(&mut response_headers);
for (key, value) in response_headers.iter() {
builder = builder.header(key, value);

View File

@@ -11,7 +11,7 @@ use super::{
usage::parser::TokenUsage,
ProxyError,
};
use axum::http::header::HeaderMap;
use axum::http::{header::HeaderMap, HeaderName};
use axum::response::{IntoResponse, Response};
use bytes::Bytes;
use futures::stream::{Stream, StreamExt};
@@ -68,6 +68,41 @@ fn get_content_encoding(headers: &HeaderMap) -> Option<String> {
.filter(|s| !s.is_empty() && s != "identity")
}
/// RFC 2616 / RFC 7230 中定义的不应被代理继续转发的响应头。
const HOP_BY_HOP_RESPONSE_HEADERS: &[&str] = &[
"connection",
"keep-alive",
"proxy-authenticate",
"proxy-authorization",
"proxy-connection",
"te",
"trailer",
"trailers",
"transfer-encoding",
"upgrade",
];
/// 移除响应侧 hop-by-hop 头,以及 `Connection` 中点名的扩展头。
pub(crate) fn strip_hop_by_hop_response_headers(headers: &mut HeaderMap) {
let connection_listed_headers: Vec<HeaderName> = headers
.get_all(axum::http::header::CONNECTION)
.iter()
.filter_map(|value| value.to_str().ok())
.flat_map(|value| value.split(','))
.map(str::trim)
.filter(|name| !name.is_empty())
.filter_map(|name| HeaderName::from_bytes(name.as_bytes()).ok())
.collect();
for name in HOP_BY_HOP_RESPONSE_HEADERS {
headers.remove(*name);
}
for name in connection_listed_headers {
headers.remove(name);
}
}
/// 移除在重建响应体后会失真的实体头。
pub(crate) fn strip_entity_headers_for_rebuilt_body(headers: &mut HeaderMap) {
headers.remove(axum::http::header::CONTENT_ENCODING);
@@ -163,10 +198,13 @@ pub async fn handle_streaming(
);
}
let mut response_headers = response.headers().clone();
strip_hop_by_hop_response_headers(&mut response_headers);
let mut builder = axum::response::Response::builder().status(status);
// 复制响应头
for (key, value) in response.headers() {
for (key, value) in &response_headers {
builder = builder.header(key, value);
}
@@ -207,8 +245,9 @@ pub async fn handle_non_streaming(
} else {
Duration::ZERO
};
let (response_headers, status, body_bytes) =
let (mut response_headers, status, body_bytes) =
read_decoded_body(response, ctx.tag, body_timeout).await?;
strip_hop_by_hop_response_headers(&mut response_headers);
log::debug!(
"[{}] 上游响应体内容: {}",
@@ -715,6 +754,90 @@ mod tests {
assert_eq!(super::strip_sse_field("id:1", "data"), None);
}
#[test]
fn test_strip_hop_by_hop_response_headers_removes_standard_headers() {
let mut headers = HeaderMap::new();
headers.insert(
axum::http::header::CONNECTION,
axum::http::HeaderValue::from_static("keep-alive"),
);
headers.insert(
axum::http::header::HeaderName::from_static("keep-alive"),
axum::http::HeaderValue::from_static("timeout=5"),
);
headers.insert(
axum::http::header::TRANSFER_ENCODING,
axum::http::HeaderValue::from_static("chunked"),
);
headers.insert(
axum::http::header::HeaderName::from_static("proxy-connection"),
axum::http::HeaderValue::from_static("keep-alive"),
);
headers.insert(
axum::http::header::CONTENT_TYPE,
axum::http::HeaderValue::from_static("application/json"),
);
headers.insert(
axum::http::header::CONTENT_LENGTH,
axum::http::HeaderValue::from_static("12"),
);
strip_hop_by_hop_response_headers(&mut headers);
assert!(!headers.contains_key(axum::http::header::CONNECTION));
assert!(!headers.contains_key("keep-alive"));
assert!(!headers.contains_key(axum::http::header::TRANSFER_ENCODING));
assert!(!headers.contains_key("proxy-connection"));
assert_eq!(
headers.get(axum::http::header::CONTENT_TYPE),
Some(&axum::http::HeaderValue::from_static("application/json"))
);
assert_eq!(
headers.get(axum::http::header::CONTENT_LENGTH),
Some(&axum::http::HeaderValue::from_static("12"))
);
}
#[test]
fn test_strip_hop_by_hop_response_headers_removes_connection_listed_extensions() {
let mut headers = HeaderMap::new();
headers.append(
axum::http::header::CONNECTION,
axum::http::HeaderValue::from_static("x-trace-hop, x-debug-hop"),
);
headers.append(
axum::http::header::CONNECTION,
axum::http::HeaderValue::from_static("upgrade"),
);
headers.insert(
axum::http::header::HeaderName::from_static("x-trace-hop"),
axum::http::HeaderValue::from_static("trace"),
);
headers.insert(
axum::http::header::HeaderName::from_static("x-debug-hop"),
axum::http::HeaderValue::from_static("debug"),
);
headers.insert(
axum::http::header::UPGRADE,
axum::http::HeaderValue::from_static("websocket"),
);
headers.insert(
axum::http::header::CONTENT_TYPE,
axum::http::HeaderValue::from_static("text/event-stream"),
);
strip_hop_by_hop_response_headers(&mut headers);
assert!(!headers.contains_key(axum::http::header::CONNECTION));
assert!(!headers.contains_key("x-trace-hop"));
assert!(!headers.contains_key("x-debug-hop"));
assert!(!headers.contains_key(axum::http::header::UPGRADE));
assert_eq!(
headers.get(axum::http::header::CONTENT_TYPE),
Some(&axum::http::HeaderValue::from_static("text/event-stream"))
);
}
fn build_state(db: Arc<Database>) -> ProxyState {
ProxyState {
db: db.clone(),

View File

@@ -428,8 +428,7 @@ impl StreamCheckService {
.header("x-stainless-retry-count", "0")
.header("x-stainless-timeout", "600")
// Other headers
.header("sec-fetch-mode", "cors")
.header("connection", "keep-alive");
.header("sec-fetch-mode", "cors");
}
// 供应商自定义 headers 最后追加,允许覆盖内置默认值(例如 user-agent