mirror of
https://github.com/farion1231/cc-switch.git
synced 2026-05-07 06:07:18 +08:00
fix(copilot): 修复 GitHub Copilot 认证和代理问题 (#1854)
* fix(copilot): 修复 GitHub Copilot 400 认证错误 问题:使用 GitHub Copilot provider 时报错 400 bad request 根因:与 copilot-api 项目对比发现多处差异 修复内容: - 更新版本号 0.26.7 到 0.38.2 - 更新 API 版本 2025-04-01 到 2025-10-01 - 添加缺失的关键 headers - 修正 openai-intent 值 - 添加动态 API endpoint 支持 - 同步更新 stream_check.rs headers Closes #1777 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix: flush stream after write_all in hyper_client proxy Add explicit flush() calls after write_all() for TLS stream, plain TCP stream, and CONNECT tunnel requests to ensure buffered data is sent immediately, preventing connection hangs in Copilot auth header flow. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * 修复登录时的剪切板在mac与linux端可能没复制验证码 * fix: flush stream after write_all in hyper_client proxy Add explicit flush() calls after write_all() for TLS stream, plain TCP stream, and CONNECT tunnel requests to ensure buffered data is sent immediately, preventing connection hangs in Copilot auth header flow. * 修复登录时的剪切板在mac与linux端可能没复制验证码 * 1、修复不同类型的个人商业等不同类型的copilot账号问题 2、将验证码复制改为异步操作 * fix: address PR review comments for Copilot auth │ │ │ │ - Fix clipboard blocking by using spawn_blocking for arboard ops │ │ - Implement dynamic endpoint routing for enterprise Copilot users │ │ - Add api_endpoints cache cleanup in remove_account() and clear_auth() │ │ - Change API endpoint log level from info to debug │ │ - Fix clear_auth() to continue cleanup even if file deletion fails │ │ - Add 9 unit tests for Copilot detection and api_endpoints cachin * style: fix cargo fmt formatting * Fix Copilot dynamic endpoint handling * fix: restore clear_auth() memory-first cleanup order and fix cache leaks - Restore clear_auth() to clean memory state before deleting the storage file. The previous order (file deletion first) caused a regression where users could get stuck in a "cannot log out" state if file removal failed. - Add missing copilot_models.clear() in clear_auth() — this cache was cleaned in remove_account() but never in the full clear path. - Add endpoint_locks cleanup in both remove_account() and clear_auth() to prevent minor in-process memory leaks. - Update test to assert the correct behavior: memory should be cleaned even when file deletion fails. --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: 周梦泽 <mengze.zhou@dafeng-tech.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Jason <farion1231@gmail.com>
This commit is contained in:
137
src-tauri/Cargo.lock
generated
137
src-tauri/Cargo.lock
generated
@@ -107,6 +107,26 @@ dependencies = [
|
||||
"derive_arbitrary",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "arboard"
|
||||
version = "3.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0348a1c054491f4bfe6ab86a7b6ab1e44e45d899005de92f58b3df180b36ddaf"
|
||||
dependencies = [
|
||||
"clipboard-win",
|
||||
"image",
|
||||
"log",
|
||||
"objc2 0.6.4",
|
||||
"objc2-app-kit 0.3.2",
|
||||
"objc2-core-foundation",
|
||||
"objc2-core-graphics",
|
||||
"objc2-foundation 0.3.2",
|
||||
"parking_lot",
|
||||
"percent-encoding",
|
||||
"windows-sys 0.60.2",
|
||||
"x11rb",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "arrayvec"
|
||||
version = "0.7.6"
|
||||
@@ -718,6 +738,7 @@ name = "cc-switch"
|
||||
version = "3.12.3"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"arboard",
|
||||
"async-stream",
|
||||
"auto-launch",
|
||||
"axum",
|
||||
@@ -843,6 +864,15 @@ dependencies = [
|
||||
"inout",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clipboard-win"
|
||||
version = "5.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bde03770d3df201d4fb868f2c9c59e66a3e4e2bd06692a0fe701e7103c7e84d4"
|
||||
dependencies = [
|
||||
"error-code",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cmake"
|
||||
version = "0.1.57"
|
||||
@@ -1445,6 +1475,12 @@ dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "error-code"
|
||||
version = "3.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dea2df4cf52843e0452895c455a1a2cfbb842a1e7329671acf418fdc53ed4c59"
|
||||
|
||||
[[package]]
|
||||
name = "event-listener"
|
||||
version = "5.4.1"
|
||||
@@ -1484,6 +1520,26 @@ version = "2.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
|
||||
|
||||
[[package]]
|
||||
name = "fax"
|
||||
version = "0.2.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f05de7d48f37cd6730705cbca900770cab77a89f413d23e100ad7fad7795a0ab"
|
||||
dependencies = [
|
||||
"fax_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fax_derive"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a0aca10fb742cb43f9e7bb8467c91aa9bcb8e3ffbc6a6f7389bb93ffc920577d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fdeflate"
|
||||
version = "0.3.7"
|
||||
@@ -1849,6 +1905,16 @@ dependencies = [
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gethostname"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1bd49230192a3797a9a4d6abe9b3eed6f7fa4c8a8a4947977c6f80025f92cbd8"
|
||||
dependencies = [
|
||||
"rustix",
|
||||
"windows-link 0.2.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.1.16"
|
||||
@@ -2067,6 +2133,17 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "half"
|
||||
version = "2.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crunchy",
|
||||
"zerocopy",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.12.3"
|
||||
@@ -2457,6 +2534,7 @@ dependencies = [
|
||||
"moxcms",
|
||||
"num-traits",
|
||||
"png 0.18.1",
|
||||
"tiff",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3088,6 +3166,7 @@ dependencies = [
|
||||
"block2 0.6.2",
|
||||
"objc2 0.6.4",
|
||||
"objc2-core-foundation",
|
||||
"objc2-core-graphics",
|
||||
"objc2-foundation 0.3.2",
|
||||
]
|
||||
|
||||
@@ -3894,6 +3973,12 @@ version = "0.1.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b5a041e753da8b807c9255f28de81879c78c876392ff2469cde94799b2896b9d"
|
||||
|
||||
[[package]]
|
||||
name = "quick-error"
|
||||
version = "2.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3"
|
||||
|
||||
[[package]]
|
||||
name = "quick-xml"
|
||||
version = "0.38.4"
|
||||
@@ -5766,6 +5851,20 @@ dependencies = [
|
||||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tiff"
|
||||
version = "0.11.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b63feaf3343d35b6ca4d50483f94843803b0f51634937cc2ec519fc32232bc52"
|
||||
dependencies = [
|
||||
"fax",
|
||||
"flate2",
|
||||
"half",
|
||||
"quick-error",
|
||||
"weezl",
|
||||
"zune-jpeg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "time"
|
||||
version = "0.3.47"
|
||||
@@ -6660,6 +6759,12 @@ dependencies = [
|
||||
"windows-core 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "weezl"
|
||||
version = "0.1.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a28ac98ddc8b9274cb41bb4d9d4d5c425b6020c50c46f25559911905610b4a88"
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.9"
|
||||
@@ -7397,6 +7502,23 @@ dependencies = [
|
||||
"pkg-config",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "x11rb"
|
||||
version = "0.13.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9993aa5be5a26815fe2c3eacfc1fde061fc1a1f094bf1ad2a18bf9c495dd7414"
|
||||
dependencies = [
|
||||
"gethostname",
|
||||
"rustix",
|
||||
"x11rb-protocol",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "x11rb-protocol"
|
||||
version = "0.13.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ea6fc2961e4ef194dcbfe56bb845534d0dc8098940c7e5c012a258bfec6701bd"
|
||||
|
||||
[[package]]
|
||||
name = "xattr"
|
||||
version = "1.6.1"
|
||||
@@ -7682,6 +7804,21 @@ dependencies = [
|
||||
"pkg-config",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zune-core"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cb8a0807f7c01457d0379ba880ba6322660448ddebc890ce29bb64da71fb40f9"
|
||||
|
||||
[[package]]
|
||||
name = "zune-jpeg"
|
||||
version = "0.5.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "27bc9d5b815bc103f142aa054f561d9187d191692ec7c2d1e2b4737f8dbd7296"
|
||||
dependencies = [
|
||||
"zune-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zvariant"
|
||||
version = "5.10.0"
|
||||
|
||||
@@ -39,6 +39,7 @@ dirs = "5.0"
|
||||
toml = "0.8"
|
||||
toml_edit = "0.22"
|
||||
reqwest = { version = "0.12", features = ["rustls-tls", "json", "stream", "socks"] }
|
||||
arboard = "3.6"
|
||||
flate2 = "1"
|
||||
brotli = "7"
|
||||
tokio = { version = "1", features = ["macros", "rt-multi-thread", "time", "sync"] }
|
||||
|
||||
@@ -34,6 +34,22 @@ pub async fn open_external(app: AppHandle, url: String) -> Result<bool, String>
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn copy_text_to_clipboard(text: String) -> Result<bool, String> {
|
||||
// Use spawn_blocking to avoid blocking the async runtime
|
||||
// Clipboard access can block on some platforms and may have thread/loop constraints
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let mut clipboard =
|
||||
arboard::Clipboard::new().map_err(|e| format!("访问系统剪贴板失败: {e}"))?;
|
||||
clipboard
|
||||
.set_text(text)
|
||||
.map_err(|e| format!("写入系统剪贴板失败: {e}"))?;
|
||||
Ok(true)
|
||||
})
|
||||
.await
|
||||
.map_err(|e| format!("剪贴板任务执行失败: {e}"))?
|
||||
}
|
||||
|
||||
/// 检查更新
|
||||
#[tauri::command]
|
||||
pub async fn check_for_updates(handle: AppHandle) -> Result<bool, String> {
|
||||
|
||||
@@ -26,6 +26,7 @@ pub async fn stream_check_provider(
|
||||
.ok_or_else(|| AppError::Message(format!("供应商 {provider_id} 不存在")))?;
|
||||
|
||||
let auth_override = resolve_copilot_auth_override(provider, &copilot_state).await?;
|
||||
let base_url_override = resolve_copilot_base_url_override(provider, &copilot_state).await?;
|
||||
let claude_api_format_override = resolve_claude_api_format_override(
|
||||
&app_type,
|
||||
provider,
|
||||
@@ -39,6 +40,7 @@ pub async fn stream_check_provider(
|
||||
provider,
|
||||
&config,
|
||||
auth_override,
|
||||
base_url_override,
|
||||
claude_api_format_override,
|
||||
)
|
||||
.await?;
|
||||
@@ -87,6 +89,8 @@ pub async fn stream_check_all_providers(
|
||||
}
|
||||
|
||||
let auth_override = resolve_copilot_auth_override(&provider, &copilot_state).await?;
|
||||
let base_url_override =
|
||||
resolve_copilot_base_url_override(&provider, &copilot_state).await?;
|
||||
let claude_api_format_override = resolve_claude_api_format_override(
|
||||
&app_type,
|
||||
&provider,
|
||||
@@ -108,6 +112,7 @@ pub async fn stream_check_all_providers(
|
||||
&provider,
|
||||
&config,
|
||||
auth_override,
|
||||
base_url_override,
|
||||
claude_api_format_override,
|
||||
)
|
||||
.await
|
||||
@@ -151,17 +156,7 @@ async fn resolve_copilot_auth_override(
|
||||
provider: &crate::provider::Provider,
|
||||
copilot_state: &State<'_, CopilotAuthState>,
|
||||
) -> Result<Option<crate::proxy::providers::AuthInfo>, AppError> {
|
||||
let is_copilot = provider
|
||||
.meta
|
||||
.as_ref()
|
||||
.and_then(|meta| meta.provider_type.as_deref())
|
||||
== Some("github_copilot")
|
||||
|| provider
|
||||
.settings_config
|
||||
.pointer("/env/ANTHROPIC_BASE_URL")
|
||||
.and_then(|value| value.as_str())
|
||||
.map(|url| url.contains("githubcopilot.com"))
|
||||
.unwrap_or(false);
|
||||
let is_copilot = is_copilot_provider(provider);
|
||||
|
||||
if !is_copilot {
|
||||
return Ok(None);
|
||||
@@ -171,7 +166,7 @@ async fn resolve_copilot_auth_override(
|
||||
let account_id = provider
|
||||
.meta
|
||||
.as_ref()
|
||||
.and_then(|meta| meta.github_account_id.clone());
|
||||
.and_then(|meta| meta.managed_account_id_for("github_copilot"));
|
||||
|
||||
let token = match account_id.as_deref() {
|
||||
Some(id) => auth_manager
|
||||
@@ -190,6 +185,49 @@ async fn resolve_copilot_auth_override(
|
||||
)))
|
||||
}
|
||||
|
||||
async fn resolve_copilot_base_url_override(
|
||||
provider: &crate::provider::Provider,
|
||||
copilot_state: &State<'_, CopilotAuthState>,
|
||||
) -> Result<Option<String>, AppError> {
|
||||
let is_copilot = is_copilot_provider(provider);
|
||||
let is_full_url = provider
|
||||
.meta
|
||||
.as_ref()
|
||||
.and_then(|meta| meta.is_full_url)
|
||||
.unwrap_or(false);
|
||||
|
||||
if !is_copilot || is_full_url {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let auth_manager = copilot_state.0.read().await;
|
||||
let account_id = provider
|
||||
.meta
|
||||
.as_ref()
|
||||
.and_then(|meta| meta.managed_account_id_for("github_copilot"));
|
||||
|
||||
let endpoint = match account_id.as_deref() {
|
||||
Some(id) => auth_manager.get_api_endpoint(id).await,
|
||||
None => auth_manager.get_default_api_endpoint().await,
|
||||
};
|
||||
|
||||
Ok(Some(endpoint))
|
||||
}
|
||||
|
||||
fn is_copilot_provider(provider: &crate::provider::Provider) -> bool {
|
||||
provider
|
||||
.meta
|
||||
.as_ref()
|
||||
.and_then(|meta| meta.provider_type.as_deref())
|
||||
== Some("github_copilot")
|
||||
|| provider
|
||||
.settings_config
|
||||
.pointer("/env/ANTHROPIC_BASE_URL")
|
||||
.and_then(|value| value.as_str())
|
||||
.map(|url| url.contains("githubcopilot.com"))
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
async fn resolve_claude_api_format_override(
|
||||
app_type: &AppType,
|
||||
provider: &crate::provider::Provider,
|
||||
@@ -237,3 +275,80 @@ async fn resolve_claude_api_format_override(
|
||||
|
||||
Ok(Some(api_format.to_string()))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::is_copilot_provider;
|
||||
use crate::provider::{Provider, ProviderMeta};
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn copilot_provider_detection_accepts_provider_type_or_base_url() {
|
||||
let typed_provider = Provider {
|
||||
id: "p1".to_string(),
|
||||
name: "typed".to_string(),
|
||||
settings_config: json!({}),
|
||||
website_url: None,
|
||||
category: None,
|
||||
created_at: None,
|
||||
sort_index: None,
|
||||
notes: None,
|
||||
meta: Some(ProviderMeta {
|
||||
provider_type: Some("github_copilot".to_string()),
|
||||
..Default::default()
|
||||
}),
|
||||
icon: None,
|
||||
icon_color: None,
|
||||
in_failover_queue: false,
|
||||
};
|
||||
assert!(is_copilot_provider(&typed_provider));
|
||||
|
||||
let url_provider = Provider {
|
||||
id: "p2".to_string(),
|
||||
name: "url".to_string(),
|
||||
settings_config: json!({
|
||||
"env": {
|
||||
"ANTHROPIC_BASE_URL": "https://api.githubcopilot.com"
|
||||
}
|
||||
}),
|
||||
website_url: None,
|
||||
category: None,
|
||||
created_at: None,
|
||||
sort_index: None,
|
||||
notes: None,
|
||||
meta: None,
|
||||
icon: None,
|
||||
icon_color: None,
|
||||
in_failover_queue: false,
|
||||
};
|
||||
assert!(is_copilot_provider(&url_provider));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn copilot_full_url_metadata_is_available_for_override_guard() {
|
||||
let provider = Provider {
|
||||
id: "p3".to_string(),
|
||||
name: "relay".to_string(),
|
||||
settings_config: json!({}),
|
||||
website_url: None,
|
||||
category: None,
|
||||
created_at: None,
|
||||
sort_index: None,
|
||||
notes: None,
|
||||
meta: Some(ProviderMeta {
|
||||
provider_type: Some("github_copilot".to_string()),
|
||||
is_full_url: Some(true),
|
||||
..Default::default()
|
||||
}),
|
||||
icon: None,
|
||||
icon_color: None,
|
||||
in_failover_queue: false,
|
||||
};
|
||||
|
||||
assert!(is_copilot_provider(&provider));
|
||||
assert_eq!(
|
||||
provider.meta.as_ref().and_then(|meta| meta.is_full_url),
|
||||
Some(true)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -872,6 +872,7 @@ pub fn run() {
|
||||
commands::restart_app,
|
||||
commands::check_for_updates,
|
||||
commands::is_portable_mode,
|
||||
commands::copy_text_to_clipboard,
|
||||
commands::get_claude_plugin_status,
|
||||
commands::read_claude_plugin_config,
|
||||
commands::apply_claude_plugin_config,
|
||||
|
||||
@@ -747,7 +747,7 @@ impl RequestForwarder {
|
||||
adapter: &dyn ProviderAdapter,
|
||||
) -> Result<(ProxyResponse, Option<String>), ProxyError> {
|
||||
// 使用适配器提取 base_url
|
||||
let base_url = adapter.extract_base_url(provider)?;
|
||||
let mut base_url = adapter.extract_base_url(provider)?;
|
||||
|
||||
let is_full_url = provider
|
||||
.meta
|
||||
@@ -770,6 +770,36 @@ impl RequestForwarder {
|
||||
.and_then(|m| m.provider_type.as_deref())
|
||||
== Some("github_copilot")
|
||||
|| base_url.contains("githubcopilot.com");
|
||||
|
||||
// GitHub Copilot 动态 endpoint 路由
|
||||
// 从 CopilotAuthManager 获取缓存的 API endpoint(支持企业版等非默认 endpoint)
|
||||
if is_copilot && !is_full_url {
|
||||
if let Some(app_handle) = &self.app_handle {
|
||||
let copilot_state = app_handle.state::<CopilotAuthState>();
|
||||
let copilot_auth = copilot_state.0.read().await;
|
||||
|
||||
// 从 provider.meta 获取关联的 GitHub 账号 ID
|
||||
let account_id = provider
|
||||
.meta
|
||||
.as_ref()
|
||||
.and_then(|m| m.managed_account_id_for("github_copilot"));
|
||||
|
||||
let dynamic_endpoint = match &account_id {
|
||||
Some(id) => copilot_auth.get_api_endpoint(id).await,
|
||||
None => copilot_auth.get_default_api_endpoint().await,
|
||||
};
|
||||
|
||||
// 只在动态 endpoint 与当前 base_url 不同时替换
|
||||
if dynamic_endpoint != base_url {
|
||||
log::debug!(
|
||||
"[Copilot] 使用动态 API endpoint: {} (原: {})",
|
||||
dynamic_endpoint,
|
||||
base_url
|
||||
);
|
||||
base_url = dynamic_endpoint;
|
||||
}
|
||||
}
|
||||
}
|
||||
let resolved_claude_api_format = if adapter.name() == "Claude" {
|
||||
Some(
|
||||
self.resolve_claude_api_format(provider, &mapped_body, is_copilot)
|
||||
@@ -893,6 +923,12 @@ impl RequestForwarder {
|
||||
"copilot-integration-id",
|
||||
"x-github-api-version",
|
||||
"openai-intent",
|
||||
// 新增 headers
|
||||
"x-initiator",
|
||||
"x-interaction-type",
|
||||
"x-vscode-user-agent-library-version",
|
||||
"x-request-id",
|
||||
"x-agent-task-id",
|
||||
]
|
||||
} else {
|
||||
&[]
|
||||
@@ -1663,4 +1699,107 @@ mod tests {
|
||||
&headers
|
||||
));
|
||||
}
|
||||
|
||||
// ==================== Copilot 动态 endpoint 路由相关测试 ====================
|
||||
|
||||
/// 验证 is_copilot 检测逻辑:通过 provider_type 判断
|
||||
#[test]
|
||||
fn copilot_detection_via_provider_type() {
|
||||
use crate::provider::{Provider, ProviderMeta};
|
||||
|
||||
let provider = Provider {
|
||||
id: "test".to_string(),
|
||||
name: "Test Copilot".to_string(),
|
||||
settings_config: serde_json::json!({}),
|
||||
website_url: None,
|
||||
category: None,
|
||||
created_at: None,
|
||||
sort_index: None,
|
||||
notes: None,
|
||||
meta: Some(ProviderMeta {
|
||||
provider_type: Some("github_copilot".to_string()),
|
||||
..Default::default()
|
||||
}),
|
||||
icon: None,
|
||||
icon_color: None,
|
||||
in_failover_queue: false,
|
||||
};
|
||||
|
||||
let is_copilot = provider
|
||||
.meta
|
||||
.as_ref()
|
||||
.and_then(|m| m.provider_type.as_deref())
|
||||
== Some("github_copilot");
|
||||
|
||||
assert!(is_copilot, "应该通过 provider_type 检测为 Copilot");
|
||||
}
|
||||
|
||||
/// 验证 is_copilot 检测逻辑:通过 base_url 判断
|
||||
#[test]
|
||||
fn copilot_detection_via_base_url() {
|
||||
let base_url = "https://api.githubcopilot.com";
|
||||
let is_copilot = base_url.contains("githubcopilot.com");
|
||||
assert!(is_copilot, "应该通过 base_url 检测为 Copilot");
|
||||
|
||||
let non_copilot_url = "https://api.anthropic.com";
|
||||
let is_not_copilot = non_copilot_url.contains("githubcopilot.com");
|
||||
assert!(!is_not_copilot, "非 Copilot URL 不应被检测为 Copilot");
|
||||
}
|
||||
|
||||
/// 验证企业版 endpoint(不包含 githubcopilot.com)场景下 is_copilot 仍然正确
|
||||
#[test]
|
||||
fn copilot_detection_for_enterprise_endpoint() {
|
||||
use crate::provider::{Provider, ProviderMeta};
|
||||
|
||||
// 企业版场景:provider_type 是 github_copilot,但 base_url 可能是企业内部域名
|
||||
let provider = Provider {
|
||||
id: "enterprise".to_string(),
|
||||
name: "Enterprise Copilot".to_string(),
|
||||
settings_config: serde_json::json!({}),
|
||||
website_url: None,
|
||||
category: None,
|
||||
created_at: None,
|
||||
sort_index: None,
|
||||
notes: None,
|
||||
meta: Some(ProviderMeta {
|
||||
provider_type: Some("github_copilot".to_string()),
|
||||
..Default::default()
|
||||
}),
|
||||
icon: None,
|
||||
icon_color: None,
|
||||
in_failover_queue: false,
|
||||
};
|
||||
|
||||
let enterprise_base_url = "https://copilot-api.corp.example.com";
|
||||
|
||||
// is_copilot 应该通过 provider_type 检测成功,即使 base_url 不包含 githubcopilot.com
|
||||
let is_copilot = provider
|
||||
.meta
|
||||
.as_ref()
|
||||
.and_then(|m| m.provider_type.as_deref())
|
||||
== Some("github_copilot")
|
||||
|| enterprise_base_url.contains("githubcopilot.com");
|
||||
|
||||
assert!(
|
||||
is_copilot,
|
||||
"企业版 Copilot 应该通过 provider_type 被正确检测"
|
||||
);
|
||||
}
|
||||
|
||||
/// 验证动态 endpoint 替换条件
|
||||
#[test]
|
||||
fn dynamic_endpoint_replacement_conditions() {
|
||||
// 条件:is_copilot && !is_full_url
|
||||
let test_cases = [
|
||||
(true, false, true, "Copilot + 非 full_url 应该替换"),
|
||||
(true, true, false, "Copilot + full_url 不应替换"),
|
||||
(false, false, false, "非 Copilot 不应替换"),
|
||||
(false, true, false, "非 Copilot + full_url 不应替换"),
|
||||
];
|
||||
|
||||
for (is_copilot, is_full_url, should_replace, desc) in test_cases {
|
||||
let will_replace = is_copilot && !is_full_url;
|
||||
assert_eq!(will_replace, should_replace, "{desc}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -359,6 +359,10 @@ async fn send_raw_request(
|
||||
.write_all(&raw)
|
||||
.await
|
||||
.map_err(|e| ProxyError::ForwardFailed(format!("Write failed: {e}")))?;
|
||||
tls_stream
|
||||
.flush()
|
||||
.await
|
||||
.map_err(|e| ProxyError::ForwardFailed(format!("Flush failed: {e}")))?;
|
||||
|
||||
let filtered = WriteFilter::new(tls_stream);
|
||||
do_hyper_response(filtered, method.clone()).await
|
||||
@@ -368,6 +372,10 @@ async fn send_raw_request(
|
||||
.write_all(&raw)
|
||||
.await
|
||||
.map_err(|e| ProxyError::ForwardFailed(format!("Write failed: {e}")))?;
|
||||
stream
|
||||
.flush()
|
||||
.await
|
||||
.map_err(|e| ProxyError::ForwardFailed(format!("Flush failed: {e}")))?;
|
||||
|
||||
let filtered = WriteFilter::new(stream);
|
||||
do_hyper_response(filtered, method.clone()).await
|
||||
@@ -441,6 +449,10 @@ async fn connect_via_proxy(
|
||||
.write_all(connect_req.as_bytes())
|
||||
.await
|
||||
.map_err(|e| ProxyError::ForwardFailed(format!("CONNECT write failed: {e}")))?;
|
||||
stream
|
||||
.flush()
|
||||
.await
|
||||
.map_err(|e| ProxyError::ForwardFailed(format!("CONNECT flush failed: {e}")))?;
|
||||
|
||||
// Read the proxy's response status line
|
||||
let mut reader = BufReader::new(&mut stream);
|
||||
|
||||
@@ -348,6 +348,8 @@ impl ProviderAdapter for ClaudeAdapter {
|
||||
)]
|
||||
}
|
||||
AuthStrategy::GitHubCopilot => {
|
||||
// 生成请求追踪 ID
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
vec![
|
||||
(
|
||||
HeaderName::from_static("authorization"),
|
||||
@@ -373,9 +375,30 @@ impl ProviderAdapter for ClaudeAdapter {
|
||||
HeaderName::from_static("x-github-api-version"),
|
||||
HeaderValue::from_static(super::copilot_auth::COPILOT_API_VERSION),
|
||||
),
|
||||
// 26-04-01新增的copilot关键 headers
|
||||
(
|
||||
HeaderName::from_static("openai-intent"),
|
||||
HeaderValue::from_static("conversation-panel"),
|
||||
HeaderValue::from_static("conversation-agent"),
|
||||
),
|
||||
(
|
||||
HeaderName::from_static("x-initiator"),
|
||||
HeaderValue::from_static("user"),
|
||||
),
|
||||
(
|
||||
HeaderName::from_static("x-interaction-type"),
|
||||
HeaderValue::from_static("conversation-agent"),
|
||||
),
|
||||
(
|
||||
HeaderName::from_static("x-vscode-user-agent-library-version"),
|
||||
HeaderValue::from_static("electron-fetch"),
|
||||
),
|
||||
(
|
||||
HeaderName::from_static("x-request-id"),
|
||||
HeaderValue::from_str(&request_id).unwrap(),
|
||||
),
|
||||
(
|
||||
HeaderName::from_static("x-agent-task-id"),
|
||||
HeaderValue::from_str(&request_id).unwrap(),
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
@@ -46,15 +46,18 @@ const TOKEN_REFRESH_BUFFER_SECONDS: i64 = 60;
|
||||
const COPILOT_MODELS_URL: &str = "https://api.githubcopilot.com/models";
|
||||
|
||||
/// Copilot API Header 常量
|
||||
pub const COPILOT_EDITOR_VERSION: &str = "vscode/1.96.0";
|
||||
pub const COPILOT_PLUGIN_VERSION: &str = "copilot-chat/0.26.7";
|
||||
pub const COPILOT_USER_AGENT: &str = "GitHubCopilotChat/0.26.7";
|
||||
pub const COPILOT_API_VERSION: &str = "2025-04-01";
|
||||
pub const COPILOT_EDITOR_VERSION: &str = "vscode/1.110.1";
|
||||
pub const COPILOT_PLUGIN_VERSION: &str = "copilot-chat/0.38.2";
|
||||
pub const COPILOT_USER_AGENT: &str = "GitHubCopilotChat/0.38.2";
|
||||
pub const COPILOT_API_VERSION: &str = "2025-10-01";
|
||||
pub const COPILOT_INTEGRATION_ID: &str = "vscode-chat";
|
||||
|
||||
/// Copilot 使用量 API URL
|
||||
const COPILOT_USAGE_URL: &str = "https://api.github.com/copilot_internal/user";
|
||||
|
||||
/// 默认 Copilot API 端点
|
||||
const DEFAULT_COPILOT_API_ENDPOINT: &str = "https://api.githubcopilot.com";
|
||||
|
||||
/// Copilot 使用量响应
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CopilotUsageResponse {
|
||||
@@ -64,6 +67,19 @@ pub struct CopilotUsageResponse {
|
||||
pub quota_reset_date: String,
|
||||
/// 配额快照
|
||||
pub quota_snapshots: QuotaSnapshots,
|
||||
/// API 端点信息 (用于动态获取 API URL)
|
||||
#[serde(default)]
|
||||
pub endpoints: Option<CopilotEndpoints>,
|
||||
}
|
||||
|
||||
/// Copilot API 端点信息
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CopilotEndpoints {
|
||||
/// API 端点 URL
|
||||
pub api: String,
|
||||
/// Telemetry 端点 URL
|
||||
#[serde(default)]
|
||||
pub telemetry: Option<String>,
|
||||
}
|
||||
|
||||
/// 配额快照
|
||||
@@ -312,6 +328,10 @@ pub struct CopilotAuthManager {
|
||||
copilot_tokens: Arc<RwLock<HashMap<String, CopilotToken>>>,
|
||||
/// Copilot Models 缓存(key = GitHub user ID,仅进程内复用)
|
||||
copilot_models: Arc<RwLock<HashMap<String, Vec<CopilotModel>>>>,
|
||||
/// Copilot API 端点缓存(key = GitHub user ID,从 /copilot_internal/user 获取)
|
||||
api_endpoints: Arc<RwLock<HashMap<String, String>>>,
|
||||
/// 每个账号的端点拉取锁,避免并发拉取重复打 GitHub API
|
||||
endpoint_locks: Arc<RwLock<HashMap<String, Arc<Mutex<()>>>>>,
|
||||
/// HTTP 客户端
|
||||
http_client: Client,
|
||||
/// 存储路径
|
||||
@@ -333,6 +353,8 @@ impl CopilotAuthManager {
|
||||
refresh_locks: Arc::new(RwLock::new(HashMap::new())),
|
||||
copilot_tokens: Arc::new(RwLock::new(HashMap::new())),
|
||||
copilot_models: Arc::new(RwLock::new(HashMap::new())),
|
||||
api_endpoints: Arc::new(RwLock::new(HashMap::new())),
|
||||
endpoint_locks: Arc::new(RwLock::new(HashMap::new())),
|
||||
http_client: Client::new(),
|
||||
storage_path,
|
||||
pending_migration: Arc::new(RwLock::new(None)),
|
||||
@@ -386,6 +408,15 @@ impl CopilotAuthManager {
|
||||
let mut refresh_locks = self.refresh_locks.write().await;
|
||||
refresh_locks.remove(account_id);
|
||||
}
|
||||
// 清理 API 端点缓存
|
||||
{
|
||||
let mut api_endpoints = self.api_endpoints.write().await;
|
||||
api_endpoints.remove(account_id);
|
||||
}
|
||||
{
|
||||
let mut endpoint_locks = self.endpoint_locks.write().await;
|
||||
endpoint_locks.remove(account_id);
|
||||
}
|
||||
|
||||
{
|
||||
let accounts = self.accounts.read().await;
|
||||
@@ -775,6 +806,14 @@ impl CopilotAuthManager {
|
||||
.await
|
||||
.map_err(|e| CopilotAuthError::ParseError(e.to_string()))?;
|
||||
|
||||
// 存储动态 API 端点(如果有)
|
||||
if let Some(ref endpoints) = usage.endpoints {
|
||||
let mut api_endpoints = self.api_endpoints.write().await;
|
||||
api_endpoints.insert(account_id.to_string(), endpoints.api.clone());
|
||||
// 使用 debug 级别避免在日志中暴露企业内部域名
|
||||
log::debug!("[CopilotAuth] 账号 {account_id} 已保存动态 API 端点");
|
||||
}
|
||||
|
||||
log::info!(
|
||||
"[CopilotAuth] 获取使用量成功,计划: {}, 重置日期: {}",
|
||||
usage.copilot_plan,
|
||||
@@ -794,6 +833,118 @@ impl CopilotAuthManager {
|
||||
|
||||
// ==================== 状态查询 ====================
|
||||
|
||||
/// 获取指定账号的 API 端点(缓存命中直接返回,未命中则从 API 惰性拉取)
|
||||
pub async fn get_api_endpoint(&self, account_id: &str) -> String {
|
||||
let _ = self.ensure_migration_complete().await;
|
||||
|
||||
{
|
||||
let endpoints = self.api_endpoints.read().await;
|
||||
if let Some(endpoint) = endpoints.get(account_id) {
|
||||
return endpoint.clone();
|
||||
}
|
||||
}
|
||||
|
||||
// 用锁串行化同一账号的并发拉取,避免对 GitHub API 的重复请求
|
||||
let lock = self.get_endpoint_lock(account_id).await;
|
||||
let _guard = lock.lock().await;
|
||||
|
||||
// 持锁后二次检查:可能已由其他请求填充
|
||||
{
|
||||
let endpoints = self.api_endpoints.read().await;
|
||||
if let Some(endpoint) = endpoints.get(account_id) {
|
||||
return endpoint.clone();
|
||||
}
|
||||
}
|
||||
|
||||
match self.fetch_and_cache_endpoint(account_id).await {
|
||||
Ok(endpoint) => endpoint,
|
||||
Err(e) => {
|
||||
log::debug!(
|
||||
"[CopilotAuth] 获取账号 {account_id} 动态 API 端点失败: {e},使用默认值"
|
||||
);
|
||||
DEFAULT_COPILOT_API_ENDPOINT.to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 获取默认账号的 API 端点
|
||||
pub async fn get_default_api_endpoint(&self) -> String {
|
||||
let _ = self.ensure_migration_complete().await;
|
||||
|
||||
match self.resolve_default_account_id().await {
|
||||
Some(id) => self.get_api_endpoint(&id).await,
|
||||
None => DEFAULT_COPILOT_API_ENDPOINT.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn fetch_and_cache_endpoint(&self, account_id: &str) -> Result<String, CopilotAuthError> {
|
||||
let github_token = {
|
||||
let accounts = self.accounts.read().await;
|
||||
accounts
|
||||
.get(account_id)
|
||||
.map(|a| a.github_token.clone())
|
||||
.ok_or_else(|| CopilotAuthError::AccountNotFound(account_id.to_string()))?
|
||||
};
|
||||
|
||||
log::debug!("[CopilotAuth] 为账号 {account_id} 惰性拉取动态 API 端点");
|
||||
|
||||
let response = self
|
||||
.http_client
|
||||
.get(COPILOT_USAGE_URL)
|
||||
.header("Authorization", format!("token {github_token}"))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("editor-version", COPILOT_EDITOR_VERSION)
|
||||
.header("editor-plugin-version", COPILOT_PLUGIN_VERSION)
|
||||
.header("user-agent", COPILOT_USER_AGENT)
|
||||
.header("x-github-api-version", COPILOT_API_VERSION)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if response.status() == reqwest::StatusCode::UNAUTHORIZED {
|
||||
return Err(CopilotAuthError::GitHubTokenInvalid);
|
||||
}
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(CopilotAuthError::CopilotTokenFetchFailed(format!(
|
||||
"获取 API 端点失败: {}",
|
||||
response.status()
|
||||
)));
|
||||
}
|
||||
|
||||
let usage: CopilotUsageResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| CopilotAuthError::ParseError(e.to_string()))?;
|
||||
|
||||
let endpoint = match usage.endpoints {
|
||||
Some(endpoints) => endpoints.api.clone(),
|
||||
None => DEFAULT_COPILOT_API_ENDPOINT.to_string(),
|
||||
};
|
||||
|
||||
// 缓存端点(包括默认值),避免重复请求
|
||||
let mut api_endpoints = self.api_endpoints.write().await;
|
||||
api_endpoints.insert(account_id.to_string(), endpoint.clone());
|
||||
log::debug!("[CopilotAuth] 账号 {account_id} 已缓存 API 端点");
|
||||
|
||||
Ok(endpoint)
|
||||
}
|
||||
|
||||
async fn get_endpoint_lock(&self, account_id: &str) -> Arc<Mutex<()>> {
|
||||
{
|
||||
let locks = self.endpoint_locks.read().await;
|
||||
if let Some(lock) = locks.get(account_id) {
|
||||
return Arc::clone(lock);
|
||||
}
|
||||
}
|
||||
|
||||
let mut locks = self.endpoint_locks.write().await;
|
||||
Arc::clone(
|
||||
locks
|
||||
.entry(account_id.to_string())
|
||||
.or_insert_with(|| Arc::new(Mutex::new(()))),
|
||||
)
|
||||
}
|
||||
|
||||
/// 获取认证状态(支持多账号)
|
||||
pub async fn get_status(&self) -> CopilotAuthStatus {
|
||||
// 确保迁移完成
|
||||
@@ -838,6 +989,7 @@ impl CopilotAuthManager {
|
||||
pub async fn clear_auth(&self) -> Result<(), CopilotAuthError> {
|
||||
log::info!("[CopilotAuth] 清除所有认证");
|
||||
|
||||
// 先清理内存状态,确保即使文件删除失败用户也能看到已登出
|
||||
{
|
||||
let mut accounts = self.accounts.write().await;
|
||||
accounts.clear();
|
||||
@@ -851,12 +1003,25 @@ impl CopilotAuthManager {
|
||||
let mut tokens = self.copilot_tokens.write().await;
|
||||
tokens.clear();
|
||||
}
|
||||
{
|
||||
let mut models = self.copilot_models.write().await;
|
||||
models.clear();
|
||||
}
|
||||
{
|
||||
let mut refresh_locks = self.refresh_locks.write().await;
|
||||
refresh_locks.clear();
|
||||
}
|
||||
// 清理 API 端点缓存
|
||||
{
|
||||
let mut api_endpoints = self.api_endpoints.write().await;
|
||||
api_endpoints.clear();
|
||||
}
|
||||
{
|
||||
let mut endpoint_locks = self.endpoint_locks.write().await;
|
||||
endpoint_locks.clear();
|
||||
}
|
||||
|
||||
// 删除存储文件
|
||||
// 最后删除存储文件
|
||||
if self.storage_path.exists() {
|
||||
std::fs::remove_file(&self.storage_path)?;
|
||||
}
|
||||
@@ -1414,4 +1579,242 @@ mod tests {
|
||||
let default_vendor = manager.get_model_vendor("claude-sonnet-4").await.unwrap();
|
||||
assert_eq!(default_vendor.as_deref(), Some("Anthropic"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_api_endpoint_returns_cached_value() {
|
||||
let temp_dir = tempdir().unwrap();
|
||||
let manager = CopilotAuthManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
// 手动设置 api_endpoints 缓存
|
||||
{
|
||||
let mut api_endpoints = manager.api_endpoints.write().await;
|
||||
api_endpoints.insert(
|
||||
"12345".to_string(),
|
||||
"https://copilot-api.enterprise.example.com".to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
let endpoint = manager.get_api_endpoint("12345").await;
|
||||
assert_eq!(endpoint, "https://copilot-api.enterprise.example.com");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_api_endpoint_returns_default_when_not_cached() {
|
||||
let temp_dir = tempdir().unwrap();
|
||||
let manager = CopilotAuthManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
let endpoint = manager.get_api_endpoint("99999").await;
|
||||
assert_eq!(endpoint, "https://api.githubcopilot.com");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_default_api_endpoint_uses_default_account() {
|
||||
let temp_dir = tempdir().unwrap();
|
||||
let manager = CopilotAuthManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
// 设置默认账号
|
||||
{
|
||||
let mut default_account_id = manager.default_account_id.write().await;
|
||||
*default_account_id = Some("12345".to_string());
|
||||
}
|
||||
// 添加账号数据
|
||||
{
|
||||
let mut accounts = manager.accounts.write().await;
|
||||
accounts.insert(
|
||||
"12345".to_string(),
|
||||
GitHubAccountData {
|
||||
github_token: "gho_test".to_string(),
|
||||
user: GitHubUser {
|
||||
login: "alice".to_string(),
|
||||
id: 12345,
|
||||
avatar_url: None,
|
||||
},
|
||||
authenticated_at: 1700000000,
|
||||
},
|
||||
);
|
||||
}
|
||||
// 设置 API endpoint 缓存
|
||||
{
|
||||
let mut api_endpoints = manager.api_endpoints.write().await;
|
||||
api_endpoints.insert(
|
||||
"12345".to_string(),
|
||||
"https://copilot-api.corp.example.com".to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
let endpoint = manager.get_default_api_endpoint().await;
|
||||
assert_eq!(endpoint, "https://copilot-api.corp.example.com");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_remove_account_clears_api_endpoint_cache() {
|
||||
let temp_dir = tempdir().unwrap();
|
||||
let manager = CopilotAuthManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
// 添加账号数据
|
||||
{
|
||||
let mut accounts = manager.accounts.write().await;
|
||||
accounts.insert(
|
||||
"12345".to_string(),
|
||||
GitHubAccountData {
|
||||
github_token: "gho_test".to_string(),
|
||||
user: GitHubUser {
|
||||
login: "alice".to_string(),
|
||||
id: 12345,
|
||||
avatar_url: None,
|
||||
},
|
||||
authenticated_at: 1700000000,
|
||||
},
|
||||
);
|
||||
}
|
||||
// 设置 API endpoint 缓存
|
||||
{
|
||||
let mut api_endpoints = manager.api_endpoints.write().await;
|
||||
api_endpoints.insert(
|
||||
"12345".to_string(),
|
||||
"https://copilot-api.enterprise.example.com".to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
// 确认缓存存在
|
||||
{
|
||||
let api_endpoints = manager.api_endpoints.read().await;
|
||||
assert!(api_endpoints.contains_key("12345"));
|
||||
}
|
||||
|
||||
// 移除账号
|
||||
manager.remove_account("12345").await.unwrap();
|
||||
|
||||
// 确认缓存已清理
|
||||
{
|
||||
let api_endpoints = manager.api_endpoints.read().await;
|
||||
assert!(!api_endpoints.contains_key("12345"));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_clear_auth_clears_all_api_endpoint_cache() {
|
||||
let temp_dir = tempdir().unwrap();
|
||||
let manager = CopilotAuthManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
// 添加多个账号的 API endpoint 缓存
|
||||
{
|
||||
let mut api_endpoints = manager.api_endpoints.write().await;
|
||||
api_endpoints.insert(
|
||||
"12345".to_string(),
|
||||
"https://copilot-api.enterprise1.example.com".to_string(),
|
||||
);
|
||||
api_endpoints.insert(
|
||||
"67890".to_string(),
|
||||
"https://copilot-api.enterprise2.example.com".to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
// 确认缓存存在
|
||||
{
|
||||
let api_endpoints = manager.api_endpoints.read().await;
|
||||
assert_eq!(api_endpoints.len(), 2);
|
||||
}
|
||||
|
||||
// 清除所有认证
|
||||
manager.clear_auth().await.unwrap();
|
||||
|
||||
// 确认缓存已清空
|
||||
{
|
||||
let api_endpoints = manager.api_endpoints.read().await;
|
||||
assert!(api_endpoints.is_empty());
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_clear_auth_cleans_memory_even_when_file_removal_fails() {
|
||||
let temp_dir = tempdir().unwrap();
|
||||
let manager = CopilotAuthManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
// Create a directory at storage_path so remove_file fails
|
||||
std::fs::create_dir_all(&manager.storage_path).unwrap();
|
||||
|
||||
{
|
||||
let mut accounts = manager.accounts.write().await;
|
||||
accounts.insert(
|
||||
"12345".to_string(),
|
||||
GitHubAccountData {
|
||||
github_token: "gho_test".to_string(),
|
||||
user: GitHubUser {
|
||||
login: "alice".to_string(),
|
||||
id: 12345,
|
||||
avatar_url: None,
|
||||
},
|
||||
authenticated_at: 1700000000,
|
||||
},
|
||||
);
|
||||
}
|
||||
{
|
||||
let mut default_account_id = manager.default_account_id.write().await;
|
||||
*default_account_id = Some("12345".to_string());
|
||||
}
|
||||
{
|
||||
let mut api_endpoints = manager.api_endpoints.write().await;
|
||||
api_endpoints.insert(
|
||||
"12345".to_string(),
|
||||
"https://copilot-api.enterprise.example.com".to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
let result = manager.clear_auth().await;
|
||||
// Should still return an error for the file deletion failure
|
||||
assert!(result.is_err());
|
||||
|
||||
// But memory state should already be cleaned
|
||||
let accounts = manager.accounts.read().await;
|
||||
assert!(accounts.is_empty());
|
||||
drop(accounts);
|
||||
|
||||
let default_account_id = manager.default_account_id.read().await;
|
||||
assert!(default_account_id.is_none());
|
||||
drop(default_account_id);
|
||||
|
||||
let api_endpoints = manager.api_endpoints.read().await;
|
||||
assert!(api_endpoints.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_api_endpoint_cache_hit_skips_fetch() {
|
||||
// 缓存命中时应直接返回,不发起网络请求
|
||||
let temp_dir = tempdir().unwrap();
|
||||
let manager = CopilotAuthManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
let enterprise_endpoint = "https://copilot-api.enterprise.example.com".to_string();
|
||||
{
|
||||
let mut api_endpoints = manager.api_endpoints.write().await;
|
||||
api_endpoints.insert("12345".to_string(), enterprise_endpoint.clone());
|
||||
}
|
||||
|
||||
// 即使没有账号数据,缓存命中也应直接返回
|
||||
let endpoint = manager.get_api_endpoint("12345").await;
|
||||
assert_eq!(endpoint, enterprise_endpoint);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_api_endpoint_returns_default_for_unknown_account() {
|
||||
let temp_dir = tempdir().unwrap();
|
||||
let manager = CopilotAuthManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
let endpoint = manager.get_api_endpoint("12345").await;
|
||||
assert_eq!(endpoint, DEFAULT_COPILOT_API_ENDPOINT);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fetch_and_cache_endpoint_requires_account() {
|
||||
// 账号不存在时 fetch_and_cache_endpoint 应返回 AccountNotFound 错误
|
||||
let temp_dir = tempdir().unwrap();
|
||||
let manager = CopilotAuthManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
let result = manager.fetch_and_cache_endpoint("nonexistent").await;
|
||||
assert!(result.is_err());
|
||||
match result.unwrap_err() {
|
||||
CopilotAuthError::AccountNotFound(id) => assert_eq!(id, "nonexistent"),
|
||||
other => panic!("期望 AccountNotFound 错误,实际: {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,6 +88,7 @@ impl StreamCheckService {
|
||||
provider: &Provider,
|
||||
config: &StreamCheckConfig,
|
||||
auth_override: Option<AuthInfo>,
|
||||
base_url_override: Option<String>,
|
||||
claude_api_format_override: Option<String>,
|
||||
) -> Result<StreamCheckResult, AppError> {
|
||||
// 合并供应商单独配置和全局配置
|
||||
@@ -100,6 +101,7 @@ impl StreamCheckService {
|
||||
provider,
|
||||
&effective_config,
|
||||
auth_override.clone(),
|
||||
base_url_override.clone(),
|
||||
claude_api_format_override.clone(),
|
||||
)
|
||||
.await;
|
||||
@@ -191,14 +193,18 @@ impl StreamCheckService {
|
||||
provider: &Provider,
|
||||
config: &StreamCheckConfig,
|
||||
auth_override: Option<AuthInfo>,
|
||||
base_url_override: Option<String>,
|
||||
claude_api_format_override: Option<String>,
|
||||
) -> Result<StreamCheckResult, AppError> {
|
||||
let start = Instant::now();
|
||||
let adapter = get_adapter(app_type);
|
||||
|
||||
let base_url = adapter
|
||||
.extract_base_url(provider)
|
||||
.map_err(|e| AppError::Message(format!("Failed to extract base_url: {e}")))?;
|
||||
let base_url = match base_url_override {
|
||||
Some(base_url) => base_url,
|
||||
None => adapter
|
||||
.extract_base_url(provider)
|
||||
.map_err(|e| AppError::Message(format!("Failed to extract base_url: {e}")))?,
|
||||
};
|
||||
|
||||
let auth = auth_override
|
||||
.or_else(|| adapter.extract_auth(provider))
|
||||
@@ -364,6 +370,8 @@ impl StreamCheckService {
|
||||
let mut request_builder = client.post(&url);
|
||||
|
||||
if is_github_copilot {
|
||||
// 生成请求追踪 ID
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
request_builder = request_builder
|
||||
.header("authorization", format!("Bearer {}", auth.api_key))
|
||||
.header("content-type", "application/json")
|
||||
@@ -380,7 +388,13 @@ impl StreamCheckService {
|
||||
copilot_auth::COPILOT_INTEGRATION_ID,
|
||||
)
|
||||
.header("x-github-api-version", copilot_auth::COPILOT_API_VERSION)
|
||||
.header("openai-intent", "conversation-panel");
|
||||
// 260401 新增copilot 的关键 headers
|
||||
.header("openai-intent", "conversation-agent")
|
||||
.header("x-initiator", "user")
|
||||
.header("x-interaction-type", "conversation-agent")
|
||||
.header("x-vscode-user-agent-library-version", "electron-fetch")
|
||||
.header("x-request-id", &request_id)
|
||||
.header("x-agent-task-id", &request_id);
|
||||
} else if is_openai_chat || is_openai_responses {
|
||||
// OpenAI-compatible targets: Bearer auth + SSE headers only
|
||||
request_builder = request_builder
|
||||
|
||||
@@ -22,6 +22,7 @@ import {
|
||||
User,
|
||||
} from "lucide-react";
|
||||
import { useCopilotAuth } from "./hooks/useCopilotAuth";
|
||||
import { copyText } from "@/lib/clipboard";
|
||||
import type { GitHubAccount } from "@/lib/api";
|
||||
|
||||
interface CopilotAuthSectionProps {
|
||||
@@ -67,7 +68,7 @@ export const CopilotAuthSection: React.FC<CopilotAuthSectionProps> = ({
|
||||
// 复制用户码
|
||||
const copyUserCode = async () => {
|
||||
if (deviceCode?.user_code) {
|
||||
await navigator.clipboard.writeText(deviceCode.user_code);
|
||||
await copyText(deviceCode.user_code);
|
||||
setCopied(true);
|
||||
setTimeout(() => setCopied(false), 2000);
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { useState, useCallback, useRef, useEffect } from "react";
|
||||
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
|
||||
import { authApi, settingsApi } from "@/lib/api";
|
||||
import { copyText } from "@/lib/clipboard";
|
||||
import type {
|
||||
ManagedAuthProvider,
|
||||
ManagedAuthStatus,
|
||||
@@ -58,7 +59,7 @@ export function useManagedAuth(authProvider: ManagedAuthProvider) {
|
||||
setError(null);
|
||||
|
||||
try {
|
||||
await navigator.clipboard.writeText(response.user_code);
|
||||
await copyText(response.user_code);
|
||||
} catch (e) {
|
||||
console.debug("[ManagedAuth] Failed to copy user code:", e);
|
||||
}
|
||||
|
||||
19
src/lib/clipboard.ts
Normal file
19
src/lib/clipboard.ts
Normal file
@@ -0,0 +1,19 @@
|
||||
import { invoke } from "@tauri-apps/api/core";
|
||||
|
||||
export async function copyText(text: string): Promise<void> {
|
||||
try {
|
||||
await invoke("copy_text_to_clipboard", { text });
|
||||
return;
|
||||
} catch (nativeError) {
|
||||
try {
|
||||
await navigator.clipboard.writeText(text);
|
||||
return;
|
||||
} catch (webError) {
|
||||
throw webError instanceof Error
|
||||
? webError
|
||||
: nativeError instanceof Error
|
||||
? nativeError
|
||||
: new Error(String(webError || nativeError));
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user