From 2d3a3521baaedcd2e5ce34f3883d3b59e37aaa2b Mon Sep 17 00:00:00 2001 From: Richard Feldman Date: Thu, 4 Dec 2025 16:50:44 -0500 Subject: [PATCH] Add OAuth Web Flow auth option for llm provider extensions --- .../wit/since_v0.7.0/extension.wit | 32 +++ .../wit/since_v0.7.0/llm-provider.wit | 47 +++++ .../wit/since_v0.8.0/extension.wit | 32 +++ .../wit/since_v0.8.0/llm-provider.wit | 47 +++++ .../src/wasm_host/wit/since_v0_8_0.rs | 188 ++++++++++++++++++ 5 files changed, 346 insertions(+) diff --git a/crates/extension_api/wit/since_v0.7.0/extension.wit b/crates/extension_api/wit/since_v0.7.0/extension.wit index 92979a8780039776853fa250be2afdb204ae5d55..f95dfa04dac25f792d14f896ee8c00ffa8dcf804 100644 --- a/crates/extension_api/wit/since_v0.7.0/extension.wit +++ b/crates/extension_api/wit/since_v0.7.0/extension.wit @@ -249,4 +249,36 @@ world extension { /// Read an environment variable. import llm-get-env-var: func(name: string) -> option; + + // ========================================================================= + // OAuth Web Auth Flow Imports + // ========================================================================= + + use llm-provider.{oauth-web-auth-config, oauth-web-auth-result, oauth-http-request, oauth-http-response}; + + /// Start an OAuth web authentication flow. + /// + /// This will: + /// 1. Start a localhost server to receive the OAuth callback + /// 2. Open the auth URL in the user's default browser + /// 3. Wait for the callback (up to the timeout) + /// 4. Return the callback URL with query parameters + /// + /// The extension is responsible for: + /// - Constructing the auth URL with client_id, redirect_uri, scope, state, etc. + /// - Parsing the callback URL to extract the authorization code + /// - Exchanging the code for tokens using llm-oauth-http-request + import llm-oauth-start-web-auth: func(config: oauth-web-auth-config) -> result; + + /// Make an HTTP request for OAuth token exchange. + /// + /// This is a simple HTTP client for OAuth flows, allowing the extension + /// to handle token exchange with full control over serialization. + import llm-oauth-http-request: func(request: oauth-http-request) -> result; + + /// Open a URL in the user's default browser. + /// + /// Useful for OAuth flows that need to open a browser but handle the + /// callback differently (e.g., polling-based flows). + import llm-oauth-open-browser: func(url: string) -> result<_, string>; } diff --git a/crates/extension_api/wit/since_v0.7.0/llm-provider.wit b/crates/extension_api/wit/since_v0.7.0/llm-provider.wit index 5912654ebcf9e517e683d13ad2b5e6d9096095eb..aec6569c2efda70faa38524e458951de732dc328 100644 --- a/crates/extension_api/wit/since_v0.7.0/llm-provider.wit +++ b/crates/extension_api/wit/since_v0.7.0/llm-provider.wit @@ -252,4 +252,51 @@ interface llm-provider { /// Minimum token count for a message to be cached. min-total-token-count: u64, } + + // ========================================================================= + // OAuth Web Auth Flow Types + // ========================================================================= + + /// Configuration for starting an OAuth web authentication flow. + record oauth-web-auth-config { + /// The URL to open in the user's browser to start authentication. + /// This should include client_id, redirect_uri, scope, state, etc. + auth-url: string, + /// The path to listen on for the OAuth callback (e.g., "/callback"). + /// A localhost server will be started to receive the redirect. + callback-path: string, + /// Timeout in seconds to wait for the callback (default: 300 = 5 minutes). + timeout-secs: option, + } + + /// Result of an OAuth web authentication flow. + record oauth-web-auth-result { + /// The full callback URL that was received, including query parameters. + /// The extension is responsible for parsing the code, state, etc. + callback-url: string, + /// The port that was used for the localhost callback server. + port: u32, + } + + /// A generic HTTP request for OAuth token exchange. + record oauth-http-request { + /// The URL to request. + url: string, + /// HTTP method (e.g., "POST", "GET"). + method: string, + /// Request headers as key-value pairs. + headers: list>, + /// Request body as a string (for form-encoded or JSON bodies). + body: string, + } + + /// Response from an OAuth HTTP request. + record oauth-http-response { + /// HTTP status code. + status: u16, + /// Response headers as key-value pairs. + headers: list>, + /// Response body as a string. + body: string, + } } diff --git a/crates/extension_api/wit/since_v0.8.0/extension.wit b/crates/extension_api/wit/since_v0.8.0/extension.wit index 92979a8780039776853fa250be2afdb204ae5d55..f95dfa04dac25f792d14f896ee8c00ffa8dcf804 100644 --- a/crates/extension_api/wit/since_v0.8.0/extension.wit +++ b/crates/extension_api/wit/since_v0.8.0/extension.wit @@ -249,4 +249,36 @@ world extension { /// Read an environment variable. import llm-get-env-var: func(name: string) -> option; + + // ========================================================================= + // OAuth Web Auth Flow Imports + // ========================================================================= + + use llm-provider.{oauth-web-auth-config, oauth-web-auth-result, oauth-http-request, oauth-http-response}; + + /// Start an OAuth web authentication flow. + /// + /// This will: + /// 1. Start a localhost server to receive the OAuth callback + /// 2. Open the auth URL in the user's default browser + /// 3. Wait for the callback (up to the timeout) + /// 4. Return the callback URL with query parameters + /// + /// The extension is responsible for: + /// - Constructing the auth URL with client_id, redirect_uri, scope, state, etc. + /// - Parsing the callback URL to extract the authorization code + /// - Exchanging the code for tokens using llm-oauth-http-request + import llm-oauth-start-web-auth: func(config: oauth-web-auth-config) -> result; + + /// Make an HTTP request for OAuth token exchange. + /// + /// This is a simple HTTP client for OAuth flows, allowing the extension + /// to handle token exchange with full control over serialization. + import llm-oauth-http-request: func(request: oauth-http-request) -> result; + + /// Open a URL in the user's default browser. + /// + /// Useful for OAuth flows that need to open a browser but handle the + /// callback differently (e.g., polling-based flows). + import llm-oauth-open-browser: func(url: string) -> result<_, string>; } diff --git a/crates/extension_api/wit/since_v0.8.0/llm-provider.wit b/crates/extension_api/wit/since_v0.8.0/llm-provider.wit index 5912654ebcf9e517e683d13ad2b5e6d9096095eb..aec6569c2efda70faa38524e458951de732dc328 100644 --- a/crates/extension_api/wit/since_v0.8.0/llm-provider.wit +++ b/crates/extension_api/wit/since_v0.8.0/llm-provider.wit @@ -252,4 +252,51 @@ interface llm-provider { /// Minimum token count for a message to be cached. min-total-token-count: u64, } + + // ========================================================================= + // OAuth Web Auth Flow Types + // ========================================================================= + + /// Configuration for starting an OAuth web authentication flow. + record oauth-web-auth-config { + /// The URL to open in the user's browser to start authentication. + /// This should include client_id, redirect_uri, scope, state, etc. + auth-url: string, + /// The path to listen on for the OAuth callback (e.g., "/callback"). + /// A localhost server will be started to receive the redirect. + callback-path: string, + /// Timeout in seconds to wait for the callback (default: 300 = 5 minutes). + timeout-secs: option, + } + + /// Result of an OAuth web authentication flow. + record oauth-web-auth-result { + /// The full callback URL that was received, including query parameters. + /// The extension is responsible for parsing the code, state, etc. + callback-url: string, + /// The port that was used for the localhost callback server. + port: u32, + } + + /// A generic HTTP request for OAuth token exchange. + record oauth-http-request { + /// The URL to request. + url: string, + /// HTTP method (e.g., "POST", "GET"). + method: string, + /// Request headers as key-value pairs. + headers: list>, + /// Request body as a string (for form-encoded or JSON bodies). + body: string, + } + + /// Response from an OAuth HTTP request. + record oauth-http-response { + /// HTTP status code. + status: u16, + /// Response headers as key-value pairs. + headers: list>, + /// Response body as a string. + body: string, + } } diff --git a/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs b/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs index b5984d7a19a462254b606473aa76d8f5d97ab43c..213a677687a9c2c403ec7d90a462db36f8dc85dd 100644 --- a/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs +++ b/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs @@ -24,12 +24,15 @@ use gpui::{BackgroundExecutor, SharedString}; use language::{BinaryStatus, LanguageName, language_settings::AllLanguageSettings}; use project::project_settings::ProjectSettings; use semver::Version; +use smol::net::TcpListener; use std::{ env, + io::{BufRead, Write}, net::Ipv4Addr, path::{Path, PathBuf}, str::FromStr, sync::{Arc, OnceLock}, + time::Duration, }; use task::{SpawnInTerminal, ZedDebugConfig}; use url::Url; @@ -1244,6 +1247,191 @@ impl ExtensionImports for WasmState { Ok(env::var(&name).ok()) } + + async fn llm_oauth_start_web_auth( + &mut self, + config: llm_provider::OauthWebAuthConfig, + ) -> wasmtime::Result> { + let auth_url = config.auth_url; + let callback_path = config.callback_path; + let timeout_secs = config.timeout_secs.unwrap_or(300); + + self.on_main_thread(move |cx| { + async move { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .map_err(|e| anyhow::anyhow!("Failed to bind localhost server: {}", e))?; + let port = listener + .local_addr() + .map_err(|e| anyhow::anyhow!("Failed to get local address: {}", e))? + .port(); + + cx.update(|cx| { + cx.open_url(&auth_url); + })?; + + let accept_future = async { + let (stream, _) = listener + .accept() + .await + .map_err(|e| anyhow::anyhow!("Failed to accept connection: {}", e))?; + + let mut reader = smol::io::BufReader::new(&stream); + let mut request_line = String::new(); + smol::io::AsyncBufReadExt::read_line(&mut reader, &mut request_line) + .await + .map_err(|e| anyhow::anyhow!("Failed to read request: {}", e))?; + + let callback_url = if let Some(path_start) = request_line.find(' ') { + if let Some(path_end) = request_line[path_start + 1..].find(' ') { + let path = &request_line[path_start + 1..path_start + 1 + path_end]; + if path.starts_with(&callback_path) || path.starts_with(&format!("/{}", callback_path.trim_start_matches('/'))) { + format!("http://localhost:{}{}", port, path) + } else { + return Err(anyhow::anyhow!( + "Unexpected callback path: {}", + path + )); + } + } else { + return Err(anyhow::anyhow!("Malformed HTTP request")); + } + } else { + return Err(anyhow::anyhow!("Malformed HTTP request")); + }; + + let response = "HTTP/1.1 200 OK\r\n\ + Content-Type: text/html\r\n\ + Connection: close\r\n\ + \r\n\ + \ + Authentication Complete\ + \ +
\ +

Authentication Complete

\ +

You can close this window and return to Zed.

\ +
"; + + let mut writer = &stream; + smol::io::AsyncWriteExt::write_all(&mut writer, response.as_bytes()) + .await + .ok(); + smol::io::AsyncWriteExt::flush(&mut writer).await.ok(); + + Ok(callback_url) + }; + + let timeout_duration = Duration::from_secs(timeout_secs as u64); + let callback_url = smol::future::or( + accept_future, + async { + smol::Timer::after(timeout_duration).await; + Err(anyhow::anyhow!( + "OAuth callback timed out after {} seconds", + timeout_secs + )) + }, + ) + .await?; + + Ok(llm_provider::OauthWebAuthResult { + callback_url, + port: port as u32, + }) + } + .boxed_local() + }) + .await + .to_wasmtime_result() + } + + async fn llm_oauth_http_request( + &mut self, + request: llm_provider::OauthHttpRequest, + ) -> wasmtime::Result> { + let http_client = self.http_client.clone(); + + self.on_main_thread(move |_cx| { + async move { + let method = match request.method.to_uppercase().as_str() { + "GET" => ::http_client::Method::GET, + "POST" => ::http_client::Method::POST, + "PUT" => ::http_client::Method::PUT, + "DELETE" => ::http_client::Method::DELETE, + "PATCH" => ::http_client::Method::PATCH, + _ => { + return Err(anyhow::anyhow!( + "Unsupported HTTP method: {}", + request.method + )); + } + }; + + let mut builder = ::http_client::HttpRequest::builder() + .method(method) + .uri(&request.url); + + for (key, value) in &request.headers { + builder = builder.header(key.as_str(), value.as_str()); + } + + let body = if request.body.is_empty() { + AsyncBody::empty() + } else { + AsyncBody::from(request.body.into_bytes()) + }; + + let http_request = builder + .body(body) + .map_err(|e| anyhow::anyhow!("Failed to build request: {}", e))?; + + let mut response = http_client + .send(http_request) + .await + .map_err(|e| anyhow::anyhow!("HTTP request failed: {}", e))?; + + let status = response.status().as_u16(); + let headers: Vec<(String, String)> = response + .headers() + .iter() + .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) + .collect(); + + let mut body_bytes = Vec::new(); + futures::AsyncReadExt::read_to_end(response.body_mut(), &mut body_bytes) + .await + .map_err(|e| anyhow::anyhow!("Failed to read response body: {}", e))?; + + let body = String::from_utf8_lossy(&body_bytes).to_string(); + + Ok(llm_provider::OauthHttpResponse { + status, + headers, + body, + }) + } + .boxed_local() + }) + .await + .to_wasmtime_result() + } + + async fn llm_oauth_open_browser( + &mut self, + url: String, + ) -> wasmtime::Result> { + self.on_main_thread(move |cx| { + async move { + cx.update(|cx| { + cx.open_url(&url); + })?; + Ok(()) + } + .boxed_local() + }) + .await + .to_wasmtime_result() + } } // =============================================================================