diff --git a/crates/cloud_api_client/src/cloud_api_client.rs b/crates/cloud_api_client/src/cloud_api_client.rs index 7fd96fcef0e8fd764bbcaa8ab59a9666095f9db9..53b2b16a6a7c9447face6daa199bc4b2125445b9 100644 --- a/crates/cloud_api_client/src/cloud_api_client.rs +++ b/crates/cloud_api_client/src/cloud_api_client.rs @@ -9,7 +9,7 @@ use futures::AsyncReadExt as _; use gpui::{App, Task}; use gpui_tokio::Tokio; use http_client::http::request; -use http_client::{AsyncBody, HttpClientWithUrl, Method, Request, StatusCode}; +use http_client::{AsyncBody, HttpClientWithUrl, HttpRequestExt, Method, Request, StatusCode}; use parking_lot::RwLock; use yawc::WebSocket; @@ -119,15 +119,16 @@ impl CloudApiClient { &self, system_id: Option, ) -> Result { - let mut request_builder = Request::builder().method(Method::POST).uri( - self.http_client - .build_zed_cloud_url("/client/llm_tokens", &[])? - .as_ref(), - ); - - if let Some(system_id) = system_id { - request_builder = request_builder.header(ZED_SYSTEM_ID_HEADER_NAME, system_id); - } + let request_builder = Request::builder() + .method(Method::POST) + .uri( + self.http_client + .build_zed_cloud_url("/client/llm_tokens", &[])? + .as_ref(), + ) + .when_some(system_id, |builder, system_id| { + builder.header(ZED_SYSTEM_ID_HEADER_NAME, system_id) + }); let request = self.build_request(request_builder, AsyncBody::default())?; diff --git a/crates/copilot/src/copilot_chat.rs b/crates/copilot/src/copilot_chat.rs index ccd8f09613eec54f2d30b619f142d111bf2a3497..a6758ce53c0aa18d04dcd376c2e0afb93add6ab5 100644 --- a/crates/copilot/src/copilot_chat.rs +++ b/crates/copilot/src/copilot_chat.rs @@ -10,6 +10,7 @@ use fs::Fs; use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; use gpui::WeakEntity; use gpui::{App, AsyncApp, Global, prelude::*}; +use http_client::HttpRequestExt; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use itertools::Itertools; use paths::home_dir; @@ -741,7 +742,7 @@ async fn stream_completion( let request_initiator = if is_user_initiated { "user" } else { "agent" }; - let mut request_builder = HttpRequest::builder() + let request_builder = HttpRequest::builder() .method(Method::POST) .uri(completion_url.as_ref()) .header( @@ -754,12 +755,10 @@ async fn stream_completion( .header("Authorization", format!("Bearer {}", api_key)) .header("Content-Type", "application/json") .header("Copilot-Integration-Id", "vscode-chat") - .header("X-Initiator", request_initiator); - - if is_vision_request { - request_builder = - request_builder.header("Copilot-Vision-Request", is_vision_request.to_string()); - } + .header("X-Initiator", request_initiator) + .when(is_vision_request, |builder| { + builder.header("Copilot-Vision-Request", is_vision_request.to_string()) + }); let is_streaming = request.stream; diff --git a/crates/http_client/src/http_client.rs b/crates/http_client/src/http_client.rs index 1429b7bf941fab5b1b508b977e898b8e153942d1..0bbb7ce037fcda014b346556202256b99e832529 100644 --- a/crates/http_client/src/http_client.rs +++ b/crates/http_client/src/http_client.rs @@ -28,6 +28,25 @@ pub enum RedirectPolicy { pub struct FollowRedirects(pub bool); pub trait HttpRequestExt { + /// Conditionally modify self with the given closure. + fn when(self, condition: bool, then: impl FnOnce(Self) -> Self) -> Self + where + Self: Sized, + { + if condition { then(self) } else { self } + } + + /// Conditionally unwrap and modify self with the given closure, if the given option is Some. + fn when_some(self, option: Option, then: impl FnOnce(Self, T) -> Self) -> Self + where + Self: Sized, + { + match option { + Some(value) => then(self, value), + None => self, + } + } + /// Whether or not to follow redirects fn follow_redirects(self, follow: RedirectPolicy) -> Self; } @@ -48,12 +67,12 @@ pub trait HttpClient: 'static + Send + Sync { req: http::Request, ) -> BoxFuture<'static, anyhow::Result>>; - fn get<'a>( - &'a self, + fn get( + &self, uri: &str, body: AsyncBody, follow_redirects: bool, - ) -> BoxFuture<'a, anyhow::Result>> { + ) -> BoxFuture<'static, anyhow::Result>> { let request = Builder::new() .uri(uri) .follow_redirects(if follow_redirects { @@ -64,16 +83,16 @@ pub trait HttpClient: 'static + Send + Sync { .body(body); match request { - Ok(request) => Box::pin(async move { self.send(request).await }), + Ok(request) => self.send(request), Err(e) => Box::pin(async move { Err(e.into()) }), } } - fn post_json<'a>( - &'a self, + fn post_json( + &self, uri: &str, body: AsyncBody, - ) -> BoxFuture<'a, anyhow::Result>> { + ) -> BoxFuture<'static, anyhow::Result>> { let request = Builder::new() .uri(uri) .method(Method::POST) @@ -81,7 +100,7 @@ pub trait HttpClient: 'static + Send + Sync { .body(body); match request { - Ok(request) => Box::pin(async move { self.send(request).await }), + Ok(request) => self.send(request), Err(e) => Box::pin(async move { Err(e.into()) }), } } diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index c62a6989501a71e444b07992bff0cbe1a1bbd6d6..40958af77535b34e2d68afde49e7ab97f07f911f 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -19,7 +19,7 @@ use gpui::{ AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task, }; use http_client::http::{HeaderMap, HeaderValue}; -use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode}; +use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode}; use language_model::{ AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, @@ -391,20 +391,17 @@ impl CloudLanguageModel { let mut refreshed_token = false; loop { - let request_builder = http_client::Request::builder() + let request = http_client::Request::builder() .method(Method::POST) - .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref()); - let request_builder = if let Some(app_version) = app_version { - request_builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string()) - } else { - request_builder - }; - - let request = request_builder + .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref()) + .when_some(app_version, |builder, app_version| { + builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string()) + }) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {token}")) .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true") .body(serde_json::to_string(&body)?.into())?; + let mut response = http_client.send(request).await?; let status = response.status(); if status.is_success() { diff --git a/crates/ollama/src/ollama.rs b/crates/ollama/src/ollama.rs index dced37e0fc1e19e61bba5e14010812f08fe3a1e5..48124f9625bf28a646ec4e9dc194bb1dd0df4c57 100644 --- a/crates/ollama/src/ollama.rs +++ b/crates/ollama/src/ollama.rs @@ -1,6 +1,6 @@ use anyhow::{Context as _, Result}; use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; -use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http}; +use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest}; use serde::{Deserialize, Serialize}; use serde_json::Value; pub use settings::KeepAlive; @@ -261,16 +261,15 @@ pub async fn stream_chat_completion( request: ChatRequest, ) -> Result>> { let uri = format!("{api_url}/api/chat"); - let mut request_builder = http::Request::builder() + let request = HttpRequest::builder() .method(Method::POST) .uri(uri) - .header("Content-Type", "application/json"); - - if let Some(api_key) = api_key { - request_builder = request_builder.header("Authorization", format!("Bearer {api_key}")) - } + .header("Content-Type", "application/json") + .when_some(api_key, |builder, api_key| { + builder.header("Authorization", format!("Bearer {api_key}")) + }) + .body(AsyncBody::from(serde_json::to_string(&request)?))?; - let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; let mut response = client.send(request).await?; if response.status().is_success() { let reader = BufReader::new(response.into_body()); @@ -300,16 +299,14 @@ pub async fn get_models( _: Option, ) -> Result> { let uri = format!("{api_url}/api/tags"); - let mut request_builder = HttpRequest::builder() + let request = HttpRequest::builder() .method(Method::GET) .uri(uri) - .header("Accept", "application/json"); - - if let Some(api_key) = api_key { - request_builder = request_builder.header("Authorization", format!("Bearer {api_key}")); - } - - let request = request_builder.body(AsyncBody::default())?; + .header("Accept", "application/json") + .when_some(api_key, |builder, api_key| { + builder.header("Authorization", format!("Bearer {api_key}")) + }) + .body(AsyncBody::default())?; let mut response = client.send(request).await?; @@ -335,18 +332,16 @@ pub async fn show_model( model: &str, ) -> Result { let uri = format!("{api_url}/api/show"); - let mut request_builder = HttpRequest::builder() + let request = HttpRequest::builder() .method(Method::POST) .uri(uri) - .header("Content-Type", "application/json"); - - if let Some(api_key) = api_key { - request_builder = request_builder.header("Authorization", format!("Bearer {api_key}")) - } - - let request = request_builder.body(AsyncBody::from( - serde_json::json!({ "model": model }).to_string(), - ))?; + .header("Content-Type", "application/json") + .when_some(api_key, |builder, api_key| { + builder.header("Authorization", format!("Bearer {api_key}")) + }) + .body(AsyncBody::from( + serde_json::json!({ "model": model }).to_string(), + ))?; let mut response = client.send(request).await?; let mut body = String::new();