diff --git a/crates/extension_api/src/extension_api.rs b/crates/extension_api/src/extension_api.rs index 14acdfd66597eca040102e429bf0a6def73b6ec6..c1bd322b5d9da1ef941a8346f5d243aa2a5eec83 100644 --- a/crates/extension_api/src/extension_api.rs +++ b/crates/extension_api/src/extension_api.rs @@ -34,8 +34,7 @@ pub use wit::{ CompletionRequest as LlmCompletionRequest, DeviceFlowPromptInfo as LlmDeviceFlowPromptInfo, ImageData as LlmImageData, MessageContent as LlmMessageContent, MessageRole as LlmMessageRole, ModelCapabilities as LlmModelCapabilities, - ModelInfo as LlmModelInfo, OauthHttpRequest as LlmOauthHttpRequest, - OauthHttpResponse as LlmOauthHttpResponse, OauthWebAuthConfig as LlmOauthWebAuthConfig, + ModelInfo as LlmModelInfo, OauthWebAuthConfig as LlmOauthWebAuthConfig, OauthWebAuthResult as LlmOauthWebAuthResult, ProviderInfo as LlmProviderInfo, RequestMessage as LlmRequestMessage, StopReason as LlmStopReason, ThinkingContent as LlmThinkingContent, TokenUsage as LlmTokenUsage, diff --git a/crates/extension_api/src/http_client.rs b/crates/extension_api/src/http_client.rs index 9e30da8db4856635460c7636f0698fef34266797..c5be8874bc2ce077be99bff9c1a4b5429370c8d9 100644 --- a/crates/extension_api/src/http_client.rs +++ b/crates/extension_api/src/http_client.rs @@ -1,7 +1,8 @@ //! An HTTP client. pub use crate::wit::zed::extension::http_client::{ - HttpMethod, HttpRequest, HttpResponse, HttpResponseStream, RedirectPolicy, fetch, fetch_stream, + HttpMethod, HttpRequest, HttpResponse, HttpResponseStream, HttpResponseWithStatus, + RedirectPolicy, fetch, fetch_fallible, fetch_stream, }; impl HttpRequest { @@ -15,6 +16,11 @@ impl HttpRequest { fetch(self) } + /// Like [`fetch`], except it doesn't treat any status codes as errors. + pub fn fetch_fallible(&self) -> Result { + fetch_fallible(self) + } + /// Executes the [`HttpRequest`] with [`fetch_stream`]. pub fn fetch_stream(&self) -> Result { fetch_stream(self) diff --git a/crates/extension_api/wit/since_v0.8.0/http-client.wit b/crates/extension_api/wit/since_v0.8.0/http-client.wit index bb0206c17a52d4d20b99f445dca4ac606e0485f7..422ca8cd843985ccbdd7b3e663db2f0d0141f544 100644 --- a/crates/extension_api/wit/since_v0.8.0/http-client.wit +++ b/crates/extension_api/wit/since_v0.8.0/http-client.wit @@ -51,9 +51,26 @@ interface http-client { body: list, } + /// An HTTP response that includes the status code. + /// + /// Used by `fetch-fallible` which returns responses for all status codes + /// rather than treating some status codes as errors. + record http-response-with-status { + /// The HTTP status code. + status: u16, + /// The response headers. + headers: list>, + /// The response body. + body: list, + } + /// Performs an HTTP request and returns the response. + /// Returns an error if the response status is 4xx or 5xx. fetch: func(req: http-request) -> result; + /// Performs an HTTP request and returns the response regardless of its status code. + fetch-fallible: func(req: http-request) -> result; + /// An HTTP response stream. resource http-response-stream { /// Retrieves the next chunk of data from the response stream. diff --git a/crates/extension_api/wit/since_v0.8.0/llm-provider.wit b/crates/extension_api/wit/since_v0.8.0/llm-provider.wit index 04d3861be0cc483978e11108e9bfca921f07d37c..1c9ce7d8ca8f22624135549594b6adc45dccda68 100644 --- a/crates/extension_api/wit/since_v0.8.0/llm-provider.wit +++ b/crates/extension_api/wit/since_v0.8.0/llm-provider.wit @@ -1,4 +1,6 @@ interface llm-provider { + use http-client.{http-request, http-response-with-status}; + /// Information about a language model provider. record provider-info { /// Unique identifier for the provider (e.g. "my-extension.my-provider"). @@ -271,28 +273,6 @@ interface llm-provider { port: u32, } - /// A generic HTTP request for OAuth token exchange. - record oauth-http-request { - /// The URL to request. - url: string, - /// HTTP method (e.g., "POST", "GET"). - method: string, - /// Request headers as key-value pairs. - headers: list>, - /// Request body as a string (for form-encoded or JSON bodies). - body: string, - } - - /// Response from an OAuth HTTP request. - record oauth-http-response { - /// HTTP status code. - status: u16, - /// Response headers as key-value pairs. - headers: list>, - /// Response body as a string. - body: string, - } - /// Get a stored credential for this provider. get-credential: func(provider-id: string) -> option; @@ -316,14 +296,15 @@ interface llm-provider { /// The extension is responsible for: /// - Constructing the auth URL with client_id, redirect_uri, scope, state, etc. /// - Parsing the callback URL to extract the authorization code - /// - Exchanging the code for tokens using oauth-http-request + /// - Exchanging the code for tokens using fetch-fallible from http-client oauth-start-web-auth: func(config: oauth-web-auth-config) -> result; /// Make an HTTP request for OAuth token exchange. /// - /// This is a simple HTTP client for OAuth flows, allowing the extension - /// to handle token exchange with full control over serialization. - oauth-send-http-request: func(request: oauth-http-request) -> result; + /// This is a convenience wrapper around http-client's fetch-fallible for OAuth flows. + /// Unlike the standard fetch, this does not treat non-2xx responses as errors, + /// allowing proper handling of OAuth error responses. + oauth-send-http-request: func(request: http-request) -> result; /// Open a URL in the user's default browser. /// diff --git a/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs b/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs index bc1c6832814ea034d8c290f1786af0cc984505a6..d07a064a5ca76d6cc1db7980499bb6acdefcb843 100644 --- a/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs +++ b/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs @@ -618,6 +618,19 @@ impl http_client::Host for WasmState { .to_wasmtime_result() } + async fn fetch_fallible( + &mut self, + request: http_client::HttpRequest, + ) -> wasmtime::Result> { + maybe!(async { + let request = convert_request(&request)?; + let mut response = self.host.http_client.send(request).await?; + convert_response_with_status(&mut response).await + }) + .await + .to_wasmtime_result() + } + async fn fetch_stream( &mut self, request: http_client::HttpRequest, @@ -721,6 +734,26 @@ async fn convert_response( Ok(extension_response) } +async fn convert_response_with_status( + response: &mut ::http_client::Response, +) -> anyhow::Result { + let status = response.status().as_u16(); + let headers: Vec<(String, String)> = response + .headers() + .iter() + .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) + .collect(); + + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + + Ok(http_client::HttpResponseWithStatus { + status, + headers, + body, + }) +} + impl nodejs::Host for WasmState { async fn node_binary_path(&mut self) -> wasmtime::Result> { self.host @@ -1376,70 +1409,12 @@ impl llm_provider::Host for WasmState { async fn oauth_send_http_request( &mut self, - request: llm_provider::OauthHttpRequest, - ) -> wasmtime::Result> { - let http_client = self.host.http_client.clone(); - - self.on_main_thread(move |_cx| { - async move { - let method = match request.method.to_uppercase().as_str() { - "GET" => ::http_client::Method::GET, - "POST" => ::http_client::Method::POST, - "PUT" => ::http_client::Method::PUT, - "DELETE" => ::http_client::Method::DELETE, - "PATCH" => ::http_client::Method::PATCH, - _ => { - return Err(anyhow::anyhow!( - "Unsupported HTTP method: {}", - request.method - )); - } - }; - - let mut builder = ::http_client::Request::builder() - .method(method) - .uri(&request.url); - - for (key, value) in &request.headers { - builder = builder.header(key.as_str(), value.as_str()); - } - - let body = if request.body.is_empty() { - AsyncBody::empty() - } else { - AsyncBody::from(request.body.into_bytes()) - }; - - let http_request = builder - .body(body) - .map_err(|e| anyhow::anyhow!("Failed to build request: {}", e))?; - - let mut response = http_client - .send(http_request) - .await - .map_err(|e| anyhow::anyhow!("HTTP request failed: {}", e))?; - - let status = response.status().as_u16(); - let headers: Vec<(String, String)> = response - .headers() - .iter() - .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) - .collect(); - - let mut body_bytes = Vec::new(); - futures::AsyncReadExt::read_to_end(response.body_mut(), &mut body_bytes) - .await - .map_err(|e| anyhow::anyhow!("Failed to read response body: {}", e))?; - - let body = String::from_utf8_lossy(&body_bytes).to_string(); - - Ok(llm_provider::OauthHttpResponse { - status, - headers, - body, - }) - } - .boxed_local() + request: http_client::HttpRequest, + ) -> wasmtime::Result> { + maybe!(async { + let request = convert_request(&request)?; + let mut response = self.host.http_client.send(request).await?; + convert_response_with_status(&mut response).await }) .await .to_wasmtime_result() diff --git a/extensions/copilot-chat/src/copilot_chat.rs b/extensions/copilot-chat/src/copilot_chat.rs index 866e86098badb560f76a3bbf04ec80a59e0c2391..d286b2f776318b81572a41b724d646dcc21d0230 100644 --- a/extensions/copilot-chat/src/copilot_chat.rs +++ b/extensions/copilot-chat/src/copilot_chat.rs @@ -456,9 +456,9 @@ impl zed::Extension for CopilotChatProvider { _provider_id: &str, ) -> Result { // Step 1: Request device and user verification codes - let device_code_response = llm_oauth_send_http_request(&LlmOauthHttpRequest { + let device_code_response = llm_oauth_send_http_request(&HttpRequest { + method: HttpMethod::Post, url: GITHUB_DEVICE_CODE_URL.to_string(), - method: "POST".to_string(), headers: vec![ ("Accept".to_string(), "application/json".to_string()), ( @@ -466,7 +466,10 @@ impl zed::Extension for CopilotChatProvider { "application/x-www-form-urlencoded".to_string(), ), ], - body: format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID), + body: Some( + format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID).into_bytes(), + ), + redirect_policy: RedirectPolicy::NoFollow, })?; if device_code_response.status != 200 { @@ -487,7 +490,7 @@ impl zed::Extension for CopilotChatProvider { interval: u64, } - let device_info: DeviceCodeResponse = serde_json::from_str(&device_code_response.body) + let device_info: DeviceCodeResponse = serde_json::from_slice(&device_code_response.body) .map_err(|e| format!("Failed to parse device code response: {}", e))?; // Store device flow state for polling @@ -534,9 +537,9 @@ impl zed::Extension for CopilotChatProvider { for _ in 0..max_attempts { thread::sleep(poll_interval); - let token_response = llm_oauth_send_http_request(&LlmOauthHttpRequest { + let token_response = llm_oauth_send_http_request(&HttpRequest { + method: HttpMethod::Post, url: GITHUB_ACCESS_TOKEN_URL.to_string(), - method: "POST".to_string(), headers: vec![ ("Accept".to_string(), "application/json".to_string()), ( @@ -544,10 +547,14 @@ impl zed::Extension for CopilotChatProvider { "application/x-www-form-urlencoded".to_string(), ), ], - body: format!( - "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code", - GITHUB_COPILOT_CLIENT_ID, state.device_code + body: Some( + format!( + "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code", + GITHUB_COPILOT_CLIENT_ID, state.device_code + ) + .into_bytes(), ), + redirect_policy: RedirectPolicy::NoFollow, })?; #[derive(Deserialize)] @@ -557,7 +564,7 @@ impl zed::Extension for CopilotChatProvider { error_description: Option, } - let token_json: TokenResponse = serde_json::from_str(&token_response.body) + let token_json: TokenResponse = serde_json::from_slice(&token_response.body) .map_err(|e| format!("Failed to parse token response: {}", e))?; if let Some(access_token) = token_json.access_token {