diff --git a/Cargo.lock b/Cargo.lock index 8df6dd15d3c0a695c2f6b16155fad7d900b79155..bb860f669acc5aea09ee382cbc97dec2e102b196 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -21791,6 +21791,7 @@ dependencies = [ "pretty_assertions", "project", "release_channel", + "serde", "serde_json", "settings", "thiserror 2.0.12", diff --git a/crates/http_client/src/http_client.rs b/crates/http_client/src/http_client.rs index 76bf0b905dbdc827f38aa37a95edc0e3b9e834eb..056cee4e346e34b5689a0dfe3278c880b7297986 100644 --- a/crates/http_client/src/http_client.rs +++ b/crates/http_client/src/http_client.rs @@ -6,13 +6,12 @@ pub use anyhow::{Result, anyhow}; pub use async_body::{AsyncBody, Inner}; use derive_more::Deref; use http::HeaderValue; -pub use http::{self, Method, Request, Response, StatusCode, Uri}; +pub use http::{self, Method, Request, Response, StatusCode, Uri, request::Builder}; use futures::{ FutureExt as _, future::{self, BoxFuture}, }; -use http::request::Builder; use parking_lot::Mutex; #[cfg(feature = "test-support")] use std::fmt; diff --git a/crates/zeta2/Cargo.toml b/crates/zeta2/Cargo.toml index bce7e5987ccec635b335110a3a38298040c68e72..2342d062979f459c23441d8a57c0a640f5ce41b2 100644 --- a/crates/zeta2/Cargo.toml +++ b/crates/zeta2/Cargo.toml @@ -28,6 +28,7 @@ language_model.workspace = true log.workspace = true project.workspace = true release_channel.workspace = true +serde.workspace = true serde_json.workspace = true thiserror.workspace = true util.workspace = true diff --git a/crates/zeta2/src/prediction.rs b/crates/zeta2/src/prediction.rs index 9611d48023d84a91e477a51ff863b9ca6f0566a8..a0dcd83b88142a5746c0b3c7d82bc7a64965edab 100644 --- a/crates/zeta2/src/prediction.rs +++ b/crates/zeta2/src/prediction.rs @@ -13,6 +13,12 @@ use uuid::Uuid; #[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)] pub struct EditPredictionId(Uuid); +impl Into for EditPredictionId { + fn into(self) -> Uuid { + self.0 + } +} + impl From for gpui::ElementId { fn from(value: EditPredictionId) -> Self { gpui::ElementId::Uuid(value.0) diff --git a/crates/zeta2/src/provider.rs b/crates/zeta2/src/provider.rs index db637208aa88e8e3ebe4b30dc3d5639497cd0ac0..3c0dd75cc23a6a7b18a0fba19d0eab0a4833ba9c 100644 --- a/crates/zeta2/src/provider.rs +++ b/crates/zeta2/src/provider.rs @@ -179,8 +179,8 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { } fn accept(&mut self, cx: &mut Context) { - self.zeta.update(cx, |zeta, _cx| { - zeta.accept_current_prediction(&self.project); + self.zeta.update(cx, |zeta, cx| { + zeta.accept_current_prediction(&self.project, cx); }); self.pending_predictions.clear(); } diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index 5cb163ceed135f0df7b1277908377a931b02aa7e..2e07122d26e55ec7bd716a18d0a2cee210c4a584 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -3,7 +3,8 @@ use chrono::TimeDelta; use client::{Client, EditPredictionUsage, UserStore}; use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature}; use cloud_llm_client::{ - EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME, + AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, + ZED_VERSION_HEADER_NAME, }; use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, PlannedPrompt}; use edit_prediction_context::{ @@ -12,7 +13,7 @@ use edit_prediction_context::{ }; use futures::AsyncReadExt as _; use futures::channel::{mpsc, oneshot}; -use gpui::http_client::Method; +use gpui::http_client::{AsyncBody, Method}; use gpui::{ App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity, http_client, prelude::*, @@ -22,6 +23,7 @@ use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint}; use language_model::{LlmApiToken, RefreshLlmTokenListener}; use project::Project; use release_channel::AppVersion; +use serde::de::DeserializeOwned; use std::collections::{HashMap, VecDeque, hash_map}; use std::path::Path; use std::str::FromStr as _; @@ -391,11 +393,46 @@ impl Zeta { } } - fn accept_current_prediction(&mut self, project: &Entity) { - if let Some(project_state) = self.projects.get_mut(&project.entity_id()) { - project_state.current_prediction.take(); + fn accept_current_prediction(&mut self, project: &Entity, cx: &mut Context) { + let Some(project_state) = self.projects.get_mut(&project.entity_id()) else { + return; + }; + + let Some(prediction) = project_state.current_prediction.take() else { + return; }; - // TODO report accepted + let request_id = prediction.prediction.id.into(); + + let client = self.client.clone(); + let llm_token = self.llm_token.clone(); + let app_version = AppVersion::global(cx); + cx.spawn(async move |this, cx| { + let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") { + http_client::Url::parse(&predict_edits_url)? + } else { + client + .http_client() + .build_zed_llm_url("/predict_edits/accept", &[])? + }; + + let response = cx + .background_spawn(Self::send_api_request::<()>( + move |builder| { + let req = builder.uri(url.as_ref()).body( + serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(), + ); + Ok(req?) + }, + client, + llm_token, + app_version, + )) + .await; + + Self::handle_api_response(&this, response, cx)?; + anyhow::Ok(()) + }) + .detach_and_log_err(cx); } fn discard_current_prediction(&mut self, project: &Entity) { @@ -545,7 +582,7 @@ impl Zeta { &options.context, index_state.as_deref(), ) else { - return Ok(None); + return Ok((None, None)); }; let retrieval_time = chrono::Utc::now() - before_retrieval; @@ -607,7 +644,8 @@ impl Zeta { anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set") } - let response = Self::perform_request(client, llm_token, app_version, request).await; + let response = + Self::send_prediction_request(client, llm_token, app_version, request).await; if let Some(debug_response_tx) = debug_response_tx { debug_response_tx @@ -620,7 +658,7 @@ impl Zeta { .ok(); } - anyhow::Ok(Some(response?)) + response.map(|(res, usage)| (Some(res), usage)) } }); @@ -629,60 +667,18 @@ impl Zeta { cx.spawn({ let project = project.clone(); async move |this, cx| { - match request_task.await { - Ok(Some((response, usage))) => { - if let Some(usage) = usage { - this.update(cx, |this, cx| { - this.user_store.update(cx, |user_store, cx| { - user_store.update_edit_prediction_usage(usage, cx); - }); - }) - .ok(); - } - - let prediction = EditPrediction::from_response( - response, &snapshot, &buffer, &project, cx, - ) - .await; - - // TODO telemetry: duration, etc - Ok(prediction) - } - Ok(None) => Ok(None), - Err(err) => { - if err.is::() { - cx.update(|cx| { - this.update(cx, |this, _cx| { - this.update_required = true; - }) - .ok(); - - let error_message: SharedString = err.to_string().into(); - show_app_notification( - NotificationId::unique::(), - cx, - move |cx| { - cx.new(|cx| { - ErrorMessagePrompt::new(error_message.clone(), cx) - .with_link_button( - "Update Zed", - "https://zed.dev/releases", - ) - }) - }, - ); - }) - .ok(); - } + let Some(response) = Self::handle_api_response(&this, request_task.await, cx)? + else { + return Ok(None); + }; - Err(err) - } - } + // TODO telemetry: duration, etc + Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await) } }) } - async fn perform_request( + async fn send_prediction_request( client: Arc, llm_token: LlmApiToken, app_version: SemanticVersion, @@ -691,27 +687,94 @@ impl Zeta { predict_edits_v3::PredictEditsResponse, Option, )> { + let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") { + http_client::Url::parse(&predict_edits_url)? + } else { + client + .http_client() + .build_zed_llm_url("/predict_edits/v3", &[])? + }; + + Self::send_api_request( + |builder| { + let req = builder + .uri(url.as_ref()) + .body(serde_json::to_string(&request)?.into()); + Ok(req?) + }, + client, + llm_token, + app_version, + ) + .await + } + + fn handle_api_response( + this: &WeakEntity, + response: Result<(T, Option)>, + cx: &mut gpui::AsyncApp, + ) -> Result { + match response { + Ok((data, usage)) => { + if let Some(usage) = usage { + this.update(cx, |this, cx| { + this.user_store.update(cx, |user_store, cx| { + user_store.update_edit_prediction_usage(usage, cx); + }); + }) + .ok(); + } + Ok(data) + } + Err(err) => { + if err.is::() { + cx.update(|cx| { + this.update(cx, |this, _cx| { + this.update_required = true; + }) + .ok(); + + let error_message: SharedString = err.to_string().into(); + show_app_notification( + NotificationId::unique::(), + cx, + move |cx| { + cx.new(|cx| { + ErrorMessagePrompt::new(error_message.clone(), cx) + .with_link_button("Update Zed", "https://zed.dev/releases") + }) + }, + ); + }) + .ok(); + } + Err(err) + } + } + } + + async fn send_api_request( + build: impl Fn(http_client::http::request::Builder) -> Result>, + client: Arc, + llm_token: LlmApiToken, + app_version: SemanticVersion, + ) -> Result<(Res, Option)> + where + Res: DeserializeOwned, + { let http_client = client.http_client(); let mut token = llm_token.acquire(&client).await?; let mut did_retry = false; loop { let request_builder = http_client::Request::builder().method(Method::POST); - let request_builder = - if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") { - request_builder.uri(predict_edits_url) - } else { - request_builder.uri( - http_client - .build_zed_llm_url("/predict_edits/v3", &[])? - .as_ref(), - ) - }; - let request = request_builder - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", token)) - .header(ZED_VERSION_HEADER_NAME, app_version.to_string()) - .body(serde_json::to_string(&request)?.into())?; + + let request = build( + request_builder + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", token)) + .header(ZED_VERSION_HEADER_NAME, app_version.to_string()), + )?; let mut response = http_client.send(request).await?; @@ -746,7 +809,7 @@ impl Zeta { let mut body = String::new(); response.body_mut().read_to_string(&mut body).await?; anyhow::bail!( - "error predicting edits.\nStatus: {:?}\nBody: {}", + "Request failed with status: {:?}\nBody: {}", response.status(), body );