@@ -3,9 +3,9 @@ use arrayvec::ArrayVec;
use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
use cloud_llm_client::{
- AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejection,
- MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
- RejectEditPredictionsBody, ZED_VERSION_HEADER_NAME,
+ AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
+ EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
+ MINIMUM_REQUIRED_VERSION_HEADER_NAME, RejectEditPredictionsBody, ZED_VERSION_HEADER_NAME,
};
use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
@@ -74,6 +74,7 @@ use crate::onboarding_modal::ZedPredictModal;
pub use crate::prediction::EditPrediction;
pub use crate::prediction::EditPredictionId;
pub use crate::prediction::EditPredictionInputs;
+use crate::prediction::EditPredictionResult;
use crate::rate_prediction_modal::{
NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
ThumbsUpActivePrediction,
@@ -310,6 +311,31 @@ impl ZetaProject {
)
.collect()
}
+
+ fn cancel_pending_prediction(
+ &mut self,
+ pending_prediction: PendingPrediction,
+ cx: &mut Context<Zeta>,
+ ) {
+ self.cancelled_predictions.insert(pending_prediction.id);
+
+ cx.spawn(async move |this, cx| {
+ let Some(prediction_id) = pending_prediction.task.await else {
+ return;
+ };
+
+ this.update(cx, |this, cx| {
+ this.reject_prediction(
+ prediction_id,
+ EditPredictionRejectReason::Canceled,
+ false,
+ cx,
+ );
+ })
+ .ok();
+ })
+ .detach()
+ }
}
#[derive(Debug, Clone)]
@@ -373,6 +399,7 @@ impl PredictionRequestedBy {
}
}
+#[derive(Debug)]
struct PendingPrediction {
id: usize,
task: Task<Option<EditPredictionId>>,
@@ -385,6 +412,18 @@ enum BufferEditPrediction<'a> {
Jump { prediction: &'a EditPrediction },
}
+#[cfg(test)]
+impl std::ops::Deref for BufferEditPrediction<'_> {
+ type Target = EditPrediction;
+
+ fn deref(&self) -> &Self::Target {
+ match self {
+ BufferEditPrediction::Local { prediction } => prediction,
+ BufferEditPrediction::Jump { prediction } => prediction,
+ }
+ }
+}
+
struct RegisteredBuffer {
snapshot: BufferSnapshot,
_subscriptions: [gpui::Subscription; 2],
@@ -467,7 +506,7 @@ impl Zeta {
let (reject_tx, mut reject_rx) = mpsc::unbounded();
cx.spawn(async move |this, cx| {
while let Some(()) = reject_rx.next().await {
- this.update(cx, |this, cx| this.reject_edit_predictions(cx))?
+ this.update(cx, |this, cx| this.flush_rejected_predictions(cx))?
.await
.log_err();
}
@@ -818,7 +857,7 @@ impl Zeta {
};
let request_id = prediction.prediction.id.to_string();
for pending_prediction in mem::take(&mut project_state.pending_predictions) {
- self.cancel_pending_prediction(pending_prediction, cx);
+ project_state.cancel_pending_prediction(pending_prediction, cx);
}
let client = self.client.clone();
@@ -856,7 +895,7 @@ impl Zeta {
.detach_and_log_err(cx);
}
- fn reject_edit_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
+ fn flush_rejected_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
match self.edit_prediction_model {
ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
ZetaEditPredictionModel::Sweep => return Task::ready(anyhow::Ok(())),
@@ -904,11 +943,16 @@ impl Zeta {
})
}
- fn discard_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
+ fn reject_current_prediction(
+ &mut self,
+ reason: EditPredictionRejectReason,
+ project: &Entity<Project>,
+ cx: &mut Context<Self>,
+ ) {
if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
project_state.pending_predictions.clear();
if let Some(prediction) = project_state.current_prediction.take() {
- self.discard_prediction(prediction.prediction.id, prediction.was_shown, cx);
+ self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
}
};
}
@@ -929,14 +973,16 @@ impl Zeta {
}
}
- fn discard_prediction(
+ fn reject_prediction(
&mut self,
prediction_id: EditPredictionId,
+ reason: EditPredictionRejectReason,
was_shown: bool,
cx: &mut Context<Self>,
) {
self.rejected_predictions.push(EditPredictionRejection {
request_id: prediction_id.to_string(),
+ reason,
was_shown,
});
@@ -944,34 +990,16 @@ impl Zeta {
self.rejected_predictions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST;
let reject_tx = self.reject_predictions_tx.clone();
self.reject_predictions_debounce_task = Some(cx.spawn(async move |_this, cx| {
- const DISCARD_COMPLETIONS_DEBOUNCE: Duration = Duration::from_secs(15);
+ const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
if !reached_request_limit {
cx.background_executor()
- .timer(DISCARD_COMPLETIONS_DEBOUNCE)
+ .timer(REJECT_REQUEST_DEBOUNCE)
.await;
}
reject_tx.unbounded_send(()).log_err();
}));
}
- fn cancel_pending_prediction(
- &self,
- pending_prediction: PendingPrediction,
- cx: &mut Context<Self>,
- ) {
- cx.spawn(async move |this, cx| {
- let Some(prediction_id) = pending_prediction.task.await else {
- return;
- };
-
- this.update(cx, |this, cx| {
- this.discard_prediction(prediction_id, false, cx);
- })
- .ok();
- })
- .detach()
- }
-
fn is_refreshing(&self, project: &Entity<Project>) -> bool {
self.projects
.get(&project.entity_id())
@@ -995,38 +1023,15 @@ impl Zeta {
return Task::ready(anyhow::Ok(None));
};
- let project = project.clone();
- cx.spawn(async move |cx| {
- if let Some(prediction) = request_task.await? {
- let id = prediction.id.clone();
- this.update(cx, |this, cx| {
- let project_state = this
- .projects
- .get_mut(&project.entity_id())
- .context("Project not found")?;
-
- let new_prediction = CurrentEditPrediction {
- requested_by: PredictionRequestedBy::Buffer(buffer.entity_id()),
- prediction: prediction,
- was_shown: false,
- };
-
- if project_state
- .current_prediction
- .as_ref()
- .is_none_or(|old_prediction| {
- new_prediction.should_replace_prediction(&old_prediction, cx)
- })
- {
- project_state.current_prediction = Some(new_prediction);
- cx.notify();
- }
- anyhow::Ok(())
- })??;
- Ok(Some(id))
- } else {
- Ok(None)
- }
+ cx.spawn(async move |_cx| {
+ request_task.await.map(|prediction_result| {
+ prediction_result.map(|prediction_result| {
+ (
+ prediction_result,
+ PredictionRequestedBy::Buffer(buffer.entity_id()),
+ )
+ })
+ })
})
})
}
@@ -1076,7 +1081,7 @@ impl Zeta {
return anyhow::Ok(None);
};
- let Some(prediction) = this
+ let Some(prediction_result) = this
.update(cx, |this, cx| {
this.request_prediction(&project, &jump_buffer, jump_position, cx)
})?
@@ -1085,21 +1090,23 @@ impl Zeta {
return anyhow::Ok(None);
};
- let id = prediction.id.clone();
this.update(cx, |this, cx| {
- if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
- zeta_project.current_prediction.get_or_insert_with(|| {
- cx.notify();
- CurrentEditPrediction {
- requested_by: PredictionRequestedBy::DiagnosticsUpdate,
- prediction,
- was_shown: false,
+ Some((
+ if this
+ .get_or_init_zeta_project(&project, cx)
+ .current_prediction
+ .is_none()
+ {
+ prediction_result
+ } else {
+ EditPredictionResult {
+ id: prediction_result.id,
+ prediction: Err(EditPredictionRejectReason::CurrentPreferred),
}
- });
- }
- })?;
-
- anyhow::Ok(Some(id))
+ },
+ PredictionRequestedBy::DiagnosticsUpdate,
+ ))
+ })
})
});
}
@@ -1117,7 +1124,8 @@ impl Zeta {
do_refresh: impl FnOnce(
WeakEntity<Self>,
&mut AsyncApp,
- ) -> Task<Result<Option<EditPredictionId>>>
+ )
+ -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
+ 'static,
) {
let zeta_project = self.get_or_init_zeta_project(&project, cx);
@@ -1152,22 +1160,77 @@ impl Zeta {
return None;
}
- let edit_prediction_id = do_refresh(this.clone(), cx).await.log_err().flatten();
+ let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
+ let new_prediction_id = new_prediction_result
+ .as_ref()
+ .map(|(prediction, _)| prediction.id.clone());
// When a prediction completes, remove it from the pending list, and cancel
// any pending predictions that were enqueued before it.
this.update(cx, |this, cx| {
let zeta_project = this.get_or_init_zeta_project(&project, cx);
- zeta_project
+
+ let is_cancelled = zeta_project
.cancelled_predictions
.remove(&pending_prediction_id);
+ let new_current_prediction = if !is_cancelled
+ && let Some((prediction_result, requested_by)) = new_prediction_result
+ {
+ match prediction_result.prediction {
+ Ok(prediction) => {
+ let new_prediction = CurrentEditPrediction {
+ requested_by,
+ prediction,
+ was_shown: false,
+ };
+
+ if let Some(current_prediction) =
+ zeta_project.current_prediction.as_ref()
+ {
+ if new_prediction.should_replace_prediction(¤t_prediction, cx)
+ {
+ this.reject_current_prediction(
+ EditPredictionRejectReason::Replaced,
+ &project,
+ cx,
+ );
+
+ Some(new_prediction)
+ } else {
+ this.reject_prediction(
+ new_prediction.prediction.id,
+ EditPredictionRejectReason::CurrentPreferred,
+ false,
+ cx,
+ );
+ None
+ }
+ } else {
+ Some(new_prediction)
+ }
+ }
+ Err(reject_reason) => {
+ this.reject_prediction(prediction_result.id, reject_reason, false, cx);
+ None
+ }
+ }
+ } else {
+ None
+ };
+
+ let zeta_project = this.get_or_init_zeta_project(&project, cx);
+
+ if let Some(new_prediction) = new_current_prediction {
+ zeta_project.current_prediction = Some(new_prediction);
+ }
+
let mut pending_predictions = mem::take(&mut zeta_project.pending_predictions);
for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
if pending_prediction.id == pending_prediction_id {
pending_predictions.remove(ix);
for pending_prediction in pending_predictions.drain(0..ix) {
- this.cancel_pending_prediction(pending_prediction, cx)
+ zeta_project.cancel_pending_prediction(pending_prediction, cx)
}
break;
}
@@ -1178,7 +1241,7 @@ impl Zeta {
})
.ok();
- edit_prediction_id
+ new_prediction_id
});
if zeta_project.pending_predictions.len() <= 1 {
@@ -1192,10 +1255,7 @@ impl Zeta {
id: pending_prediction_id,
task,
});
- zeta_project
- .cancelled_predictions
- .insert(pending_prediction.id);
- self.cancel_pending_prediction(pending_prediction, cx);
+ zeta_project.cancel_pending_prediction(pending_prediction, cx);
}
}
@@ -1205,7 +1265,7 @@ impl Zeta {
active_buffer: &Entity<Buffer>,
position: language::Anchor,
cx: &mut Context<Self>,
- ) -> Task<Result<Option<EditPrediction>>> {
+ ) -> Task<Result<Option<EditPredictionResult>>> {
self.request_prediction_internal(
project.clone(),
active_buffer.clone(),
@@ -1222,7 +1282,7 @@ impl Zeta {
position: language::Anchor,
allow_jump: bool,
cx: &mut Context<Self>,
- ) -> Task<Result<Option<EditPrediction>>> {
+ ) -> Task<Result<Option<EditPredictionResult>>> {
const DIAGNOSTIC_LINES_RANGE: u32 = 20;
self.get_or_init_zeta_project(&project, cx);
@@ -1268,9 +1328,7 @@ impl Zeta {
};
cx.spawn(async move |this, cx| {
- let prediction = task
- .await?
- .filter(|prediction| !prediction.edits.is_empty());
+ let prediction = task.await?;
if prediction.is_none() && allow_jump {
let cursor_point = position.to_point(&snapshot);
@@ -1392,7 +1450,7 @@ impl Zeta {
position: language::Anchor,
events: Vec<Arc<Event>>,
cx: &mut Context<Self>,
- ) -> Task<Result<Option<EditPrediction>>> {
+ ) -> Task<Result<Option<EditPredictionResult>>> {
let project_state = self.projects.get(&project.entity_id());
let index_state = project_state.and_then(|state| {
@@ -1689,7 +1747,7 @@ impl Zeta {
let (res, usage) = response?;
let request_id = EditPredictionId(res.id.clone().into());
let Some(mut output_text) = text_from_response(res) else {
- return Ok((None, usage));
+ return Ok((Some((request_id, None)), usage));
};
if output_text.contains(CURSOR_MARKER) {
@@ -1747,11 +1805,13 @@ impl Zeta {
anyhow::Ok((
Some((
request_id,
- inputs,
- edited_buffer,
- edited_buffer_snapshot.clone(),
- edits,
- received_response_at,
+ Some((
+ inputs,
+ edited_buffer,
+ edited_buffer_snapshot.clone(),
+ edits,
+ received_response_at,
+ )),
)),
usage,
))
@@ -1760,30 +1820,40 @@ impl Zeta {
cx.spawn({
async move |this, cx| {
+ let Some((id, prediction)) =
+ Self::handle_api_response(&this, request_task.await, cx)?
+ else {
+ return Ok(None);
+ };
+
let Some((
- id,
inputs,
edited_buffer,
edited_buffer_snapshot,
edits,
received_response_at,
- )) = Self::handle_api_response(&this, request_task.await, cx)?
+ )) = prediction
else {
- return Ok(None);
+ return Ok(Some(EditPredictionResult {
+ id,
+ prediction: Err(EditPredictionRejectReason::Empty),
+ }));
};
// TODO telemetry: duration, etc
- Ok(EditPrediction::new(
- id,
- &edited_buffer,
- &edited_buffer_snapshot,
- edits.into(),
- buffer_snapshotted_at,
- received_response_at,
- inputs,
- cx,
- )
- .await)
+ Ok(Some(
+ EditPredictionResult::new(
+ id,
+ &edited_buffer,
+ &edited_buffer_snapshot,
+ edits.into(),
+ buffer_snapshotted_at,
+ received_response_at,
+ inputs,
+ cx,
+ )
+ .await,
+ ))
}
})
}
@@ -2806,6 +2876,9 @@ mod tests {
use client::UserStore;
use clock::FakeSystemClock;
+ use cloud_llm_client::{
+ EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
+ };
use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
use futures::{
AsyncReadExt, StreamExt,
@@ -2830,7 +2903,7 @@ mod tests {
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
- let (zeta, mut req_rx) = init_test(cx);
+ let (zeta, mut requests) = init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
@@ -2861,7 +2934,7 @@ mod tests {
zeta.update(cx, |zeta, cx| {
zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
});
- let (_request, respond_tx) = req_rx.next().await.unwrap();
+ let (_request, respond_tx) = requests.predict.next().await.unwrap();
respond_tx
.send(model_response(indoc! {r"
@@ -2888,7 +2961,7 @@ mod tests {
let refresh_task = zeta.update(cx, |zeta, cx| {
zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
});
- let (_request, respond_tx) = req_rx.next().await.unwrap();
+ let (_request, respond_tx) = requests.predict.next().await.unwrap();
respond_tx
.send(open_ai::Response {
id: Uuid::new_v4().to_string(),
@@ -2929,14 +3002,14 @@ mod tests {
refresh_task.await.unwrap();
zeta.update(cx, |zeta, cx| {
- zeta.discard_current_prediction(&project, cx);
+ zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
});
// Prediction for another file
zeta.update(cx, |zeta, cx| {
zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
});
- let (_request, respond_tx) = req_rx.next().await.unwrap();
+ let (_request, respond_tx) = requests.predict.next().await.unwrap();
respond_tx
.send(model_response(indoc! {r#"
--- a/root/2.txt
@@ -2977,7 +3050,7 @@ mod tests {
#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
- let (zeta, mut req_rx) = init_test(cx);
+ let (zeta, mut requests) = init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
@@ -3002,7 +3075,7 @@ mod tests {
zeta.request_prediction(&project, &buffer, position, cx)
});
- let (_, respond_tx) = req_rx.next().await.unwrap();
+ let (_, respond_tx) = requests.predict.next().await.unwrap();
// TODO Put back when we have a structured request again
// assert_eq!(
@@ -3029,7 +3102,7 @@ mod tests {
"}))
.unwrap();
- let prediction = prediction_task.await.unwrap().unwrap();
+ let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
assert_eq!(prediction.edits.len(), 1);
assert_eq!(
@@ -3041,7 +3114,7 @@ mod tests {
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
- let (zeta, mut req_rx) = init_test(cx);
+ let (zeta, mut requests) = init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
@@ -3075,7 +3148,7 @@ mod tests {
zeta.request_prediction(&project, &buffer, position, cx)
});
- let (request, respond_tx) = req_rx.next().await.unwrap();
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
let prompt = prompt_from_request(&request);
assert!(
@@ -3103,7 +3176,7 @@ mod tests {
"#}))
.unwrap();
- let prediction = prediction_task.await.unwrap().unwrap();
+ let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
assert_eq!(prediction.edits.len(), 1);
assert_eq!(
@@ -3113,6 +3186,522 @@ mod tests {
assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
+ #[gpui::test]
+ async fn test_empty_prediction(cx: &mut TestAppContext) {
+ let (zeta, mut requests) = init_test(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.md": "Hello!\nHow\nBye\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+ zeta.update(cx, |zeta, cx| {
+ zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ const NO_OP_DIFF: &str = indoc! { r"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ ... @@
+ Hello!
+ -How
+ +How
+ Bye
+ "};
+
+ let (_, respond_tx) = requests.predict.next().await.unwrap();
+ let response = model_response(NO_OP_DIFF);
+ let id = response.id.clone();
+ respond_tx.send(response).unwrap();
+
+ cx.run_until_parked();
+
+ zeta.read_with(cx, |zeta, cx| {
+ assert!(
+ zeta.current_prediction_for_buffer(&buffer, &project, cx)
+ .is_none()
+ );
+ });
+
+ // prediction is reported as rejected
+ let (reject_request, _) = requests.reject.next().await.unwrap();
+
+ assert_eq!(
+ &reject_request.rejections,
+ &[EditPredictionRejection {
+ request_id: id,
+ reason: EditPredictionRejectReason::Empty,
+ was_shown: false
+ }]
+ );
+ }
+
+ #[gpui::test]
+ async fn test_interpolated_empty(cx: &mut TestAppContext) {
+ let (zeta, mut requests) = init_test(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.md": "Hello!\nHow\nBye\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+ zeta.update(cx, |zeta, cx| {
+ zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (_, respond_tx) = requests.predict.next().await.unwrap();
+
+ buffer.update(cx, |buffer, cx| {
+ buffer.set_text("Hello!\nHow are you?\nBye", cx);
+ });
+
+ let response = model_response(SIMPLE_DIFF);
+ let id = response.id.clone();
+ respond_tx.send(response).unwrap();
+
+ cx.run_until_parked();
+
+ zeta.read_with(cx, |zeta, cx| {
+ assert!(
+ zeta.current_prediction_for_buffer(&buffer, &project, cx)
+ .is_none()
+ );
+ });
+
+ // prediction is reported as rejected
+ let (reject_request, _) = requests.reject.next().await.unwrap();
+
+ assert_eq!(
+ &reject_request.rejections,
+ &[EditPredictionRejection {
+ request_id: id,
+ reason: EditPredictionRejectReason::InterpolatedEmpty,
+ was_shown: false
+ }]
+ );
+ }
+
+ const SIMPLE_DIFF: &str = indoc! { r"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ ... @@
+ Hello!
+ -How
+ +How are you?
+ Bye
+ "};
+
+ #[gpui::test]
+ async fn test_replace_current(cx: &mut TestAppContext) {
+ let (zeta, mut requests) = init_test(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.md": "Hello!\nHow\nBye\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+ zeta.update(cx, |zeta, cx| {
+ zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (_, respond_tx) = requests.predict.next().await.unwrap();
+ let first_response = model_response(SIMPLE_DIFF);
+ let first_id = first_response.id.clone();
+ respond_tx.send(first_response).unwrap();
+
+ cx.run_until_parked();
+
+ zeta.read_with(cx, |zeta, cx| {
+ assert_eq!(
+ zeta.current_prediction_for_buffer(&buffer, &project, cx)
+ .unwrap()
+ .id
+ .0,
+ first_id
+ );
+ });
+
+ // a second request is triggered
+ zeta.update(cx, |zeta, cx| {
+ zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (_, respond_tx) = requests.predict.next().await.unwrap();
+ let second_response = model_response(SIMPLE_DIFF);
+ let second_id = second_response.id.clone();
+ respond_tx.send(second_response).unwrap();
+
+ cx.run_until_parked();
+
+ zeta.read_with(cx, |zeta, cx| {
+ // second replaces first
+ assert_eq!(
+ zeta.current_prediction_for_buffer(&buffer, &project, cx)
+ .unwrap()
+ .id
+ .0,
+ second_id
+ );
+ });
+
+ // first is reported as replaced
+ let (reject_request, _) = requests.reject.next().await.unwrap();
+
+ assert_eq!(
+ &reject_request.rejections,
+ &[EditPredictionRejection {
+ request_id: first_id,
+ reason: EditPredictionRejectReason::Replaced,
+ was_shown: false
+ }]
+ );
+ }
+
+ #[gpui::test]
+ async fn test_current_preferred(cx: &mut TestAppContext) {
+ let (zeta, mut requests) = init_test(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.md": "Hello!\nHow\nBye\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+ zeta.update(cx, |zeta, cx| {
+ zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (_, respond_tx) = requests.predict.next().await.unwrap();
+ let first_response = model_response(SIMPLE_DIFF);
+ let first_id = first_response.id.clone();
+ respond_tx.send(first_response).unwrap();
+
+ cx.run_until_parked();
+
+ zeta.read_with(cx, |zeta, cx| {
+ assert_eq!(
+ zeta.current_prediction_for_buffer(&buffer, &project, cx)
+ .unwrap()
+ .id
+ .0,
+ first_id
+ );
+ });
+
+ // a second request is triggered
+ zeta.update(cx, |zeta, cx| {
+ zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (_, respond_tx) = requests.predict.next().await.unwrap();
+ // worse than current prediction
+ let second_response = model_response(indoc! { r"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ ... @@
+ Hello!
+ -How
+ +How are
+ Bye
+ "});
+ let second_id = second_response.id.clone();
+ respond_tx.send(second_response).unwrap();
+
+ cx.run_until_parked();
+
+ zeta.read_with(cx, |zeta, cx| {
+ // first is preferred over second
+ assert_eq!(
+ zeta.current_prediction_for_buffer(&buffer, &project, cx)
+ .unwrap()
+ .id
+ .0,
+ first_id
+ );
+ });
+
+ // second is reported as rejected
+ let (reject_request, _) = requests.reject.next().await.unwrap();
+
+ assert_eq!(
+ &reject_request.rejections,
+ &[EditPredictionRejection {
+ request_id: second_id,
+ reason: EditPredictionRejectReason::CurrentPreferred,
+ was_shown: false
+ }]
+ );
+ }
+
+ #[gpui::test]
+ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
+ let (zeta, mut requests) = init_test(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.md": "Hello!\nHow\nBye\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+ zeta.update(cx, |zeta, cx| {
+ // start two refresh tasks
+ zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+
+ zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (_, respond_first) = requests.predict.next().await.unwrap();
+ let (_, respond_second) = requests.predict.next().await.unwrap();
+
+ // wait for throttle
+ cx.run_until_parked();
+
+ // second responds first
+ let second_response = model_response(SIMPLE_DIFF);
+ let second_id = second_response.id.clone();
+ respond_second.send(second_response).unwrap();
+
+ cx.run_until_parked();
+
+ zeta.read_with(cx, |zeta, cx| {
+ // current prediction is second
+ assert_eq!(
+ zeta.current_prediction_for_buffer(&buffer, &project, cx)
+ .unwrap()
+ .id
+ .0,
+ second_id
+ );
+ });
+
+ let first_response = model_response(SIMPLE_DIFF);
+ let first_id = first_response.id.clone();
+ respond_first.send(first_response).unwrap();
+
+ cx.run_until_parked();
+
+ zeta.read_with(cx, |zeta, cx| {
+ // current prediction is still second, since first was cancelled
+ assert_eq!(
+ zeta.current_prediction_for_buffer(&buffer, &project, cx)
+ .unwrap()
+ .id
+ .0,
+ second_id
+ );
+ });
+
+ // first is reported as rejected
+ let (reject_request, _) = requests.reject.next().await.unwrap();
+
+ cx.run_until_parked();
+
+ assert_eq!(
+ &reject_request.rejections,
+ &[EditPredictionRejection {
+ request_id: first_id,
+ reason: EditPredictionRejectReason::Canceled,
+ was_shown: false
+ }]
+ );
+ }
+
+ #[gpui::test]
+ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
+ let (zeta, mut requests) = init_test(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.md": "Hello!\nHow\nBye\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+ zeta.update(cx, |zeta, cx| {
+ // start two refresh tasks
+ zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ // wait for throttle, so requests are sent
+ cx.run_until_parked();
+
+ let (_, respond_first) = requests.predict.next().await.unwrap();
+ let (_, respond_second) = requests.predict.next().await.unwrap();
+
+ zeta.update(cx, |zeta, cx| {
+ // start a third request
+ zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+
+ // 2 are pending, so 2nd is cancelled
+ assert_eq!(
+ zeta.get_or_init_zeta_project(&project, cx)
+ .cancelled_predictions
+ .iter()
+ .copied()
+ .collect::<Vec<_>>(),
+ [1]
+ );
+ });
+
+ // wait for throttle
+ cx.run_until_parked();
+
+ let (_, respond_third) = requests.predict.next().await.unwrap();
+
+ let first_response = model_response(SIMPLE_DIFF);
+ let first_id = first_response.id.clone();
+ respond_first.send(first_response).unwrap();
+
+ cx.run_until_parked();
+
+ zeta.read_with(cx, |zeta, cx| {
+ // current prediction is first
+ assert_eq!(
+ zeta.current_prediction_for_buffer(&buffer, &project, cx)
+ .unwrap()
+ .id
+ .0,
+ first_id
+ );
+ });
+
+ let cancelled_response = model_response(SIMPLE_DIFF);
+ let cancelled_id = cancelled_response.id.clone();
+ respond_second.send(cancelled_response).unwrap();
+
+ cx.run_until_parked();
+
+ zeta.read_with(cx, |zeta, cx| {
+ // current prediction is still first, since second was cancelled
+ assert_eq!(
+ zeta.current_prediction_for_buffer(&buffer, &project, cx)
+ .unwrap()
+ .id
+ .0,
+ first_id
+ );
+ });
+
+ let third_response = model_response(SIMPLE_DIFF);
+ let third_response_id = third_response.id.clone();
+ respond_third.send(third_response).unwrap();
+
+ cx.run_until_parked();
+
+ zeta.read_with(cx, |zeta, cx| {
+ // third completes and replaces first
+ assert_eq!(
+ zeta.current_prediction_for_buffer(&buffer, &project, cx)
+ .unwrap()
+ .id
+ .0,
+ third_response_id
+ );
+ });
+
+ // second is reported as rejected
+ let (reject_request, _) = requests.reject.next().await.unwrap();
+
+ cx.run_until_parked();
+
+ assert_eq!(
+ &reject_request.rejections,
+ &[
+ EditPredictionRejection {
+ request_id: cancelled_id,
+ reason: EditPredictionRejectReason::Canceled,
+ was_shown: false
+ },
+ EditPredictionRejection {
+ request_id: first_id,
+ reason: EditPredictionRejectReason::Replaced,
+ was_shown: false
+ }
+ ]
+ );
+ }
+
// Skipped until we start including diagnostics in prompt
// #[gpui::test]
// async fn test_request_diagnostics(cx: &mut TestAppContext) {