Cargo.lock 🔗
@@ -21791,6 +21791,7 @@ dependencies = [
"pretty_assertions",
"project",
"release_channel",
+ "serde",
"serde_json",
"settings",
"thiserror 2.0.12",
Agus Zubiaga created
Release Notes:
- N/A
Cargo.lock | 1
crates/http_client/src/http_client.rs | 3
crates/zeta2/Cargo.toml | 1
crates/zeta2/src/prediction.rs | 6
crates/zeta2/src/provider.rs | 4
crates/zeta2/src/zeta2.rs | 211 ++++++++++++++++++----------
6 files changed, 148 insertions(+), 78 deletions(-)
@@ -21791,6 +21791,7 @@ dependencies = [
"pretty_assertions",
"project",
"release_channel",
+ "serde",
"serde_json",
"settings",
"thiserror 2.0.12",
@@ -6,13 +6,12 @@ pub use anyhow::{Result, anyhow};
pub use async_body::{AsyncBody, Inner};
use derive_more::Deref;
use http::HeaderValue;
-pub use http::{self, Method, Request, Response, StatusCode, Uri};
+pub use http::{self, Method, Request, Response, StatusCode, Uri, request::Builder};
use futures::{
FutureExt as _,
future::{self, BoxFuture},
};
-use http::request::Builder;
use parking_lot::Mutex;
#[cfg(feature = "test-support")]
use std::fmt;
@@ -28,6 +28,7 @@ language_model.workspace = true
log.workspace = true
project.workspace = true
release_channel.workspace = true
+serde.workspace = true
serde_json.workspace = true
thiserror.workspace = true
util.workspace = true
@@ -13,6 +13,12 @@ use uuid::Uuid;
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(Uuid);
+/// Unwraps a prediction id into its underlying [`Uuid`].
+///
+/// Written as `From` rather than `Into`: the standard blanket impl then supplies
+/// `Into<Uuid>` for free, so existing `.into()` call sites keep compiling.
+impl From<EditPredictionId> for Uuid {
+ fn from(value: EditPredictionId) -> Self {
+ value.0
+ }
+}
+
impl From<EditPredictionId> for gpui::ElementId {
fn from(value: EditPredictionId) -> Self {
gpui::ElementId::Uuid(value.0)
@@ -179,8 +179,8 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
}
fn accept(&mut self, cx: &mut Context<Self>) {
- self.zeta.update(cx, |zeta, _cx| {
- zeta.accept_current_prediction(&self.project);
+ self.zeta.update(cx, |zeta, cx| {
+ zeta.accept_current_prediction(&self.project, cx);
});
self.pending_predictions.clear();
}
@@ -3,7 +3,8 @@ use chrono::TimeDelta;
use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
use cloud_llm_client::{
- EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
+ AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
+ ZED_VERSION_HEADER_NAME,
};
use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, PlannedPrompt};
use edit_prediction_context::{
@@ -12,7 +13,7 @@ use edit_prediction_context::{
};
use futures::AsyncReadExt as _;
use futures::channel::{mpsc, oneshot};
-use gpui::http_client::Method;
+use gpui::http_client::{AsyncBody, Method};
use gpui::{
App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
http_client, prelude::*,
@@ -22,6 +23,7 @@ use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use project::Project;
use release_channel::AppVersion;
+use serde::de::DeserializeOwned;
use std::collections::{HashMap, VecDeque, hash_map};
use std::path::Path;
use std::str::FromStr as _;
@@ -391,11 +393,46 @@ impl Zeta {
}
}
- fn accept_current_prediction(&mut self, project: &Entity<Project>) {
- if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
- project_state.current_prediction.take();
+ /// Takes the current prediction for `project`, if any, and reports it as accepted.
+ ///
+ /// Returns early when the project is unknown or has no current prediction.
+ /// Otherwise a detached task sends the prediction id to the accept endpoint
+ /// (overridable through `ZED_ACCEPT_PREDICTION_URL`); the task's error is handed
+ /// to `detach_and_log_err` instead of being returned to the caller.
+ fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
+ let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
+ return;
+ };
+
+ let Some(prediction) = project_state.current_prediction.take() else {
+ return;
};
- // TODO report accepted
+ let request_id = prediction.prediction.id.into();
+
+ let app_version = AppVersion::global(cx);
+ let llm_token = self.llm_token.clone();
+ let client = self.client.clone();
+ cx.spawn(async move |this, cx| {
+ // The environment variable takes precedence over the default endpoint.
+ let url = match std::env::var("ZED_ACCEPT_PREDICTION_URL") {
+ Ok(accept_prediction_url) => http_client::Url::parse(&accept_prediction_url)?,
+ Err(_) => client
+ .http_client()
+ .build_zed_llm_url("/predict_edits/accept", &[])?,
+ };
+
+ // The response payload is deserialized as `()`; only usage and errors matter here.
+ let send_request = Self::send_api_request::<()>(
+ move |builder| {
+ let body = serde_json::to_string(&AcceptEditPredictionBody { request_id })?;
+ Ok(builder.uri(url.as_ref()).body(body.into())?)
+ },
+ client,
+ llm_token,
+ app_version,
+ );
+ let response = cx.background_spawn(send_request).await;
+
+ Self::handle_api_response(&this, response, cx)?;
+ anyhow::Ok(())
+ })
+ .detach_and_log_err(cx);
}
fn discard_current_prediction(&mut self, project: &Entity<Project>) {
@@ -545,7 +582,7 @@ impl Zeta {
&options.context,
index_state.as_deref(),
) else {
- return Ok(None);
+ return Ok((None, None));
};
let retrieval_time = chrono::Utc::now() - before_retrieval;
@@ -607,7 +644,8 @@ impl Zeta {
anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
}
- let response = Self::perform_request(client, llm_token, app_version, request).await;
+ let response =
+ Self::send_prediction_request(client, llm_token, app_version, request).await;
if let Some(debug_response_tx) = debug_response_tx {
debug_response_tx
@@ -620,7 +658,7 @@ impl Zeta {
.ok();
}
- anyhow::Ok(Some(response?))
+ response.map(|(res, usage)| (Some(res), usage))
}
});
@@ -629,60 +667,18 @@ impl Zeta {
cx.spawn({
let project = project.clone();
async move |this, cx| {
- match request_task.await {
- Ok(Some((response, usage))) => {
- if let Some(usage) = usage {
- this.update(cx, |this, cx| {
- this.user_store.update(cx, |user_store, cx| {
- user_store.update_edit_prediction_usage(usage, cx);
- });
- })
- .ok();
- }
-
- let prediction = EditPrediction::from_response(
- response, &snapshot, &buffer, &project, cx,
- )
- .await;
-
- // TODO telemetry: duration, etc
- Ok(prediction)
- }
- Ok(None) => Ok(None),
- Err(err) => {
- if err.is::<ZedUpdateRequiredError>() {
- cx.update(|cx| {
- this.update(cx, |this, _cx| {
- this.update_required = true;
- })
- .ok();
-
- let error_message: SharedString = err.to_string().into();
- show_app_notification(
- NotificationId::unique::<ZedUpdateRequiredError>(),
- cx,
- move |cx| {
- cx.new(|cx| {
- ErrorMessagePrompt::new(error_message.clone(), cx)
- .with_link_button(
- "Update Zed",
- "https://zed.dev/releases",
- )
- })
- },
- );
- })
- .ok();
- }
+ let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
+ else {
+ return Ok(None);
+ };
- Err(err)
- }
- }
+ // TODO telemetry: duration, etc
+ Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
}
})
}
- async fn perform_request(
+ async fn send_prediction_request(
client: Arc<Client>,
llm_token: LlmApiToken,
app_version: SemanticVersion,
@@ -691,27 +687,94 @@ impl Zeta {
predict_edits_v3::PredictEditsResponse,
Option<EditPredictionUsage>,
)> {
+ let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
+ http_client::Url::parse(&predict_edits_url)?
+ } else {
+ client
+ .http_client()
+ .build_zed_llm_url("/predict_edits/v3", &[])?
+ };
+
+ Self::send_api_request(
+ |builder| {
+ let req = builder
+ .uri(url.as_ref())
+ .body(serde_json::to_string(&request)?.into());
+ Ok(req?)
+ },
+ client,
+ llm_token,
+ app_version,
+ )
+ .await
+ }
+
+ /// Applies the side effects of an API response and unwraps its payload.
+ ///
+ /// On success, any reported usage is forwarded to the user store before the
+ /// payload is returned. On failure, a `ZedUpdateRequiredError` additionally sets
+ /// `update_required` and shows an "Update Zed" notification; the error itself is
+ /// returned to the caller in every failure case.
+ fn handle_api_response<T>(
+ this: &WeakEntity<Self>,
+ response: Result<(T, Option<EditPredictionUsage>)>,
+ cx: &mut gpui::AsyncApp,
+ ) -> Result<T> {
+ let err = match response {
+ Ok((data, usage)) => {
+ if let Some(usage) = usage {
+ // The update result is discarded (`.ok()`), so a failed update
+ // does not turn a successful response into an error.
+ this.update(cx, |zeta, cx| {
+ zeta.user_store.update(cx, |user_store, cx| {
+ user_store.update_edit_prediction_usage(usage, cx);
+ });
+ })
+ .ok();
+ }
+ return Ok(data);
+ }
+ Err(err) => err,
+ };
+
+ if err.is::<ZedUpdateRequiredError>() {
+ cx.update(|cx| {
+ this.update(cx, |zeta, _cx| {
+ zeta.update_required = true;
+ })
+ .ok();
+
+ let message: SharedString = err.to_string().into();
+ show_app_notification(
+ NotificationId::unique::<ZedUpdateRequiredError>(),
+ cx,
+ move |cx| {
+ cx.new(|cx| {
+ ErrorMessagePrompt::new(message.clone(), cx)
+ .with_link_button("Update Zed", "https://zed.dev/releases")
+ })
+ },
+ );
+ })
+ .ok();
+ }
+
+ Err(err)
+ }
+
+ async fn send_api_request<Res>(
+ build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
+ client: Arc<Client>,
+ llm_token: LlmApiToken,
+ app_version: SemanticVersion,
+ ) -> Result<(Res, Option<EditPredictionUsage>)>
+ where
+ Res: DeserializeOwned,
+ {
let http_client = client.http_client();
let mut token = llm_token.acquire(&client).await?;
let mut did_retry = false;
loop {
let request_builder = http_client::Request::builder().method(Method::POST);
- let request_builder =
- if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
- request_builder.uri(predict_edits_url)
- } else {
- request_builder.uri(
- http_client
- .build_zed_llm_url("/predict_edits/v3", &[])?
- .as_ref(),
- )
- };
- let request = request_builder
- .header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {}", token))
- .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
- .body(serde_json::to_string(&request)?.into())?;
+
+ let request = build(
+ request_builder
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", token))
+ .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
+ )?;
let mut response = http_client.send(request).await?;
@@ -746,7 +809,7 @@ impl Zeta {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
anyhow::bail!(
- "error predicting edits.\nStatus: {:?}\nBody: {}",
+ "Request failed with status: {:?}\nBody: {}",
response.status(),
body
);