Replace OAuth-specific HTTP WIT stuff with generic http alternatives.

Richard Feldman created

Change summary

crates/extension_api/src/extension_api.rs               |   3 
crates/extension_api/src/http_client.rs                 |   8 
crates/extension_api/wit/since_v0.8.0/http-client.wit   |  17 +
crates/extension_api/wit/since_v0.8.0/llm-provider.wit  |  33 --
crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs | 103 ++++------
extensions/copilot-chat/src/copilot_chat.rs             |  27 +-
6 files changed, 88 insertions(+), 103 deletions(-)

Detailed changes

crates/extension_api/src/extension_api.rs 🔗

@@ -34,8 +34,7 @@ pub use wit::{
         CompletionRequest as LlmCompletionRequest, DeviceFlowPromptInfo as LlmDeviceFlowPromptInfo,
         ImageData as LlmImageData, MessageContent as LlmMessageContent,
         MessageRole as LlmMessageRole, ModelCapabilities as LlmModelCapabilities,
-        ModelInfo as LlmModelInfo, OauthHttpRequest as LlmOauthHttpRequest,
-        OauthHttpResponse as LlmOauthHttpResponse, OauthWebAuthConfig as LlmOauthWebAuthConfig,
+        ModelInfo as LlmModelInfo, OauthWebAuthConfig as LlmOauthWebAuthConfig,
         OauthWebAuthResult as LlmOauthWebAuthResult, ProviderInfo as LlmProviderInfo,
         RequestMessage as LlmRequestMessage, StopReason as LlmStopReason,
         ThinkingContent as LlmThinkingContent, TokenUsage as LlmTokenUsage,

crates/extension_api/src/http_client.rs 🔗

@@ -1,7 +1,8 @@
 //! An HTTP client.
 
 pub use crate::wit::zed::extension::http_client::{
-    HttpMethod, HttpRequest, HttpResponse, HttpResponseStream, RedirectPolicy, fetch, fetch_stream,
+    HttpMethod, HttpRequest, HttpResponse, HttpResponseStream, HttpResponseWithStatus,
+    RedirectPolicy, fetch, fetch_fallible, fetch_stream,
 };
 
 impl HttpRequest {
@@ -15,6 +16,11 @@ impl HttpRequest {
         fetch(self)
     }
 
+    /// Like [`fetch`], except it doesn't treat any status codes as errors.
+    pub fn fetch_fallible(&self) -> Result<HttpResponseWithStatus, String> {
+        fetch_fallible(self)
+    }
+
     /// Executes the [`HttpRequest`] with [`fetch_stream`].
     pub fn fetch_stream(&self) -> Result<HttpResponseStream, String> {
         fetch_stream(self)

crates/extension_api/wit/since_v0.8.0/http-client.wit 🔗

@@ -51,9 +51,26 @@ interface http-client {
         body: list<u8>,
     }
 
+    /// An HTTP response that includes the status code.
+    ///
+    /// Used by `fetch-fallible` which returns responses for all status codes
+    /// rather than treating some status codes as errors.
+    record http-response-with-status {
+        /// The HTTP status code.
+        status: u16,
+        /// The response headers.
+        headers: list<tuple<string, string>>,
+        /// The response body.
+        body: list<u8>,
+    }
+
     /// Performs an HTTP request and returns the response.
+    /// Returns an error if the response status is 4xx or 5xx.
     fetch: func(req: http-request) -> result<http-response, string>;
 
+    /// Performs an HTTP request and returns the response regardless of its status code.
+    fetch-fallible: func(req: http-request) -> result<http-response-with-status, string>;
+
     /// An HTTP response stream.
     resource http-response-stream {
         /// Retrieves the next chunk of data from the response stream.

crates/extension_api/wit/since_v0.8.0/llm-provider.wit 🔗

@@ -1,4 +1,6 @@
 interface llm-provider {
+    use http-client.{http-request, http-response-with-status};
+
     /// Information about a language model provider.
     record provider-info {
         /// Unique identifier for the provider (e.g. "my-extension.my-provider").
@@ -271,28 +273,6 @@ interface llm-provider {
         port: u32,
     }
 
-    /// A generic HTTP request for OAuth token exchange.
-    record oauth-http-request {
-        /// The URL to request.
-        url: string,
-        /// HTTP method (e.g., "POST", "GET").
-        method: string,
-        /// Request headers as key-value pairs.
-        headers: list<tuple<string, string>>,
-        /// Request body as a string (for form-encoded or JSON bodies).
-        body: string,
-    }
-
-    /// Response from an OAuth HTTP request.
-    record oauth-http-response {
-        /// HTTP status code.
-        status: u16,
-        /// Response headers as key-value pairs.
-        headers: list<tuple<string, string>>,
-        /// Response body as a string.
-        body: string,
-    }
-
     /// Get a stored credential for this provider.
     get-credential: func(provider-id: string) -> option<string>;
 
@@ -316,14 +296,15 @@ interface llm-provider {
     /// The extension is responsible for:
     /// - Constructing the auth URL with client_id, redirect_uri, scope, state, etc.
     /// - Parsing the callback URL to extract the authorization code
-    /// - Exchanging the code for tokens using oauth-http-request
+    /// - Exchanging the code for tokens using fetch-fallible from http-client
     oauth-start-web-auth: func(config: oauth-web-auth-config) -> result<oauth-web-auth-result, string>;
 
     /// Make an HTTP request for OAuth token exchange.
     ///
-    /// This is a simple HTTP client for OAuth flows, allowing the extension
-    /// to handle token exchange with full control over serialization.
-    oauth-send-http-request: func(request: oauth-http-request) -> result<oauth-http-response, string>;
+    /// This is a convenience wrapper around http-client's fetch-fallible for OAuth flows.
+    /// Unlike the standard fetch, this does not treat non-2xx responses as errors,
+    /// allowing proper handling of OAuth error responses.
+    oauth-send-http-request: func(request: http-request) -> result<http-response-with-status, string>;
 
     /// Open a URL in the user's default browser.
     ///

crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs 🔗

@@ -618,6 +618,19 @@ impl http_client::Host for WasmState {
         .to_wasmtime_result()
     }
 
+    async fn fetch_fallible(
+        &mut self,
+        request: http_client::HttpRequest,
+    ) -> wasmtime::Result<Result<http_client::HttpResponseWithStatus, String>> {
+        maybe!(async {
+            let request = convert_request(&request)?;
+            let mut response = self.host.http_client.send(request).await?;
+            convert_response_with_status(&mut response).await
+        })
+        .await
+        .to_wasmtime_result()
+    }
+
     async fn fetch_stream(
         &mut self,
         request: http_client::HttpRequest,
@@ -721,6 +734,26 @@ async fn convert_response(
     Ok(extension_response)
 }
 
+async fn convert_response_with_status(
+    response: &mut ::http_client::Response<AsyncBody>,
+) -> anyhow::Result<http_client::HttpResponseWithStatus> {
+    let status = response.status().as_u16();
+    let headers: Vec<(String, String)> = response
+        .headers()
+        .iter()
+        .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
+        .collect();
+
+    let mut body = Vec::new();
+    response.body_mut().read_to_end(&mut body).await?;
+
+    Ok(http_client::HttpResponseWithStatus {
+        status,
+        headers,
+        body,
+    })
+}
+
 impl nodejs::Host for WasmState {
     async fn node_binary_path(&mut self) -> wasmtime::Result<Result<String, String>> {
         self.host
@@ -1376,70 +1409,12 @@ impl llm_provider::Host for WasmState {
 
     async fn oauth_send_http_request(
         &mut self,
-        request: llm_provider::OauthHttpRequest,
-    ) -> wasmtime::Result<Result<llm_provider::OauthHttpResponse, String>> {
-        let http_client = self.host.http_client.clone();
-
-        self.on_main_thread(move |_cx| {
-            async move {
-                let method = match request.method.to_uppercase().as_str() {
-                    "GET" => ::http_client::Method::GET,
-                    "POST" => ::http_client::Method::POST,
-                    "PUT" => ::http_client::Method::PUT,
-                    "DELETE" => ::http_client::Method::DELETE,
-                    "PATCH" => ::http_client::Method::PATCH,
-                    _ => {
-                        return Err(anyhow::anyhow!(
-                            "Unsupported HTTP method: {}",
-                            request.method
-                        ));
-                    }
-                };
-
-                let mut builder = ::http_client::Request::builder()
-                    .method(method)
-                    .uri(&request.url);
-
-                for (key, value) in &request.headers {
-                    builder = builder.header(key.as_str(), value.as_str());
-                }
-
-                let body = if request.body.is_empty() {
-                    AsyncBody::empty()
-                } else {
-                    AsyncBody::from(request.body.into_bytes())
-                };
-
-                let http_request = builder
-                    .body(body)
-                    .map_err(|e| anyhow::anyhow!("Failed to build request: {}", e))?;
-
-                let mut response = http_client
-                    .send(http_request)
-                    .await
-                    .map_err(|e| anyhow::anyhow!("HTTP request failed: {}", e))?;
-
-                let status = response.status().as_u16();
-                let headers: Vec<(String, String)> = response
-                    .headers()
-                    .iter()
-                    .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
-                    .collect();
-
-                let mut body_bytes = Vec::new();
-                futures::AsyncReadExt::read_to_end(response.body_mut(), &mut body_bytes)
-                    .await
-                    .map_err(|e| anyhow::anyhow!("Failed to read response body: {}", e))?;
-
-                let body = String::from_utf8_lossy(&body_bytes).to_string();
-
-                Ok(llm_provider::OauthHttpResponse {
-                    status,
-                    headers,
-                    body,
-                })
-            }
-            .boxed_local()
+        request: http_client::HttpRequest,
+    ) -> wasmtime::Result<Result<http_client::HttpResponseWithStatus, String>> {
+        maybe!(async {
+            let request = convert_request(&request)?;
+            let mut response = self.host.http_client.send(request).await?;
+            convert_response_with_status(&mut response).await
         })
         .await
         .to_wasmtime_result()

extensions/copilot-chat/src/copilot_chat.rs 🔗

@@ -456,9 +456,9 @@ impl zed::Extension for CopilotChatProvider {
         _provider_id: &str,
     ) -> Result<LlmDeviceFlowPromptInfo, String> {
         // Step 1: Request device and user verification codes
-        let device_code_response = llm_oauth_send_http_request(&LlmOauthHttpRequest {
+        let device_code_response = llm_oauth_send_http_request(&HttpRequest {
+            method: HttpMethod::Post,
             url: GITHUB_DEVICE_CODE_URL.to_string(),
-            method: "POST".to_string(),
             headers: vec![
                 ("Accept".to_string(), "application/json".to_string()),
                 (
@@ -466,7 +466,10 @@ impl zed::Extension for CopilotChatProvider {
                     "application/x-www-form-urlencoded".to_string(),
                 ),
             ],
-            body: format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID),
+            body: Some(
+                format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID).into_bytes(),
+            ),
+            redirect_policy: RedirectPolicy::NoFollow,
         })?;
 
         if device_code_response.status != 200 {
@@ -487,7 +490,7 @@ impl zed::Extension for CopilotChatProvider {
             interval: u64,
         }
 
-        let device_info: DeviceCodeResponse = serde_json::from_str(&device_code_response.body)
+        let device_info: DeviceCodeResponse = serde_json::from_slice(&device_code_response.body)
             .map_err(|e| format!("Failed to parse device code response: {}", e))?;
 
         // Store device flow state for polling
@@ -534,9 +537,9 @@ impl zed::Extension for CopilotChatProvider {
         for _ in 0..max_attempts {
             thread::sleep(poll_interval);
 
-            let token_response = llm_oauth_send_http_request(&LlmOauthHttpRequest {
+            let token_response = llm_oauth_send_http_request(&HttpRequest {
+                method: HttpMethod::Post,
                 url: GITHUB_ACCESS_TOKEN_URL.to_string(),
-                method: "POST".to_string(),
                 headers: vec![
                     ("Accept".to_string(), "application/json".to_string()),
                     (
@@ -544,10 +547,14 @@ impl zed::Extension for CopilotChatProvider {
                         "application/x-www-form-urlencoded".to_string(),
                     ),
                 ],
-                body: format!(
-                    "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
-                    GITHUB_COPILOT_CLIENT_ID, state.device_code
+                body: Some(
+                    format!(
+                        "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
+                        GITHUB_COPILOT_CLIENT_ID, state.device_code
+                    )
+                    .into_bytes(),
                 ),
+                redirect_policy: RedirectPolicy::NoFollow,
             })?;
 
             #[derive(Deserialize)]
@@ -557,7 +564,7 @@ impl zed::Extension for CopilotChatProvider {
                 error_description: Option<String>,
             }
 
-            let token_json: TokenResponse = serde_json::from_str(&token_response.body)
+            let token_json: TokenResponse = serde_json::from_slice(&token_response.body)
                 .map_err(|e| format!("Failed to parse token response: {}", e))?;
 
             if let Some(access_token) = token_json.access_token {