edit prediction: Report early-rejected predictions and fix cancel bug (#43585)

Agus Zubiaga and MrSubidubi created

Many prediction requests end up being rejected early without ever being
set as the current prediction. Before this change, those cases weren't
reported as rejections because the `request_prediction_with_*` functions
simply returned `Ok(None)`.

With this update, whenever we get a successful response from the
provider, we will return at least the `id`, allowing it to be properly
reported. The request now also includes a "reject reason," since the
different variants carry distinct implications for prediction quality.

All of these scenarios are now covered by tests. While adding them, I
also found and fixed a bug where some cancelled predictions were
incorrectly being set as the current one.

Release Notes:

- N/A

---------

Co-authored-by: MrSubidubi <dev@bahn.sh>

Change summary

crates/cloud_llm_client/src/cloud_llm_client.rs |  21 
crates/zeta/src/prediction.rs                   |  89 +
crates/zeta/src/provider.rs                     |  13 
crates/zeta/src/sweep_ai.rs                     |  10 
crates/zeta/src/zeta.rs                         | 831 ++++++++++++++++--
crates/zeta/src/zeta1.rs                        |  13 
crates/zeta/src/zeta_tests.rs                   |   2 
crates/zeta_cli/src/predict.rs                  |   5 
8 files changed, 815 insertions(+), 169 deletions(-)

Detailed changes

crates/cloud_llm_client/src/cloud_llm_client.rs 🔗

@@ -200,12 +200,31 @@ pub struct RejectEditPredictionsBody {
     pub rejections: Vec<EditPredictionRejection>,
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 pub struct EditPredictionRejection {
     pub request_id: String,
+    #[serde(default)]
+    pub reason: EditPredictionRejectReason,
     pub was_shown: bool,
 }
 
+#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
+pub enum EditPredictionRejectReason {
+    /// New requests were triggered before this one completed
+    Canceled,
+    /// No edits returned
+    Empty,
+    /// Edits returned, but none remained after interpolation
+    InterpolatedEmpty,
+    /// The new prediction was preferred over the current one
+    Replaced,
+    /// The current prediction was preferred over the new one
+    CurrentPreferred,
+    /// The current prediction was discarded
+    #[default]
+    Discarded,
+}
+
 #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub enum CompletionMode {

crates/zeta/src/prediction.rs 🔗

@@ -5,6 +5,7 @@ use std::{
     time::{Duration, Instant},
 };
 
+use cloud_llm_client::EditPredictionRejectReason;
 use gpui::{AsyncApp, Entity, SharedString};
 use language::{Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot};
 use serde::Serialize;
@@ -24,28 +25,13 @@ impl std::fmt::Display for EditPredictionId {
     }
 }
 
-#[derive(Clone)]
-pub struct EditPrediction {
+/// A prediction response that was returned from the provider, whether it was ultimately valid or not.
+pub struct EditPredictionResult {
     pub id: EditPredictionId,
-    pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
-    pub snapshot: BufferSnapshot,
-    pub edit_preview: EditPreview,
-    // We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction.
-    pub buffer: Entity<Buffer>,
-    pub buffer_snapshotted_at: Instant,
-    pub response_received_at: Instant,
-    pub inputs: EditPredictionInputs,
-}
-
-#[derive(Debug, Clone, Serialize)]
-pub struct EditPredictionInputs {
-    pub events: Vec<Arc<cloud_llm_client::predict_edits_v3::Event>>,
-    pub included_files: Vec<cloud_llm_client::predict_edits_v3::IncludedFile>,
-    pub cursor_point: cloud_llm_client::predict_edits_v3::Point,
-    pub cursor_path: Arc<Path>,
+    pub prediction: Result<EditPrediction, EditPredictionRejectReason>,
 }
 
-impl EditPrediction {
+impl EditPredictionResult {
     pub async fn new(
         id: EditPredictionId,
         edited_buffer: &Entity<Buffer>,
@@ -55,8 +41,15 @@ impl EditPrediction {
         response_received_at: Instant,
         inputs: EditPredictionInputs,
         cx: &mut AsyncApp,
-    ) -> Option<Self> {
-        let (edits, snapshot, edit_preview_task) = edited_buffer
+    ) -> Self {
+        if edits.is_empty() {
+            return Self {
+                id,
+                prediction: Err(EditPredictionRejectReason::Empty),
+            };
+        }
+
+        let Some((edits, snapshot, edit_preview_task)) = edited_buffer
             .read_with(cx, |buffer, cx| {
                 let new_snapshot = buffer.snapshot();
                 let edits: Arc<[_]> =
@@ -64,22 +57,54 @@ impl EditPrediction {
 
                 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
             })
-            .ok()??;
+            .ok()
+            .flatten()
+        else {
+            return Self {
+                id,
+                prediction: Err(EditPredictionRejectReason::InterpolatedEmpty),
+            };
+        };
 
         let edit_preview = edit_preview_task.await;
 
-        Some(EditPrediction {
-            id,
-            edits,
-            snapshot,
-            edit_preview,
-            inputs,
-            buffer: edited_buffer.clone(),
-            buffer_snapshotted_at,
-            response_received_at,
-        })
+        Self {
+            id: id.clone(),
+            prediction: Ok(EditPrediction {
+                id,
+                edits,
+                snapshot,
+                edit_preview,
+                inputs,
+                buffer: edited_buffer.clone(),
+                buffer_snapshotted_at,
+                response_received_at,
+            }),
+        }
     }
+}
 
+#[derive(Clone)]
+pub struct EditPrediction {
+    pub id: EditPredictionId,
+    pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
+    pub snapshot: BufferSnapshot,
+    pub edit_preview: EditPreview,
+    pub buffer: Entity<Buffer>,
+    pub buffer_snapshotted_at: Instant,
+    pub response_received_at: Instant,
+    pub inputs: EditPredictionInputs,
+}
+
+#[derive(Debug, Clone, Serialize)]
+pub struct EditPredictionInputs {
+    pub events: Vec<Arc<cloud_llm_client::predict_edits_v3::Event>>,
+    pub included_files: Vec<cloud_llm_client::predict_edits_v3::IncludedFile>,
+    pub cursor_point: cloud_llm_client::predict_edits_v3::Point,
+    pub cursor_path: Arc<Path>,
+}
+
+impl EditPrediction {
     pub fn interpolate(
         &self,
         new_snapshot: &TextBufferSnapshot,

crates/zeta/src/provider.rs 🔗

@@ -1,6 +1,7 @@
 use std::{cmp, sync::Arc, time::Duration};
 
 use client::{Client, UserStore};
+use cloud_llm_client::EditPredictionRejectReason;
 use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
 use gpui::{App, Entity, prelude::*};
 use language::ToPoint as _;
@@ -132,7 +133,11 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
 
     fn discard(&mut self, cx: &mut Context<Self>) {
         self.zeta.update(cx, |zeta, cx| {
-            zeta.discard_current_prediction(&self.project, cx);
+            zeta.reject_current_prediction(
+                EditPredictionRejectReason::Discarded,
+                &self.project,
+                cx,
+            );
         });
     }
 
@@ -169,7 +174,11 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
 
         let Some(edits) = prediction.interpolate(&snapshot) else {
             self.zeta.update(cx, |zeta, cx| {
-                zeta.discard_current_prediction(&self.project, cx);
+                zeta.reject_current_prediction(
+                    EditPredictionRejectReason::InterpolatedEmpty,
+                    &self.project,
+                    cx,
+                );
             });
             return None;
         };

crates/zeta/src/sweep_ai.rs 🔗

@@ -18,7 +18,7 @@ use std::{
     time::Instant,
 };
 
-use crate::{EditPrediction, EditPredictionId, EditPredictionInputs};
+use crate::{EditPredictionId, EditPredictionInputs, prediction::EditPredictionResult};
 
 const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
 
@@ -45,7 +45,7 @@ impl SweepAi {
         recent_paths: &VecDeque<ProjectPath>,
         diagnostic_search_range: Range<Point>,
         cx: &mut App,
-    ) -> Task<Result<Option<EditPrediction>>> {
+    ) -> Task<Result<Option<EditPredictionResult>>> {
         let debug_info = self.debug_info.clone();
         let Some(api_token) = self.api_token.clone() else {
             return Task::ready(Ok(None));
@@ -242,8 +242,8 @@ impl SweepAi {
 
         cx.spawn(async move |cx| {
             let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
-            anyhow::Ok(
-                EditPrediction::new(
+            anyhow::Ok(Some(
+                EditPredictionResult::new(
                     EditPredictionId(id.into()),
                     &buffer,
                     &old_snapshot,
@@ -254,7 +254,7 @@ impl SweepAi {
                     cx,
                 )
                 .await,
-            )
+            ))
         })
     }
 }

crates/zeta/src/zeta.rs 🔗

@@ -3,9 +3,9 @@ use arrayvec::ArrayVec;
 use client::{Client, EditPredictionUsage, UserStore};
 use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
 use cloud_llm_client::{
-    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejection,
-    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
-    RejectEditPredictionsBody, ZED_VERSION_HEADER_NAME,
+    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
+    EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
+    MINIMUM_REQUIRED_VERSION_HEADER_NAME, RejectEditPredictionsBody, ZED_VERSION_HEADER_NAME,
 };
 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
 use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
@@ -74,6 +74,7 @@ use crate::onboarding_modal::ZedPredictModal;
 pub use crate::prediction::EditPrediction;
 pub use crate::prediction::EditPredictionId;
 pub use crate::prediction::EditPredictionInputs;
+use crate::prediction::EditPredictionResult;
 use crate::rate_prediction_modal::{
     NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
     ThumbsUpActivePrediction,
@@ -310,6 +311,31 @@ impl ZetaProject {
             )
             .collect()
     }
+
+    fn cancel_pending_prediction(
+        &mut self,
+        pending_prediction: PendingPrediction,
+        cx: &mut Context<Zeta>,
+    ) {
+        self.cancelled_predictions.insert(pending_prediction.id);
+
+        cx.spawn(async move |this, cx| {
+            let Some(prediction_id) = pending_prediction.task.await else {
+                return;
+            };
+
+            this.update(cx, |this, cx| {
+                this.reject_prediction(
+                    prediction_id,
+                    EditPredictionRejectReason::Canceled,
+                    false,
+                    cx,
+                );
+            })
+            .ok();
+        })
+        .detach()
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -373,6 +399,7 @@ impl PredictionRequestedBy {
     }
 }
 
+#[derive(Debug)]
 struct PendingPrediction {
     id: usize,
     task: Task<Option<EditPredictionId>>,
@@ -385,6 +412,18 @@ enum BufferEditPrediction<'a> {
     Jump { prediction: &'a EditPrediction },
 }
 
+#[cfg(test)]
+impl std::ops::Deref for BufferEditPrediction<'_> {
+    type Target = EditPrediction;
+
+    fn deref(&self) -> &Self::Target {
+        match self {
+            BufferEditPrediction::Local { prediction } => prediction,
+            BufferEditPrediction::Jump { prediction } => prediction,
+        }
+    }
+}
+
 struct RegisteredBuffer {
     snapshot: BufferSnapshot,
     _subscriptions: [gpui::Subscription; 2],
@@ -467,7 +506,7 @@ impl Zeta {
         let (reject_tx, mut reject_rx) = mpsc::unbounded();
         cx.spawn(async move |this, cx| {
             while let Some(()) = reject_rx.next().await {
-                this.update(cx, |this, cx| this.reject_edit_predictions(cx))?
+                this.update(cx, |this, cx| this.flush_rejected_predictions(cx))?
                     .await
                     .log_err();
             }
@@ -818,7 +857,7 @@ impl Zeta {
         };
         let request_id = prediction.prediction.id.to_string();
         for pending_prediction in mem::take(&mut project_state.pending_predictions) {
-            self.cancel_pending_prediction(pending_prediction, cx);
+            project_state.cancel_pending_prediction(pending_prediction, cx);
         }
 
         let client = self.client.clone();
@@ -856,7 +895,7 @@ impl Zeta {
         .detach_and_log_err(cx);
     }
 
-    fn reject_edit_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
+    fn flush_rejected_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
         match self.edit_prediction_model {
             ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
             ZetaEditPredictionModel::Sweep => return Task::ready(anyhow::Ok(())),
@@ -904,11 +943,16 @@ impl Zeta {
         })
     }
 
-    fn discard_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
+    fn reject_current_prediction(
+        &mut self,
+        reason: EditPredictionRejectReason,
+        project: &Entity<Project>,
+        cx: &mut Context<Self>,
+    ) {
         if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
             project_state.pending_predictions.clear();
             if let Some(prediction) = project_state.current_prediction.take() {
-                self.discard_prediction(prediction.prediction.id, prediction.was_shown, cx);
+                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
             }
         };
     }
@@ -929,14 +973,16 @@ impl Zeta {
         }
     }
 
-    fn discard_prediction(
+    fn reject_prediction(
         &mut self,
         prediction_id: EditPredictionId,
+        reason: EditPredictionRejectReason,
         was_shown: bool,
         cx: &mut Context<Self>,
     ) {
         self.rejected_predictions.push(EditPredictionRejection {
             request_id: prediction_id.to_string(),
+            reason,
             was_shown,
         });
 
@@ -944,34 +990,16 @@ impl Zeta {
             self.rejected_predictions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST;
         let reject_tx = self.reject_predictions_tx.clone();
         self.reject_predictions_debounce_task = Some(cx.spawn(async move |_this, cx| {
-            const DISCARD_COMPLETIONS_DEBOUNCE: Duration = Duration::from_secs(15);
+            const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
             if !reached_request_limit {
                 cx.background_executor()
-                    .timer(DISCARD_COMPLETIONS_DEBOUNCE)
+                    .timer(REJECT_REQUEST_DEBOUNCE)
                     .await;
             }
             reject_tx.unbounded_send(()).log_err();
         }));
     }
 
-    fn cancel_pending_prediction(
-        &self,
-        pending_prediction: PendingPrediction,
-        cx: &mut Context<Self>,
-    ) {
-        cx.spawn(async move |this, cx| {
-            let Some(prediction_id) = pending_prediction.task.await else {
-                return;
-            };
-
-            this.update(cx, |this, cx| {
-                this.discard_prediction(prediction_id, false, cx);
-            })
-            .ok();
-        })
-        .detach()
-    }
-
     fn is_refreshing(&self, project: &Entity<Project>) -> bool {
         self.projects
             .get(&project.entity_id())
@@ -995,38 +1023,15 @@ impl Zeta {
                 return Task::ready(anyhow::Ok(None));
             };
 
-            let project = project.clone();
-            cx.spawn(async move |cx| {
-                if let Some(prediction) = request_task.await? {
-                    let id = prediction.id.clone();
-                    this.update(cx, |this, cx| {
-                        let project_state = this
-                            .projects
-                            .get_mut(&project.entity_id())
-                            .context("Project not found")?;
-
-                        let new_prediction = CurrentEditPrediction {
-                            requested_by: PredictionRequestedBy::Buffer(buffer.entity_id()),
-                            prediction: prediction,
-                            was_shown: false,
-                        };
-
-                        if project_state
-                            .current_prediction
-                            .as_ref()
-                            .is_none_or(|old_prediction| {
-                                new_prediction.should_replace_prediction(&old_prediction, cx)
-                            })
-                        {
-                            project_state.current_prediction = Some(new_prediction);
-                            cx.notify();
-                        }
-                        anyhow::Ok(())
-                    })??;
-                    Ok(Some(id))
-                } else {
-                    Ok(None)
-                }
+            cx.spawn(async move |_cx| {
+                request_task.await.map(|prediction_result| {
+                    prediction_result.map(|prediction_result| {
+                        (
+                            prediction_result,
+                            PredictionRequestedBy::Buffer(buffer.entity_id()),
+                        )
+                    })
+                })
             })
         })
     }
@@ -1076,7 +1081,7 @@ impl Zeta {
                     return anyhow::Ok(None);
                 };
 
-                let Some(prediction) = this
+                let Some(prediction_result) = this
                     .update(cx, |this, cx| {
                         this.request_prediction(&project, &jump_buffer, jump_position, cx)
                     })?
@@ -1085,21 +1090,23 @@ impl Zeta {
                     return anyhow::Ok(None);
                 };
 
-                let id = prediction.id.clone();
                 this.update(cx, |this, cx| {
-                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
-                        zeta_project.current_prediction.get_or_insert_with(|| {
-                            cx.notify();
-                            CurrentEditPrediction {
-                                requested_by: PredictionRequestedBy::DiagnosticsUpdate,
-                                prediction,
-                                was_shown: false,
+                    Some((
+                        if this
+                            .get_or_init_zeta_project(&project, cx)
+                            .current_prediction
+                            .is_none()
+                        {
+                            prediction_result
+                        } else {
+                            EditPredictionResult {
+                                id: prediction_result.id,
+                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                             }
-                        });
-                    }
-                })?;
-
-                anyhow::Ok(Some(id))
+                        },
+                        PredictionRequestedBy::DiagnosticsUpdate,
+                    ))
+                })
             })
         });
     }
@@ -1117,7 +1124,8 @@ impl Zeta {
         do_refresh: impl FnOnce(
             WeakEntity<Self>,
             &mut AsyncApp,
-        ) -> Task<Result<Option<EditPredictionId>>>
+        )
+            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
         + 'static,
     ) {
         let zeta_project = self.get_or_init_zeta_project(&project, cx);
@@ -1152,22 +1160,77 @@ impl Zeta {
                 return None;
             }
 
-            let edit_prediction_id = do_refresh(this.clone(), cx).await.log_err().flatten();
+            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
+            let new_prediction_id = new_prediction_result
+                .as_ref()
+                .map(|(prediction, _)| prediction.id.clone());
 
             // When a prediction completes, remove it from the pending list, and cancel
             // any pending predictions that were enqueued before it.
             this.update(cx, |this, cx| {
                 let zeta_project = this.get_or_init_zeta_project(&project, cx);
-                zeta_project
+
+                let is_cancelled = zeta_project
                     .cancelled_predictions
                     .remove(&pending_prediction_id);
 
+                let new_current_prediction = if !is_cancelled
+                    && let Some((prediction_result, requested_by)) = new_prediction_result
+                {
+                    match prediction_result.prediction {
+                        Ok(prediction) => {
+                            let new_prediction = CurrentEditPrediction {
+                                requested_by,
+                                prediction,
+                                was_shown: false,
+                            };
+
+                            if let Some(current_prediction) =
+                                zeta_project.current_prediction.as_ref()
+                            {
+                                if new_prediction.should_replace_prediction(&current_prediction, cx)
+                                {
+                                    this.reject_current_prediction(
+                                        EditPredictionRejectReason::Replaced,
+                                        &project,
+                                        cx,
+                                    );
+
+                                    Some(new_prediction)
+                                } else {
+                                    this.reject_prediction(
+                                        new_prediction.prediction.id,
+                                        EditPredictionRejectReason::CurrentPreferred,
+                                        false,
+                                        cx,
+                                    );
+                                    None
+                                }
+                            } else {
+                                Some(new_prediction)
+                            }
+                        }
+                        Err(reject_reason) => {
+                            this.reject_prediction(prediction_result.id, reject_reason, false, cx);
+                            None
+                        }
+                    }
+                } else {
+                    None
+                };
+
+                let zeta_project = this.get_or_init_zeta_project(&project, cx);
+
+                if let Some(new_prediction) = new_current_prediction {
+                    zeta_project.current_prediction = Some(new_prediction);
+                }
+
                 let mut pending_predictions = mem::take(&mut zeta_project.pending_predictions);
                 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                     if pending_prediction.id == pending_prediction_id {
                         pending_predictions.remove(ix);
                         for pending_prediction in pending_predictions.drain(0..ix) {
-                            this.cancel_pending_prediction(pending_prediction, cx)
+                            zeta_project.cancel_pending_prediction(pending_prediction, cx)
                         }
                         break;
                     }
@@ -1178,7 +1241,7 @@ impl Zeta {
             })
             .ok();
 
-            edit_prediction_id
+            new_prediction_id
         });
 
         if zeta_project.pending_predictions.len() <= 1 {
@@ -1192,10 +1255,7 @@ impl Zeta {
                 id: pending_prediction_id,
                 task,
             });
-            zeta_project
-                .cancelled_predictions
-                .insert(pending_prediction.id);
-            self.cancel_pending_prediction(pending_prediction, cx);
+            zeta_project.cancel_pending_prediction(pending_prediction, cx);
         }
     }
 
@@ -1205,7 +1265,7 @@ impl Zeta {
         active_buffer: &Entity<Buffer>,
         position: language::Anchor,
         cx: &mut Context<Self>,
-    ) -> Task<Result<Option<EditPrediction>>> {
+    ) -> Task<Result<Option<EditPredictionResult>>> {
         self.request_prediction_internal(
             project.clone(),
             active_buffer.clone(),
@@ -1222,7 +1282,7 @@ impl Zeta {
         position: language::Anchor,
         allow_jump: bool,
         cx: &mut Context<Self>,
-    ) -> Task<Result<Option<EditPrediction>>> {
+    ) -> Task<Result<Option<EditPredictionResult>>> {
         const DIAGNOSTIC_LINES_RANGE: u32 = 20;
 
         self.get_or_init_zeta_project(&project, cx);
@@ -1268,9 +1328,7 @@ impl Zeta {
         };
 
         cx.spawn(async move |this, cx| {
-            let prediction = task
-                .await?
-                .filter(|prediction| !prediction.edits.is_empty());
+            let prediction = task.await?;
 
             if prediction.is_none() && allow_jump {
                 let cursor_point = position.to_point(&snapshot);
@@ -1392,7 +1450,7 @@ impl Zeta {
         position: language::Anchor,
         events: Vec<Arc<Event>>,
         cx: &mut Context<Self>,
-    ) -> Task<Result<Option<EditPrediction>>> {
+    ) -> Task<Result<Option<EditPredictionResult>>> {
         let project_state = self.projects.get(&project.entity_id());
 
         let index_state = project_state.and_then(|state| {
@@ -1689,7 +1747,7 @@ impl Zeta {
                 let (res, usage) = response?;
                 let request_id = EditPredictionId(res.id.clone().into());
                 let Some(mut output_text) = text_from_response(res) else {
-                    return Ok((None, usage));
+                    return Ok((Some((request_id, None)), usage));
                 };
 
                 if output_text.contains(CURSOR_MARKER) {
@@ -1747,11 +1805,13 @@ impl Zeta {
                 anyhow::Ok((
                     Some((
                         request_id,
-                        inputs,
-                        edited_buffer,
-                        edited_buffer_snapshot.clone(),
-                        edits,
-                        received_response_at,
+                        Some((
+                            inputs,
+                            edited_buffer,
+                            edited_buffer_snapshot.clone(),
+                            edits,
+                            received_response_at,
+                        )),
                     )),
                     usage,
                 ))
@@ -1760,30 +1820,40 @@ impl Zeta {
 
         cx.spawn({
             async move |this, cx| {
+                let Some((id, prediction)) =
+                    Self::handle_api_response(&this, request_task.await, cx)?
+                else {
+                    return Ok(None);
+                };
+
                 let Some((
-                    id,
                     inputs,
                     edited_buffer,
                     edited_buffer_snapshot,
                     edits,
                     received_response_at,
-                )) = Self::handle_api_response(&this, request_task.await, cx)?
+                )) = prediction
                 else {
-                    return Ok(None);
+                    return Ok(Some(EditPredictionResult {
+                        id,
+                        prediction: Err(EditPredictionRejectReason::Empty),
+                    }));
                 };
 
                 // TODO telemetry: duration, etc
-                Ok(EditPrediction::new(
-                    id,
-                    &edited_buffer,
-                    &edited_buffer_snapshot,
-                    edits.into(),
-                    buffer_snapshotted_at,
-                    received_response_at,
-                    inputs,
-                    cx,
-                )
-                .await)
+                Ok(Some(
+                    EditPredictionResult::new(
+                        id,
+                        &edited_buffer,
+                        &edited_buffer_snapshot,
+                        edits.into(),
+                        buffer_snapshotted_at,
+                        received_response_at,
+                        inputs,
+                        cx,
+                    )
+                    .await,
+                ))
             }
         })
     }
@@ -2806,6 +2876,9 @@ mod tests {
 
     use client::UserStore;
     use clock::FakeSystemClock;
+    use cloud_llm_client::{
+        EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
+    };
     use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
     use futures::{
         AsyncReadExt, StreamExt,
@@ -2830,7 +2903,7 @@ mod tests {
 
     #[gpui::test]
     async fn test_current_state(cx: &mut TestAppContext) {
-        let (zeta, mut req_rx) = init_test(cx);
+        let (zeta, mut requests) = init_test(cx);
         let fs = FakeFs::new(cx.executor());
         fs.insert_tree(
             "/root",
@@ -2861,7 +2934,7 @@ mod tests {
         zeta.update(cx, |zeta, cx| {
             zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
         });
-        let (_request, respond_tx) = req_rx.next().await.unwrap();
+        let (_request, respond_tx) = requests.predict.next().await.unwrap();
 
         respond_tx
             .send(model_response(indoc! {r"
@@ -2888,7 +2961,7 @@ mod tests {
         let refresh_task = zeta.update(cx, |zeta, cx| {
             zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
         });
-        let (_request, respond_tx) = req_rx.next().await.unwrap();
+        let (_request, respond_tx) = requests.predict.next().await.unwrap();
         respond_tx
             .send(open_ai::Response {
                 id: Uuid::new_v4().to_string(),
@@ -2929,14 +3002,14 @@ mod tests {
         refresh_task.await.unwrap();
 
         zeta.update(cx, |zeta, cx| {
-            zeta.discard_current_prediction(&project, cx);
+            zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
         });
 
         // Prediction for another file
         zeta.update(cx, |zeta, cx| {
             zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
         });
-        let (_request, respond_tx) = req_rx.next().await.unwrap();
+        let (_request, respond_tx) = requests.predict.next().await.unwrap();
         respond_tx
             .send(model_response(indoc! {r#"
                 --- a/root/2.txt
@@ -2977,7 +3050,7 @@ mod tests {
 
     #[gpui::test]
     async fn test_simple_request(cx: &mut TestAppContext) {
-        let (zeta, mut req_rx) = init_test(cx);
+        let (zeta, mut requests) = init_test(cx);
         let fs = FakeFs::new(cx.executor());
         fs.insert_tree(
             "/root",
@@ -3002,7 +3075,7 @@ mod tests {
             zeta.request_prediction(&project, &buffer, position, cx)
         });
 
-        let (_, respond_tx) = req_rx.next().await.unwrap();
+        let (_, respond_tx) = requests.predict.next().await.unwrap();
 
         // TODO Put back when we have a structured request again
         // assert_eq!(
@@ -3029,7 +3102,7 @@ mod tests {
             "}))
             .unwrap();
 
-        let prediction = prediction_task.await.unwrap().unwrap();
+        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 
         assert_eq!(prediction.edits.len(), 1);
         assert_eq!(
@@ -3041,7 +3114,7 @@ mod tests {
 
     #[gpui::test]
     async fn test_request_events(cx: &mut TestAppContext) {
-        let (zeta, mut req_rx) = init_test(cx);
+        let (zeta, mut requests) = init_test(cx);
         let fs = FakeFs::new(cx.executor());
         fs.insert_tree(
             "/root",
@@ -3075,7 +3148,7 @@ mod tests {
             zeta.request_prediction(&project, &buffer, position, cx)
         });
 
-        let (request, respond_tx) = req_rx.next().await.unwrap();
+        let (request, respond_tx) = requests.predict.next().await.unwrap();
 
         let prompt = prompt_from_request(&request);
         assert!(
@@ -3103,7 +3176,7 @@ mod tests {
             "#}))
             .unwrap();
 
-        let prediction = prediction_task.await.unwrap().unwrap();
+        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 
         assert_eq!(prediction.edits.len(), 1);
         assert_eq!(
@@ -3113,6 +3186,522 @@ mod tests {
         assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
     }
 
+    #[gpui::test]
+    async fn test_empty_prediction(cx: &mut TestAppContext) {
+        let (zeta, mut requests) = init_test(cx);
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(
+            "/root",
+            json!({
+                "foo.md":  "Hello!\nHow\nBye\n"
+            }),
+        )
+        .await;
+        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+        let buffer = project
+            .update(cx, |project, cx| {
+                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+                project.open_buffer(path, cx)
+            })
+            .await
+            .unwrap();
+        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+        let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+        zeta.update(cx, |zeta, cx| {
+            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+        });
+
+        const NO_OP_DIFF: &str = indoc! { r"
+            --- a/root/foo.md
+            +++ b/root/foo.md
+            @@ ... @@
+             Hello!
+            -How
+            +How
+             Bye
+        "};
+
+        let (_, respond_tx) = requests.predict.next().await.unwrap();
+        let response = model_response(NO_OP_DIFF);
+        let id = response.id.clone();
+        respond_tx.send(response).unwrap();
+
+        cx.run_until_parked();
+
+        zeta.read_with(cx, |zeta, cx| {
+            assert!(
+                zeta.current_prediction_for_buffer(&buffer, &project, cx)
+                    .is_none()
+            );
+        });
+
+        // prediction is reported as rejected
+        let (reject_request, _) = requests.reject.next().await.unwrap();
+
+        assert_eq!(
+            &reject_request.rejections,
+            &[EditPredictionRejection {
+                request_id: id,
+                reason: EditPredictionRejectReason::Empty,
+                was_shown: false
+            }]
+        );
+    }
+
+    #[gpui::test]
+    async fn test_interpolated_empty(cx: &mut TestAppContext) {
+        let (zeta, mut requests) = init_test(cx);
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(
+            "/root",
+            json!({
+                "foo.md":  "Hello!\nHow\nBye\n"
+            }),
+        )
+        .await;
+        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+        let buffer = project
+            .update(cx, |project, cx| {
+                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+                project.open_buffer(path, cx)
+            })
+            .await
+            .unwrap();
+        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+        let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+        zeta.update(cx, |zeta, cx| {
+            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+        });
+
+        let (_, respond_tx) = requests.predict.next().await.unwrap();
+
+        buffer.update(cx, |buffer, cx| {
+            buffer.set_text("Hello!\nHow are you?\nBye", cx);
+        });
+
+        let response = model_response(SIMPLE_DIFF);
+        let id = response.id.clone();
+        respond_tx.send(response).unwrap();
+
+        cx.run_until_parked();
+
+        zeta.read_with(cx, |zeta, cx| {
+            assert!(
+                zeta.current_prediction_for_buffer(&buffer, &project, cx)
+                    .is_none()
+            );
+        });
+
+        // prediction is reported as rejected
+        let (reject_request, _) = requests.reject.next().await.unwrap();
+
+        assert_eq!(
+            &reject_request.rejections,
+            &[EditPredictionRejection {
+                request_id: id,
+                reason: EditPredictionRejectReason::InterpolatedEmpty,
+                was_shown: false
+            }]
+        );
+    }
+
+    const SIMPLE_DIFF: &str = indoc! { r"
+        --- a/root/foo.md
+        +++ b/root/foo.md
+        @@ ... @@
+         Hello!
+        -How
+        +How are you?
+         Bye
+    "};
+
+    #[gpui::test]
+    async fn test_replace_current(cx: &mut TestAppContext) {
+        let (zeta, mut requests) = init_test(cx);
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(
+            "/root",
+            json!({
+                "foo.md":  "Hello!\nHow\nBye\n"
+            }),
+        )
+        .await;
+        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+        let buffer = project
+            .update(cx, |project, cx| {
+                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+                project.open_buffer(path, cx)
+            })
+            .await
+            .unwrap();
+        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+        let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+        zeta.update(cx, |zeta, cx| {
+            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+        });
+
+        let (_, respond_tx) = requests.predict.next().await.unwrap();
+        let first_response = model_response(SIMPLE_DIFF);
+        let first_id = first_response.id.clone();
+        respond_tx.send(first_response).unwrap();
+
+        cx.run_until_parked();
+
+        zeta.read_with(cx, |zeta, cx| {
+            assert_eq!(
+                zeta.current_prediction_for_buffer(&buffer, &project, cx)
+                    .unwrap()
+                    .id
+                    .0,
+                first_id
+            );
+        });
+
+        // a second request is triggered
+        zeta.update(cx, |zeta, cx| {
+            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+        });
+
+        let (_, respond_tx) = requests.predict.next().await.unwrap();
+        let second_response = model_response(SIMPLE_DIFF);
+        let second_id = second_response.id.clone();
+        respond_tx.send(second_response).unwrap();
+
+        cx.run_until_parked();
+
+        zeta.read_with(cx, |zeta, cx| {
+            // second replaces first
+            assert_eq!(
+                zeta.current_prediction_for_buffer(&buffer, &project, cx)
+                    .unwrap()
+                    .id
+                    .0,
+                second_id
+            );
+        });
+
+        // first is reported as replaced
+        let (reject_request, _) = requests.reject.next().await.unwrap();
+
+        assert_eq!(
+            &reject_request.rejections,
+            &[EditPredictionRejection {
+                request_id: first_id,
+                reason: EditPredictionRejectReason::Replaced,
+                was_shown: false
+            }]
+        );
+    }
+
+    #[gpui::test]
+    async fn test_current_preferred(cx: &mut TestAppContext) {
+        let (zeta, mut requests) = init_test(cx);
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(
+            "/root",
+            json!({
+                "foo.md":  "Hello!\nHow\nBye\n"
+            }),
+        )
+        .await;
+        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+        let buffer = project
+            .update(cx, |project, cx| {
+                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+                project.open_buffer(path, cx)
+            })
+            .await
+            .unwrap();
+        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+        let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+        zeta.update(cx, |zeta, cx| {
+            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+        });
+
+        let (_, respond_tx) = requests.predict.next().await.unwrap();
+        let first_response = model_response(SIMPLE_DIFF);
+        let first_id = first_response.id.clone();
+        respond_tx.send(first_response).unwrap();
+
+        cx.run_until_parked();
+
+        zeta.read_with(cx, |zeta, cx| {
+            assert_eq!(
+                zeta.current_prediction_for_buffer(&buffer, &project, cx)
+                    .unwrap()
+                    .id
+                    .0,
+                first_id
+            );
+        });
+
+        // a second request is triggered
+        zeta.update(cx, |zeta, cx| {
+            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+        });
+
+        let (_, respond_tx) = requests.predict.next().await.unwrap();
+        // worse than current prediction
+        let second_response = model_response(indoc! { r"
+            --- a/root/foo.md
+            +++ b/root/foo.md
+            @@ ... @@
+             Hello!
+            -How
+            +How are
+             Bye
+        "});
+        let second_id = second_response.id.clone();
+        respond_tx.send(second_response).unwrap();
+
+        cx.run_until_parked();
+
+        zeta.read_with(cx, |zeta, cx| {
+            // first is preferred over second
+            assert_eq!(
+                zeta.current_prediction_for_buffer(&buffer, &project, cx)
+                    .unwrap()
+                    .id
+                    .0,
+                first_id
+            );
+        });
+
+        // second is reported as rejected
+        let (reject_request, _) = requests.reject.next().await.unwrap();
+
+        assert_eq!(
+            &reject_request.rejections,
+            &[EditPredictionRejection {
+                request_id: second_id,
+                reason: EditPredictionRejectReason::CurrentPreferred,
+                was_shown: false
+            }]
+        );
+    }
+
+    #[gpui::test]
+    async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
+        let (zeta, mut requests) = init_test(cx);
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(
+            "/root",
+            json!({
+                "foo.md":  "Hello!\nHow\nBye\n"
+            }),
+        )
+        .await;
+        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+        let buffer = project
+            .update(cx, |project, cx| {
+                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+                project.open_buffer(path, cx)
+            })
+            .await
+            .unwrap();
+        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+        let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+        zeta.update(cx, |zeta, cx| {
+            // start two refresh tasks
+            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+
+            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+        });
+
+        let (_, respond_first) = requests.predict.next().await.unwrap();
+        let (_, respond_second) = requests.predict.next().await.unwrap();
+
+        // wait for throttle
+        cx.run_until_parked();
+
+        // second responds first
+        let second_response = model_response(SIMPLE_DIFF);
+        let second_id = second_response.id.clone();
+        respond_second.send(second_response).unwrap();
+
+        cx.run_until_parked();
+
+        zeta.read_with(cx, |zeta, cx| {
+            // current prediction is second
+            assert_eq!(
+                zeta.current_prediction_for_buffer(&buffer, &project, cx)
+                    .unwrap()
+                    .id
+                    .0,
+                second_id
+            );
+        });
+
+        let first_response = model_response(SIMPLE_DIFF);
+        let first_id = first_response.id.clone();
+        respond_first.send(first_response).unwrap();
+
+        cx.run_until_parked();
+
+        zeta.read_with(cx, |zeta, cx| {
+            // current prediction is still second, since first was cancelled
+            assert_eq!(
+                zeta.current_prediction_for_buffer(&buffer, &project, cx)
+                    .unwrap()
+                    .id
+                    .0,
+                second_id
+            );
+        });
+
+        // first is reported as rejected
+        let (reject_request, _) = requests.reject.next().await.unwrap();
+
+        cx.run_until_parked();
+
+        assert_eq!(
+            &reject_request.rejections,
+            &[EditPredictionRejection {
+                request_id: first_id,
+                reason: EditPredictionRejectReason::Canceled,
+                was_shown: false
+            }]
+        );
+    }
+
+    #[gpui::test]
+    async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
+        let (zeta, mut requests) = init_test(cx);
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(
+            "/root",
+            json!({
+                "foo.md":  "Hello!\nHow\nBye\n"
+            }),
+        )
+        .await;
+        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+        let buffer = project
+            .update(cx, |project, cx| {
+                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+                project.open_buffer(path, cx)
+            })
+            .await
+            .unwrap();
+        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+        let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+        zeta.update(cx, |zeta, cx| {
+            // start two refresh tasks
+            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+        });
+
+        // wait for throttle, so requests are sent
+        cx.run_until_parked();
+
+        let (_, respond_first) = requests.predict.next().await.unwrap();
+        let (_, respond_second) = requests.predict.next().await.unwrap();
+
+        zeta.update(cx, |zeta, cx| {
+            // start a third request
+            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+
+            // 2 are pending, so 2nd is cancelled
+            assert_eq!(
+                zeta.get_or_init_zeta_project(&project, cx)
+                    .cancelled_predictions
+                    .iter()
+                    .copied()
+                    .collect::<Vec<_>>(),
+                [1]
+            );
+        });
+
+        // wait for throttle
+        cx.run_until_parked();
+
+        let (_, respond_third) = requests.predict.next().await.unwrap();
+
+        let first_response = model_response(SIMPLE_DIFF);
+        let first_id = first_response.id.clone();
+        respond_first.send(first_response).unwrap();
+
+        cx.run_until_parked();
+
+        zeta.read_with(cx, |zeta, cx| {
+            // current prediction is first
+            assert_eq!(
+                zeta.current_prediction_for_buffer(&buffer, &project, cx)
+                    .unwrap()
+                    .id
+                    .0,
+                first_id
+            );
+        });
+
+        let cancelled_response = model_response(SIMPLE_DIFF);
+        let cancelled_id = cancelled_response.id.clone();
+        respond_second.send(cancelled_response).unwrap();
+
+        cx.run_until_parked();
+
+        zeta.read_with(cx, |zeta, cx| {
+            // current prediction is still first, since second was cancelled
+            assert_eq!(
+                zeta.current_prediction_for_buffer(&buffer, &project, cx)
+                    .unwrap()
+                    .id
+                    .0,
+                first_id
+            );
+        });
+
+        let third_response = model_response(SIMPLE_DIFF);
+        let third_response_id = third_response.id.clone();
+        respond_third.send(third_response).unwrap();
+
+        cx.run_until_parked();
+
+        zeta.read_with(cx, |zeta, cx| {
+            // third completes and replaces first
+            assert_eq!(
+                zeta.current_prediction_for_buffer(&buffer, &project, cx)
+                    .unwrap()
+                    .id
+                    .0,
+                third_response_id
+            );
+        });
+
+        // second is reported as cancelled, and first as replaced
+        let (reject_request, _) = requests.reject.next().await.unwrap();
+
+        cx.run_until_parked();
+
+        assert_eq!(
+            &reject_request.rejections,
+            &[
+                EditPredictionRejection {
+                    request_id: cancelled_id,
+                    reason: EditPredictionRejectReason::Canceled,
+                    was_shown: false
+                },
+                EditPredictionRejection {
+                    request_id: first_id,
+                    reason: EditPredictionRejectReason::Replaced,
+                    was_shown: false
+                }
+            ]
+        );
+    }
+
     // Skipped until we start including diagnostics in prompt
     // #[gpui::test]
     // async fn test_request_diagnostics(cx: &mut TestAppContext) {

crates/zeta/src/zeta1.rs ๐Ÿ”—

@@ -4,7 +4,7 @@ use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
 
 use crate::{
     EditPredictionId, ZedUpdateRequiredError, Zeta,
-    prediction::{EditPrediction, EditPredictionInputs},
+    prediction::{EditPredictionInputs, EditPredictionResult},
 };
 use anyhow::{Context as _, Result};
 use cloud_llm_client::{
@@ -36,7 +36,7 @@ pub(crate) fn request_prediction_with_zeta1(
     position: language::Anchor,
     events: Vec<Arc<Event>>,
     cx: &mut Context<Zeta>,
-) -> Task<Result<Option<EditPrediction>>> {
+) -> Task<Result<Option<EditPredictionResult>>> {
     let buffer = buffer.clone();
     let buffer_snapshotted_at = Instant::now();
     let client = zeta.client.clone();
@@ -216,7 +216,7 @@ pub(crate) fn request_prediction_with_zeta1(
             );
         }
 
-        edit_prediction
+        edit_prediction.map(Some)
     })
 }
 
@@ -229,7 +229,7 @@ fn process_completion_response(
     buffer_snapshotted_at: Instant,
     received_response_at: Instant,
     cx: &AsyncApp,
-) -> Task<Result<Option<EditPrediction>>> {
+) -> Task<Result<EditPredictionResult>> {
     let snapshot = snapshot.clone();
     let request_id = prediction_response.request_id;
     let output_excerpt = prediction_response.output_excerpt;
@@ -246,8 +246,9 @@ fn process_completion_response(
             .await?
             .into();
 
-        Ok(EditPrediction::new(
-            EditPredictionId(request_id.into()),
+        let id = EditPredictionId(request_id.into());
+        Ok(EditPredictionResult::new(
+            id,
             &buffer,
             &snapshot,
             edits,

crates/zeta/src/zeta_tests.rs ๐Ÿ”—

@@ -538,7 +538,7 @@ async fn run_edit_prediction(
     let prediction_task = zeta.update(cx, |zeta, cx| {
         zeta.request_prediction(&project, buffer, cursor, cx)
     });
-    prediction_task.await.unwrap().unwrap()
+    prediction_task.await.unwrap().unwrap().prediction.unwrap()
 }
 
 async fn make_test_zeta(

crates/zeta_cli/src/predict.rs ๐Ÿ”—

@@ -235,7 +235,10 @@ pub async fn perform_predict(
     let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
 
     result.diff = prediction
-        .and_then(|prediction| prediction.edit_preview.as_unified_diff(&prediction.edits))
+        .and_then(|prediction| {
+            let prediction = prediction.prediction.ok()?;
+            prediction.edit_preview.as_unified_diff(&prediction.edits)
+        })
         .unwrap_or_default();
 
     anyhow::Ok(result)