Add OAuth via web authentication to llm extensions, migrate copilot

Richard Feldman created

Change summary

Cargo.lock                                              |   2 
crates/extension_api/src/extension_api.rs               |   9 
crates/extension_host/Cargo.toml                        |   2 
crates/extension_host/src/copilot_migration.rs          | 161 +++++++++
crates/extension_host/src/extension_host.rs             |   4 
crates/extension_host/src/wasm_host/wit/since_v0_7_0.rs | 188 +++++++++++
crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs |  22 
7 files changed, 374 insertions(+), 14 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -5908,6 +5908,7 @@ dependencies = [
  "criterion",
  "ctor",
  "dap",
+ "dirs 4.0.0",
  "editor",
  "extension",
  "fs",
@@ -5936,6 +5937,7 @@ dependencies = [
  "serde_json",
  "serde_json_lenient",
  "settings",
+ "smol",
  "task",
  "telemetry",
  "tempfile",

crates/extension_api/src/extension_api.rs 🔗

@@ -17,8 +17,9 @@ pub use serde_json;
 pub use wit::{
     CodeLabel, CodeLabelSpan, CodeLabelSpanLiteral, Command, DownloadedFileType, EnvVars,
     KeyValueStore, LanguageServerInstallationStatus, Project, Range, Worktree, download_file,
-    llm_delete_credential, llm_get_credential, llm_get_env_var, llm_request_credential,
-    llm_store_credential, make_file_executable,
+    llm_delete_credential, llm_get_credential, llm_get_env_var, llm_oauth_http_request,
+    llm_oauth_open_browser, llm_oauth_start_web_auth, llm_request_credential, llm_store_credential,
+    make_file_executable,
     zed::extension::context_server::ContextServerConfiguration,
     zed::extension::dap::{
         AttachRequest, BuildTaskDefinition, BuildTaskDefinitionTemplatePayload, BuildTaskTemplate,
@@ -35,7 +36,9 @@ pub use wit::{
         CompletionRequest as LlmCompletionRequest, CredentialType as LlmCredentialType,
         ImageData as LlmImageData, MessageContent as LlmMessageContent,
         MessageRole as LlmMessageRole, ModelCapabilities as LlmModelCapabilities,
-        ModelInfo as LlmModelInfo, ProviderInfo as LlmProviderInfo,
+        ModelInfo as LlmModelInfo, OauthHttpRequest as LlmOauthHttpRequest,
+        OauthHttpResponse as LlmOauthHttpResponse, OauthWebAuthConfig as LlmOauthWebAuthConfig,
+        OauthWebAuthResult as LlmOauthWebAuthResult, ProviderInfo as LlmProviderInfo,
         RequestMessage as LlmRequestMessage, StopReason as LlmStopReason,
         ThinkingContent as LlmThinkingContent, TokenUsage as LlmTokenUsage,
         ToolChoice as LlmToolChoice, ToolDefinition as LlmToolDefinition,

crates/extension_host/Cargo.toml 🔗

@@ -24,6 +24,7 @@ client.workspace = true
 collections.workspace = true
 credentials_provider.workspace = true
 dap.workspace = true
+dirs.workspace = true
 editor.workspace = true
 extension.workspace = true
 fs.workspace = true
@@ -48,6 +49,7 @@ serde.workspace = true
 serde_json.workspace = true
 serde_json_lenient.workspace = true
 settings.workspace = true
+smol.workspace = true
 task.workspace = true
 telemetry.workspace = true
 tempfile.workspace = true

crates/extension_host/src/copilot_migration.rs 🔗

@@ -0,0 +1,161 @@
+use credentials_provider::CredentialsProvider;
+use gpui::App;
+use std::path::PathBuf;
+
+const COPILOT_CHAT_EXTENSION_ID: &str = "copilot_chat";
+const COPILOT_CHAT_PROVIDER_ID: &str = "copilot_chat";
+
+pub fn migrate_copilot_credentials_if_needed(extension_id: &str, cx: &mut App) {
+    if extension_id != COPILOT_CHAT_EXTENSION_ID {
+        return;
+    }
+
+    let credential_key = format!(
+        "extension-llm-{}:{}",
+        COPILOT_CHAT_EXTENSION_ID, COPILOT_CHAT_PROVIDER_ID
+    );
+
+    let credentials_provider = <dyn CredentialsProvider>::global(cx);
+
+    cx.spawn(async move |cx| {
+        let existing_credential = credentials_provider
+            .read_credentials(&credential_key, &cx)
+            .await
+            .ok()
+            .flatten();
+
+        if existing_credential.is_some() {
+            log::debug!("Copilot Chat extension already has credentials, skipping migration");
+            return;
+        }
+
+        let oauth_token = match read_copilot_oauth_token().await {
+            Some(token) => token,
+            None => {
+                log::debug!("No existing Copilot OAuth token found to migrate");
+                return;
+            }
+        };
+
+        log::info!("Migrating existing Copilot OAuth token to Copilot Chat extension");
+
+        match credentials_provider
+            .write_credentials(&credential_key, "api_key", oauth_token.as_bytes(), &cx)
+            .await
+        {
+            Ok(()) => {
+                log::info!("Successfully migrated Copilot OAuth token to Copilot Chat extension");
+            }
+            Err(err) => {
+                log::error!("Failed to migrate Copilot OAuth token: {}", err);
+            }
+        }
+    })
+    .detach();
+}
+
+async fn read_copilot_oauth_token() -> Option<String> {
+    let config_paths = copilot_config_paths();
+
+    for path in config_paths {
+        if let Some(token) = read_oauth_token_from_file(&path).await {
+            return Some(token);
+        }
+    }
+
+    None
+}
+
+fn copilot_config_paths() -> Vec<PathBuf> {
+    let config_dir = if cfg!(target_os = "windows") {
+        dirs::data_local_dir()
+    } else {
+        std::env::var("XDG_CONFIG_HOME")
+            .map(PathBuf::from)
+            .ok()
+            .or_else(|| dirs::home_dir().map(|h| h.join(".config")))
+    };
+
+    let Some(config_dir) = config_dir else {
+        return Vec::new();
+    };
+
+    let copilot_dir = config_dir.join("github-copilot");
+
+    vec![
+        copilot_dir.join("hosts.json"),
+        copilot_dir.join("apps.json"),
+    ]
+}
+
+async fn read_oauth_token_from_file(path: &PathBuf) -> Option<String> {
+    let contents = match smol::fs::read_to_string(path).await {
+        Ok(contents) => contents,
+        Err(_) => return None,
+    };
+
+    extract_oauth_token(&contents, "github.com")
+}
+
+fn extract_oauth_token(contents: &str, domain: &str) -> Option<String> {
+    let value: serde_json::Value = serde_json::from_str(contents).ok()?;
+    let obj = value.as_object()?;
+
+    for (key, value) in obj.iter() {
+        if key.starts_with(domain) {
+            if let Some(token) = value.get("oauth_token").and_then(|v| v.as_str()) {
+                return Some(token.to_string());
+            }
+        }
+    }
+
+    None
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_extract_oauth_token() {
+        let contents = r#"{
+            "github.com": {
+                "oauth_token": "ghu_test_token_12345"
+            }
+        }"#;
+
+        let token = extract_oauth_token(contents, "github.com");
+        assert_eq!(token, Some("ghu_test_token_12345".to_string()));
+    }
+
+    #[test]
+    fn test_extract_oauth_token_with_prefix() {
+        let contents = r#"{
+            "github.com:user": {
+                "oauth_token": "ghu_another_token"
+            }
+        }"#;
+
+        let token = extract_oauth_token(contents, "github.com");
+        assert_eq!(token, Some("ghu_another_token".to_string()));
+    }
+
+    #[test]
+    fn test_extract_oauth_token_missing() {
+        let contents = r#"{
+            "gitlab.com": {
+                "oauth_token": "some_token"
+            }
+        }"#;
+
+        let token = extract_oauth_token(contents, "github.com");
+        assert_eq!(token, None);
+    }
+
+    #[test]
+    fn test_extract_oauth_token_invalid_json() {
+        let contents = "not valid json";
+        let token = extract_oauth_token(contents, "github.com");
+        assert_eq!(token, None);
+    }
+}

crates/extension_host/src/extension_host.rs 🔗

@@ -1,4 +1,5 @@
 mod capability_granter;
+mod copilot_migration;
 pub mod extension_settings;
 pub mod headless_host;
 pub mod wasm_host;
@@ -788,6 +789,9 @@ impl ExtensionStore {
                             this.emit(extension::Event::ExtensionInstalled(manifest.clone()), cx)
                         });
                     }
+
+                    // Run extension-specific migrations
+                    copilot_migration::migrate_copilot_credentials_if_needed(&extension_id, cx);
                 })
                 .ok();
             }

crates/extension_host/src/wasm_host/wit/since_v0_7_0.rs 🔗

@@ -24,12 +24,14 @@ use gpui::{BackgroundExecutor, SharedString};
 use language::{BinaryStatus, LanguageName, language_settings::AllLanguageSettings};
 use project::project_settings::ProjectSettings;
 use semver::Version;
+use smol::net::TcpListener;
 use std::{
     env,
     net::Ipv4Addr,
     path::{Path, PathBuf},
     str::FromStr,
     sync::{Arc, OnceLock},
+    time::Duration,
 };
 use task::{SpawnInTerminal, ZedDebugConfig};
 use url::Url;
@@ -1247,6 +1249,192 @@ impl ExtensionImports for WasmState {
 
         Ok(env::var(&name).ok())
     }
+
+    async fn llm_oauth_start_web_auth(
+        &mut self,
+        config: llm_provider::OauthWebAuthConfig,
+    ) -> wasmtime::Result<Result<llm_provider::OauthWebAuthResult, String>> {
+        let auth_url = config.auth_url;
+        let callback_path = config.callback_path;
+        let timeout_secs = config.timeout_secs.unwrap_or(300);
+
+        self.on_main_thread(move |cx| {
+            async move {
+                let listener = TcpListener::bind("127.0.0.1:0")
+                    .await
+                    .map_err(|e| anyhow::anyhow!("Failed to bind localhost server: {}", e))?;
+                let port = listener
+                    .local_addr()
+                    .map_err(|e| anyhow::anyhow!("Failed to get local address: {}", e))?
+                    .port();
+
+                cx.update(|cx| {
+                    cx.open_url(&auth_url);
+                })?;
+
+                let accept_future = async {
+                    let (mut stream, _) = listener
+                        .accept()
+                        .await
+                        .map_err(|e| anyhow::anyhow!("Failed to accept connection: {}", e))?;
+
+                    let mut request_line = String::new();
+                    {
+                        let mut reader = smol::io::BufReader::new(&mut stream);
+                        smol::io::AsyncBufReadExt::read_line(&mut reader, &mut request_line)
+                            .await
+                            .map_err(|e| anyhow::anyhow!("Failed to read request: {}", e))?;
+                    }
+
+                    let callback_url = if let Some(path_start) = request_line.find(' ') {
+                        if let Some(path_end) = request_line[path_start + 1..].find(' ') {
+                            let path = &request_line[path_start + 1..path_start + 1 + path_end];
+                            if path.starts_with(&callback_path) || path.starts_with(&format!("/{}", callback_path.trim_start_matches('/'))) {
+                                format!("http://localhost:{}{}", port, path)
+                            } else {
+                                return Err(anyhow::anyhow!(
+                                    "Unexpected callback path: {}",
+                                    path
+                                ));
+                            }
+                        } else {
+                            return Err(anyhow::anyhow!("Malformed HTTP request"));
+                        }
+                    } else {
+                        return Err(anyhow::anyhow!("Malformed HTTP request"));
+                    };
+
+                    let response = "HTTP/1.1 200 OK\r\n\
+                        Content-Type: text/html\r\n\
+                        Connection: close\r\n\
+                        \r\n\
+                        <!DOCTYPE html>\
+                        <html><head><title>Authentication Complete</title></head>\
+                        <body style=\"font-family: system-ui, sans-serif; display: flex; justify-content: center; align-items: center; height: 100vh; margin: 0;\">\
+                        <div style=\"text-align: center;\">\
+                        <h1>Authentication Complete</h1>\
+                        <p>You can close this window and return to Zed.</p>\
+                        </div></body></html>";
+
+                    smol::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes())
+                        .await
+                        .ok();
+                    smol::io::AsyncWriteExt::flush(&mut stream).await.ok();
+
+                    Ok(callback_url)
+                };
+
+                let timeout_duration = Duration::from_secs(timeout_secs as u64);
+                let callback_url = smol::future::or(
+                    accept_future,
+                    async {
+                        smol::Timer::after(timeout_duration).await;
+                        Err(anyhow::anyhow!(
+                            "OAuth callback timed out after {} seconds",
+                            timeout_secs
+                        ))
+                    },
+                )
+                .await?;
+
+                Ok(llm_provider::OauthWebAuthResult {
+                    callback_url,
+                    port: port as u32,
+                })
+            }
+            .boxed_local()
+        })
+        .await
+        .to_wasmtime_result()
+    }
+
+    async fn llm_oauth_http_request(
+        &mut self,
+        request: llm_provider::OauthHttpRequest,
+    ) -> wasmtime::Result<Result<llm_provider::OauthHttpResponse, String>> {
+        let http_client = self.host.http_client.clone();
+
+        self.on_main_thread(move |_cx| {
+            async move {
+                let method = match request.method.to_uppercase().as_str() {
+                    "GET" => ::http_client::Method::GET,
+                    "POST" => ::http_client::Method::POST,
+                    "PUT" => ::http_client::Method::PUT,
+                    "DELETE" => ::http_client::Method::DELETE,
+                    "PATCH" => ::http_client::Method::PATCH,
+                    _ => {
+                        return Err(anyhow::anyhow!(
+                            "Unsupported HTTP method: {}",
+                            request.method
+                        ));
+                    }
+                };
+
+                let mut builder = ::http_client::Request::builder()
+                    .method(method)
+                    .uri(&request.url);
+
+                for (key, value) in &request.headers {
+                    builder = builder.header(key.as_str(), value.as_str());
+                }
+
+                let body = if request.body.is_empty() {
+                    AsyncBody::empty()
+                } else {
+                    AsyncBody::from(request.body.into_bytes())
+                };
+
+                let http_request = builder
+                    .body(body)
+                    .map_err(|e| anyhow::anyhow!("Failed to build request: {}", e))?;
+
+                let mut response = http_client
+                    .send(http_request)
+                    .await
+                    .map_err(|e| anyhow::anyhow!("HTTP request failed: {}", e))?;
+
+                let status = response.status().as_u16();
+                let headers: Vec<(String, String)> = response
+                    .headers()
+                    .iter()
+                    .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
+                    .collect();
+
+                let mut body_bytes = Vec::new();
+                futures::AsyncReadExt::read_to_end(response.body_mut(), &mut body_bytes)
+                    .await
+                    .map_err(|e| anyhow::anyhow!("Failed to read response body: {}", e))?;
+
+                let body = String::from_utf8_lossy(&body_bytes).to_string();
+
+                Ok(llm_provider::OauthHttpResponse {
+                    status,
+                    headers,
+                    body,
+                })
+            }
+            .boxed_local()
+        })
+        .await
+        .to_wasmtime_result()
+    }
+
+    async fn llm_oauth_open_browser(
+        &mut self,
+        url: String,
+    ) -> wasmtime::Result<Result<(), String>> {
+        self.on_main_thread(move |cx| {
+            async move {
+                cx.update(|cx| {
+                    cx.open_url(&url);
+                })?;
+                Ok(())
+            }
+            .boxed_local()
+        })
+        .await
+        .to_wasmtime_result()
+    }
 }
 
 // =============================================================================

crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs 🔗

@@ -27,7 +27,6 @@ use semver::Version;
 use smol::net::TcpListener;
 use std::{
     env,
-    io::{BufRead, Write},
     net::Ipv4Addr,
     path::{Path, PathBuf},
     str::FromStr,
@@ -1271,16 +1270,18 @@ impl ExtensionImports for WasmState {
                 })?;
 
                 let accept_future = async {
-                    let (stream, _) = listener
+                    let (mut stream, _) = listener
                         .accept()
                         .await
                         .map_err(|e| anyhow::anyhow!("Failed to accept connection: {}", e))?;
 
-                    let mut reader = smol::io::BufReader::new(&stream);
                     let mut request_line = String::new();
-                    smol::io::AsyncBufReadExt::read_line(&mut reader, &mut request_line)
-                        .await
-                        .map_err(|e| anyhow::anyhow!("Failed to read request: {}", e))?;
+                    {
+                        let mut reader = smol::io::BufReader::new(&mut stream);
+                        smol::io::AsyncBufReadExt::read_line(&mut reader, &mut request_line)
+                            .await
+                            .map_err(|e| anyhow::anyhow!("Failed to read request: {}", e))?;
+                    }
 
                     let callback_url = if let Some(path_start) = request_line.find(' ') {
                         if let Some(path_end) = request_line[path_start + 1..].find(' ') {
@@ -1312,11 +1313,10 @@ impl ExtensionImports for WasmState {
                         <p>You can close this window and return to Zed.</p>\
                         </div></body></html>";
 
-                    let mut writer = &stream;
-                    smol::io::AsyncWriteExt::write_all(&mut writer, response.as_bytes())
+                    smol::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes())
                         .await
                         .ok();
-                    smol::io::AsyncWriteExt::flush(&mut writer).await.ok();
+                    smol::io::AsyncWriteExt::flush(&mut stream).await.ok();
 
                     Ok(callback_url)
                 };
@@ -1349,7 +1349,7 @@ impl ExtensionImports for WasmState {
         &mut self,
         request: llm_provider::OauthHttpRequest,
     ) -> wasmtime::Result<Result<llm_provider::OauthHttpResponse, String>> {
-        let http_client = self.http_client.clone();
+        let http_client = self.host.http_client.clone();
 
         self.on_main_thread(move |_cx| {
             async move {
@@ -1367,7 +1367,7 @@ impl ExtensionImports for WasmState {
                     }
                 };
 
-                let mut builder = ::http_client::HttpRequest::builder()
+                let mut builder = ::http_client::Request::builder()
                     .method(method)
                     .uri(&request.url);