Report automatically discarded zeta predictions (#42761)

Agus Zubiaga , Max Brunsfeld , and Ben Kunkle created

We weren't reporting predictions that were generated but never made it
out of the provider, such as predictions that failed to interpolate, and
those that are cancelled because another request completes before it.

Release Notes:

- N/A

---------

Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
Co-authored-by: Ben Kunkle <ben@zed.dev>

Change summary

crates/zed/src/zed/edit_prediction_registry.rs |   3 
crates/zeta/src/zeta.rs                        | 103 +++++++++++++------
2 files changed, 73 insertions(+), 33 deletions(-)

Detailed changes

crates/zed/src/zed/edit_prediction_registry.rs 🔗

@@ -251,11 +251,12 @@ fn assign_edit_prediction_provider(
                             });
                         }
 
-                        let provider = cx.new(|_| {
+                        let provider = cx.new(|cx| {
                             zeta::ZetaEditPredictionProvider::new(
                                 zeta,
                                 project.clone(),
                                 singleton_buffer,
+                                cx,
                             )
                         });
                         editor.set_edit_prediction_provider(Some(provider), window, cx);

crates/zeta/src/zeta.rs 🔗

@@ -1283,6 +1283,7 @@ struct CurrentEditPrediction {
     buffer_id: EntityId,
     completion: EditPrediction,
     was_shown: bool,
+    was_accepted: bool,
 }
 
 impl CurrentEditPrediction {
@@ -1310,7 +1311,7 @@ impl CurrentEditPrediction {
 
 struct PendingCompletion {
     id: usize,
-    _task: Task<()>,
+    task: Task<()>,
 }
 
 #[derive(Debug, Clone, Copy)]
@@ -1386,6 +1387,7 @@ pub struct ZetaEditPredictionProvider {
     zeta: Entity<Zeta>,
     singleton_buffer: Option<Entity<Buffer>>,
     pending_completions: ArrayVec<PendingCompletion, 2>,
+    canceled_completions: HashMap<usize, Task<()>>,
     next_pending_completion_id: usize,
     current_completion: Option<CurrentEditPrediction>,
     last_request_timestamp: Instant,
@@ -1399,17 +1401,34 @@ impl ZetaEditPredictionProvider {
         zeta: Entity<Zeta>,
         project: Entity<Project>,
         singleton_buffer: Option<Entity<Buffer>>,
+        cx: &mut Context<Self>,
     ) -> Self {
+        cx.on_release(|this, cx| {
+            this.take_current_edit_prediction(cx);
+        })
+        .detach();
+
         Self {
             zeta,
             singleton_buffer,
             pending_completions: ArrayVec::new(),
+            canceled_completions: HashMap::default(),
             next_pending_completion_id: 0,
             current_completion: None,
             last_request_timestamp: Instant::now(),
             project,
         }
     }
+
+    fn take_current_edit_prediction(&mut self, cx: &mut App) {
+        if let Some(completion) = self.current_completion.take() {
+            if !completion.was_accepted {
+                self.zeta.update(cx, |zeta, cx| {
+                    zeta.discard_completion(completion.completion.id, completion.was_shown, cx);
+                });
+            }
+        }
+    }
 }
 
 impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
@@ -1531,42 +1550,65 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
                             buffer_id: buffer.entity_id(),
                             completion,
                             was_shown: false,
+                            was_accepted: false,
                         })
                     })
                 }
                 Err(error) => Err(error),
             };
-            let Some(new_completion) = completion
-                .context("edit prediction failed")
-                .log_err()
-                .flatten()
-            else {
-                this.update(cx, |this, cx| {
-                    if this.pending_completions[0].id == pending_completion_id {
+
+            let discarded = this
+                .update(cx, |this, cx| {
+                    if this
+                        .pending_completions
+                        .first()
+                        .is_some_and(|completion| completion.id == pending_completion_id)
+                    {
                         this.pending_completions.remove(0);
                     } else {
-                        this.pending_completions.clear();
+                        if let Some(discarded) = this.pending_completions.drain(..).next() {
+                            this.canceled_completions
+                                .insert(discarded.id, discarded.task);
+                        }
+                    }
+
+                    let canceled = this.canceled_completions.remove(&pending_completion_id);
+
+                    if canceled.is_some()
+                        && let Ok(Some(new_completion)) = &completion
+                    {
+                        this.zeta.update(cx, |zeta, cx| {
+                            zeta.discard_completion(new_completion.completion.id, false, cx);
+                        });
+                        return true;
                     }
 
                     cx.notify();
+                    false
                 })
-                .ok();
+                .ok()
+                .unwrap_or(true);
+
+            if discarded {
+                return;
+            }
+
+            let Some(new_completion) = completion
+                .context("edit prediction failed")
+                .log_err()
+                .flatten()
+            else {
                 return;
             };
 
             this.update(cx, |this, cx| {
-                if this.pending_completions[0].id == pending_completion_id {
-                    this.pending_completions.remove(0);
-                } else {
-                    this.pending_completions.clear();
-                }
-
                 if let Some(old_completion) = this.current_completion.as_ref() {
                     let snapshot = buffer.read(cx).snapshot();
                     if new_completion.should_replace_completion(old_completion, &snapshot) {
                         this.zeta.update(cx, |zeta, cx| {
                             zeta.completion_shown(&new_completion.completion, cx);
                         });
+                        this.take_current_edit_prediction(cx);
                         this.current_completion = Some(new_completion);
                     }
                 } else {
@@ -1586,13 +1628,16 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
         if self.pending_completions.len() <= 1 {
             self.pending_completions.push(PendingCompletion {
                 id: pending_completion_id,
-                _task: task,
+                task,
             });
         } else if self.pending_completions.len() == 2 {
-            self.pending_completions.pop();
+            if let Some(discarded) = self.pending_completions.pop() {
+                self.canceled_completions
+                    .insert(discarded.id, discarded.task);
+            }
             self.pending_completions.push(PendingCompletion {
                 id: pending_completion_id,
-                _task: task,
+                task,
             });
         }
     }
@@ -1608,14 +1653,12 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
     }
 
     fn accept(&mut self, cx: &mut Context<Self>) {
-        let completion_id = self
-            .current_completion
-            .as_ref()
-            .map(|completion| completion.completion.id);
-        if let Some(completion_id) = completion_id {
+        let completion = self.current_completion.as_mut();
+        if let Some(completion) = completion {
+            completion.was_accepted = true;
             self.zeta
                 .update(cx, |zeta, cx| {
-                    zeta.accept_edit_prediction(completion_id, cx)
+                    zeta.accept_edit_prediction(completion.completion.id, cx)
                 })
                 .detach();
         }
@@ -1624,11 +1667,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
 
     fn discard(&mut self, cx: &mut Context<Self>) {
         self.pending_completions.clear();
-        if let Some(completion) = self.current_completion.take() {
-            self.zeta.update(cx, |zeta, cx| {
-                zeta.discard_completion(completion.completion.id, completion.was_shown, cx);
-            });
-        }
+        self.take_current_edit_prediction(cx);
     }
 
     fn did_show(&mut self, _cx: &mut Context<Self>) {
@@ -1651,13 +1690,13 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
 
         // Invalidate previous completion if it was generated for a different buffer.
         if *buffer_id != buffer.entity_id() {
-            self.current_completion.take();
+            self.take_current_edit_prediction(cx);
             return None;
         }
 
         let buffer = buffer.read(cx);
         let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
-            self.current_completion.take();
+            self.take_current_edit_prediction(cx);
             return None;
         };