diff --git a/src-tauri/src/proxy/handlers.rs b/src-tauri/src/proxy/handlers.rs index 3a2bd4ca..c203cb8b 100644 --- a/src-tauri/src/proxy/handlers.rs +++ b/src-tauri/src/proxy/handlers.rs @@ -20,7 +20,8 @@ use super::{ }, response_processor::{ create_logged_passthrough_stream, process_response, read_decoded_body, - strip_entity_headers_for_rebuilt_body, SseUsageCollector, + strip_entity_headers_for_rebuilt_body, strip_hop_by_hop_response_headers, + SseUsageCollector, }, server::ProxyState, types::*, @@ -216,10 +217,6 @@ async fn handle_claude_transform( "Cache-Control", axum::http::HeaderValue::from_static("no-cache"), ); - headers.insert( - "Connection", - axum::http::HeaderValue::from_static("keep-alive"), - ); let body = axum::body::Body::from_stream(logged_stream); return Ok((headers, body).into_response()); @@ -287,6 +284,7 @@ async fn handle_claude_transform( // 构建响应 let mut builder = axum::response::Response::builder().status(status); strip_entity_headers_for_rebuilt_body(&mut response_headers); + strip_hop_by_hop_response_headers(&mut response_headers); for (key, value) in response_headers.iter() { builder = builder.header(key, value); diff --git a/src-tauri/src/proxy/response_processor.rs b/src-tauri/src/proxy/response_processor.rs index 1a185f93..3008bb5f 100644 --- a/src-tauri/src/proxy/response_processor.rs +++ b/src-tauri/src/proxy/response_processor.rs @@ -11,7 +11,7 @@ use super::{ usage::parser::TokenUsage, ProxyError, }; -use axum::http::header::HeaderMap; +use axum::http::{header::HeaderMap, HeaderName}; use axum::response::{IntoResponse, Response}; use bytes::Bytes; use futures::stream::{Stream, StreamExt}; @@ -68,6 +68,41 @@ fn get_content_encoding(headers: &HeaderMap) -> Option { .filter(|s| !s.is_empty() && s != "identity") } +/// RFC 2616 / RFC 7230 中定义的不应被代理继续转发的响应头。 +const HOP_BY_HOP_RESPONSE_HEADERS: &[&str] = &[ + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "proxy-connection", + "te", + "trailer", + "trailers", + "transfer-encoding", + "upgrade", +]; + +/// 移除响应侧 hop-by-hop 头,以及 `Connection` 中点名的扩展头。 +pub(crate) fn strip_hop_by_hop_response_headers(headers: &mut HeaderMap) { + let connection_listed_headers: Vec = headers + .get_all(axum::http::header::CONNECTION) + .iter() + .filter_map(|value| value.to_str().ok()) + .flat_map(|value| value.split(',')) + .map(str::trim) + .filter(|name| !name.is_empty()) + .filter_map(|name| HeaderName::from_bytes(name.as_bytes()).ok()) + .collect(); + + for name in HOP_BY_HOP_RESPONSE_HEADERS { + headers.remove(*name); + } + + for name in connection_listed_headers { + headers.remove(name); + } +} + /// 移除在重建响应体后会失真的实体头。 pub(crate) fn strip_entity_headers_for_rebuilt_body(headers: &mut HeaderMap) { headers.remove(axum::http::header::CONTENT_ENCODING); @@ -163,10 +198,13 @@ pub async fn handle_streaming( ); } + let mut response_headers = response.headers().clone(); + strip_hop_by_hop_response_headers(&mut response_headers); + let mut builder = axum::response::Response::builder().status(status); // 复制响应头 - for (key, value) in response.headers() { + for (key, value) in &response_headers { builder = builder.header(key, value); } @@ -207,8 +245,9 @@ pub async fn handle_non_streaming( } else { Duration::ZERO }; - let (response_headers, status, body_bytes) = + let (mut response_headers, status, body_bytes) = read_decoded_body(response, ctx.tag, body_timeout).await?; + strip_hop_by_hop_response_headers(&mut response_headers); log::debug!( "[{}] 上游响应体内容: {}", @@ -715,6 +754,90 @@ mod tests { assert_eq!(super::strip_sse_field("id:1", "data"), None); } + #[test] + fn test_strip_hop_by_hop_response_headers_removes_standard_headers() { + let mut headers = HeaderMap::new(); + headers.insert( + axum::http::header::CONNECTION, + axum::http::HeaderValue::from_static("keep-alive"), + ); + headers.insert( + axum::http::header::HeaderName::from_static("keep-alive"), + axum::http::HeaderValue::from_static("timeout=5"), + ); + headers.insert( + axum::http::header::TRANSFER_ENCODING, + axum::http::HeaderValue::from_static("chunked"), + ); + headers.insert( + axum::http::header::HeaderName::from_static("proxy-connection"), + axum::http::HeaderValue::from_static("keep-alive"), + ); + headers.insert( + axum::http::header::CONTENT_TYPE, + axum::http::HeaderValue::from_static("application/json"), + ); + headers.insert( + axum::http::header::CONTENT_LENGTH, + axum::http::HeaderValue::from_static("12"), + ); + + strip_hop_by_hop_response_headers(&mut headers); + + assert!(!headers.contains_key(axum::http::header::CONNECTION)); + assert!(!headers.contains_key("keep-alive")); + assert!(!headers.contains_key(axum::http::header::TRANSFER_ENCODING)); + assert!(!headers.contains_key("proxy-connection")); + assert_eq!( + headers.get(axum::http::header::CONTENT_TYPE), + Some(&axum::http::HeaderValue::from_static("application/json")) + ); + assert_eq!( + headers.get(axum::http::header::CONTENT_LENGTH), + Some(&axum::http::HeaderValue::from_static("12")) + ); + } + + #[test] + fn test_strip_hop_by_hop_response_headers_removes_connection_listed_extensions() { + let mut headers = HeaderMap::new(); + headers.append( + axum::http::header::CONNECTION, + axum::http::HeaderValue::from_static("x-trace-hop, x-debug-hop"), + ); + headers.append( + axum::http::header::CONNECTION, + axum::http::HeaderValue::from_static("upgrade"), + ); + headers.insert( + axum::http::header::HeaderName::from_static("x-trace-hop"), + axum::http::HeaderValue::from_static("trace"), + ); + headers.insert( + axum::http::header::HeaderName::from_static("x-debug-hop"), + axum::http::HeaderValue::from_static("debug"), + ); + headers.insert( + axum::http::header::UPGRADE, + axum::http::HeaderValue::from_static("websocket"), + ); + headers.insert( + axum::http::header::CONTENT_TYPE, + axum::http::HeaderValue::from_static("text/event-stream"), + ); + + strip_hop_by_hop_response_headers(&mut headers); + + assert!(!headers.contains_key(axum::http::header::CONNECTION)); + assert!(!headers.contains_key("x-trace-hop")); + assert!(!headers.contains_key("x-debug-hop")); + assert!(!headers.contains_key(axum::http::header::UPGRADE)); + assert_eq!( + headers.get(axum::http::header::CONTENT_TYPE), + Some(&axum::http::HeaderValue::from_static("text/event-stream")) + ); + } + fn build_state(db: Arc) -> ProxyState { ProxyState { db: db.clone(), diff --git a/src-tauri/src/services/stream_check.rs b/src-tauri/src/services/stream_check.rs index ef545fdf..ffe679e6 100644 --- a/src-tauri/src/services/stream_check.rs +++ b/src-tauri/src/services/stream_check.rs @@ -428,8 +428,7 @@ impl StreamCheckService { .header("x-stainless-retry-count", "0") .header("x-stainless-timeout", "600") // Other headers - .header("sec-fetch-mode", "cors") - .header("connection", "keep-alive"); + .header("sec-fetch-mode", "cors"); } // 供应商自定义 headers 最后追加,允许覆盖内置默认值(例如 user-agent)