diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index ff0024d0..1ae679ef 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -107,6 +107,26 @@ dependencies = [ "derive_arbitrary", ] +[[package]] +name = "arboard" +version = "3.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0348a1c054491f4bfe6ab86a7b6ab1e44e45d899005de92f58b3df180b36ddaf" +dependencies = [ + "clipboard-win", + "image", + "log", + "objc2 0.6.4", + "objc2-app-kit 0.3.2", + "objc2-core-foundation", + "objc2-core-graphics", + "objc2-foundation 0.3.2", + "parking_lot", + "percent-encoding", + "windows-sys 0.60.2", + "x11rb", +] + [[package]] name = "arrayvec" version = "0.7.6" @@ -718,6 +738,7 @@ name = "cc-switch" version = "3.12.3" dependencies = [ "anyhow", + "arboard", "async-stream", "auto-launch", "axum", @@ -843,6 +864,15 @@ dependencies = [ "inout", ] +[[package]] +name = "clipboard-win" +version = "5.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bde03770d3df201d4fb868f2c9c59e66a3e4e2bd06692a0fe701e7103c7e84d4" +dependencies = [ + "error-code", +] + [[package]] name = "cmake" version = "0.1.57" @@ -1445,6 +1475,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "error-code" +version = "3.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dea2df4cf52843e0452895c455a1a2cfbb842a1e7329671acf418fdc53ed4c59" + [[package]] name = "event-listener" version = "5.4.1" @@ -1484,6 +1520,26 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +[[package]] +name = "fax" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f05de7d48f37cd6730705cbca900770cab77a89f413d23e100ad7fad7795a0ab" +dependencies = [ + "fax_derive", +] + +[[package]] +name = "fax_derive" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0aca10fb742cb43f9e7bb8467c91aa9bcb8e3ffbc6a6f7389bb93ffc920577d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "fdeflate" version = "0.3.7" @@ -1849,6 +1905,16 @@ dependencies = [ "version_check", ] +[[package]] +name = "gethostname" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bd49230192a3797a9a4d6abe9b3eed6f7fa4c8a8a4947977c6f80025f92cbd8" +dependencies = [ + "rustix", + "windows-link 0.2.1", +] + [[package]] name = "getrandom" version = "0.1.16" @@ -2067,6 +2133,17 @@ dependencies = [ "tracing", ] +[[package]] +name = "half" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" +dependencies = [ + "cfg-if", + "crunchy", + "zerocopy", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -2457,6 +2534,7 @@ dependencies = [ "moxcms", "num-traits", "png 0.18.1", + "tiff", ] [[package]] @@ -3088,6 +3166,7 @@ dependencies = [ "block2 0.6.2", "objc2 0.6.4", "objc2-core-foundation", + "objc2-core-graphics", "objc2-foundation 0.3.2", ] @@ -3894,6 +3973,12 @@ version = "0.1.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5a041e753da8b807c9255f28de81879c78c876392ff2469cde94799b2896b9d" +[[package]] +name = "quick-error" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" + [[package]] name = "quick-xml" version = "0.38.4" @@ -5766,6 +5851,20 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "tiff" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b63feaf3343d35b6ca4d50483f94843803b0f51634937cc2ec519fc32232bc52" +dependencies = [ + "fax", + "flate2", + "half", + "quick-error", + "weezl", + "zune-jpeg", +] + [[package]] name = "time" version = "0.3.47" @@ -6660,6 +6759,12 @@ dependencies = [ "windows-core 0.61.2", ] +[[package]] +name = "weezl" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a28ac98ddc8b9274cb41bb4d9d4d5c425b6020c50c46f25559911905610b4a88" + [[package]] name = "winapi" version = "0.3.9" @@ -7397,6 +7502,23 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "x11rb" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9993aa5be5a26815fe2c3eacfc1fde061fc1a1f094bf1ad2a18bf9c495dd7414" +dependencies = [ + "gethostname", + "rustix", + "x11rb-protocol", +] + +[[package]] +name = "x11rb-protocol" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea6fc2961e4ef194dcbfe56bb845534d0dc8098940c7e5c012a258bfec6701bd" + [[package]] name = "xattr" version = "1.6.1" @@ -7682,6 +7804,21 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "zune-core" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb8a0807f7c01457d0379ba880ba6322660448ddebc890ce29bb64da71fb40f9" + +[[package]] +name = "zune-jpeg" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27bc9d5b815bc103f142aa054f561d9187d191692ec7c2d1e2b4737f8dbd7296" +dependencies = [ + "zune-core", +] + [[package]] name = "zvariant" version = "5.10.0" diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index f4ba03c6..9ac16458 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -39,6 +39,7 @@ dirs = "5.0" toml = "0.8" toml_edit = "0.22" reqwest = { version = "0.12", features = ["rustls-tls", "json", "stream", "socks"] } +arboard = "3.6" flate2 = "1" brotli = "7" tokio = { version = "1", features = ["macros", "rt-multi-thread", "time", "sync"] } diff --git a/src-tauri/src/commands/misc.rs b/src-tauri/src/commands/misc.rs index 3bb33798..5001e1f3 100644 --- a/src-tauri/src/commands/misc.rs +++ b/src-tauri/src/commands/misc.rs @@ -34,6 +34,22 @@ pub async fn open_external(app: AppHandle, url: String) -> Result Ok(true) } +#[tauri::command] +pub async fn copy_text_to_clipboard(text: String) -> Result { + // Use spawn_blocking to avoid blocking the async runtime + // Clipboard access can block on some platforms and may have thread/loop constraints + tokio::task::spawn_blocking(move || { + let mut clipboard = + arboard::Clipboard::new().map_err(|e| format!("访问系统剪贴板失败: {e}"))?; + clipboard + .set_text(text) + .map_err(|e| format!("写入系统剪贴板失败: {e}"))?; + Ok(true) + }) + .await + .map_err(|e| format!("剪贴板任务执行失败: {e}"))? +} + /// 检查更新 #[tauri::command] pub async fn check_for_updates(handle: AppHandle) -> Result { diff --git a/src-tauri/src/commands/stream_check.rs b/src-tauri/src/commands/stream_check.rs index cc7e4d51..b7d5ea0b 100644 --- a/src-tauri/src/commands/stream_check.rs +++ b/src-tauri/src/commands/stream_check.rs @@ -26,6 +26,7 @@ pub async fn stream_check_provider( .ok_or_else(|| AppError::Message(format!("供应商 {provider_id} 不存在")))?; let auth_override = resolve_copilot_auth_override(provider, &copilot_state).await?; + let base_url_override = resolve_copilot_base_url_override(provider, &copilot_state).await?; let claude_api_format_override = resolve_claude_api_format_override( &app_type, provider, @@ -39,6 +40,7 @@ pub async fn stream_check_provider( provider, &config, auth_override, + base_url_override, claude_api_format_override, ) .await?; @@ -87,6 +89,8 @@ pub async fn stream_check_all_providers( } let auth_override = resolve_copilot_auth_override(&provider, &copilot_state).await?; + let base_url_override = + resolve_copilot_base_url_override(&provider, &copilot_state).await?; let claude_api_format_override = resolve_claude_api_format_override( &app_type, &provider, @@ -108,6 +112,7 @@ pub async fn stream_check_all_providers( &provider, &config, auth_override, + base_url_override, claude_api_format_override, ) .await @@ -151,17 +156,7 @@ async fn resolve_copilot_auth_override( provider: &crate::provider::Provider, copilot_state: &State<'_, CopilotAuthState>, ) -> Result, AppError> { - let is_copilot = provider - .meta - .as_ref() - .and_then(|meta| meta.provider_type.as_deref()) - == Some("github_copilot") - || provider - .settings_config - .pointer("/env/ANTHROPIC_BASE_URL") - .and_then(|value| value.as_str()) - .map(|url| url.contains("githubcopilot.com")) - .unwrap_or(false); + let is_copilot = is_copilot_provider(provider); if !is_copilot { return Ok(None); @@ -171,7 +166,7 @@ async fn resolve_copilot_auth_override( let account_id = provider .meta .as_ref() - .and_then(|meta| meta.github_account_id.clone()); + .and_then(|meta| meta.managed_account_id_for("github_copilot")); let token = match account_id.as_deref() { Some(id) => auth_manager @@ -190,6 +185,49 @@ async fn resolve_copilot_auth_override( ))) } +async fn resolve_copilot_base_url_override( + provider: &crate::provider::Provider, + copilot_state: &State<'_, CopilotAuthState>, +) -> Result, AppError> { + let is_copilot = is_copilot_provider(provider); + let is_full_url = provider + .meta + .as_ref() + .and_then(|meta| meta.is_full_url) + .unwrap_or(false); + + if !is_copilot || is_full_url { + return Ok(None); + } + + let auth_manager = copilot_state.0.read().await; + let account_id = provider + .meta + .as_ref() + .and_then(|meta| meta.managed_account_id_for("github_copilot")); + + let endpoint = match account_id.as_deref() { + Some(id) => auth_manager.get_api_endpoint(id).await, + None => auth_manager.get_default_api_endpoint().await, + }; + + Ok(Some(endpoint)) +} + +fn is_copilot_provider(provider: &crate::provider::Provider) -> bool { + provider + .meta + .as_ref() + .and_then(|meta| meta.provider_type.as_deref()) + == Some("github_copilot") + || provider + .settings_config + .pointer("/env/ANTHROPIC_BASE_URL") + .and_then(|value| value.as_str()) + .map(|url| url.contains("githubcopilot.com")) + .unwrap_or(false) +} + async fn resolve_claude_api_format_override( app_type: &AppType, provider: &crate::provider::Provider, @@ -237,3 +275,80 @@ async fn resolve_claude_api_format_override( Ok(Some(api_format.to_string())) } + +#[cfg(test)] +mod tests { + use super::is_copilot_provider; + use crate::provider::{Provider, ProviderMeta}; + use serde_json::json; + + #[test] + fn copilot_provider_detection_accepts_provider_type_or_base_url() { + let typed_provider = Provider { + id: "p1".to_string(), + name: "typed".to_string(), + settings_config: json!({}), + website_url: None, + category: None, + created_at: None, + sort_index: None, + notes: None, + meta: Some(ProviderMeta { + provider_type: Some("github_copilot".to_string()), + ..Default::default() + }), + icon: None, + icon_color: None, + in_failover_queue: false, + }; + assert!(is_copilot_provider(&typed_provider)); + + let url_provider = Provider { + id: "p2".to_string(), + name: "url".to_string(), + settings_config: json!({ + "env": { + "ANTHROPIC_BASE_URL": "https://api.githubcopilot.com" + } + }), + website_url: None, + category: None, + created_at: None, + sort_index: None, + notes: None, + meta: None, + icon: None, + icon_color: None, + in_failover_queue: false, + }; + assert!(is_copilot_provider(&url_provider)); + } + + #[test] + fn copilot_full_url_metadata_is_available_for_override_guard() { + let provider = Provider { + id: "p3".to_string(), + name: "relay".to_string(), + settings_config: json!({}), + website_url: None, + category: None, + created_at: None, + sort_index: None, + notes: None, + meta: Some(ProviderMeta { + provider_type: Some("github_copilot".to_string()), + is_full_url: Some(true), + ..Default::default() + }), + icon: None, + icon_color: None, + in_failover_queue: false, + }; + + assert!(is_copilot_provider(&provider)); + assert_eq!( + provider.meta.as_ref().and_then(|meta| meta.is_full_url), + Some(true) + ); + } +} diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index b9011834..1be714b4 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -872,6 +872,7 @@ pub fn run() { commands::restart_app, commands::check_for_updates, commands::is_portable_mode, + commands::copy_text_to_clipboard, commands::get_claude_plugin_status, commands::read_claude_plugin_config, commands::apply_claude_plugin_config, diff --git a/src-tauri/src/proxy/forwarder.rs b/src-tauri/src/proxy/forwarder.rs index f9dd145e..635a8bc3 100644 --- a/src-tauri/src/proxy/forwarder.rs +++ b/src-tauri/src/proxy/forwarder.rs @@ -747,7 +747,7 @@ impl RequestForwarder { adapter: &dyn ProviderAdapter, ) -> Result<(ProxyResponse, Option), ProxyError> { // 使用适配器提取 base_url - let base_url = adapter.extract_base_url(provider)?; + let mut base_url = adapter.extract_base_url(provider)?; let is_full_url = provider .meta @@ -770,6 +770,36 @@ impl RequestForwarder { .and_then(|m| m.provider_type.as_deref()) == Some("github_copilot") || base_url.contains("githubcopilot.com"); + + // GitHub Copilot 动态 endpoint 路由 + // 从 CopilotAuthManager 获取缓存的 API endpoint(支持企业版等非默认 endpoint) + if is_copilot && !is_full_url { + if let Some(app_handle) = &self.app_handle { + let copilot_state = app_handle.state::(); + let copilot_auth = copilot_state.0.read().await; + + // 从 provider.meta 获取关联的 GitHub 账号 ID + let account_id = provider + .meta + .as_ref() + .and_then(|m| m.managed_account_id_for("github_copilot")); + + let dynamic_endpoint = match &account_id { + Some(id) => copilot_auth.get_api_endpoint(id).await, + None => copilot_auth.get_default_api_endpoint().await, + }; + + // 只在动态 endpoint 与当前 base_url 不同时替换 + if dynamic_endpoint != base_url { + log::debug!( + "[Copilot] 使用动态 API endpoint: {} (原: {})", + dynamic_endpoint, + base_url + ); + base_url = dynamic_endpoint; + } + } + } let resolved_claude_api_format = if adapter.name() == "Claude" { Some( self.resolve_claude_api_format(provider, &mapped_body, is_copilot) @@ -893,6 +923,12 @@ impl RequestForwarder { "copilot-integration-id", "x-github-api-version", "openai-intent", + // 新增 headers + "x-initiator", + "x-interaction-type", + "x-vscode-user-agent-library-version", + "x-request-id", + "x-agent-task-id", ] } else { &[] @@ -1663,4 +1699,107 @@ mod tests { &headers )); } + + // ==================== Copilot 动态 endpoint 路由相关测试 ==================== + + /// 验证 is_copilot 检测逻辑:通过 provider_type 判断 + #[test] + fn copilot_detection_via_provider_type() { + use crate::provider::{Provider, ProviderMeta}; + + let provider = Provider { + id: "test".to_string(), + name: "Test Copilot".to_string(), + settings_config: serde_json::json!({}), + website_url: None, + category: None, + created_at: None, + sort_index: None, + notes: None, + meta: Some(ProviderMeta { + provider_type: Some("github_copilot".to_string()), + ..Default::default() + }), + icon: None, + icon_color: None, + in_failover_queue: false, + }; + + let is_copilot = provider + .meta + .as_ref() + .and_then(|m| m.provider_type.as_deref()) + == Some("github_copilot"); + + assert!(is_copilot, "应该通过 provider_type 检测为 Copilot"); + } + + /// 验证 is_copilot 检测逻辑:通过 base_url 判断 + #[test] + fn copilot_detection_via_base_url() { + let base_url = "https://api.githubcopilot.com"; + let is_copilot = base_url.contains("githubcopilot.com"); + assert!(is_copilot, "应该通过 base_url 检测为 Copilot"); + + let non_copilot_url = "https://api.anthropic.com"; + let is_not_copilot = non_copilot_url.contains("githubcopilot.com"); + assert!(!is_not_copilot, "非 Copilot URL 不应被检测为 Copilot"); + } + + /// 验证企业版 endpoint(不包含 githubcopilot.com)场景下 is_copilot 仍然正确 + #[test] + fn copilot_detection_for_enterprise_endpoint() { + use crate::provider::{Provider, ProviderMeta}; + + // 企业版场景:provider_type 是 github_copilot,但 base_url 可能是企业内部域名 + let provider = Provider { + id: "enterprise".to_string(), + name: "Enterprise Copilot".to_string(), + settings_config: serde_json::json!({}), + website_url: None, + category: None, + created_at: None, + sort_index: None, + notes: None, + meta: Some(ProviderMeta { + provider_type: Some("github_copilot".to_string()), + ..Default::default() + }), + icon: None, + icon_color: None, + in_failover_queue: false, + }; + + let enterprise_base_url = "https://copilot-api.corp.example.com"; + + // is_copilot 应该通过 provider_type 检测成功,即使 base_url 不包含 githubcopilot.com + let is_copilot = provider + .meta + .as_ref() + .and_then(|m| m.provider_type.as_deref()) + == Some("github_copilot") + || enterprise_base_url.contains("githubcopilot.com"); + + assert!( + is_copilot, + "企业版 Copilot 应该通过 provider_type 被正确检测" + ); + } + + /// 验证动态 endpoint 替换条件 + #[test] + fn dynamic_endpoint_replacement_conditions() { + // 条件:is_copilot && !is_full_url + let test_cases = [ + (true, false, true, "Copilot + 非 full_url 应该替换"), + (true, true, false, "Copilot + full_url 不应替换"), + (false, false, false, "非 Copilot 不应替换"), + (false, true, false, "非 Copilot + full_url 不应替换"), + ]; + + for (is_copilot, is_full_url, should_replace, desc) in test_cases { + let will_replace = is_copilot && !is_full_url; + assert_eq!(will_replace, should_replace, "{desc}"); + } + } } diff --git a/src-tauri/src/proxy/hyper_client.rs b/src-tauri/src/proxy/hyper_client.rs index ecd69ae6..f4533759 100644 --- a/src-tauri/src/proxy/hyper_client.rs +++ b/src-tauri/src/proxy/hyper_client.rs @@ -359,6 +359,10 @@ async fn send_raw_request( .write_all(&raw) .await .map_err(|e| ProxyError::ForwardFailed(format!("Write failed: {e}")))?; + tls_stream + .flush() + .await + .map_err(|e| ProxyError::ForwardFailed(format!("Flush failed: {e}")))?; let filtered = WriteFilter::new(tls_stream); do_hyper_response(filtered, method.clone()).await @@ -368,6 +372,10 @@ async fn send_raw_request( .write_all(&raw) .await .map_err(|e| ProxyError::ForwardFailed(format!("Write failed: {e}")))?; + stream + .flush() + .await + .map_err(|e| ProxyError::ForwardFailed(format!("Flush failed: {e}")))?; let filtered = WriteFilter::new(stream); do_hyper_response(filtered, method.clone()).await @@ -441,6 +449,10 @@ async fn connect_via_proxy( .write_all(connect_req.as_bytes()) .await .map_err(|e| ProxyError::ForwardFailed(format!("CONNECT write failed: {e}")))?; + stream + .flush() + .await + .map_err(|e| ProxyError::ForwardFailed(format!("CONNECT flush failed: {e}")))?; // Read the proxy's response status line let mut reader = BufReader::new(&mut stream); diff --git a/src-tauri/src/proxy/providers/claude.rs b/src-tauri/src/proxy/providers/claude.rs index 8d59dc2a..2cd7c017 100644 --- a/src-tauri/src/proxy/providers/claude.rs +++ b/src-tauri/src/proxy/providers/claude.rs @@ -348,6 +348,8 @@ impl ProviderAdapter for ClaudeAdapter { )] } AuthStrategy::GitHubCopilot => { + // 生成请求追踪 ID + let request_id = uuid::Uuid::new_v4().to_string(); vec![ ( HeaderName::from_static("authorization"), @@ -373,9 +375,30 @@ impl ProviderAdapter for ClaudeAdapter { HeaderName::from_static("x-github-api-version"), HeaderValue::from_static(super::copilot_auth::COPILOT_API_VERSION), ), + // 26-04-01新增的copilot关键 headers ( HeaderName::from_static("openai-intent"), - HeaderValue::from_static("conversation-panel"), + HeaderValue::from_static("conversation-agent"), + ), + ( + HeaderName::from_static("x-initiator"), + HeaderValue::from_static("user"), + ), + ( + HeaderName::from_static("x-interaction-type"), + HeaderValue::from_static("conversation-agent"), + ), + ( + HeaderName::from_static("x-vscode-user-agent-library-version"), + HeaderValue::from_static("electron-fetch"), + ), + ( + HeaderName::from_static("x-request-id"), + HeaderValue::from_str(&request_id).unwrap(), + ), + ( + HeaderName::from_static("x-agent-task-id"), + HeaderValue::from_str(&request_id).unwrap(), ), ] } diff --git a/src-tauri/src/proxy/providers/copilot_auth.rs b/src-tauri/src/proxy/providers/copilot_auth.rs index 04f3f9f3..44a23271 100644 --- a/src-tauri/src/proxy/providers/copilot_auth.rs +++ b/src-tauri/src/proxy/providers/copilot_auth.rs @@ -46,15 +46,18 @@ const TOKEN_REFRESH_BUFFER_SECONDS: i64 = 60; const COPILOT_MODELS_URL: &str = "https://api.githubcopilot.com/models"; /// Copilot API Header 常量 -pub const COPILOT_EDITOR_VERSION: &str = "vscode/1.96.0"; -pub const COPILOT_PLUGIN_VERSION: &str = "copilot-chat/0.26.7"; -pub const COPILOT_USER_AGENT: &str = "GitHubCopilotChat/0.26.7"; -pub const COPILOT_API_VERSION: &str = "2025-04-01"; +pub const COPILOT_EDITOR_VERSION: &str = "vscode/1.110.1"; +pub const COPILOT_PLUGIN_VERSION: &str = "copilot-chat/0.38.2"; +pub const COPILOT_USER_AGENT: &str = "GitHubCopilotChat/0.38.2"; +pub const COPILOT_API_VERSION: &str = "2025-10-01"; pub const COPILOT_INTEGRATION_ID: &str = "vscode-chat"; /// Copilot 使用量 API URL const COPILOT_USAGE_URL: &str = "https://api.github.com/copilot_internal/user"; +/// 默认 Copilot API 端点 +const DEFAULT_COPILOT_API_ENDPOINT: &str = "https://api.githubcopilot.com"; + /// Copilot 使用量响应 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CopilotUsageResponse { @@ -64,6 +67,19 @@ pub struct CopilotUsageResponse { pub quota_reset_date: String, /// 配额快照 pub quota_snapshots: QuotaSnapshots, + /// API 端点信息 (用于动态获取 API URL) + #[serde(default)] + pub endpoints: Option, +} + +/// Copilot API 端点信息 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CopilotEndpoints { + /// API 端点 URL + pub api: String, + /// Telemetry 端点 URL + #[serde(default)] + pub telemetry: Option, } /// 配额快照 @@ -312,6 +328,10 @@ pub struct CopilotAuthManager { copilot_tokens: Arc>>, /// Copilot Models 缓存(key = GitHub user ID,仅进程内复用) copilot_models: Arc>>>, + /// Copilot API 端点缓存(key = GitHub user ID,从 /copilot_internal/user 获取) + api_endpoints: Arc>>, + /// 每个账号的端点拉取锁,避免并发拉取重复打 GitHub API + endpoint_locks: Arc>>>>, /// HTTP 客户端 http_client: Client, /// 存储路径 @@ -333,6 +353,8 @@ impl CopilotAuthManager { refresh_locks: Arc::new(RwLock::new(HashMap::new())), copilot_tokens: Arc::new(RwLock::new(HashMap::new())), copilot_models: Arc::new(RwLock::new(HashMap::new())), + api_endpoints: Arc::new(RwLock::new(HashMap::new())), + endpoint_locks: Arc::new(RwLock::new(HashMap::new())), http_client: Client::new(), storage_path, pending_migration: Arc::new(RwLock::new(None)), @@ -386,6 +408,15 @@ impl CopilotAuthManager { let mut refresh_locks = self.refresh_locks.write().await; refresh_locks.remove(account_id); } + // 清理 API 端点缓存 + { + let mut api_endpoints = self.api_endpoints.write().await; + api_endpoints.remove(account_id); + } + { + let mut endpoint_locks = self.endpoint_locks.write().await; + endpoint_locks.remove(account_id); + } { let accounts = self.accounts.read().await; @@ -775,6 +806,14 @@ impl CopilotAuthManager { .await .map_err(|e| CopilotAuthError::ParseError(e.to_string()))?; + // 存储动态 API 端点(如果有) + if let Some(ref endpoints) = usage.endpoints { + let mut api_endpoints = self.api_endpoints.write().await; + api_endpoints.insert(account_id.to_string(), endpoints.api.clone()); + // 使用 debug 级别避免在日志中暴露企业内部域名 + log::debug!("[CopilotAuth] 账号 {account_id} 已保存动态 API 端点"); + } + log::info!( "[CopilotAuth] 获取使用量成功,计划: {}, 重置日期: {}", usage.copilot_plan, @@ -794,6 +833,118 @@ impl CopilotAuthManager { // ==================== 状态查询 ==================== + /// 获取指定账号的 API 端点(缓存命中直接返回,未命中则从 API 惰性拉取) + pub async fn get_api_endpoint(&self, account_id: &str) -> String { + let _ = self.ensure_migration_complete().await; + + { + let endpoints = self.api_endpoints.read().await; + if let Some(endpoint) = endpoints.get(account_id) { + return endpoint.clone(); + } + } + + // 用锁串行化同一账号的并发拉取,避免对 GitHub API 的重复请求 + let lock = self.get_endpoint_lock(account_id).await; + let _guard = lock.lock().await; + + // 持锁后二次检查:可能已由其他请求填充 + { + let endpoints = self.api_endpoints.read().await; + if let Some(endpoint) = endpoints.get(account_id) { + return endpoint.clone(); + } + } + + match self.fetch_and_cache_endpoint(account_id).await { + Ok(endpoint) => endpoint, + Err(e) => { + log::debug!( + "[CopilotAuth] 获取账号 {account_id} 动态 API 端点失败: {e},使用默认值" + ); + DEFAULT_COPILOT_API_ENDPOINT.to_string() + } + } + } + + /// 获取默认账号的 API 端点 + pub async fn get_default_api_endpoint(&self) -> String { + let _ = self.ensure_migration_complete().await; + + match self.resolve_default_account_id().await { + Some(id) => self.get_api_endpoint(&id).await, + None => DEFAULT_COPILOT_API_ENDPOINT.to_string(), + } + } + + async fn fetch_and_cache_endpoint(&self, account_id: &str) -> Result { + let github_token = { + let accounts = self.accounts.read().await; + accounts + .get(account_id) + .map(|a| a.github_token.clone()) + .ok_or_else(|| CopilotAuthError::AccountNotFound(account_id.to_string()))? + }; + + log::debug!("[CopilotAuth] 为账号 {account_id} 惰性拉取动态 API 端点"); + + let response = self + .http_client + .get(COPILOT_USAGE_URL) + .header("Authorization", format!("token {github_token}")) + .header("Content-Type", "application/json") + .header("editor-version", COPILOT_EDITOR_VERSION) + .header("editor-plugin-version", COPILOT_PLUGIN_VERSION) + .header("user-agent", COPILOT_USER_AGENT) + .header("x-github-api-version", COPILOT_API_VERSION) + .send() + .await?; + + if response.status() == reqwest::StatusCode::UNAUTHORIZED { + return Err(CopilotAuthError::GitHubTokenInvalid); + } + + if !response.status().is_success() { + return Err(CopilotAuthError::CopilotTokenFetchFailed(format!( + "获取 API 端点失败: {}", + response.status() + ))); + } + + let usage: CopilotUsageResponse = response + .json() + .await + .map_err(|e| CopilotAuthError::ParseError(e.to_string()))?; + + let endpoint = match usage.endpoints { + Some(endpoints) => endpoints.api.clone(), + None => DEFAULT_COPILOT_API_ENDPOINT.to_string(), + }; + + // 缓存端点(包括默认值),避免重复请求 + let mut api_endpoints = self.api_endpoints.write().await; + api_endpoints.insert(account_id.to_string(), endpoint.clone()); + log::debug!("[CopilotAuth] 账号 {account_id} 已缓存 API 端点"); + + Ok(endpoint) + } + + async fn get_endpoint_lock(&self, account_id: &str) -> Arc> { + { + let locks = self.endpoint_locks.read().await; + if let Some(lock) = locks.get(account_id) { + return Arc::clone(lock); + } + } + + let mut locks = self.endpoint_locks.write().await; + Arc::clone( + locks + .entry(account_id.to_string()) + .or_insert_with(|| Arc::new(Mutex::new(()))), + ) + } + /// 获取认证状态(支持多账号) pub async fn get_status(&self) -> CopilotAuthStatus { // 确保迁移完成 @@ -838,6 +989,7 @@ impl CopilotAuthManager { pub async fn clear_auth(&self) -> Result<(), CopilotAuthError> { log::info!("[CopilotAuth] 清除所有认证"); + // 先清理内存状态,确保即使文件删除失败用户也能看到已登出 { let mut accounts = self.accounts.write().await; accounts.clear(); @@ -851,12 +1003,25 @@ impl CopilotAuthManager { let mut tokens = self.copilot_tokens.write().await; tokens.clear(); } + { + let mut models = self.copilot_models.write().await; + models.clear(); + } { let mut refresh_locks = self.refresh_locks.write().await; refresh_locks.clear(); } + // 清理 API 端点缓存 + { + let mut api_endpoints = self.api_endpoints.write().await; + api_endpoints.clear(); + } + { + let mut endpoint_locks = self.endpoint_locks.write().await; + endpoint_locks.clear(); + } - // 删除存储文件 + // 最后删除存储文件 if self.storage_path.exists() { std::fs::remove_file(&self.storage_path)?; } @@ -1414,4 +1579,242 @@ mod tests { let default_vendor = manager.get_model_vendor("claude-sonnet-4").await.unwrap(); assert_eq!(default_vendor.as_deref(), Some("Anthropic")); } + + #[tokio::test] + async fn test_get_api_endpoint_returns_cached_value() { + let temp_dir = tempdir().unwrap(); + let manager = CopilotAuthManager::new(temp_dir.path().to_path_buf()); + + // 手动设置 api_endpoints 缓存 + { + let mut api_endpoints = manager.api_endpoints.write().await; + api_endpoints.insert( + "12345".to_string(), + "https://copilot-api.enterprise.example.com".to_string(), + ); + } + + let endpoint = manager.get_api_endpoint("12345").await; + assert_eq!(endpoint, "https://copilot-api.enterprise.example.com"); + } + + #[tokio::test] + async fn test_get_api_endpoint_returns_default_when_not_cached() { + let temp_dir = tempdir().unwrap(); + let manager = CopilotAuthManager::new(temp_dir.path().to_path_buf()); + + let endpoint = manager.get_api_endpoint("99999").await; + assert_eq!(endpoint, "https://api.githubcopilot.com"); + } + + #[tokio::test] + async fn test_get_default_api_endpoint_uses_default_account() { + let temp_dir = tempdir().unwrap(); + let manager = CopilotAuthManager::new(temp_dir.path().to_path_buf()); + + // 设置默认账号 + { + let mut default_account_id = manager.default_account_id.write().await; + *default_account_id = Some("12345".to_string()); + } + // 添加账号数据 + { + let mut accounts = manager.accounts.write().await; + accounts.insert( + "12345".to_string(), + GitHubAccountData { + github_token: "gho_test".to_string(), + user: GitHubUser { + login: "alice".to_string(), + id: 12345, + avatar_url: None, + }, + authenticated_at: 1700000000, + }, + ); + } + // 设置 API endpoint 缓存 + { + let mut api_endpoints = manager.api_endpoints.write().await; + api_endpoints.insert( + "12345".to_string(), + "https://copilot-api.corp.example.com".to_string(), + ); + } + + let endpoint = manager.get_default_api_endpoint().await; + assert_eq!(endpoint, "https://copilot-api.corp.example.com"); + } + + #[tokio::test] + async fn test_remove_account_clears_api_endpoint_cache() { + let temp_dir = tempdir().unwrap(); + let manager = CopilotAuthManager::new(temp_dir.path().to_path_buf()); + + // 添加账号数据 + { + let mut accounts = manager.accounts.write().await; + accounts.insert( + "12345".to_string(), + GitHubAccountData { + github_token: "gho_test".to_string(), + user: GitHubUser { + login: "alice".to_string(), + id: 12345, + avatar_url: None, + }, + authenticated_at: 1700000000, + }, + ); + } + // 设置 API endpoint 缓存 + { + let mut api_endpoints = manager.api_endpoints.write().await; + api_endpoints.insert( + "12345".to_string(), + "https://copilot-api.enterprise.example.com".to_string(), + ); + } + + // 确认缓存存在 + { + let api_endpoints = manager.api_endpoints.read().await; + assert!(api_endpoints.contains_key("12345")); + } + + // 移除账号 + manager.remove_account("12345").await.unwrap(); + + // 确认缓存已清理 + { + let api_endpoints = manager.api_endpoints.read().await; + assert!(!api_endpoints.contains_key("12345")); + } + } + + #[tokio::test] + async fn test_clear_auth_clears_all_api_endpoint_cache() { + let temp_dir = tempdir().unwrap(); + let manager = CopilotAuthManager::new(temp_dir.path().to_path_buf()); + + // 添加多个账号的 API endpoint 缓存 + { + let mut api_endpoints = manager.api_endpoints.write().await; + api_endpoints.insert( + "12345".to_string(), + "https://copilot-api.enterprise1.example.com".to_string(), + ); + api_endpoints.insert( + "67890".to_string(), + "https://copilot-api.enterprise2.example.com".to_string(), + ); + } + + // 确认缓存存在 + { + let api_endpoints = manager.api_endpoints.read().await; + assert_eq!(api_endpoints.len(), 2); + } + + // 清除所有认证 + manager.clear_auth().await.unwrap(); + + // 确认缓存已清空 + { + let api_endpoints = manager.api_endpoints.read().await; + assert!(api_endpoints.is_empty()); + } + } + + #[tokio::test] + async fn test_clear_auth_cleans_memory_even_when_file_removal_fails() { + let temp_dir = tempdir().unwrap(); + let manager = CopilotAuthManager::new(temp_dir.path().to_path_buf()); + + // Create a directory at storage_path so remove_file fails + std::fs::create_dir_all(&manager.storage_path).unwrap(); + + { + let mut accounts = manager.accounts.write().await; + accounts.insert( + "12345".to_string(), + GitHubAccountData { + github_token: "gho_test".to_string(), + user: GitHubUser { + login: "alice".to_string(), + id: 12345, + avatar_url: None, + }, + authenticated_at: 1700000000, + }, + ); + } + { + let mut default_account_id = manager.default_account_id.write().await; + *default_account_id = Some("12345".to_string()); + } + { + let mut api_endpoints = manager.api_endpoints.write().await; + api_endpoints.insert( + "12345".to_string(), + "https://copilot-api.enterprise.example.com".to_string(), + ); + } + + let result = manager.clear_auth().await; + // Should still return an error for the file deletion failure + assert!(result.is_err()); + + // But memory state should already be cleaned + let accounts = manager.accounts.read().await; + assert!(accounts.is_empty()); + drop(accounts); + + let default_account_id = manager.default_account_id.read().await; + assert!(default_account_id.is_none()); + drop(default_account_id); + + let api_endpoints = manager.api_endpoints.read().await; + assert!(api_endpoints.is_empty()); + } + + #[tokio::test] + async fn test_get_api_endpoint_cache_hit_skips_fetch() { + // 缓存命中时应直接返回,不发起网络请求 + let temp_dir = tempdir().unwrap(); + let manager = CopilotAuthManager::new(temp_dir.path().to_path_buf()); + + let enterprise_endpoint = "https://copilot-api.enterprise.example.com".to_string(); + { + let mut api_endpoints = manager.api_endpoints.write().await; + api_endpoints.insert("12345".to_string(), enterprise_endpoint.clone()); + } + + // 即使没有账号数据,缓存命中也应直接返回 + let endpoint = manager.get_api_endpoint("12345").await; + assert_eq!(endpoint, enterprise_endpoint); + } + + #[tokio::test] + async fn test_get_api_endpoint_returns_default_for_unknown_account() { + let temp_dir = tempdir().unwrap(); + let manager = CopilotAuthManager::new(temp_dir.path().to_path_buf()); + + let endpoint = manager.get_api_endpoint("12345").await; + assert_eq!(endpoint, DEFAULT_COPILOT_API_ENDPOINT); + } + + #[tokio::test] + async fn test_fetch_and_cache_endpoint_requires_account() { + // 账号不存在时 fetch_and_cache_endpoint 应返回 AccountNotFound 错误 + let temp_dir = tempdir().unwrap(); + let manager = CopilotAuthManager::new(temp_dir.path().to_path_buf()); + + let result = manager.fetch_and_cache_endpoint("nonexistent").await; + assert!(result.is_err()); + match result.unwrap_err() { + CopilotAuthError::AccountNotFound(id) => assert_eq!(id, "nonexistent"), + other => panic!("期望 AccountNotFound 错误,实际: {other:?}"), + } + } } diff --git a/src-tauri/src/services/stream_check.rs b/src-tauri/src/services/stream_check.rs index 39e21d05..79f1021e 100644 --- a/src-tauri/src/services/stream_check.rs +++ b/src-tauri/src/services/stream_check.rs @@ -88,6 +88,7 @@ impl StreamCheckService { provider: &Provider, config: &StreamCheckConfig, auth_override: Option, + base_url_override: Option, claude_api_format_override: Option, ) -> Result { // 合并供应商单独配置和全局配置 @@ -100,6 +101,7 @@ impl StreamCheckService { provider, &effective_config, auth_override.clone(), + base_url_override.clone(), claude_api_format_override.clone(), ) .await; @@ -191,14 +193,18 @@ impl StreamCheckService { provider: &Provider, config: &StreamCheckConfig, auth_override: Option, + base_url_override: Option, claude_api_format_override: Option, ) -> Result { let start = Instant::now(); let adapter = get_adapter(app_type); - let base_url = adapter - .extract_base_url(provider) - .map_err(|e| AppError::Message(format!("Failed to extract base_url: {e}")))?; + let base_url = match base_url_override { + Some(base_url) => base_url, + None => adapter + .extract_base_url(provider) + .map_err(|e| AppError::Message(format!("Failed to extract base_url: {e}")))?, + }; let auth = auth_override .or_else(|| adapter.extract_auth(provider)) @@ -364,6 +370,8 @@ impl StreamCheckService { let mut request_builder = client.post(&url); if is_github_copilot { + // 生成请求追踪 ID + let request_id = uuid::Uuid::new_v4().to_string(); request_builder = request_builder .header("authorization", format!("Bearer {}", auth.api_key)) .header("content-type", "application/json") @@ -380,7 +388,13 @@ impl StreamCheckService { copilot_auth::COPILOT_INTEGRATION_ID, ) .header("x-github-api-version", copilot_auth::COPILOT_API_VERSION) - .header("openai-intent", "conversation-panel"); + // 260401 新增copilot 的关键 headers + .header("openai-intent", "conversation-agent") + .header("x-initiator", "user") + .header("x-interaction-type", "conversation-agent") + .header("x-vscode-user-agent-library-version", "electron-fetch") + .header("x-request-id", &request_id) + .header("x-agent-task-id", &request_id); } else if is_openai_chat || is_openai_responses { // OpenAI-compatible targets: Bearer auth + SSE headers only request_builder = request_builder diff --git a/src/components/providers/forms/CopilotAuthSection.tsx b/src/components/providers/forms/CopilotAuthSection.tsx index b79d07ea..608eb6d0 100644 --- a/src/components/providers/forms/CopilotAuthSection.tsx +++ b/src/components/providers/forms/CopilotAuthSection.tsx @@ -22,6 +22,7 @@ import { User, } from "lucide-react"; import { useCopilotAuth } from "./hooks/useCopilotAuth"; +import { copyText } from "@/lib/clipboard"; import type { GitHubAccount } from "@/lib/api"; interface CopilotAuthSectionProps { @@ -67,7 +68,7 @@ export const CopilotAuthSection: React.FC = ({ // 复制用户码 const copyUserCode = async () => { if (deviceCode?.user_code) { - await navigator.clipboard.writeText(deviceCode.user_code); + await copyText(deviceCode.user_code); setCopied(true); setTimeout(() => setCopied(false), 2000); } diff --git a/src/components/providers/forms/hooks/useManagedAuth.ts b/src/components/providers/forms/hooks/useManagedAuth.ts index 787f1094..138f3f72 100644 --- a/src/components/providers/forms/hooks/useManagedAuth.ts +++ b/src/components/providers/forms/hooks/useManagedAuth.ts @@ -1,6 +1,7 @@ import { useState, useCallback, useRef, useEffect } from "react"; import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query"; import { authApi, settingsApi } from "@/lib/api"; +import { copyText } from "@/lib/clipboard"; import type { ManagedAuthProvider, ManagedAuthStatus, @@ -58,7 +59,7 @@ export function useManagedAuth(authProvider: ManagedAuthProvider) { setError(null); try { - await navigator.clipboard.writeText(response.user_code); + await copyText(response.user_code); } catch (e) { console.debug("[ManagedAuth] Failed to copy user code:", e); } diff --git a/src/lib/clipboard.ts b/src/lib/clipboard.ts new file mode 100644 index 00000000..ddbb3129 --- /dev/null +++ b/src/lib/clipboard.ts @@ -0,0 +1,19 @@ +import { invoke } from "@tauri-apps/api/core"; + +export async function copyText(text: string): Promise { + try { + await invoke("copy_text_to_clipboard", { text }); + return; + } catch (nativeError) { + try { + await navigator.clipboard.writeText(text); + return; + } catch (webError) { + throw webError instanceof Error + ? webError + : nativeError instanceof Error + ? nativeError + : new Error(String(webError || nativeError)); + } + } +}