Limit edit prediction reject batches to max (#43965)

Agus Zubiaga created

We currently attempt to flush all rejected predictions at once even if
we have accumulated more than
`MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST`. Instead, we will now flush
as many as possible, and then keep the rest for the next batch.


Release Notes:

- N/A

Change summary

crates/cloud_llm_client/src/cloud_llm_client.rs |   7 
crates/zeta/src/provider.rs                     |  11 
crates/zeta/src/zeta.rs                         | 306 +++++++++++++-----
3 files changed, 221 insertions(+), 103 deletions(-)

Detailed changes

crates/cloud_llm_client/src/cloud_llm_client.rs 🔗

@@ -206,11 +206,16 @@ pub struct AcceptEditPredictionBody {
     pub request_id: String,
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Deserialize)]
 pub struct RejectEditPredictionsBody {
     pub rejections: Vec<EditPredictionRejection>,
 }
 
+#[derive(Debug, Clone, Serialize)]
+pub struct RejectEditPredictionsBodyRef<'a> {
+    pub rejections: &'a [EditPredictionRejection],
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 pub struct EditPredictionRejection {
     pub request_id: String,

crates/zeta/src/provider.rs 🔗

@@ -132,12 +132,8 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
     }
 
     fn discard(&mut self, cx: &mut Context<Self>) {
-        self.zeta.update(cx, |zeta, cx| {
-            zeta.reject_current_prediction(
-                EditPredictionRejectReason::Discarded,
-                &self.project,
-                cx,
-            );
+        self.zeta.update(cx, |zeta, _cx| {
+            zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &self.project);
         });
     }
 
@@ -173,11 +169,10 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
         let snapshot = buffer.snapshot();
 
         let Some(edits) = prediction.interpolate(&snapshot) else {
-            self.zeta.update(cx, |zeta, cx| {
+            self.zeta.update(cx, |zeta, _cx| {
                 zeta.reject_current_prediction(
                     EditPredictionRejectReason::InterpolatedEmpty,
                     &self.project,
-                    cx,
                 );
             });
             return None;

crates/zeta/src/zeta.rs 🔗

@@ -5,7 +5,7 @@ use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
 use cloud_llm_client::{
     AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
     EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
-    MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBody,
+    MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
     ZED_VERSION_HEADER_NAME,
 };
 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
@@ -19,8 +19,10 @@ use edit_prediction_context::{
     SyntaxIndex, SyntaxIndexState,
 };
 use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
+use futures::channel::mpsc::UnboundedReceiver;
 use futures::channel::{mpsc, oneshot};
-use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _};
+use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, select_biased};
+use gpui::BackgroundExecutor;
 use gpui::{
     App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
     http_client::{self, AsyncBody, Method},
@@ -100,6 +102,7 @@ actions!(
 const EVENT_COUNT_MAX: usize = 6;
 const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
 const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
+const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
 
 pub struct SweepFeatureFlag;
 
@@ -195,9 +198,7 @@ pub struct Zeta {
     edit_prediction_model: ZetaEditPredictionModel,
     pub sweep_ai: SweepAi,
     data_collection_choice: DataCollectionChoice,
-    rejected_predictions: Vec<EditPredictionRejection>,
-    reject_predictions_tx: mpsc::UnboundedSender<()>,
-    reject_predictions_debounce_task: Option<Task<()>>,
+    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
     shown_predictions: VecDeque<EditPrediction>,
     rated_predictions: HashSet<EditPredictionId>,
 }
@@ -325,13 +326,8 @@ impl ZetaProject {
                 return;
             };
 
-            this.update(cx, |this, cx| {
-                this.reject_prediction(
-                    prediction_id,
-                    EditPredictionRejectReason::Canceled,
-                    false,
-                    cx,
-                );
+            this.update(cx, |this, _cx| {
+                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
             })
             .ok();
         })
@@ -504,14 +500,24 @@ impl Zeta {
         let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
         let data_collection_choice = Self::load_data_collection_choice();
 
-        let (reject_tx, mut reject_rx) = mpsc::unbounded();
-        cx.spawn(async move |this, cx| {
-            while let Some(()) = reject_rx.next().await {
-                this.update(cx, |this, cx| this.flush_rejected_predictions(cx))?
-                    .await
-                    .log_err();
+        let llm_token = LlmApiToken::default();
+
+        let (reject_tx, reject_rx) = mpsc::unbounded();
+        cx.background_spawn({
+            let client = client.clone();
+            let llm_token = llm_token.clone();
+            let app_version = AppVersion::global(cx);
+            let background_executor = cx.background_executor().clone();
+            async move {
+                Self::handle_rejected_predictions(
+                    reject_rx,
+                    client,
+                    llm_token,
+                    app_version,
+                    background_executor,
+                )
+                .await
             }
-            anyhow::Ok(())
         })
         .detach();
 
@@ -520,7 +526,7 @@ impl Zeta {
             client,
             user_store,
             options: DEFAULT_OPTIONS,
-            llm_token: LlmApiToken::default(),
+            llm_token,
             _llm_token_subscription: cx.subscribe(
                 &refresh_llm_token_listener,
                 |this, _listener, _event, cx| {
@@ -540,8 +546,6 @@ impl Zeta {
             edit_prediction_model: ZetaEditPredictionModel::Zeta2,
             sweep_ai: SweepAi::new(cx),
             data_collection_choice,
-            rejected_predictions: Vec::new(),
-            reject_predictions_debounce_task: None,
             reject_predictions_tx: reject_tx,
             rated_predictions: Default::default(),
             shown_predictions: Default::default(),
@@ -901,64 +905,73 @@ impl Zeta {
         .detach_and_log_err(cx);
     }
 
-    fn flush_rejected_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
-        match self.edit_prediction_model {
-            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
-            ZetaEditPredictionModel::Sweep => return Task::ready(anyhow::Ok(())),
-        }
+    async fn handle_rejected_predictions(
+        rx: UnboundedReceiver<EditPredictionRejection>,
+        client: Arc<Client>,
+        llm_token: LlmApiToken,
+        app_version: Version,
+        background_executor: BackgroundExecutor,
+    ) {
+        let mut rx = std::pin::pin!(rx.peekable());
+        let mut batched = Vec::new();
 
-        let client = self.client.clone();
-        let llm_token = self.llm_token.clone();
-        let app_version = AppVersion::global(cx);
-        let last_rejection = self.rejected_predictions.last().cloned();
-        let Some(last_rejection) = last_rejection else {
-            return Task::ready(anyhow::Ok(()));
-        };
+        while let Some(rejection) = rx.next().await {
+            batched.push(rejection);
 
-        let body = serde_json::to_string(&RejectEditPredictionsBody {
-            rejections: self.rejected_predictions.clone(),
-        })
-        .ok();
+            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
+                select_biased! {
+                    next = rx.as_mut().peek().fuse() => {
+                        if next.is_some() {
+                            continue;
+                        }
+                    }
+                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
+                }
+            }
 
-        cx.spawn(async move |this, cx| {
             let url = client
                 .http_client()
-                .build_zed_llm_url("/predict_edits/reject", &[])?;
+                .build_zed_llm_url("/predict_edits/reject", &[])
+                .unwrap();
 
-            cx.background_spawn(Self::send_api_request::<()>(
-                move |builder| {
-                    let req = builder.uri(url.as_ref()).body(body.clone().into());
-                    Ok(req?)
+            let flush_count = batched
+                .len()
+                // in case items have accumulated after failure
+                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
+            let start = batched.len() - flush_count;
+
+            let body = RejectEditPredictionsBodyRef {
+                rejections: &batched[start..],
+            };
+
+            let result = Self::send_api_request::<()>(
+                |builder| {
+                    let req = builder
+                        .uri(url.as_ref())
+                        .body(serde_json::to_string(&body)?.into());
+                    anyhow::Ok(req?)
                 },
-                client,
-                llm_token,
-                app_version,
-            ))
-            .await
-            .context("Failed to reject edit predictions")?;
+                client.clone(),
+                llm_token.clone(),
+                app_version.clone(),
+            )
+            .await;
 
-            this.update(cx, |this, _| {
-                if let Some(ix) = this
-                    .rejected_predictions
-                    .iter()
-                    .position(|rejection| rejection.request_id == last_rejection.request_id)
-                {
-                    this.rejected_predictions.drain(..ix + 1);
-                }
-            })
-        })
+            if result.log_err().is_some() {
+                batched.drain(start..);
+            }
+        }
     }
 
     fn reject_current_prediction(
         &mut self,
         reason: EditPredictionRejectReason,
         project: &Entity<Project>,
-        cx: &mut Context<Self>,
     ) {
         if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
             project_state.pending_predictions.clear();
             if let Some(prediction) = project_state.current_prediction.take() {
-                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
+                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
             }
         };
     }
@@ -984,26 +997,14 @@ impl Zeta {
         prediction_id: EditPredictionId,
         reason: EditPredictionRejectReason,
         was_shown: bool,
-        cx: &mut Context<Self>,
     ) {
-        self.rejected_predictions.push(EditPredictionRejection {
-            request_id: prediction_id.to_string(),
-            reason,
-            was_shown,
-        });
-
-        let reached_request_limit =
-            self.rejected_predictions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2;
-        let reject_tx = self.reject_predictions_tx.clone();
-        self.reject_predictions_debounce_task = Some(cx.spawn(async move |_this, cx| {
-            const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
-            if !reached_request_limit {
-                cx.background_executor()
-                    .timer(REJECT_REQUEST_DEBOUNCE)
-                    .await;
-            }
-            reject_tx.unbounded_send(()).log_err();
-        }));
+        self.reject_predictions_tx
+            .unbounded_send(EditPredictionRejection {
+                request_id: prediction_id.to_string(),
+                reason,
+                was_shown,
+            })
+            .log_err();
     }
 
     fn is_refreshing(&self, project: &Entity<Project>) -> bool {
@@ -1211,7 +1212,6 @@ impl Zeta {
                                     this.reject_current_prediction(
                                         EditPredictionRejectReason::Replaced,
                                         &project,
-                                        cx,
                                     );
 
                                     Some(new_prediction)
@@ -1220,7 +1220,6 @@ impl Zeta {
                                         new_prediction.prediction.id,
                                         EditPredictionRejectReason::CurrentPreferred,
                                         false,
-                                        cx,
                                     );
                                     None
                                 }
@@ -1229,7 +1228,7 @@ impl Zeta {
                             }
                         }
                         Err(reject_reason) => {
-                            this.reject_prediction(prediction_result.id, reject_reason, false, cx);
+                            this.reject_prediction(prediction_result.id, reject_reason, false);
                             None
                         }
                     }
@@ -2906,7 +2905,7 @@ fn feature_gate_predict_edits_actions(cx: &mut App) {
 
 #[cfg(test)]
 mod tests {
-    use std::{path::Path, sync::Arc};
+    use std::{path::Path, sync::Arc, time::Duration};
 
     use client::UserStore;
     use clock::FakeSystemClock;
@@ -2933,7 +2932,7 @@ mod tests {
     use util::path;
     use uuid::Uuid;
 
-    use crate::{BufferEditPrediction, Zeta};
+    use crate::{BufferEditPrediction, EditPredictionId, REJECT_REQUEST_DEBOUNCE, Zeta};
 
     #[gpui::test]
     async fn test_current_state(cx: &mut TestAppContext) {
@@ -3035,8 +3034,8 @@ mod tests {
             .unwrap();
         refresh_task.await.unwrap();
 
-        zeta.update(cx, |zeta, cx| {
-            zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
+        zeta.update(cx, |zeta, _cx| {
+            zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
         });
 
         // Prediction for another file
@@ -3545,14 +3544,17 @@ mod tests {
         let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
         let position = snapshot.anchor_before(language::Point::new(1, 3));
 
+        // start two refresh tasks
         zeta.update(cx, |zeta, cx| {
-            // start two refresh tasks
             zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+        });
 
+        let (_, respond_first) = requests.predict.next().await.unwrap();
+
+        zeta.update(cx, |zeta, cx| {
             zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
         });
 
-        let (_, respond_first) = requests.predict.next().await.unwrap();
         let (_, respond_second) = requests.predict.next().await.unwrap();
 
         // wait for throttle
@@ -3631,18 +3633,22 @@ mod tests {
         let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
         let position = snapshot.anchor_before(language::Point::new(1, 3));
 
+        // start two refresh tasks
         zeta.update(cx, |zeta, cx| {
-            // start two refresh tasks
             zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+        });
+
+        let (_, respond_first) = requests.predict.next().await.unwrap();
+
+        zeta.update(cx, |zeta, cx| {
             zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
         });
 
+        let (_, respond_second) = requests.predict.next().await.unwrap();
+
         // wait for throttle, so requests are sent
         cx.run_until_parked();
 
-        let (_, respond_first) = requests.predict.next().await.unwrap();
-        let (_, respond_second) = requests.predict.next().await.unwrap();
-
         zeta.update(cx, |zeta, cx| {
             // start a third request
             zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
@@ -3736,6 +3742,118 @@ mod tests {
         );
     }
 
+    #[gpui::test]
+    async fn test_rejections_flushing(cx: &mut TestAppContext) {
+        let (zeta, mut requests) = init_test(cx);
+
+        zeta.update(cx, |zeta, _cx| {
+            zeta.reject_prediction(
+                EditPredictionId("test-1".into()),
+                EditPredictionRejectReason::Discarded,
+                false,
+            );
+            zeta.reject_prediction(
+                EditPredictionId("test-2".into()),
+                EditPredictionRejectReason::Canceled,
+                true,
+            );
+        });
+
+        cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
+        cx.run_until_parked();
+
+        let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
+        respond_tx.send(()).unwrap();
+
+        // batched
+        assert_eq!(reject_request.rejections.len(), 2);
+        assert_eq!(
+            reject_request.rejections[0],
+            EditPredictionRejection {
+                request_id: "test-1".to_string(),
+                reason: EditPredictionRejectReason::Discarded,
+                was_shown: false
+            }
+        );
+        assert_eq!(
+            reject_request.rejections[1],
+            EditPredictionRejection {
+                request_id: "test-2".to_string(),
+                reason: EditPredictionRejectReason::Canceled,
+                was_shown: true
+            }
+        );
+
+        // Reaching batch size limit sends without debounce
+        zeta.update(cx, |zeta, _cx| {
+            for i in 0..70 {
+                zeta.reject_prediction(
+                    EditPredictionId(format!("batch-{}", i).into()),
+                    EditPredictionRejectReason::Discarded,
+                    false,
+                );
+            }
+        });
+
+        // First MAX/2 items are sent immediately
+        cx.run_until_parked();
+        let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
+        respond_tx.send(()).unwrap();
+
+        assert_eq!(reject_request.rejections.len(), 50);
+        assert_eq!(reject_request.rejections[0].request_id, "batch-0");
+        assert_eq!(reject_request.rejections[49].request_id, "batch-49");
+
+        // Remaining items are debounced with the next batch
+        cx.executor().advance_clock(Duration::from_secs(15));
+        cx.run_until_parked();
+
+        let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
+        respond_tx.send(()).unwrap();
+
+        assert_eq!(reject_request.rejections.len(), 20);
+        assert_eq!(reject_request.rejections[0].request_id, "batch-50");
+        assert_eq!(reject_request.rejections[19].request_id, "batch-69");
+
+        // Request failure
+        zeta.update(cx, |zeta, _cx| {
+            zeta.reject_prediction(
+                EditPredictionId("retry-1".into()),
+                EditPredictionRejectReason::Discarded,
+                false,
+            );
+        });
+
+        cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
+        cx.run_until_parked();
+
+        let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
+        assert_eq!(reject_request.rejections.len(), 1);
+        assert_eq!(reject_request.rejections[0].request_id, "retry-1");
+        // Simulate failure
+        drop(_respond_tx);
+
+        // Add another rejection
+        zeta.update(cx, |zeta, _cx| {
+            zeta.reject_prediction(
+                EditPredictionId("retry-2".into()),
+                EditPredictionRejectReason::Discarded,
+                false,
+            );
+        });
+
+        cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
+        cx.run_until_parked();
+
+        // Retry should include both the failed item and the new one
+        let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
+        respond_tx.send(()).unwrap();
+
+        assert_eq!(reject_request.rejections.len(), 2);
+        assert_eq!(reject_request.rejections[0].request_id, "retry-1");
+        assert_eq!(reject_request.rejections[1].request_id, "retry-2");
+    }
+
     // Skipped until we start including diagnostics in prompt
     // #[gpui::test]
     // async fn test_request_diagnostics(cx: &mut TestAppContext) {