diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index fd16478b5a7ade4b8ef86924d2ce737cb2f62c56..74b6687f62c641ce4076778efa4369a45529f4f9 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -251,11 +251,12 @@ fn assign_edit_prediction_provider( }); } - let provider = cx.new(|_| { + let provider = cx.new(|cx| { zeta::ZetaEditPredictionProvider::new( zeta, project.clone(), singleton_buffer, + cx, ) }); editor.set_edit_prediction_provider(Some(provider), window, cx); diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 577ca77c13c0b9f8e0eff578c20d0a933c858bce..c2ef5cb826db0947c18e1e91a6163cccc12deb11 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -1283,6 +1283,7 @@ struct CurrentEditPrediction { buffer_id: EntityId, completion: EditPrediction, was_shown: bool, + was_accepted: bool, } impl CurrentEditPrediction { @@ -1310,7 +1311,7 @@ impl CurrentEditPrediction { struct PendingCompletion { id: usize, - _task: Task<()>, + task: Task<()>, } #[derive(Debug, Clone, Copy)] @@ -1386,6 +1387,7 @@ pub struct ZetaEditPredictionProvider { zeta: Entity, singleton_buffer: Option>, pending_completions: ArrayVec, + canceled_completions: HashMap>, next_pending_completion_id: usize, current_completion: Option, last_request_timestamp: Instant, @@ -1399,17 +1401,34 @@ impl ZetaEditPredictionProvider { zeta: Entity, project: Entity, singleton_buffer: Option>, + cx: &mut Context, ) -> Self { + cx.on_release(|this, cx| { + this.take_current_edit_prediction(cx); + }) + .detach(); + Self { zeta, singleton_buffer, pending_completions: ArrayVec::new(), + canceled_completions: HashMap::default(), next_pending_completion_id: 0, current_completion: None, last_request_timestamp: Instant::now(), project, } } + + fn take_current_edit_prediction(&mut self, cx: &mut App) { + if let Some(completion) = self.current_completion.take() { + if !completion.was_accepted { + self.zeta.update(cx, |zeta, cx| { + zeta.discard_completion(completion.completion.id, completion.was_shown, cx); + }); + } + } + } } impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { @@ -1531,42 +1550,65 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { buffer_id: buffer.entity_id(), completion, was_shown: false, + was_accepted: false, }) }) } Err(error) => Err(error), }; - let Some(new_completion) = completion - .context("edit prediction failed") - .log_err() - .flatten() - else { - this.update(cx, |this, cx| { - if this.pending_completions[0].id == pending_completion_id { + + let discarded = this + .update(cx, |this, cx| { + if this + .pending_completions + .first() + .is_some_and(|completion| completion.id == pending_completion_id) + { this.pending_completions.remove(0); } else { - this.pending_completions.clear(); + if let Some(discarded) = this.pending_completions.drain(..).next() { + this.canceled_completions + .insert(discarded.id, discarded.task); + } + } + + let canceled = this.canceled_completions.remove(&pending_completion_id); + + if canceled.is_some() + && let Ok(Some(new_completion)) = &completion + { + this.zeta.update(cx, |zeta, cx| { + zeta.discard_completion(new_completion.completion.id, false, cx); + }); + return true; } cx.notify(); + false }) - .ok(); + .ok() + .unwrap_or(true); + + if discarded { + return; + } + + let Some(new_completion) = completion + .context("edit prediction failed") + .log_err() + .flatten() + else { return; }; this.update(cx, |this, cx| { - if this.pending_completions[0].id == pending_completion_id { - this.pending_completions.remove(0); - } else { - this.pending_completions.clear(); - } - if let Some(old_completion) = this.current_completion.as_ref() { let snapshot = buffer.read(cx).snapshot(); if new_completion.should_replace_completion(old_completion, &snapshot) { this.zeta.update(cx, |zeta, cx| { zeta.completion_shown(&new_completion.completion, cx); }); + this.take_current_edit_prediction(cx); this.current_completion = Some(new_completion); } } else { @@ -1586,13 +1628,16 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { if self.pending_completions.len() <= 1 { self.pending_completions.push(PendingCompletion { id: pending_completion_id, - _task: task, + task, }); } else if self.pending_completions.len() == 2 { - self.pending_completions.pop(); + if let Some(discarded) = self.pending_completions.pop() { + self.canceled_completions + .insert(discarded.id, discarded.task); + } self.pending_completions.push(PendingCompletion { id: pending_completion_id, - _task: task, + task, }); } } @@ -1608,14 +1653,12 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { } fn accept(&mut self, cx: &mut Context) { - let completion_id = self - .current_completion - .as_ref() - .map(|completion| completion.completion.id); - if let Some(completion_id) = completion_id { + let completion = self.current_completion.as_mut(); + if let Some(completion) = completion { + completion.was_accepted = true; self.zeta .update(cx, |zeta, cx| { - zeta.accept_edit_prediction(completion_id, cx) + zeta.accept_edit_prediction(completion.completion.id, cx) }) .detach(); } @@ -1624,11 +1667,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { fn discard(&mut self, cx: &mut Context) { self.pending_completions.clear(); - if let Some(completion) = self.current_completion.take() { - self.zeta.update(cx, |zeta, cx| { - zeta.discard_completion(completion.completion.id, completion.was_shown, cx); - }); - } + self.take_current_edit_prediction(cx); } fn did_show(&mut self, _cx: &mut Context) { @@ -1651,13 +1690,13 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { // Invalidate previous completion if it was generated for a different buffer. if *buffer_id != buffer.entity_id() { - self.current_completion.take(); + self.take_current_edit_prediction(cx); return None; } let buffer = buffer.read(cx); let Some(edits) = completion.interpolate(&buffer.snapshot()) else { - self.current_completion.take(); + self.take_current_edit_prediction(cx); return None; };