Report discarded zeta predictions and indicate whether they were shown (#42403)

Max Brunsfeld , Michael Sloan , Ben Kunkle , and Agus Zubiaga created

Release Notes:

- N/A

---------

Co-authored-by: Michael Sloan <mgsloan@gmail.com>
Co-authored-by: Ben Kunkle <ben@zed.dev>
Co-authored-by: Agus Zubiaga <agus@zed.dev>

Change summary

crates/cloud_llm_client/src/cloud_llm_client.rs |  14 +
crates/edit_prediction/src/edit_prediction.rs   |   6 
crates/editor/src/editor.rs                     |   7 
crates/zeta/src/zeta.rs                         | 135 ++++++++++++++++++
4 files changed, 158 insertions(+), 4 deletions(-)

Detailed changes

crates/cloud_llm_client/src/cloud_llm_client.rs 🔗

@@ -58,6 +58,9 @@ pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
 /// The name of the header used by the client to indicate that it supports receiving xAI models.
 pub const CLIENT_SUPPORTS_X_AI_HEADER_NAME: &str = "x-zed-client-supports-x-ai";
 
+/// The maximum number of edit predictions that can be rejected per request.
+pub const MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST: usize = 100;
+
 #[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub enum UsageLimit {
@@ -192,6 +195,17 @@ pub struct AcceptEditPredictionBody {
     pub request_id: String,
 }
 
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct RejectEditPredictionsBody {
+    pub rejections: Vec<EditPredictionRejection>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct EditPredictionRejection {
+    pub request_id: String,
+    pub was_shown: bool,
+}
+
 #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub enum CompletionMode {

crates/edit_prediction/src/edit_prediction.rs 🔗

@@ -104,6 +104,7 @@ pub trait EditPredictionProvider: 'static + Sized {
     );
     fn accept(&mut self, cx: &mut Context<Self>);
     fn discard(&mut self, cx: &mut Context<Self>);
+    fn did_show(&mut self, _cx: &mut Context<Self>) {}
     fn suggest(
         &mut self,
         buffer: &Entity<Buffer>,
@@ -142,6 +143,7 @@ pub trait EditPredictionProviderHandle {
         direction: Direction,
         cx: &mut App,
     );
+    fn did_show(&self, cx: &mut App);
     fn accept(&self, cx: &mut App);
     fn discard(&self, cx: &mut App);
     fn suggest(
@@ -233,6 +235,10 @@ where
         self.update(cx, |this, cx| this.discard(cx))
     }
 
+    fn did_show(&self, cx: &mut App) {
+        self.update(cx, |this, cx| this.did_show(cx))
+    }
+
     fn suggest(
         &self,
         buffer: &Entity<Buffer>,

crates/editor/src/editor.rs 🔗

@@ -7865,6 +7865,10 @@ impl Editor {
                 self.edit_prediction_preview,
                 EditPredictionPreview::Inactive { .. }
             ) {
+                if let Some(provider) = self.edit_prediction_provider.as_ref() {
+                    provider.provider.did_show(cx)
+                }
+
                 self.edit_prediction_preview = EditPredictionPreview::Active {
                     previous_scroll_position: None,
                     since: Instant::now(),
@@ -8044,6 +8048,9 @@ impl Editor {
                 && !self.edit_predictions_hidden_for_vim_mode;
 
             if show_completions_in_buffer {
+                if let Some(provider) = &self.edit_prediction_provider {
+                    provider.provider.did_show(cx);
+                }
                 if edits
                     .iter()
                     .all(|(range, _)| range.to_offset(&multibuffer).is_empty())

crates/zeta/src/zeta.rs 🔗

@@ -8,7 +8,9 @@ mod rate_completion_modal;
 
 pub(crate) use completion_diff_element::*;
 use db::kvp::{Dismissable, KEY_VALUE_STORE};
+use db::smol::stream::StreamExt as _;
 use edit_prediction::DataCollectionState;
+use futures::channel::mpsc;
 pub use init::*;
 use license_detection::LicenseDetectionWatcher;
 pub use rate_completion_modal::*;
@@ -17,8 +19,10 @@ use anyhow::{Context as _, Result, anyhow};
 use arrayvec::ArrayVec;
 use client::{Client, EditPredictionUsage, UserStore};
 use cloud_llm_client::{
-    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
-    PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, ZED_VERSION_HEADER_NAME,
+    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejection,
+    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
+    PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, RejectEditPredictionsBody,
+    ZED_VERSION_HEADER_NAME,
 };
 use collections::{HashMap, HashSet, VecDeque};
 use futures::AsyncReadExt;
@@ -171,12 +175,15 @@ pub struct Zeta {
     shown_completions: VecDeque<EditPrediction>,
     rated_completions: HashSet<EditPredictionId>,
     data_collection_choice: DataCollectionChoice,
+    discarded_completions: Vec<EditPredictionRejection>,
     llm_token: LlmApiToken,
     _llm_token_subscription: Subscription,
     /// Whether an update to a newer version of Zed is required to continue using Zeta.
     update_required: bool,
     user_store: Entity<UserStore>,
     license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
+    discard_completions_debounce_task: Option<Task<()>>,
+    discard_completions_tx: mpsc::UnboundedSender<()>,
 }
 
 struct ZetaProject {
@@ -226,11 +233,25 @@ impl Zeta {
     fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
         let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
         let data_collection_choice = Self::load_data_collection_choice();
+        let (reject_tx, mut reject_rx) = mpsc::unbounded();
+        cx.spawn(async move |this, cx| {
+            while let Some(()) = reject_rx.next().await {
+                this.update(cx, |this, cx| this.reject_edit_predictions(cx))?
+                    .await
+                    .log_err();
+            }
+            anyhow::Ok(())
+        })
+        .detach();
+
         Self {
             projects: HashMap::default(),
             client,
             shown_completions: VecDeque::new(),
             rated_completions: HashSet::default(),
+            discarded_completions: Vec::new(),
+            discard_completions_debounce_task: None,
+            discard_completions_tx: reject_tx,
             data_collection_choice,
             llm_token: LlmApiToken::default(),
             _llm_token_subscription: cx.subscribe(
@@ -692,6 +713,75 @@ impl Zeta {
         })
     }
 
+    fn reject_edit_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
+        let client = self.client.clone();
+        let llm_token = self.llm_token.clone();
+        let app_version = AppVersion::global(cx);
+        let last_rejection = self.discarded_completions.last().cloned();
+        let body = serde_json::to_string(&RejectEditPredictionsBody {
+            rejections: self.discarded_completions.clone(),
+        })
+        .ok();
+
+        let Some(last_rejection) = last_rejection else {
+            return Task::ready(anyhow::Ok(()));
+        };
+
+        cx.spawn(async move |this, cx| {
+            let http_client = client.http_client();
+            let mut response = llm_token_retry(&llm_token, &client, |token| {
+                let request_builder = http_client::Request::builder().method(Method::POST);
+                let request_builder = request_builder.uri(
+                    http_client
+                        .build_zed_llm_url("/predict_edits/reject", &[])?
+                        .as_ref(),
+                );
+                Ok(request_builder
+                    .header("Content-Type", "application/json")
+                    .header("Authorization", format!("Bearer {}", token))
+                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
+                    .body(
+                        body.as_ref()
+                            .context("failed to serialize body")?
+                            .clone()
+                            .into(),
+                    )?)
+            })
+            .await?;
+
+            if let Some(minimum_required_version) = response
+                .headers()
+                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
+                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
+                && app_version < minimum_required_version
+            {
+                return Err(anyhow!(ZedUpdateRequiredError {
+                    minimum_version: minimum_required_version
+                }));
+            }
+
+            if response.status().is_success() {
+                this.update(cx, |this, _| {
+                    if let Some(ix) = this
+                        .discarded_completions
+                        .iter()
+                        .position(|rejection| rejection.request_id == last_rejection.request_id)
+                    {
+                        this.discarded_completions.drain(..ix + 1);
+                    }
+                })
+            } else {
+                let mut body = String::new();
+                response.body_mut().read_to_string(&mut body).await?;
+                Err(anyhow!(
+                    "error rejecting edit predictions.\nStatus: {:?}\nBody: {}",
+                    response.status(),
+                    body
+                ))
+            }
+        })
+    }
+
     fn process_completion_response(
         prediction_response: PredictEditsResponse,
         buffer: Entity<Buffer>,
@@ -995,6 +1085,31 @@ impl Zeta {
             )
         });
     }
+
+    fn discard_completion(
+        &mut self,
+        completion_id: EditPredictionId,
+        was_shown: bool,
+        cx: &mut Context<Self>,
+    ) {
+        self.discarded_completions.push(EditPredictionRejection {
+            request_id: completion_id.to_string(),
+            was_shown,
+        });
+
+        let reached_request_limit =
+            self.discarded_completions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST;
+        let discard_completions_tx = self.discard_completions_tx.clone();
+        self.discard_completions_debounce_task = Some(cx.spawn(async move |_this, cx| {
+            const DISCARD_COMPLETIONS_DEBOUNCE: Duration = Duration::from_secs(15);
+            if !reached_request_limit {
+                cx.background_executor()
+                    .timer(DISCARD_COMPLETIONS_DEBOUNCE)
+                    .await;
+            }
+            discard_completions_tx.unbounded_send(()).log_err();
+        }));
+    }
 }
 
 pub struct PerformPredictEditsParams {
@@ -1167,6 +1282,7 @@ impl Event {
 struct CurrentEditPrediction {
     buffer_id: EntityId,
     completion: EditPrediction,
+    was_shown: bool,
 }
 
 impl CurrentEditPrediction {
@@ -1414,6 +1530,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
                         c.map(|completion| CurrentEditPrediction {
                             buffer_id: buffer.entity_id(),
                             completion,
+                            was_shown: false,
                         })
                     })
                 }
@@ -1505,9 +1622,19 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
         self.pending_completions.clear();
     }
 
-    fn discard(&mut self, _cx: &mut Context<Self>) {
+    fn discard(&mut self, cx: &mut Context<Self>) {
         self.pending_completions.clear();
-        self.current_completion.take();
+        if let Some(completion) = self.current_completion.take() {
+            self.zeta.update(cx, |zeta, cx| {
+                zeta.discard_completion(completion.completion.id, completion.was_shown, cx);
+            });
+        }
+    }
+
+    fn did_show(&mut self, _cx: &mut Context<Self>) {
+        if let Some(current_completion) = self.current_completion.as_mut() {
+            current_completion.was_shown = true;
+        }
     }
 
     fn suggest(