Files
cc-switch/src-tauri/src/proxy/server.rs
2026-04-13 12:26:19 +08:00

351 lines
14 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//! HTTP代理服务器
//!
//! 基于 Axum 的 HTTP 服务器，处理代理请求
//!
//! Uses a manual hyper HTTP/1.1 accept loop with `preserve_header_case(true)` so
//! that the original header-name casing from the CLI client is captured in a
//! `HeaderCaseMap` extension. This map is later forwarded to the upstream via
//! the hyper-based HTTP client, producing wire-level header casing identical to
//! a direct (non-proxied) CLI request.
use super::{
failover_switch::FailoverSwitchManager, handlers, log_codes::srv as log_srv,
provider_router::ProviderRouter, types::*, ProxyError,
};
use crate::database::Database;
use axum::{
extract::DefaultBodyLimit,
routing::{get, post},
Router,
};
use hyper_util::rt::TokioIo;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::{oneshot, RwLock};
use tokio::task::JoinHandle;
/// Proxy server state shared with the request handlers (handed to the router via
/// `with_state` in `ProxyServer::build_router`).
///
/// Every field except `app_handle` is an `Arc`, so clones of this struct share the same
/// underlying locks and services.
#[derive(Clone)]
pub struct ProxyState {
/// Database handle; also used by `ProxyServer::new` to construct `provider_router` and
/// `failover_manager`.
pub db: Arc<Database>,
/// Runtime proxy configuration; overwritten by `ProxyServer::apply_runtime_config`.
pub config: Arc<RwLock<ProxyConfig>>,
/// Status snapshot: `running`, `address` and `port` are set by `ProxyServer::start`, and
/// `running` is cleared again when the accept loop exits.
pub status: Arc<RwLock<ProxyStatus>>,
/// Instant the server was started; `None` while stopped. Used to compute uptime.
pub start_time: Arc<RwLock<Option<std::time::Instant>>>,
/// Provider currently in use for each app type: `app_type -> (provider_id, provider_name)`.
pub current_providers: Arc<RwLock<std::collections::HashMap<String, (String, String)>>>,
/// Shared `ProviderRouter`; it holds the circuit-breaker state so that state persists
/// across requests.
pub provider_router: Arc<ProviderRouter>,
/// Optional Tauri app handle, used to emit events and update the tray menu.
pub app_handle: Option<tauri::AppHandle>,
/// Failover switch manager.
pub failover_manager: Arc<FailoverSwitchManager>,
}
/// HTTP proxy server: manages the listener lifecycle (`start` / `stop`) around the
/// shared [`ProxyState`].
pub struct ProxyServer {
/// Configuration captured at construction; `start` binds its `listen_address` and
/// `listen_port`.
config: ProxyConfig,
/// State shared with the request handlers.
state: ProxyState,
/// Sends the shutdown signal to the accept loop; `Some` between a successful `start`
/// and the next `stop`.
shutdown_tx: Arc<RwLock<Option<oneshot::Sender<()>>>>,
/// Handle of the server task, used to wait for the server to actually shut down.
server_handle: Arc<RwLock<Option<JoinHandle<()>>>>,
}
impl ProxyServer {
pub fn new(
config: ProxyConfig,
db: Arc<Database>,
app_handle: Option<tauri::AppHandle>,
) -> Self {
// 创建共享的 ProviderRouter熔断器状态将跨所有请求保持
let provider_router = Arc::new(ProviderRouter::new(db.clone()));
// 创建故障转移切换管理器
let failover_manager = Arc::new(FailoverSwitchManager::new(db.clone()));
let state = ProxyState {
db,
config: Arc::new(RwLock::new(config.clone())),
status: Arc::new(RwLock::new(ProxyStatus::default())),
start_time: Arc::new(RwLock::new(None)),
current_providers: Arc::new(RwLock::new(std::collections::HashMap::new())),
provider_router,
app_handle,
failover_manager,
};
Self {
config,
state,
shutdown_tx: Arc::new(RwLock::new(None)),
server_handle: Arc::new(RwLock::new(None)),
}
}
pub async fn start(&self) -> Result<ProxyServerInfo, ProxyError> {
// 检查是否已在运行
if self.shutdown_tx.read().await.is_some() {
return Err(ProxyError::AlreadyRunning);
}
let addr: SocketAddr =
format!("{}:{}", self.config.listen_address, self.config.listen_port)
.parse()
.map_err(|e| ProxyError::BindFailed(format!("无效的地址: {e}")))?;
// 创建关闭通道
let (shutdown_tx, shutdown_rx) = oneshot::channel();
// 构建路由
let app = self.build_router();
// 绑定监听器
let listener = tokio::net::TcpListener::bind(&addr)
.await
.map_err(|e| ProxyError::BindFailed(e.to_string()))?;
log::info!("[{}] 代理服务器启动于 {addr}", log_srv::STARTED);
// 更新全局代理端口,用于系统代理检测
crate::proxy::http_client::set_proxy_port(self.config.listen_port);
// 保存关闭句柄
*self.shutdown_tx.write().await = Some(shutdown_tx);
// 更新状态
let mut status = self.state.status.write().await;
status.running = true;
status.address = self.config.listen_address.clone();
status.port = self.config.listen_port;
drop(status);
// 记录启动时间
*self.state.start_time.write().await = Some(std::time::Instant::now());
// 启动服务器 — 使用手动 hyper HTTP/1.1 accept loop
// 开启 preserve_header_case 以捕获客户端请求头的原始大小写
let state = self.state.clone();
let handle = tokio::spawn(async move {
let mut shutdown_rx = shutdown_rx;
loop {
tokio::select! {
result = listener.accept() => {
let (stream, _remote_addr) = match result {
Ok(v) => v,
Err(e) => {
log::error!("[{SRV}] accept 失败: {e}", SRV = log_srv::ACCEPT_ERR);
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
continue;
}
};
let app = app.clone();
tokio::spawn(async move {
// Peek raw TCP bytes to capture original header casing
// before hyper parses (and lowercases) the header names.
let original_cases = {
let mut peek_buf = vec![0u8; 8192];
match stream.peek(&mut peek_buf).await {
Ok(n) => {
let cases = super::hyper_client::OriginalHeaderCases::from_raw_bytes(&peek_buf[..n]);
log::debug!(
"[ProxyServer] Peeked {} bytes, captured {} header casings",
n, cases.cases.len()
);
cases
}
Err(e) => {
log::debug!("[ProxyServer] peek failed (non-fatal): {e}");
super::hyper_client::OriginalHeaderCases::default()
}
}
};
// service_fn 将 axum Routertower::Service桥接到 hyper
let service = hyper::service::service_fn(move |req: hyper::Request<hyper::body::Incoming>| {
let mut router = app.clone();
let cases = original_cases.clone();
async move {
// 将 hyper::body::Incoming 转为 axum::body::Body保留 extensions
let (mut parts, body) = req.into_parts();
// Insert our own header case map alongside hyper's internal one
parts.extensions.insert(cases);
let body = axum::body::Body::new(body);
let axum_req = http::Request::from_parts(parts, body);
<Router as tower::Service<http::Request<axum::body::Body>>>::call(&mut router, axum_req).await
}
});
if let Err(e) = hyper::server::conn::http1::Builder::new()
.preserve_header_case(true)
.serve_connection(TokioIo::new(stream), service)
.await
{
// Connection reset / broken pipe 等在代理场景下很常见debug 级别
log::debug!("[{SRV}] connection error: {e}", SRV = log_srv::CONN_ERR);
}
});
}
_ = &mut shutdown_rx => {
break;
}
}
}
// 服务器停止后更新状态
state.status.write().await.running = false;
*state.start_time.write().await = None;
});
// 保存服务器任务句柄
*self.server_handle.write().await = Some(handle);
Ok(ProxyServerInfo {
address: self.config.listen_address.clone(),
port: self.config.listen_port,
started_at: chrono::Utc::now().to_rfc3339(),
})
}
pub async fn stop(&self) -> Result<(), ProxyError> {
// 1. 发送关闭信号
if let Some(tx) = self.shutdown_tx.write().await.take() {
let _ = tx.send(());
} else {
return Err(ProxyError::NotRunning);
}
// 2. 等待服务器任务结束(带 5 秒超时保护)
if let Some(handle) = self.server_handle.write().await.take() {
match tokio::time::timeout(std::time::Duration::from_secs(5), handle).await {
Ok(Ok(())) => {
log::info!("[{}] 代理服务器已完全停止", log_srv::STOPPED);
Ok(())
}
Ok(Err(e)) => {
log::warn!("[{}] 代理服务器任务异常终止: {e}", log_srv::TASK_ERROR);
Err(ProxyError::StopFailed(e.to_string()))
}
Err(_) => {
log::warn!(
"[{}] 代理服务器停止超时5秒强制继续",
log_srv::STOP_TIMEOUT
);
Err(ProxyError::StopTimeout)
}
}
} else {
Ok(())
}
}
pub async fn get_status(&self) -> ProxyStatus {
let mut status = self.state.status.read().await.clone();
// 计算运行时间
if let Some(start) = *self.state.start_time.read().await {
status.uptime_seconds = start.elapsed().as_secs();
}
// 从 current_providers HashMap 获取每个应用类型当前正在使用的 provider
let current_providers = self.state.current_providers.read().await;
status.active_targets = current_providers
.iter()
.map(|(app_type, (provider_id, provider_name))| ActiveTarget {
app_type: app_type.clone(),
provider_id: provider_id.clone(),
provider_name: provider_name.clone(),
})
.collect();
status
}
/// 更新某个应用类型当前“目标供应商”(用于 UI 展示 active_targets
///
/// 注意:这不代表该供应商一定已经处理过请求,而是用于“热切换/启用故障转移立即切 P1”
/// 等场景下,让 UI 能立刻反映最新目标。
pub async fn set_active_target(&self, app_type: &str, provider_id: &str, provider_name: &str) {
let mut current_providers = self.state.current_providers.write().await;
current_providers.insert(
app_type.to_string(),
(provider_id.to_string(), provider_name.to_string()),
);
}
fn build_router(&self) -> Router {
Router::new()
// 健康检查
.route("/health", get(handlers::health_check))
.route("/status", get(handlers::get_status))
// Claude API (支持带前缀和不带前缀两种格式)
.route("/v1/messages", post(handlers::handle_messages))
.route("/claude/v1/messages", post(handlers::handle_messages))
// OpenAI Chat Completions API (Codex CLI支持带前缀和不带前缀)
.route("/chat/completions", post(handlers::handle_chat_completions))
.route(
"/v1/chat/completions",
post(handlers::handle_chat_completions),
)
.route(
"/v1/v1/chat/completions",
post(handlers::handle_chat_completions),
)
.route(
"/codex/v1/chat/completions",
post(handlers::handle_chat_completions),
)
// OpenAI Responses API (Codex CLI支持带前缀和不带前缀)
.route("/responses", post(handlers::handle_responses))
.route("/v1/responses", post(handlers::handle_responses))
.route("/v1/v1/responses", post(handlers::handle_responses))
.route("/codex/v1/responses", post(handlers::handle_responses))
// OpenAI Responses Compact API (Codex CLI 远程压缩,透传)
.route(
"/responses/compact",
post(handlers::handle_responses_compact),
)
.route(
"/v1/responses/compact",
post(handlers::handle_responses_compact),
)
.route(
"/v1/v1/responses/compact",
post(handlers::handle_responses_compact),
)
.route(
"/codex/v1/responses/compact",
post(handlers::handle_responses_compact),
)
// Gemini API (支持带前缀和不带前缀)
.route("/v1beta/*path", post(handlers::handle_gemini))
.route("/gemini/v1beta/*path", post(handlers::handle_gemini))
// 提高默认请求体大小限制(避免 413 Payload Too Large
.layer(DefaultBodyLimit::max(200 * 1024 * 1024))
.with_state(self.state.clone())
}
/// 在不重启服务的情况下更新运行时配置
pub async fn apply_runtime_config(&self, config: &ProxyConfig) {
*self.state.config.write().await = config.clone();
}
/// 热更新熔断器配置
///
/// 将新配置应用到所有已创建的熔断器实例
pub async fn update_circuit_breaker_configs(
&self,
config: super::circuit_breaker::CircuitBreakerConfig,
) {
self.state.provider_router.update_all_configs(config).await;
}
/// 重置指定 Provider 的熔断器
pub async fn reset_provider_circuit_breaker(&self, provider_id: &str, app_type: &str) {
self.state
.provider_router
.reset_provider_breaker(provider_id, app_type)
.await;
}
}