diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs index 241e760887cdf0c4455f6769c79a813de0626028..15b5a4eda4f8473f48cc66d255598cc6c1d09f08 100644 --- a/crates/cloud_llm_client/src/cloud_llm_client.rs +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -200,12 +200,31 @@ pub struct RejectEditPredictionsBody { pub rejections: Vec, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct EditPredictionRejection { pub request_id: String, + #[serde(default)] + pub reason: EditPredictionRejectReason, pub was_shown: bool, } +#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq)] +pub enum EditPredictionRejectReason { + /// New requests were triggered before this one completed + Canceled, + /// No edits returned + Empty, + /// Edits returned, but none remained after interpolation + InterpolatedEmpty, + /// The new prediction was preferred over the current one + Replaced, + /// The current prediction was preferred over the new one + CurrentPreferred, + /// The current prediction was discarded + #[default] + Discarded, +} + #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum CompletionMode { diff --git a/crates/zeta/src/prediction.rs b/crates/zeta/src/prediction.rs index 0125e739f335fc133cbff84dcd8b4c4bac3e6e7b..fd3241730030fe8bdd95e2cae9ee87b406ade735 100644 --- a/crates/zeta/src/prediction.rs +++ b/crates/zeta/src/prediction.rs @@ -5,6 +5,7 @@ use std::{ time::{Duration, Instant}, }; +use cloud_llm_client::EditPredictionRejectReason; use gpui::{AsyncApp, Entity, SharedString}; use language::{Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot}; use serde::Serialize; @@ -24,28 +25,13 @@ impl std::fmt::Display for EditPredictionId { } } -#[derive(Clone)] -pub struct EditPrediction { +/// A prediction response that was returned from the provider, whether it was ultimately valid or not. +pub struct EditPredictionResult { pub id: EditPredictionId, - pub edits: Arc<[(Range, Arc)]>, - pub snapshot: BufferSnapshot, - pub edit_preview: EditPreview, - // We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction. - pub buffer: Entity, - pub buffer_snapshotted_at: Instant, - pub response_received_at: Instant, - pub inputs: EditPredictionInputs, -} - -#[derive(Debug, Clone, Serialize)] -pub struct EditPredictionInputs { - pub events: Vec>, - pub included_files: Vec, - pub cursor_point: cloud_llm_client::predict_edits_v3::Point, - pub cursor_path: Arc, + pub prediction: Result, } -impl EditPrediction { +impl EditPredictionResult { pub async fn new( id: EditPredictionId, edited_buffer: &Entity, @@ -55,8 +41,15 @@ impl EditPrediction { response_received_at: Instant, inputs: EditPredictionInputs, cx: &mut AsyncApp, - ) -> Option { - let (edits, snapshot, edit_preview_task) = edited_buffer + ) -> Self { + if edits.is_empty() { + return Self { + id, + prediction: Err(EditPredictionRejectReason::Empty), + }; + } + + let Some((edits, snapshot, edit_preview_task)) = edited_buffer .read_with(cx, |buffer, cx| { let new_snapshot = buffer.snapshot(); let edits: Arc<[_]> = @@ -64,22 +57,54 @@ impl EditPrediction { Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx))) }) - .ok()??; + .ok() + .flatten() + else { + return Self { + id, + prediction: Err(EditPredictionRejectReason::InterpolatedEmpty), + }; + }; let edit_preview = edit_preview_task.await; - Some(EditPrediction { - id, - edits, - snapshot, - edit_preview, - inputs, - buffer: edited_buffer.clone(), - buffer_snapshotted_at, - response_received_at, - }) + Self { + id: id.clone(), + prediction: Ok(EditPrediction { + id, + edits, + snapshot, + edit_preview, + inputs, + buffer: edited_buffer.clone(), + buffer_snapshotted_at, + response_received_at, + }), + } } +} +#[derive(Clone)] +pub struct EditPrediction { + pub id: EditPredictionId, + pub edits: Arc<[(Range, Arc)]>, + pub snapshot: BufferSnapshot, + pub edit_preview: EditPreview, + pub buffer: Entity, + pub buffer_snapshotted_at: Instant, + pub response_received_at: Instant, + pub inputs: EditPredictionInputs, +} + +#[derive(Debug, Clone, Serialize)] +pub struct EditPredictionInputs { + pub events: Vec>, + pub included_files: Vec, + pub cursor_point: cloud_llm_client::predict_edits_v3::Point, + pub cursor_path: Arc, +} + +impl EditPrediction { pub fn interpolate( &self, new_snapshot: &TextBufferSnapshot, diff --git a/crates/zeta/src/provider.rs b/crates/zeta/src/provider.rs index 76c950714afa808ea04cf5fead89979374f2b99b..b91df0963386543fbd1e8645e5893a35fe202cc5 100644 --- a/crates/zeta/src/provider.rs +++ b/crates/zeta/src/provider.rs @@ -1,6 +1,7 @@ use std::{cmp, sync::Arc, time::Duration}; use client::{Client, UserStore}; +use cloud_llm_client::EditPredictionRejectReason; use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider}; use gpui::{App, Entity, prelude::*}; use language::ToPoint as _; @@ -132,7 +133,11 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { fn discard(&mut self, cx: &mut Context) { self.zeta.update(cx, |zeta, cx| { - zeta.discard_current_prediction(&self.project, cx); + zeta.reject_current_prediction( + EditPredictionRejectReason::Discarded, + &self.project, + cx, + ); }); } @@ -169,7 +174,11 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { let Some(edits) = prediction.interpolate(&snapshot) else { self.zeta.update(cx, |zeta, cx| { - zeta.discard_current_prediction(&self.project, cx); + zeta.reject_current_prediction( + EditPredictionRejectReason::InterpolatedEmpty, + &self.project, + cx, + ); }); return None; }; diff --git a/crates/zeta/src/sweep_ai.rs b/crates/zeta/src/sweep_ai.rs index c88dda2ae2fd11dd37965e58560df9e98528c9d9..f40e9711f231523174a2d2edbd9fe1adb14ad498 100644 --- a/crates/zeta/src/sweep_ai.rs +++ b/crates/zeta/src/sweep_ai.rs @@ -18,7 +18,7 @@ use std::{ time::Instant, }; -use crate::{EditPrediction, EditPredictionId, EditPredictionInputs}; +use crate::{EditPredictionId, EditPredictionInputs, prediction::EditPredictionResult}; const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete"; @@ -45,7 +45,7 @@ impl SweepAi { recent_paths: &VecDeque, diagnostic_search_range: Range, cx: &mut App, - ) -> Task>> { + ) -> Task>> { let debug_info = self.debug_info.clone(); let Some(api_token) = self.api_token.clone() else { return Task::ready(Ok(None)); @@ -242,8 +242,8 @@ impl SweepAi { cx.spawn(async move |cx| { let (id, edits, old_snapshot, response_received_at, inputs) = result.await?; - anyhow::Ok( - EditPrediction::new( + anyhow::Ok(Some( + EditPredictionResult::new( EditPredictionId(id.into()), &buffer, &old_snapshot, @@ -254,7 +254,7 @@ impl SweepAi { cx, ) .await, - ) + )) }) } } diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 26a2388a96e4a828fc4c7bd6fe5d3dbb57bfc911..5cf0191e2f8180ea7bcfbef07c046372d2ee22c9 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -3,9 +3,9 @@ use arrayvec::ArrayVec; use client::{Client, EditPredictionUsage, UserStore}; use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature}; use cloud_llm_client::{ - AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejection, - MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME, - RejectEditPredictionsBody, ZED_VERSION_HEADER_NAME, + AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason, + EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, + MINIMUM_REQUIRED_VERSION_HEADER_NAME, RejectEditPredictionsBody, ZED_VERSION_HEADER_NAME, }; use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery}; use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES}; @@ -74,6 +74,7 @@ use crate::onboarding_modal::ZedPredictModal; pub use crate::prediction::EditPrediction; pub use crate::prediction::EditPredictionId; pub use crate::prediction::EditPredictionInputs; +use crate::prediction::EditPredictionResult; use crate::rate_prediction_modal::{ NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction, ThumbsUpActivePrediction, @@ -310,6 +311,31 @@ impl ZetaProject { ) .collect() } + + fn cancel_pending_prediction( + &mut self, + pending_prediction: PendingPrediction, + cx: &mut Context, + ) { + self.cancelled_predictions.insert(pending_prediction.id); + + cx.spawn(async move |this, cx| { + let Some(prediction_id) = pending_prediction.task.await else { + return; + }; + + this.update(cx, |this, cx| { + this.reject_prediction( + prediction_id, + EditPredictionRejectReason::Canceled, + false, + cx, + ); + }) + .ok(); + }) + .detach() + } } #[derive(Debug, Clone)] @@ -373,6 +399,7 @@ impl PredictionRequestedBy { } } +#[derive(Debug)] struct PendingPrediction { id: usize, task: Task>, @@ -385,6 +412,18 @@ enum BufferEditPrediction<'a> { Jump { prediction: &'a EditPrediction }, } +#[cfg(test)] +impl std::ops::Deref for BufferEditPrediction<'_> { + type Target = EditPrediction; + + fn deref(&self) -> &Self::Target { + match self { + BufferEditPrediction::Local { prediction } => prediction, + BufferEditPrediction::Jump { prediction } => prediction, + } + } +} + struct RegisteredBuffer { snapshot: BufferSnapshot, _subscriptions: [gpui::Subscription; 2], @@ -467,7 +506,7 @@ impl Zeta { let (reject_tx, mut reject_rx) = mpsc::unbounded(); cx.spawn(async move |this, cx| { while let Some(()) = reject_rx.next().await { - this.update(cx, |this, cx| this.reject_edit_predictions(cx))? + this.update(cx, |this, cx| this.flush_rejected_predictions(cx))? .await .log_err(); } @@ -818,7 +857,7 @@ impl Zeta { }; let request_id = prediction.prediction.id.to_string(); for pending_prediction in mem::take(&mut project_state.pending_predictions) { - self.cancel_pending_prediction(pending_prediction, cx); + project_state.cancel_pending_prediction(pending_prediction, cx); } let client = self.client.clone(); @@ -856,7 +895,7 @@ impl Zeta { .detach_and_log_err(cx); } - fn reject_edit_predictions(&mut self, cx: &mut Context) -> Task> { + fn flush_rejected_predictions(&mut self, cx: &mut Context) -> Task> { match self.edit_prediction_model { ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {} ZetaEditPredictionModel::Sweep => return Task::ready(anyhow::Ok(())), @@ -904,11 +943,16 @@ impl Zeta { }) } - fn discard_current_prediction(&mut self, project: &Entity, cx: &mut Context) { + fn reject_current_prediction( + &mut self, + reason: EditPredictionRejectReason, + project: &Entity, + cx: &mut Context, + ) { if let Some(project_state) = self.projects.get_mut(&project.entity_id()) { project_state.pending_predictions.clear(); if let Some(prediction) = project_state.current_prediction.take() { - self.discard_prediction(prediction.prediction.id, prediction.was_shown, cx); + self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx); } }; } @@ -929,14 +973,16 @@ impl Zeta { } } - fn discard_prediction( + fn reject_prediction( &mut self, prediction_id: EditPredictionId, + reason: EditPredictionRejectReason, was_shown: bool, cx: &mut Context, ) { self.rejected_predictions.push(EditPredictionRejection { request_id: prediction_id.to_string(), + reason, was_shown, }); @@ -944,34 +990,16 @@ impl Zeta { self.rejected_predictions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST; let reject_tx = self.reject_predictions_tx.clone(); self.reject_predictions_debounce_task = Some(cx.spawn(async move |_this, cx| { - const DISCARD_COMPLETIONS_DEBOUNCE: Duration = Duration::from_secs(15); + const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15); if !reached_request_limit { cx.background_executor() - .timer(DISCARD_COMPLETIONS_DEBOUNCE) + .timer(REJECT_REQUEST_DEBOUNCE) .await; } reject_tx.unbounded_send(()).log_err(); })); } - fn cancel_pending_prediction( - &self, - pending_prediction: PendingPrediction, - cx: &mut Context, - ) { - cx.spawn(async move |this, cx| { - let Some(prediction_id) = pending_prediction.task.await else { - return; - }; - - this.update(cx, |this, cx| { - this.discard_prediction(prediction_id, false, cx); - }) - .ok(); - }) - .detach() - } - fn is_refreshing(&self, project: &Entity) -> bool { self.projects .get(&project.entity_id()) @@ -995,38 +1023,15 @@ impl Zeta { return Task::ready(anyhow::Ok(None)); }; - let project = project.clone(); - cx.spawn(async move |cx| { - if let Some(prediction) = request_task.await? { - let id = prediction.id.clone(); - this.update(cx, |this, cx| { - let project_state = this - .projects - .get_mut(&project.entity_id()) - .context("Project not found")?; - - let new_prediction = CurrentEditPrediction { - requested_by: PredictionRequestedBy::Buffer(buffer.entity_id()), - prediction: prediction, - was_shown: false, - }; - - if project_state - .current_prediction - .as_ref() - .is_none_or(|old_prediction| { - new_prediction.should_replace_prediction(&old_prediction, cx) - }) - { - project_state.current_prediction = Some(new_prediction); - cx.notify(); - } - anyhow::Ok(()) - })??; - Ok(Some(id)) - } else { - Ok(None) - } + cx.spawn(async move |_cx| { + request_task.await.map(|prediction_result| { + prediction_result.map(|prediction_result| { + ( + prediction_result, + PredictionRequestedBy::Buffer(buffer.entity_id()), + ) + }) + }) }) }) } @@ -1076,7 +1081,7 @@ impl Zeta { return anyhow::Ok(None); }; - let Some(prediction) = this + let Some(prediction_result) = this .update(cx, |this, cx| { this.request_prediction(&project, &jump_buffer, jump_position, cx) })? @@ -1085,21 +1090,23 @@ impl Zeta { return anyhow::Ok(None); }; - let id = prediction.id.clone(); this.update(cx, |this, cx| { - if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) { - zeta_project.current_prediction.get_or_insert_with(|| { - cx.notify(); - CurrentEditPrediction { - requested_by: PredictionRequestedBy::DiagnosticsUpdate, - prediction, - was_shown: false, + Some(( + if this + .get_or_init_zeta_project(&project, cx) + .current_prediction + .is_none() + { + prediction_result + } else { + EditPredictionResult { + id: prediction_result.id, + prediction: Err(EditPredictionRejectReason::CurrentPreferred), } - }); - } - })?; - - anyhow::Ok(Some(id)) + }, + PredictionRequestedBy::DiagnosticsUpdate, + )) + }) }) }); } @@ -1117,7 +1124,8 @@ impl Zeta { do_refresh: impl FnOnce( WeakEntity, &mut AsyncApp, - ) -> Task>> + ) + -> Task>> + 'static, ) { let zeta_project = self.get_or_init_zeta_project(&project, cx); @@ -1152,22 +1160,77 @@ impl Zeta { return None; } - let edit_prediction_id = do_refresh(this.clone(), cx).await.log_err().flatten(); + let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten(); + let new_prediction_id = new_prediction_result + .as_ref() + .map(|(prediction, _)| prediction.id.clone()); // When a prediction completes, remove it from the pending list, and cancel // any pending predictions that were enqueued before it. this.update(cx, |this, cx| { let zeta_project = this.get_or_init_zeta_project(&project, cx); - zeta_project + + let is_cancelled = zeta_project .cancelled_predictions .remove(&pending_prediction_id); + let new_current_prediction = if !is_cancelled + && let Some((prediction_result, requested_by)) = new_prediction_result + { + match prediction_result.prediction { + Ok(prediction) => { + let new_prediction = CurrentEditPrediction { + requested_by, + prediction, + was_shown: false, + }; + + if let Some(current_prediction) = + zeta_project.current_prediction.as_ref() + { + if new_prediction.should_replace_prediction(¤t_prediction, cx) + { + this.reject_current_prediction( + EditPredictionRejectReason::Replaced, + &project, + cx, + ); + + Some(new_prediction) + } else { + this.reject_prediction( + new_prediction.prediction.id, + EditPredictionRejectReason::CurrentPreferred, + false, + cx, + ); + None + } + } else { + Some(new_prediction) + } + } + Err(reject_reason) => { + this.reject_prediction(prediction_result.id, reject_reason, false, cx); + None + } + } + } else { + None + }; + + let zeta_project = this.get_or_init_zeta_project(&project, cx); + + if let Some(new_prediction) = new_current_prediction { + zeta_project.current_prediction = Some(new_prediction); + } + let mut pending_predictions = mem::take(&mut zeta_project.pending_predictions); for (ix, pending_prediction) in pending_predictions.iter().enumerate() { if pending_prediction.id == pending_prediction_id { pending_predictions.remove(ix); for pending_prediction in pending_predictions.drain(0..ix) { - this.cancel_pending_prediction(pending_prediction, cx) + zeta_project.cancel_pending_prediction(pending_prediction, cx) } break; } @@ -1178,7 +1241,7 @@ impl Zeta { }) .ok(); - edit_prediction_id + new_prediction_id }); if zeta_project.pending_predictions.len() <= 1 { @@ -1192,10 +1255,7 @@ impl Zeta { id: pending_prediction_id, task, }); - zeta_project - .cancelled_predictions - .insert(pending_prediction.id); - self.cancel_pending_prediction(pending_prediction, cx); + zeta_project.cancel_pending_prediction(pending_prediction, cx); } } @@ -1205,7 +1265,7 @@ impl Zeta { active_buffer: &Entity, position: language::Anchor, cx: &mut Context, - ) -> Task>> { + ) -> Task>> { self.request_prediction_internal( project.clone(), active_buffer.clone(), @@ -1222,7 +1282,7 @@ impl Zeta { position: language::Anchor, allow_jump: bool, cx: &mut Context, - ) -> Task>> { + ) -> Task>> { const DIAGNOSTIC_LINES_RANGE: u32 = 20; self.get_or_init_zeta_project(&project, cx); @@ -1268,9 +1328,7 @@ impl Zeta { }; cx.spawn(async move |this, cx| { - let prediction = task - .await? - .filter(|prediction| !prediction.edits.is_empty()); + let prediction = task.await?; if prediction.is_none() && allow_jump { let cursor_point = position.to_point(&snapshot); @@ -1392,7 +1450,7 @@ impl Zeta { position: language::Anchor, events: Vec>, cx: &mut Context, - ) -> Task>> { + ) -> Task>> { let project_state = self.projects.get(&project.entity_id()); let index_state = project_state.and_then(|state| { @@ -1689,7 +1747,7 @@ impl Zeta { let (res, usage) = response?; let request_id = EditPredictionId(res.id.clone().into()); let Some(mut output_text) = text_from_response(res) else { - return Ok((None, usage)); + return Ok((Some((request_id, None)), usage)); }; if output_text.contains(CURSOR_MARKER) { @@ -1747,11 +1805,13 @@ impl Zeta { anyhow::Ok(( Some(( request_id, - inputs, - edited_buffer, - edited_buffer_snapshot.clone(), - edits, - received_response_at, + Some(( + inputs, + edited_buffer, + edited_buffer_snapshot.clone(), + edits, + received_response_at, + )), )), usage, )) @@ -1760,30 +1820,40 @@ impl Zeta { cx.spawn({ async move |this, cx| { + let Some((id, prediction)) = + Self::handle_api_response(&this, request_task.await, cx)? + else { + return Ok(None); + }; + let Some(( - id, inputs, edited_buffer, edited_buffer_snapshot, edits, received_response_at, - )) = Self::handle_api_response(&this, request_task.await, cx)? + )) = prediction else { - return Ok(None); + return Ok(Some(EditPredictionResult { + id, + prediction: Err(EditPredictionRejectReason::Empty), + })); }; // TODO telemetry: duration, etc - Ok(EditPrediction::new( - id, - &edited_buffer, - &edited_buffer_snapshot, - edits.into(), - buffer_snapshotted_at, - received_response_at, - inputs, - cx, - ) - .await) + Ok(Some( + EditPredictionResult::new( + id, + &edited_buffer, + &edited_buffer_snapshot, + edits.into(), + buffer_snapshotted_at, + received_response_at, + inputs, + cx, + ) + .await, + )) } }) } @@ -2806,6 +2876,9 @@ mod tests { use client::UserStore; use clock::FakeSystemClock; + use cloud_llm_client::{ + EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody, + }; use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery}; use futures::{ AsyncReadExt, StreamExt, @@ -2830,7 +2903,7 @@ mod tests { #[gpui::test] async fn test_current_state(cx: &mut TestAppContext) { - let (zeta, mut req_rx) = init_test(cx); + let (zeta, mut requests) = init_test(cx); let fs = FakeFs::new(cx.executor()); fs.insert_tree( "/root", @@ -2861,7 +2934,7 @@ mod tests { zeta.update(cx, |zeta, cx| { zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx) }); - let (_request, respond_tx) = req_rx.next().await.unwrap(); + let (_request, respond_tx) = requests.predict.next().await.unwrap(); respond_tx .send(model_response(indoc! {r" @@ -2888,7 +2961,7 @@ mod tests { let refresh_task = zeta.update(cx, |zeta, cx| { zeta.refresh_context(project.clone(), buffer1.clone(), position, cx) }); - let (_request, respond_tx) = req_rx.next().await.unwrap(); + let (_request, respond_tx) = requests.predict.next().await.unwrap(); respond_tx .send(open_ai::Response { id: Uuid::new_v4().to_string(), @@ -2929,14 +3002,14 @@ mod tests { refresh_task.await.unwrap(); zeta.update(cx, |zeta, cx| { - zeta.discard_current_prediction(&project, cx); + zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx); }); // Prediction for another file zeta.update(cx, |zeta, cx| { zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx) }); - let (_request, respond_tx) = req_rx.next().await.unwrap(); + let (_request, respond_tx) = requests.predict.next().await.unwrap(); respond_tx .send(model_response(indoc! {r#" --- a/root/2.txt @@ -2977,7 +3050,7 @@ mod tests { #[gpui::test] async fn test_simple_request(cx: &mut TestAppContext) { - let (zeta, mut req_rx) = init_test(cx); + let (zeta, mut requests) = init_test(cx); let fs = FakeFs::new(cx.executor()); fs.insert_tree( "/root", @@ -3002,7 +3075,7 @@ mod tests { zeta.request_prediction(&project, &buffer, position, cx) }); - let (_, respond_tx) = req_rx.next().await.unwrap(); + let (_, respond_tx) = requests.predict.next().await.unwrap(); // TODO Put back when we have a structured request again // assert_eq!( @@ -3029,7 +3102,7 @@ mod tests { "})) .unwrap(); - let prediction = prediction_task.await.unwrap().unwrap(); + let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap(); assert_eq!(prediction.edits.len(), 1); assert_eq!( @@ -3041,7 +3114,7 @@ mod tests { #[gpui::test] async fn test_request_events(cx: &mut TestAppContext) { - let (zeta, mut req_rx) = init_test(cx); + let (zeta, mut requests) = init_test(cx); let fs = FakeFs::new(cx.executor()); fs.insert_tree( "/root", @@ -3075,7 +3148,7 @@ mod tests { zeta.request_prediction(&project, &buffer, position, cx) }); - let (request, respond_tx) = req_rx.next().await.unwrap(); + let (request, respond_tx) = requests.predict.next().await.unwrap(); let prompt = prompt_from_request(&request); assert!( @@ -3103,7 +3176,7 @@ mod tests { "#})) .unwrap(); - let prediction = prediction_task.await.unwrap().unwrap(); + let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap(); assert_eq!(prediction.edits.len(), 1); assert_eq!( @@ -3113,6 +3186,522 @@ mod tests { assert_eq!(prediction.edits[0].1.as_ref(), " are you?"); } + #[gpui::test] + async fn test_empty_prediction(cx: &mut TestAppContext) { + let (zeta, mut requests) = init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.md": "Hello!\nHow\nBye\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(1, 3)); + + zeta.update(cx, |zeta, cx| { + zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + const NO_OP_DIFF: &str = indoc! { r" + --- a/root/foo.md + +++ b/root/foo.md + @@ ... @@ + Hello! + -How + +How + Bye + "}; + + let (_, respond_tx) = requests.predict.next().await.unwrap(); + let response = model_response(NO_OP_DIFF); + let id = response.id.clone(); + respond_tx.send(response).unwrap(); + + cx.run_until_parked(); + + zeta.read_with(cx, |zeta, cx| { + assert!( + zeta.current_prediction_for_buffer(&buffer, &project, cx) + .is_none() + ); + }); + + // prediction is reported as rejected + let (reject_request, _) = requests.reject.next().await.unwrap(); + + assert_eq!( + &reject_request.rejections, + &[EditPredictionRejection { + request_id: id, + reason: EditPredictionRejectReason::Empty, + was_shown: false + }] + ); + } + + #[gpui::test] + async fn test_interpolated_empty(cx: &mut TestAppContext) { + let (zeta, mut requests) = init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.md": "Hello!\nHow\nBye\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(1, 3)); + + zeta.update(cx, |zeta, cx| { + zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (_, respond_tx) = requests.predict.next().await.unwrap(); + + buffer.update(cx, |buffer, cx| { + buffer.set_text("Hello!\nHow are you?\nBye", cx); + }); + + let response = model_response(SIMPLE_DIFF); + let id = response.id.clone(); + respond_tx.send(response).unwrap(); + + cx.run_until_parked(); + + zeta.read_with(cx, |zeta, cx| { + assert!( + zeta.current_prediction_for_buffer(&buffer, &project, cx) + .is_none() + ); + }); + + // prediction is reported as rejected + let (reject_request, _) = requests.reject.next().await.unwrap(); + + assert_eq!( + &reject_request.rejections, + &[EditPredictionRejection { + request_id: id, + reason: EditPredictionRejectReason::InterpolatedEmpty, + was_shown: false + }] + ); + } + + const SIMPLE_DIFF: &str = indoc! { r" + --- a/root/foo.md + +++ b/root/foo.md + @@ ... @@ + Hello! + -How + +How are you? + Bye + "}; + + #[gpui::test] + async fn test_replace_current(cx: &mut TestAppContext) { + let (zeta, mut requests) = init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.md": "Hello!\nHow\nBye\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(1, 3)); + + zeta.update(cx, |zeta, cx| { + zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (_, respond_tx) = requests.predict.next().await.unwrap(); + let first_response = model_response(SIMPLE_DIFF); + let first_id = first_response.id.clone(); + respond_tx.send(first_response).unwrap(); + + cx.run_until_parked(); + + zeta.read_with(cx, |zeta, cx| { + assert_eq!( + zeta.current_prediction_for_buffer(&buffer, &project, cx) + .unwrap() + .id + .0, + first_id + ); + }); + + // a second request is triggered + zeta.update(cx, |zeta, cx| { + zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (_, respond_tx) = requests.predict.next().await.unwrap(); + let second_response = model_response(SIMPLE_DIFF); + let second_id = second_response.id.clone(); + respond_tx.send(second_response).unwrap(); + + cx.run_until_parked(); + + zeta.read_with(cx, |zeta, cx| { + // second replaces first + assert_eq!( + zeta.current_prediction_for_buffer(&buffer, &project, cx) + .unwrap() + .id + .0, + second_id + ); + }); + + // first is reported as replaced + let (reject_request, _) = requests.reject.next().await.unwrap(); + + assert_eq!( + &reject_request.rejections, + &[EditPredictionRejection { + request_id: first_id, + reason: EditPredictionRejectReason::Replaced, + was_shown: false + }] + ); + } + + #[gpui::test] + async fn test_current_preferred(cx: &mut TestAppContext) { + let (zeta, mut requests) = init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.md": "Hello!\nHow\nBye\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(1, 3)); + + zeta.update(cx, |zeta, cx| { + zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (_, respond_tx) = requests.predict.next().await.unwrap(); + let first_response = model_response(SIMPLE_DIFF); + let first_id = first_response.id.clone(); + respond_tx.send(first_response).unwrap(); + + cx.run_until_parked(); + + zeta.read_with(cx, |zeta, cx| { + assert_eq!( + zeta.current_prediction_for_buffer(&buffer, &project, cx) + .unwrap() + .id + .0, + first_id + ); + }); + + // a second request is triggered + zeta.update(cx, |zeta, cx| { + zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (_, respond_tx) = requests.predict.next().await.unwrap(); + // worse than current prediction + let second_response = model_response(indoc! { r" + --- a/root/foo.md + +++ b/root/foo.md + @@ ... @@ + Hello! + -How + +How are + Bye + "}); + let second_id = second_response.id.clone(); + respond_tx.send(second_response).unwrap(); + + cx.run_until_parked(); + + zeta.read_with(cx, |zeta, cx| { + // first is preferred over second + assert_eq!( + zeta.current_prediction_for_buffer(&buffer, &project, cx) + .unwrap() + .id + .0, + first_id + ); + }); + + // second is reported as rejected + let (reject_request, _) = requests.reject.next().await.unwrap(); + + assert_eq!( + &reject_request.rejections, + &[EditPredictionRejection { + request_id: second_id, + reason: EditPredictionRejectReason::CurrentPreferred, + was_shown: false + }] + ); + } + + #[gpui::test] + async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) { + let (zeta, mut requests) = init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.md": "Hello!\nHow\nBye\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(1, 3)); + + zeta.update(cx, |zeta, cx| { + // start two refresh tasks + zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + + zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (_, respond_first) = requests.predict.next().await.unwrap(); + let (_, respond_second) = requests.predict.next().await.unwrap(); + + // wait for throttle + cx.run_until_parked(); + + // second responds first + let second_response = model_response(SIMPLE_DIFF); + let second_id = second_response.id.clone(); + respond_second.send(second_response).unwrap(); + + cx.run_until_parked(); + + zeta.read_with(cx, |zeta, cx| { + // current prediction is second + assert_eq!( + zeta.current_prediction_for_buffer(&buffer, &project, cx) + .unwrap() + .id + .0, + second_id + ); + }); + + let first_response = model_response(SIMPLE_DIFF); + let first_id = first_response.id.clone(); + respond_first.send(first_response).unwrap(); + + cx.run_until_parked(); + + zeta.read_with(cx, |zeta, cx| { + // current prediction is still second, since first was cancelled + assert_eq!( + zeta.current_prediction_for_buffer(&buffer, &project, cx) + .unwrap() + .id + .0, + second_id + ); + }); + + // first is reported as rejected + let (reject_request, _) = requests.reject.next().await.unwrap(); + + cx.run_until_parked(); + + assert_eq!( + &reject_request.rejections, + &[EditPredictionRejection { + request_id: first_id, + reason: EditPredictionRejectReason::Canceled, + was_shown: false + }] + ); + } + + #[gpui::test] + async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) { + let (zeta, mut requests) = init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.md": "Hello!\nHow\nBye\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(1, 3)); + + zeta.update(cx, |zeta, cx| { + // start two refresh tasks + zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + // wait for throttle, so requests are sent + cx.run_until_parked(); + + let (_, respond_first) = requests.predict.next().await.unwrap(); + let (_, respond_second) = requests.predict.next().await.unwrap(); + + zeta.update(cx, |zeta, cx| { + // start a third request + zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + + // 2 are pending, so 2nd is cancelled + assert_eq!( + zeta.get_or_init_zeta_project(&project, cx) + .cancelled_predictions + .iter() + .copied() + .collect::>(), + [1] + ); + }); + + // wait for throttle + cx.run_until_parked(); + + let (_, respond_third) = requests.predict.next().await.unwrap(); + + let first_response = model_response(SIMPLE_DIFF); + let first_id = first_response.id.clone(); + respond_first.send(first_response).unwrap(); + + cx.run_until_parked(); + + zeta.read_with(cx, |zeta, cx| { + // current prediction is first + assert_eq!( + zeta.current_prediction_for_buffer(&buffer, &project, cx) + .unwrap() + .id + .0, + first_id + ); + }); + + let cancelled_response = model_response(SIMPLE_DIFF); + let cancelled_id = cancelled_response.id.clone(); + respond_second.send(cancelled_response).unwrap(); + + cx.run_until_parked(); + + zeta.read_with(cx, |zeta, cx| { + // current prediction is still first, since second was cancelled + assert_eq!( + zeta.current_prediction_for_buffer(&buffer, &project, cx) + .unwrap() + .id + .0, + first_id + ); + }); + + let third_response = model_response(SIMPLE_DIFF); + let third_response_id = third_response.id.clone(); + respond_third.send(third_response).unwrap(); + + cx.run_until_parked(); + + zeta.read_with(cx, |zeta, cx| { + // third completes and replaces first + assert_eq!( + zeta.current_prediction_for_buffer(&buffer, &project, cx) + .unwrap() + .id + .0, + third_response_id + ); + }); + + // second is reported as rejected + let (reject_request, _) = requests.reject.next().await.unwrap(); + + cx.run_until_parked(); + + assert_eq!( + &reject_request.rejections, + &[ + EditPredictionRejection { + request_id: cancelled_id, + reason: EditPredictionRejectReason::Canceled, + was_shown: false + }, + EditPredictionRejection { + request_id: first_id, + reason: EditPredictionRejectReason::Replaced, + was_shown: false + } + ] + ); + } + // Skipped until we start including diagnostics in prompt // #[gpui::test] // async fn test_request_diagnostics(cx: &mut TestAppContext) { @@ -3242,24 +3831,26 @@ mod tests { content } - fn init_test( - cx: &mut TestAppContext, - ) -> ( - Entity, - mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender)>, - ) { + struct RequestChannels { + predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender)>, + reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>, + } + + fn init_test(cx: &mut TestAppContext) -> (Entity, RequestChannels) { cx.update(move |cx| { let settings_store = SettingsStore::test(cx); cx.set_global(settings_store); zlog::init_test(); - let (req_tx, req_rx) = mpsc::unbounded(); + let (predict_req_tx, predict_req_rx) = mpsc::unbounded(); + let (reject_req_tx, reject_req_rx) = mpsc::unbounded(); let http_client = FakeHttpClient::create({ move |req| { let uri = req.uri().path().to_string(); let mut body = req.into_body(); - let req_tx = req_tx.clone(); + let predict_req_tx = predict_req_tx.clone(); + let reject_req_tx = reject_req_tx.clone(); async move { let resp = match uri.as_str() { "/client/llm_tokens" => serde_json::to_string(&json!({ @@ -3272,7 +3863,16 @@ mod tests { let req = serde_json::from_slice(&buf).unwrap(); let (res_tx, res_rx) = oneshot::channel(); - req_tx.unbounded_send((req, res_tx)).unwrap(); + predict_req_tx.unbounded_send((req, res_tx)).unwrap(); + serde_json::to_string(&res_rx.await?).unwrap() + } + "/predict_edits/reject" => { + let mut buf = Vec::new(); + body.read_to_end(&mut buf).await.ok(); + let req = serde_json::from_slice(&buf).unwrap(); + + let (res_tx, res_rx) = oneshot::channel(); + reject_req_tx.unbounded_send((req, res_tx)).unwrap(); serde_json::to_string(&res_rx.await?).unwrap() } _ => { @@ -3293,7 +3893,13 @@ mod tests { let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); let zeta = Zeta::global(&client, &user_store, cx); - (zeta, req_rx) + ( + zeta, + RequestChannels { + predict: predict_req_rx, + reject: reject_req_rx, + }, + ) }) } } diff --git a/crates/zeta/src/zeta1.rs b/crates/zeta/src/zeta1.rs index 7f80d60d5efcbbd0bd7b9426508c344c063d5597..96d175d5eb11c2c8be40779cf77bfb743d39dff6 100644 --- a/crates/zeta/src/zeta1.rs +++ b/crates/zeta/src/zeta1.rs @@ -4,7 +4,7 @@ use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant}; use crate::{ EditPredictionId, ZedUpdateRequiredError, Zeta, - prediction::{EditPrediction, EditPredictionInputs}, + prediction::{EditPredictionInputs, EditPredictionResult}, }; use anyhow::{Context as _, Result}; use cloud_llm_client::{ @@ -36,7 +36,7 @@ pub(crate) fn request_prediction_with_zeta1( position: language::Anchor, events: Vec>, cx: &mut Context, -) -> Task>> { +) -> Task>> { let buffer = buffer.clone(); let buffer_snapshotted_at = Instant::now(); let client = zeta.client.clone(); @@ -216,7 +216,7 @@ pub(crate) fn request_prediction_with_zeta1( ); } - edit_prediction + edit_prediction.map(Some) }) } @@ -229,7 +229,7 @@ fn process_completion_response( buffer_snapshotted_at: Instant, received_response_at: Instant, cx: &AsyncApp, -) -> Task>> { +) -> Task> { let snapshot = snapshot.clone(); let request_id = prediction_response.request_id; let output_excerpt = prediction_response.output_excerpt; @@ -246,8 +246,9 @@ fn process_completion_response( .await? .into(); - Ok(EditPrediction::new( - EditPredictionId(request_id.into()), + let id = EditPredictionId(request_id.into()); + Ok(EditPredictionResult::new( + id, &buffer, &snapshot, edits, diff --git a/crates/zeta/src/zeta_tests.rs b/crates/zeta/src/zeta_tests.rs index eb12f81af25d72b5e7003187ab0a9536622c9a74..9b7abb216f5e8e7a9c8bd14a33c2f6ecd9f16174 100644 --- a/crates/zeta/src/zeta_tests.rs +++ b/crates/zeta/src/zeta_tests.rs @@ -538,7 +538,7 @@ async fn run_edit_prediction( let prediction_task = zeta.update(cx, |zeta, cx| { zeta.request_prediction(&project, buffer, cursor, cx) }); - prediction_task.await.unwrap().unwrap() + prediction_task.await.unwrap().unwrap().prediction.unwrap() } async fn make_test_zeta( diff --git a/crates/zeta_cli/src/predict.rs b/crates/zeta_cli/src/predict.rs index 8a1a4131fb684a5186b2111f9d922fa34d6972e1..c2d68a471fa5de7765c1042473fc8118a3fc9415 100644 --- a/crates/zeta_cli/src/predict.rs +++ b/crates/zeta_cli/src/predict.rs @@ -235,7 +235,10 @@ pub async fn perform_predict( let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap(); result.diff = prediction - .and_then(|prediction| prediction.edit_preview.as_unified_diff(&prediction.edits)) + .and_then(|prediction| { + let prediction = prediction.prediction.ok()?; + prediction.edit_preview.as_unified_diff(&prediction.edits) + }) .unwrap_or_default(); anyhow::Ok(result)