Detailed changes
@@ -5908,6 +5908,7 @@ dependencies = [
"criterion",
"ctor",
"dap",
+ "dirs 4.0.0",
"editor",
"extension",
"fs",
@@ -5936,6 +5937,7 @@ dependencies = [
"serde_json",
"serde_json_lenient",
"settings",
+ "smol",
"task",
"telemetry",
"tempfile",
@@ -17,8 +17,9 @@ pub use serde_json;
pub use wit::{
CodeLabel, CodeLabelSpan, CodeLabelSpanLiteral, Command, DownloadedFileType, EnvVars,
KeyValueStore, LanguageServerInstallationStatus, Project, Range, Worktree, download_file,
- llm_delete_credential, llm_get_credential, llm_get_env_var, llm_request_credential,
- llm_store_credential, make_file_executable,
+ llm_delete_credential, llm_get_credential, llm_get_env_var, llm_oauth_http_request,
+ llm_oauth_open_browser, llm_oauth_start_web_auth, llm_request_credential, llm_store_credential,
+ make_file_executable,
zed::extension::context_server::ContextServerConfiguration,
zed::extension::dap::{
AttachRequest, BuildTaskDefinition, BuildTaskDefinitionTemplatePayload, BuildTaskTemplate,
@@ -35,7 +36,9 @@ pub use wit::{
CompletionRequest as LlmCompletionRequest, CredentialType as LlmCredentialType,
ImageData as LlmImageData, MessageContent as LlmMessageContent,
MessageRole as LlmMessageRole, ModelCapabilities as LlmModelCapabilities,
- ModelInfo as LlmModelInfo, ProviderInfo as LlmProviderInfo,
+ ModelInfo as LlmModelInfo, OauthHttpRequest as LlmOauthHttpRequest,
+ OauthHttpResponse as LlmOauthHttpResponse, OauthWebAuthConfig as LlmOauthWebAuthConfig,
+ OauthWebAuthResult as LlmOauthWebAuthResult, ProviderInfo as LlmProviderInfo,
RequestMessage as LlmRequestMessage, StopReason as LlmStopReason,
ThinkingContent as LlmThinkingContent, TokenUsage as LlmTokenUsage,
ToolChoice as LlmToolChoice, ToolDefinition as LlmToolDefinition,
@@ -24,6 +24,7 @@ client.workspace = true
collections.workspace = true
credentials_provider.workspace = true
dap.workspace = true
+dirs.workspace = true
editor.workspace = true
extension.workspace = true
fs.workspace = true
@@ -48,6 +49,7 @@ serde.workspace = true
serde_json.workspace = true
serde_json_lenient.workspace = true
settings.workspace = true
+smol.workspace = true
task.workspace = true
telemetry.workspace = true
tempfile.workspace = true
@@ -0,0 +1,161 @@
+use credentials_provider::CredentialsProvider;
+use gpui::App;
+use std::path::PathBuf;
+
+const COPILOT_CHAT_EXTENSION_ID: &str = "copilot_chat";
+const COPILOT_CHAT_PROVIDER_ID: &str = "copilot_chat";
+
+pub fn migrate_copilot_credentials_if_needed(extension_id: &str, cx: &mut App) {
+ if extension_id != COPILOT_CHAT_EXTENSION_ID {
+ return;
+ }
+
+ let credential_key = format!(
+ "extension-llm-{}:{}",
+ COPILOT_CHAT_EXTENSION_ID, COPILOT_CHAT_PROVIDER_ID
+ );
+
+ let credentials_provider = <dyn CredentialsProvider>::global(cx);
+
+ cx.spawn(async move |cx| {
+ let existing_credential = credentials_provider
+ .read_credentials(&credential_key, &cx)
+ .await
+ .ok()
+ .flatten();
+
+ if existing_credential.is_some() {
+ log::debug!("Copilot Chat extension already has credentials, skipping migration");
+ return;
+ }
+
+ let oauth_token = match read_copilot_oauth_token().await {
+ Some(token) => token,
+ None => {
+ log::debug!("No existing Copilot OAuth token found to migrate");
+ return;
+ }
+ };
+
+ log::info!("Migrating existing Copilot OAuth token to Copilot Chat extension");
+
+ match credentials_provider
+ .write_credentials(&credential_key, "api_key", oauth_token.as_bytes(), &cx)
+ .await
+ {
+ Ok(()) => {
+ log::info!("Successfully migrated Copilot OAuth token to Copilot Chat extension");
+ }
+ Err(err) => {
+ log::error!("Failed to migrate Copilot OAuth token: {}", err);
+ }
+ }
+ })
+ .detach();
+}
+
+async fn read_copilot_oauth_token() -> Option<String> {
+ let config_paths = copilot_config_paths();
+
+ for path in config_paths {
+ if let Some(token) = read_oauth_token_from_file(&path).await {
+ return Some(token);
+ }
+ }
+
+ None
+}
+
+fn copilot_config_paths() -> Vec<PathBuf> {
+ let config_dir = if cfg!(target_os = "windows") {
+ dirs::data_local_dir()
+ } else {
+ std::env::var("XDG_CONFIG_HOME")
+ .map(PathBuf::from)
+ .ok()
+ .or_else(|| dirs::home_dir().map(|h| h.join(".config")))
+ };
+
+ let Some(config_dir) = config_dir else {
+ return Vec::new();
+ };
+
+ let copilot_dir = config_dir.join("github-copilot");
+
+ vec![
+ copilot_dir.join("hosts.json"),
+ copilot_dir.join("apps.json"),
+ ]
+}
+
+async fn read_oauth_token_from_file(path: &PathBuf) -> Option<String> {
+ let contents = match smol::fs::read_to_string(path).await {
+ Ok(contents) => contents,
+ Err(_) => return None,
+ };
+
+ extract_oauth_token(&contents, "github.com")
+}
+
+fn extract_oauth_token(contents: &str, domain: &str) -> Option<String> {
+ let value: serde_json::Value = serde_json::from_str(contents).ok()?;
+ let obj = value.as_object()?;
+
+ for (key, value) in obj.iter() {
+ if key.starts_with(domain) {
+ if let Some(token) = value.get("oauth_token").and_then(|v| v.as_str()) {
+ return Some(token.to_string());
+ }
+ }
+ }
+
+ None
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_extract_oauth_token() {
+ let contents = r#"{
+ "github.com": {
+ "oauth_token": "ghu_test_token_12345"
+ }
+ }"#;
+
+ let token = extract_oauth_token(contents, "github.com");
+ assert_eq!(token, Some("ghu_test_token_12345".to_string()));
+ }
+
+ #[test]
+ fn test_extract_oauth_token_with_prefix() {
+ let contents = r#"{
+ "github.com:user": {
+ "oauth_token": "ghu_another_token"
+ }
+ }"#;
+
+ let token = extract_oauth_token(contents, "github.com");
+ assert_eq!(token, Some("ghu_another_token".to_string()));
+ }
+
+ #[test]
+ fn test_extract_oauth_token_missing() {
+ let contents = r#"{
+ "gitlab.com": {
+ "oauth_token": "some_token"
+ }
+ }"#;
+
+ let token = extract_oauth_token(contents, "github.com");
+ assert_eq!(token, None);
+ }
+
+ #[test]
+ fn test_extract_oauth_token_invalid_json() {
+ let contents = "not valid json";
+ let token = extract_oauth_token(contents, "github.com");
+ assert_eq!(token, None);
+ }
+}
@@ -1,4 +1,5 @@
mod capability_granter;
+mod copilot_migration;
pub mod extension_settings;
pub mod headless_host;
pub mod wasm_host;
@@ -788,6 +789,9 @@ impl ExtensionStore {
this.emit(extension::Event::ExtensionInstalled(manifest.clone()), cx)
});
}
+
+ // Run extension-specific migrations
+ copilot_migration::migrate_copilot_credentials_if_needed(&extension_id, cx);
})
.ok();
}
@@ -24,12 +24,14 @@ use gpui::{BackgroundExecutor, SharedString};
use language::{BinaryStatus, LanguageName, language_settings::AllLanguageSettings};
use project::project_settings::ProjectSettings;
use semver::Version;
+use smol::net::TcpListener;
use std::{
env,
net::Ipv4Addr,
path::{Path, PathBuf},
str::FromStr,
sync::{Arc, OnceLock},
+ time::Duration,
};
use task::{SpawnInTerminal, ZedDebugConfig};
use url::Url;
@@ -1247,6 +1249,192 @@ impl ExtensionImports for WasmState {
Ok(env::var(&name).ok())
}
+
+ async fn llm_oauth_start_web_auth(
+ &mut self,
+ config: llm_provider::OauthWebAuthConfig,
+ ) -> wasmtime::Result<Result<llm_provider::OauthWebAuthResult, String>> {
+ let auth_url = config.auth_url;
+ let callback_path = config.callback_path;
+ let timeout_secs = config.timeout_secs.unwrap_or(300);
+
+ self.on_main_thread(move |cx| {
+ async move {
+ let listener = TcpListener::bind("127.0.0.1:0")
+ .await
+ .map_err(|e| anyhow::anyhow!("Failed to bind localhost server: {}", e))?;
+ let port = listener
+ .local_addr()
+ .map_err(|e| anyhow::anyhow!("Failed to get local address: {}", e))?
+ .port();
+
+ cx.update(|cx| {
+ cx.open_url(&auth_url);
+ })?;
+
+ let accept_future = async {
+ let (mut stream, _) = listener
+ .accept()
+ .await
+ .map_err(|e| anyhow::anyhow!("Failed to accept connection: {}", e))?;
+
+ let mut request_line = String::new();
+ {
+ let mut reader = smol::io::BufReader::new(&mut stream);
+ smol::io::AsyncBufReadExt::read_line(&mut reader, &mut request_line)
+ .await
+ .map_err(|e| anyhow::anyhow!("Failed to read request: {}", e))?;
+ }
+
+ let callback_url = if let Some(path_start) = request_line.find(' ') {
+ if let Some(path_end) = request_line[path_start + 1..].find(' ') {
+ let path = &request_line[path_start + 1..path_start + 1 + path_end];
+ if path.starts_with(&callback_path) || path.starts_with(&format!("/{}", callback_path.trim_start_matches('/'))) {
+ format!("http://localhost:{}{}", port, path)
+ } else {
+ return Err(anyhow::anyhow!(
+ "Unexpected callback path: {}",
+ path
+ ));
+ }
+ } else {
+ return Err(anyhow::anyhow!("Malformed HTTP request"));
+ }
+ } else {
+ return Err(anyhow::anyhow!("Malformed HTTP request"));
+ };
+
+ let response = "HTTP/1.1 200 OK\r\n\
+ Content-Type: text/html\r\n\
+ Connection: close\r\n\
+ \r\n\
+ <!DOCTYPE html>\
+ <html><head><title>Authentication Complete</title></head>\
+ <body style=\"font-family: system-ui, sans-serif; display: flex; justify-content: center; align-items: center; height: 100vh; margin: 0;\">\
+ <div style=\"text-align: center;\">\
+ <h1>Authentication Complete</h1>\
+ <p>You can close this window and return to Zed.</p>\
+ </div></body></html>";
+
+ smol::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes())
+ .await
+ .ok();
+ smol::io::AsyncWriteExt::flush(&mut stream).await.ok();
+
+ Ok(callback_url)
+ };
+
+ let timeout_duration = Duration::from_secs(timeout_secs as u64);
+ let callback_url = smol::future::or(
+ accept_future,
+ async {
+ smol::Timer::after(timeout_duration).await;
+ Err(anyhow::anyhow!(
+ "OAuth callback timed out after {} seconds",
+ timeout_secs
+ ))
+ },
+ )
+ .await?;
+
+ Ok(llm_provider::OauthWebAuthResult {
+ callback_url,
+ port: port as u32,
+ })
+ }
+ .boxed_local()
+ })
+ .await
+ .to_wasmtime_result()
+ }
+
+ async fn llm_oauth_http_request(
+ &mut self,
+ request: llm_provider::OauthHttpRequest,
+ ) -> wasmtime::Result<Result<llm_provider::OauthHttpResponse, String>> {
+ let http_client = self.host.http_client.clone();
+
+ self.on_main_thread(move |_cx| {
+ async move {
+ let method = match request.method.to_uppercase().as_str() {
+ "GET" => ::http_client::Method::GET,
+ "POST" => ::http_client::Method::POST,
+ "PUT" => ::http_client::Method::PUT,
+ "DELETE" => ::http_client::Method::DELETE,
+ "PATCH" => ::http_client::Method::PATCH,
+ _ => {
+ return Err(anyhow::anyhow!(
+ "Unsupported HTTP method: {}",
+ request.method
+ ));
+ }
+ };
+
+ let mut builder = ::http_client::Request::builder()
+ .method(method)
+ .uri(&request.url);
+
+ for (key, value) in &request.headers {
+ builder = builder.header(key.as_str(), value.as_str());
+ }
+
+ let body = if request.body.is_empty() {
+ AsyncBody::empty()
+ } else {
+ AsyncBody::from(request.body.into_bytes())
+ };
+
+ let http_request = builder
+ .body(body)
+ .map_err(|e| anyhow::anyhow!("Failed to build request: {}", e))?;
+
+ let mut response = http_client
+ .send(http_request)
+ .await
+ .map_err(|e| anyhow::anyhow!("HTTP request failed: {}", e))?;
+
+ let status = response.status().as_u16();
+ let headers: Vec<(String, String)> = response
+ .headers()
+ .iter()
+ .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
+ .collect();
+
+ let mut body_bytes = Vec::new();
+ futures::AsyncReadExt::read_to_end(response.body_mut(), &mut body_bytes)
+ .await
+ .map_err(|e| anyhow::anyhow!("Failed to read response body: {}", e))?;
+
+ let body = String::from_utf8_lossy(&body_bytes).to_string();
+
+ Ok(llm_provider::OauthHttpResponse {
+ status,
+ headers,
+ body,
+ })
+ }
+ .boxed_local()
+ })
+ .await
+ .to_wasmtime_result()
+ }
+
+ async fn llm_oauth_open_browser(
+ &mut self,
+ url: String,
+ ) -> wasmtime::Result<Result<(), String>> {
+ self.on_main_thread(move |cx| {
+ async move {
+ cx.update(|cx| {
+ cx.open_url(&url);
+ })?;
+ Ok(())
+ }
+ .boxed_local()
+ })
+ .await
+ .to_wasmtime_result()
+ }
}
// =============================================================================
@@ -27,7 +27,6 @@ use semver::Version;
use smol::net::TcpListener;
use std::{
env,
- io::{BufRead, Write},
net::Ipv4Addr,
path::{Path, PathBuf},
str::FromStr,
@@ -1271,16 +1270,18 @@ impl ExtensionImports for WasmState {
})?;
let accept_future = async {
- let (stream, _) = listener
+ let (mut stream, _) = listener
.accept()
.await
.map_err(|e| anyhow::anyhow!("Failed to accept connection: {}", e))?;
- let mut reader = smol::io::BufReader::new(&stream);
let mut request_line = String::new();
- smol::io::AsyncBufReadExt::read_line(&mut reader, &mut request_line)
- .await
- .map_err(|e| anyhow::anyhow!("Failed to read request: {}", e))?;
+ {
+ let mut reader = smol::io::BufReader::new(&mut stream);
+ smol::io::AsyncBufReadExt::read_line(&mut reader, &mut request_line)
+ .await
+ .map_err(|e| anyhow::anyhow!("Failed to read request: {}", e))?;
+ }
let callback_url = if let Some(path_start) = request_line.find(' ') {
if let Some(path_end) = request_line[path_start + 1..].find(' ') {
@@ -1312,11 +1313,10 @@ impl ExtensionImports for WasmState {
<p>You can close this window and return to Zed.</p>\
</div></body></html>";
- let mut writer = &stream;
- smol::io::AsyncWriteExt::write_all(&mut writer, response.as_bytes())
+ smol::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes())
.await
.ok();
- smol::io::AsyncWriteExt::flush(&mut writer).await.ok();
+ smol::io::AsyncWriteExt::flush(&mut stream).await.ok();
Ok(callback_url)
};
@@ -1349,7 +1349,7 @@ impl ExtensionImports for WasmState {
&mut self,
request: llm_provider::OauthHttpRequest,
) -> wasmtime::Result<Result<llm_provider::OauthHttpResponse, String>> {
- let http_client = self.http_client.clone();
+ let http_client = self.host.http_client.clone();
self.on_main_thread(move |_cx| {
async move {
@@ -1367,7 +1367,7 @@ impl ExtensionImports for WasmState {
}
};
- let mut builder = ::http_client::HttpRequest::builder()
+ let mut builder = ::http_client::Request::builder()
.method(method)
.uri(&request.url);