Detailed changes
@@ -9,7 +9,7 @@ use futures::AsyncReadExt as _;
use gpui::{App, Task};
use gpui_tokio::Tokio;
use http_client::http::request;
-use http_client::{AsyncBody, HttpClientWithUrl, Method, Request, StatusCode};
+use http_client::{AsyncBody, HttpClientWithUrl, HttpRequestExt, Method, Request, StatusCode};
use parking_lot::RwLock;
use yawc::WebSocket;
@@ -119,15 +119,16 @@ impl CloudApiClient {
&self,
system_id: Option<String>,
) -> Result<CreateLlmTokenResponse> {
- let mut request_builder = Request::builder().method(Method::POST).uri(
- self.http_client
- .build_zed_cloud_url("/client/llm_tokens", &[])?
- .as_ref(),
- );
-
- if let Some(system_id) = system_id {
- request_builder = request_builder.header(ZED_SYSTEM_ID_HEADER_NAME, system_id);
- }
+ let request_builder = Request::builder()
+ .method(Method::POST)
+ .uri(
+ self.http_client
+ .build_zed_cloud_url("/client/llm_tokens", &[])?
+ .as_ref(),
+ )
+ .when_some(system_id, |builder, system_id| {
+ builder.header(ZED_SYSTEM_ID_HEADER_NAME, system_id)
+ });
let request = self.build_request(request_builder, AsyncBody::default())?;
@@ -10,6 +10,7 @@ use fs::Fs;
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use gpui::WeakEntity;
use gpui::{App, AsyncApp, Global, prelude::*};
+use http_client::HttpRequestExt;
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use itertools::Itertools;
use paths::home_dir;
@@ -741,7 +742,7 @@ async fn stream_completion(
let request_initiator = if is_user_initiated { "user" } else { "agent" };
- let mut request_builder = HttpRequest::builder()
+ let request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(completion_url.as_ref())
.header(
@@ -754,12 +755,10 @@ async fn stream_completion(
.header("Authorization", format!("Bearer {}", api_key))
.header("Content-Type", "application/json")
.header("Copilot-Integration-Id", "vscode-chat")
- .header("X-Initiator", request_initiator);
-
- if is_vision_request {
- request_builder =
- request_builder.header("Copilot-Vision-Request", is_vision_request.to_string());
- }
+ .header("X-Initiator", request_initiator)
+ .when(is_vision_request, |builder| {
+ builder.header("Copilot-Vision-Request", is_vision_request.to_string())
+ });
let is_streaming = request.stream;
@@ -28,6 +28,25 @@ pub enum RedirectPolicy {
pub struct FollowRedirects(pub bool);
pub trait HttpRequestExt {
+ /// Conditionally modify self with the given closure.
+ fn when(self, condition: bool, then: impl FnOnce(Self) -> Self) -> Self
+ where
+ Self: Sized,
+ {
+ if condition { then(self) } else { self }
+ }
+
+ /// Conditionally unwrap and modify self with the given closure, if the given option is Some.
+ fn when_some<T>(self, option: Option<T>, then: impl FnOnce(Self, T) -> Self) -> Self
+ where
+ Self: Sized,
+ {
+ match option {
+ Some(value) => then(self, value),
+ None => self,
+ }
+ }
+
/// Whether or not to follow redirects
fn follow_redirects(self, follow: RedirectPolicy) -> Self;
}
@@ -48,12 +67,12 @@ pub trait HttpClient: 'static + Send + Sync {
req: http::Request<AsyncBody>,
) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>>;
- fn get<'a>(
- &'a self,
+ fn get(
+ &self,
uri: &str,
body: AsyncBody,
follow_redirects: bool,
- ) -> BoxFuture<'a, anyhow::Result<Response<AsyncBody>>> {
+ ) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> {
let request = Builder::new()
.uri(uri)
.follow_redirects(if follow_redirects {
@@ -64,16 +83,16 @@ pub trait HttpClient: 'static + Send + Sync {
.body(body);
match request {
- Ok(request) => Box::pin(async move { self.send(request).await }),
+ Ok(request) => self.send(request),
Err(e) => Box::pin(async move { Err(e.into()) }),
}
}
- fn post_json<'a>(
- &'a self,
+ fn post_json(
+ &self,
uri: &str,
body: AsyncBody,
- ) -> BoxFuture<'a, anyhow::Result<Response<AsyncBody>>> {
+ ) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> {
let request = Builder::new()
.uri(uri)
.method(Method::POST)
@@ -81,7 +100,7 @@ pub trait HttpClient: 'static + Send + Sync {
.body(body);
match request {
- Ok(request) => Box::pin(async move { self.send(request).await }),
+ Ok(request) => self.send(request),
Err(e) => Box::pin(async move { Err(e.into()) }),
}
}
@@ -19,7 +19,7 @@ use gpui::{
AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
};
use http_client::http::{HeaderMap, HeaderValue};
-use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
+use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
@@ -391,20 +391,17 @@ impl CloudLanguageModel {
let mut refreshed_token = false;
loop {
- let request_builder = http_client::Request::builder()
+ let request = http_client::Request::builder()
.method(Method::POST)
- .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref());
- let request_builder = if let Some(app_version) = app_version {
- request_builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
- } else {
- request_builder
- };
-
- let request = request_builder
+ .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
+ .when_some(app_version, |builder, app_version| {
+ builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
+ })
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {token}"))
.header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
.body(serde_json::to_string(&body)?.into())?;
+
let mut response = http_client.send(request).await?;
let status = response.status();
if status.is_success() {
@@ -1,6 +1,6 @@
use anyhow::{Context as _, Result};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
-use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
+use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
use serde_json::Value;
pub use settings::KeepAlive;
@@ -261,16 +261,15 @@ pub async fn stream_chat_completion(
request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
let uri = format!("{api_url}/api/chat");
- let mut request_builder = http::Request::builder()
+ let request = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
- .header("Content-Type", "application/json");
-
- if let Some(api_key) = api_key {
- request_builder = request_builder.header("Authorization", format!("Bearer {api_key}"))
- }
+ .header("Content-Type", "application/json")
+ .when_some(api_key, |builder, api_key| {
+ builder.header("Authorization", format!("Bearer {api_key}"))
+ })
+ .body(AsyncBody::from(serde_json::to_string(&request)?))?;
- let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
let mut response = client.send(request).await?;
if response.status().is_success() {
let reader = BufReader::new(response.into_body());
@@ -300,16 +299,14 @@ pub async fn get_models(
_: Option<Duration>,
) -> Result<Vec<LocalModelListing>> {
let uri = format!("{api_url}/api/tags");
- let mut request_builder = HttpRequest::builder()
+ let request = HttpRequest::builder()
.method(Method::GET)
.uri(uri)
- .header("Accept", "application/json");
-
- if let Some(api_key) = api_key {
- request_builder = request_builder.header("Authorization", format!("Bearer {api_key}"));
- }
-
- let request = request_builder.body(AsyncBody::default())?;
+ .header("Accept", "application/json")
+ .when_some(api_key, |builder, api_key| {
+ builder.header("Authorization", format!("Bearer {api_key}"))
+ })
+ .body(AsyncBody::default())?;
let mut response = client.send(request).await?;
@@ -335,18 +332,16 @@ pub async fn show_model(
model: &str,
) -> Result<ModelShow> {
let uri = format!("{api_url}/api/show");
- let mut request_builder = HttpRequest::builder()
+ let request = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
- .header("Content-Type", "application/json");
-
- if let Some(api_key) = api_key {
- request_builder = request_builder.header("Authorization", format!("Bearer {api_key}"))
- }
-
- let request = request_builder.body(AsyncBody::from(
- serde_json::json!({ "model": model }).to_string(),
- ))?;
+ .header("Content-Type", "application/json")
+ .when_some(api_key, |builder, api_key| {
+ builder.header("Authorization", format!("Bearer {api_key}"))
+ })
+ .body(AsyncBody::from(
+ serde_json::json!({ "model": model }).to_string(),
+ ))?;
let mut response = client.send(request).await?;
let mut body = String::new();