diff --git a/Cargo.lock b/Cargo.lock index 8ec0db929036d5053750be078ab3ea7b2d481c36..854a74f25adc1337da20667414e0495ff3d78911 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5908,6 +5908,7 @@ dependencies = [ "criterion", "ctor", "dap", + "dirs 4.0.0", "editor", "extension", "fs", @@ -5936,6 +5937,7 @@ dependencies = [ "serde_json", "serde_json_lenient", "settings", + "smol", "task", "telemetry", "tempfile", diff --git a/crates/extension_api/src/extension_api.rs b/crates/extension_api/src/extension_api.rs index ac0827cd014d77b582aaf7db2da8fbc55a05f957..1bb2a84059fb309cad8736a83133b7754dbed4bc 100644 --- a/crates/extension_api/src/extension_api.rs +++ b/crates/extension_api/src/extension_api.rs @@ -17,8 +17,9 @@ pub use serde_json; pub use wit::{ CodeLabel, CodeLabelSpan, CodeLabelSpanLiteral, Command, DownloadedFileType, EnvVars, KeyValueStore, LanguageServerInstallationStatus, Project, Range, Worktree, download_file, - llm_delete_credential, llm_get_credential, llm_get_env_var, llm_request_credential, - llm_store_credential, make_file_executable, + llm_delete_credential, llm_get_credential, llm_get_env_var, llm_oauth_http_request, + llm_oauth_open_browser, llm_oauth_start_web_auth, llm_request_credential, llm_store_credential, + make_file_executable, zed::extension::context_server::ContextServerConfiguration, zed::extension::dap::{ AttachRequest, BuildTaskDefinition, BuildTaskDefinitionTemplatePayload, BuildTaskTemplate, @@ -35,7 +36,9 @@ pub use wit::{ CompletionRequest as LlmCompletionRequest, CredentialType as LlmCredentialType, ImageData as LlmImageData, MessageContent as LlmMessageContent, MessageRole as LlmMessageRole, ModelCapabilities as LlmModelCapabilities, - ModelInfo as LlmModelInfo, ProviderInfo as LlmProviderInfo, + ModelInfo as LlmModelInfo, OauthHttpRequest as LlmOauthHttpRequest, + OauthHttpResponse as LlmOauthHttpResponse, OauthWebAuthConfig as LlmOauthWebAuthConfig, + OauthWebAuthResult as LlmOauthWebAuthResult, ProviderInfo as LlmProviderInfo, RequestMessage as LlmRequestMessage, StopReason as LlmStopReason, ThinkingContent as LlmThinkingContent, TokenUsage as LlmTokenUsage, ToolChoice as LlmToolChoice, ToolDefinition as LlmToolDefinition, diff --git a/crates/extension_host/Cargo.toml b/crates/extension_host/Cargo.toml index a5c9357b9c80b70f0bf362ba04cd581d52f67828..0f3d1eefee9e04e77ea6cbbea3249f44c4efd504 100644 --- a/crates/extension_host/Cargo.toml +++ b/crates/extension_host/Cargo.toml @@ -24,6 +24,7 @@ client.workspace = true collections.workspace = true credentials_provider.workspace = true dap.workspace = true +dirs.workspace = true editor.workspace = true extension.workspace = true fs.workspace = true @@ -48,6 +49,7 @@ serde.workspace = true serde_json.workspace = true serde_json_lenient.workspace = true settings.workspace = true +smol.workspace = true task.workspace = true telemetry.workspace = true tempfile.workspace = true diff --git a/crates/extension_host/src/copilot_migration.rs b/crates/extension_host/src/copilot_migration.rs new file mode 100644 index 0000000000000000000000000000000000000000..90fdf48c0de69c14a560d4600fd4d24986891d2d --- /dev/null +++ b/crates/extension_host/src/copilot_migration.rs @@ -0,0 +1,161 @@ +use credentials_provider::CredentialsProvider; +use gpui::App; +use std::path::PathBuf; + +const COPILOT_CHAT_EXTENSION_ID: &str = "copilot_chat"; +const COPILOT_CHAT_PROVIDER_ID: &str = "copilot_chat"; + +pub fn migrate_copilot_credentials_if_needed(extension_id: &str, cx: &mut App) { + if extension_id != COPILOT_CHAT_EXTENSION_ID { + return; + } + + let credential_key = format!( + "extension-llm-{}:{}", + COPILOT_CHAT_EXTENSION_ID, COPILOT_CHAT_PROVIDER_ID + ); + + let credentials_provider = ::global(cx); + + cx.spawn(async move |cx| { + let existing_credential = credentials_provider + .read_credentials(&credential_key, &cx) + .await + .ok() + .flatten(); + + if existing_credential.is_some() { + log::debug!("Copilot Chat extension already has credentials, skipping migration"); + return; + } + + let oauth_token = match read_copilot_oauth_token().await { + Some(token) => token, + None => { + log::debug!("No existing Copilot OAuth token found to migrate"); + return; + } + }; + + log::info!("Migrating existing Copilot OAuth token to Copilot Chat extension"); + + match credentials_provider + .write_credentials(&credential_key, "api_key", oauth_token.as_bytes(), &cx) + .await + { + Ok(()) => { + log::info!("Successfully migrated Copilot OAuth token to Copilot Chat extension"); + } + Err(err) => { + log::error!("Failed to migrate Copilot OAuth token: {}", err); + } + } + }) + .detach(); +} + +async fn read_copilot_oauth_token() -> Option { + let config_paths = copilot_config_paths(); + + for path in config_paths { + if let Some(token) = read_oauth_token_from_file(&path).await { + return Some(token); + } + } + + None +} + +fn copilot_config_paths() -> Vec { + let config_dir = if cfg!(target_os = "windows") { + dirs::data_local_dir() + } else { + std::env::var("XDG_CONFIG_HOME") + .map(PathBuf::from) + .ok() + .or_else(|| dirs::home_dir().map(|h| h.join(".config"))) + }; + + let Some(config_dir) = config_dir else { + return Vec::new(); + }; + + let copilot_dir = config_dir.join("github-copilot"); + + vec![ + copilot_dir.join("hosts.json"), + copilot_dir.join("apps.json"), + ] +} + +async fn read_oauth_token_from_file(path: &PathBuf) -> Option { + let contents = match smol::fs::read_to_string(path).await { + Ok(contents) => contents, + Err(_) => return None, + }; + + extract_oauth_token(&contents, "github.com") +} + +fn extract_oauth_token(contents: &str, domain: &str) -> Option { + let value: serde_json::Value = serde_json::from_str(contents).ok()?; + let obj = value.as_object()?; + + for (key, value) in obj.iter() { + if key.starts_with(domain) { + if let Some(token) = value.get("oauth_token").and_then(|v| v.as_str()) { + return Some(token.to_string()); + } + } + } + + None +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_extract_oauth_token() { + let contents = r#"{ + "github.com": { + "oauth_token": "ghu_test_token_12345" + } + }"#; + + let token = extract_oauth_token(contents, "github.com"); + assert_eq!(token, Some("ghu_test_token_12345".to_string())); + } + + #[test] + fn test_extract_oauth_token_with_prefix() { + let contents = r#"{ + "github.com:user": { + "oauth_token": "ghu_another_token" + } + }"#; + + let token = extract_oauth_token(contents, "github.com"); + assert_eq!(token, Some("ghu_another_token".to_string())); + } + + #[test] + fn test_extract_oauth_token_missing() { + let contents = r#"{ + "gitlab.com": { + "oauth_token": "some_token" + } + }"#; + + let token = extract_oauth_token(contents, "github.com"); + assert_eq!(token, None); + } + + #[test] + fn test_extract_oauth_token_invalid_json() { + let contents = "not valid json"; + let token = extract_oauth_token(contents, "github.com"); + assert_eq!(token, None); + } +} diff --git a/crates/extension_host/src/extension_host.rs b/crates/extension_host/src/extension_host.rs index 689224dda0e92a6e715950a61601727bc2a7731d..f2feb4e8657b056580d790057a099847282d09f9 100644 --- a/crates/extension_host/src/extension_host.rs +++ b/crates/extension_host/src/extension_host.rs @@ -1,4 +1,5 @@ mod capability_granter; +mod copilot_migration; pub mod extension_settings; pub mod headless_host; pub mod wasm_host; @@ -788,6 +789,9 @@ impl ExtensionStore { this.emit(extension::Event::ExtensionInstalled(manifest.clone()), cx) }); } + + // Run extension-specific migrations + copilot_migration::migrate_copilot_credentials_if_needed(&extension_id, cx); }) .ok(); } diff --git a/crates/extension_host/src/wasm_host/wit/since_v0_7_0.rs b/crates/extension_host/src/wasm_host/wit/since_v0_7_0.rs index b2a6cc8315849d0c8364460011a381eaf041fba0..6d1457bebd4fb1865dfe0e3ae52139ba5d435f2b 100644 --- a/crates/extension_host/src/wasm_host/wit/since_v0_7_0.rs +++ b/crates/extension_host/src/wasm_host/wit/since_v0_7_0.rs @@ -24,12 +24,14 @@ use gpui::{BackgroundExecutor, SharedString}; use language::{BinaryStatus, LanguageName, language_settings::AllLanguageSettings}; use project::project_settings::ProjectSettings; use semver::Version; +use smol::net::TcpListener; use std::{ env, net::Ipv4Addr, path::{Path, PathBuf}, str::FromStr, sync::{Arc, OnceLock}, + time::Duration, }; use task::{SpawnInTerminal, ZedDebugConfig}; use url::Url; @@ -1247,6 +1249,192 @@ impl ExtensionImports for WasmState { Ok(env::var(&name).ok()) } + + async fn llm_oauth_start_web_auth( + &mut self, + config: llm_provider::OauthWebAuthConfig, + ) -> wasmtime::Result> { + let auth_url = config.auth_url; + let callback_path = config.callback_path; + let timeout_secs = config.timeout_secs.unwrap_or(300); + + self.on_main_thread(move |cx| { + async move { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .map_err(|e| anyhow::anyhow!("Failed to bind localhost server: {}", e))?; + let port = listener + .local_addr() + .map_err(|e| anyhow::anyhow!("Failed to get local address: {}", e))? + .port(); + + cx.update(|cx| { + cx.open_url(&auth_url); + })?; + + let accept_future = async { + let (mut stream, _) = listener + .accept() + .await + .map_err(|e| anyhow::anyhow!("Failed to accept connection: {}", e))?; + + let mut request_line = String::new(); + { + let mut reader = smol::io::BufReader::new(&mut stream); + smol::io::AsyncBufReadExt::read_line(&mut reader, &mut request_line) + .await + .map_err(|e| anyhow::anyhow!("Failed to read request: {}", e))?; + } + + let callback_url = if let Some(path_start) = request_line.find(' ') { + if let Some(path_end) = request_line[path_start + 1..].find(' ') { + let path = &request_line[path_start + 1..path_start + 1 + path_end]; + if path.starts_with(&callback_path) || path.starts_with(&format!("/{}", callback_path.trim_start_matches('/'))) { + format!("http://localhost:{}{}", port, path) + } else { + return Err(anyhow::anyhow!( + "Unexpected callback path: {}", + path + )); + } + } else { + return Err(anyhow::anyhow!("Malformed HTTP request")); + } + } else { + return Err(anyhow::anyhow!("Malformed HTTP request")); + }; + + let response = "HTTP/1.1 200 OK\r\n\ + Content-Type: text/html\r\n\ + Connection: close\r\n\ + \r\n\ + \ + Authentication Complete\ + \ +
\ +

Authentication Complete

\ +

You can close this window and return to Zed.

\ +
"; + + smol::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()) + .await + .ok(); + smol::io::AsyncWriteExt::flush(&mut stream).await.ok(); + + Ok(callback_url) + }; + + let timeout_duration = Duration::from_secs(timeout_secs as u64); + let callback_url = smol::future::or( + accept_future, + async { + smol::Timer::after(timeout_duration).await; + Err(anyhow::anyhow!( + "OAuth callback timed out after {} seconds", + timeout_secs + )) + }, + ) + .await?; + + Ok(llm_provider::OauthWebAuthResult { + callback_url, + port: port as u32, + }) + } + .boxed_local() + }) + .await + .to_wasmtime_result() + } + + async fn llm_oauth_http_request( + &mut self, + request: llm_provider::OauthHttpRequest, + ) -> wasmtime::Result> { + let http_client = self.host.http_client.clone(); + + self.on_main_thread(move |_cx| { + async move { + let method = match request.method.to_uppercase().as_str() { + "GET" => ::http_client::Method::GET, + "POST" => ::http_client::Method::POST, + "PUT" => ::http_client::Method::PUT, + "DELETE" => ::http_client::Method::DELETE, + "PATCH" => ::http_client::Method::PATCH, + _ => { + return Err(anyhow::anyhow!( + "Unsupported HTTP method: {}", + request.method + )); + } + }; + + let mut builder = ::http_client::Request::builder() + .method(method) + .uri(&request.url); + + for (key, value) in &request.headers { + builder = builder.header(key.as_str(), value.as_str()); + } + + let body = if request.body.is_empty() { + AsyncBody::empty() + } else { + AsyncBody::from(request.body.into_bytes()) + }; + + let http_request = builder + .body(body) + .map_err(|e| anyhow::anyhow!("Failed to build request: {}", e))?; + + let mut response = http_client + .send(http_request) + .await + .map_err(|e| anyhow::anyhow!("HTTP request failed: {}", e))?; + + let status = response.status().as_u16(); + let headers: Vec<(String, String)> = response + .headers() + .iter() + .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) + .collect(); + + let mut body_bytes = Vec::new(); + futures::AsyncReadExt::read_to_end(response.body_mut(), &mut body_bytes) + .await + .map_err(|e| anyhow::anyhow!("Failed to read response body: {}", e))?; + + let body = String::from_utf8_lossy(&body_bytes).to_string(); + + Ok(llm_provider::OauthHttpResponse { + status, + headers, + body, + }) + } + .boxed_local() + }) + .await + .to_wasmtime_result() + } + + async fn llm_oauth_open_browser( + &mut self, + url: String, + ) -> wasmtime::Result> { + self.on_main_thread(move |cx| { + async move { + cx.update(|cx| { + cx.open_url(&url); + })?; + Ok(()) + } + .boxed_local() + }) + .await + .to_wasmtime_result() + } } // ============================================================================= diff --git a/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs b/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs index 213a677687a9c2c403ec7d90a462db36f8dc85dd..714caa05ff130159cf842c397ea2e2ed6503ff4f 100644 --- a/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs +++ b/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs @@ -27,7 +27,6 @@ use semver::Version; use smol::net::TcpListener; use std::{ env, - io::{BufRead, Write}, net::Ipv4Addr, path::{Path, PathBuf}, str::FromStr, @@ -1271,16 +1270,18 @@ impl ExtensionImports for WasmState { })?; let accept_future = async { - let (stream, _) = listener + let (mut stream, _) = listener .accept() .await .map_err(|e| anyhow::anyhow!("Failed to accept connection: {}", e))?; - let mut reader = smol::io::BufReader::new(&stream); let mut request_line = String::new(); - smol::io::AsyncBufReadExt::read_line(&mut reader, &mut request_line) - .await - .map_err(|e| anyhow::anyhow!("Failed to read request: {}", e))?; + { + let mut reader = smol::io::BufReader::new(&mut stream); + smol::io::AsyncBufReadExt::read_line(&mut reader, &mut request_line) + .await + .map_err(|e| anyhow::anyhow!("Failed to read request: {}", e))?; + } let callback_url = if let Some(path_start) = request_line.find(' ') { if let Some(path_end) = request_line[path_start + 1..].find(' ') { @@ -1312,11 +1313,10 @@ impl ExtensionImports for WasmState {

You can close this window and return to Zed.

\ "; - let mut writer = &stream; - smol::io::AsyncWriteExt::write_all(&mut writer, response.as_bytes()) + smol::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()) .await .ok(); - smol::io::AsyncWriteExt::flush(&mut writer).await.ok(); + smol::io::AsyncWriteExt::flush(&mut stream).await.ok(); Ok(callback_url) }; @@ -1349,7 +1349,7 @@ impl ExtensionImports for WasmState { &mut self, request: llm_provider::OauthHttpRequest, ) -> wasmtime::Result> { - let http_client = self.http_client.clone(); + let http_client = self.host.http_client.clone(); self.on_main_thread(move |_cx| { async move { @@ -1367,7 +1367,7 @@ impl ExtensionImports for WasmState { } }; - let mut builder = ::http_client::HttpRequest::builder() + let mut builder = ::http_client::Request::builder() .method(method) .uri(&request.url);