From 2db237aa52ff4aaa0b55b95167f6bc5a04272ad3 Mon Sep 17 00:00:00 2001
From: Agus Zubiaga
Date: Tue, 2 Dec 2025 13:22:16 -0300
Subject: [PATCH] Limit edit prediction reject batches to max (#43965)

We currently attempt to flush all rejected predictions at once, even if we
have accumulated more than `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST`.
Instead, we now flush as many as possible and keep the rest for the next
batch. A simplified sketch of the resulting flush policy follows the diff
below.

Release Notes:

- N/A
---
 .../cloud_llm_client/src/cloud_llm_client.rs |   7 +-
 crates/zeta/src/provider.rs                  |  11 +-
 crates/zeta/src/zeta.rs                      | 306 ++++++++++++------
 3 files changed, 221 insertions(+), 103 deletions(-)

diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs
index 35916bd6801485c8c2bfde9330a47da19025f2c3..917929a985c85610b907e682792e132cb84d8403 100644
--- a/crates/cloud_llm_client/src/cloud_llm_client.rs
+++ b/crates/cloud_llm_client/src/cloud_llm_client.rs
@@ -206,11 +206,16 @@ pub struct AcceptEditPredictionBody {
     pub request_id: String,
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Deserialize)]
 pub struct RejectEditPredictionsBody {
     pub rejections: Vec<EditPredictionRejection>,
 }
 
+#[derive(Debug, Clone, Serialize)]
+pub struct RejectEditPredictionsBodyRef<'a> {
+    pub rejections: &'a [EditPredictionRejection],
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 pub struct EditPredictionRejection {
     pub request_id: String,
diff --git a/crates/zeta/src/provider.rs b/crates/zeta/src/provider.rs
index 5a2117397b7dd94d1fd61c4fb9880ebe447dbc1f..019d780e579c079f745f56136bdbd3a4add76b50 100644
--- a/crates/zeta/src/provider.rs
+++ b/crates/zeta/src/provider.rs
@@ -132,12 +132,8 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
     }
 
     fn discard(&mut self, cx: &mut Context<Self>) {
-        self.zeta.update(cx, |zeta, cx| {
-            zeta.reject_current_prediction(
-                EditPredictionRejectReason::Discarded,
-                &self.project,
-                cx,
-            );
+        self.zeta.update(cx, |zeta, _cx| {
+            zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &self.project);
         });
     }
 
@@ -173,11 +169,10 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
         let snapshot = buffer.snapshot();
 
         let Some(edits) = prediction.interpolate(&snapshot) else {
-            self.zeta.update(cx, |zeta, cx| {
+            self.zeta.update(cx, |zeta, _cx| {
                 zeta.reject_current_prediction(
                     EditPredictionRejectReason::InterpolatedEmpty,
                     &self.project,
-                    cx,
                 );
             });
             return None;
diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs
index 909f21200cc7c055adb80b1e510e6f13e7fc9784..dba90abbc839566781d18308e53c4b0faa96e1d7 100644
--- a/crates/zeta/src/zeta.rs
+++ b/crates/zeta/src/zeta.rs
@@ -5,7 +5,7 @@ use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
 use cloud_llm_client::{
     AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
     EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
-    MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBody,
+    MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
     ZED_VERSION_HEADER_NAME,
 };
 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
@@ -19,8 +19,10 @@ use edit_prediction_context::{
     SyntaxIndex, SyntaxIndexState,
 };
 use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
+use futures::channel::mpsc::UnboundedReceiver;
 use futures::channel::{mpsc, oneshot};
-use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _};
+use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, select_biased};
+use gpui::BackgroundExecutor;
 use gpui::{
     App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
     http_client::{self, AsyncBody, Method},
@@ -100,6 +102,7 @@ actions!(
 const EVENT_COUNT_MAX: usize = 6;
 const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
 const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
+const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
 
 pub struct SweepFeatureFlag;
 
@@ -195,9 +198,7 @@ pub struct Zeta {
     edit_prediction_model: ZetaEditPredictionModel,
     pub sweep_ai: SweepAi,
     data_collection_choice: DataCollectionChoice,
-    rejected_predictions: Vec<EditPredictionRejection>,
-    reject_predictions_tx: mpsc::UnboundedSender<()>,
-    reject_predictions_debounce_task: Option<Task<()>>,
+    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
     shown_predictions: VecDeque<EditPrediction>,
     rated_predictions: HashSet<EditPredictionId>,
 }
@@ -325,13 +326,8 @@ impl ZetaProject {
                 return;
             };
 
-            this.update(cx, |this, cx| {
-                this.reject_prediction(
-                    prediction_id,
-                    EditPredictionRejectReason::Canceled,
-                    false,
-                    cx,
-                );
+            this.update(cx, |this, _cx| {
+                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
             })
             .ok();
         })
@@ -504,14 +500,24 @@
         let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
         let data_collection_choice = Self::load_data_collection_choice();
 
-        let (reject_tx, mut reject_rx) = mpsc::unbounded();
-        cx.spawn(async move |this, cx| {
-            while let Some(()) = reject_rx.next().await {
-                this.update(cx, |this, cx| this.flush_rejected_predictions(cx))?
-                    .await
-                    .log_err();
+        let llm_token = LlmApiToken::default();
+
+        let (reject_tx, reject_rx) = mpsc::unbounded();
+        cx.background_spawn({
+            let client = client.clone();
+            let llm_token = llm_token.clone();
+            let app_version = AppVersion::global(cx);
+            let background_executor = cx.background_executor().clone();
+            async move {
+                Self::handle_rejected_predictions(
+                    reject_rx,
+                    client,
+                    llm_token,
+                    app_version,
+                    background_executor,
+                )
+                .await
             }
-            anyhow::Ok(())
         })
         .detach();
 
@@ -520,7 +526,7 @@
             client,
             user_store,
             options: DEFAULT_OPTIONS,
-            llm_token: LlmApiToken::default(),
+            llm_token,
             _llm_token_subscription: cx.subscribe(
                 &refresh_llm_token_listener,
                 |this, _listener, _event, cx| {
@@ -540,8 +546,6 @@
             edit_prediction_model: ZetaEditPredictionModel::Zeta2,
             sweep_ai: SweepAi::new(cx),
             data_collection_choice,
-            rejected_predictions: Vec::new(),
-            reject_predictions_debounce_task: None,
             reject_predictions_tx: reject_tx,
             rated_predictions: Default::default(),
             shown_predictions: Default::default(),
@@ -901,64 +905,73 @@ impl Zeta {
         .detach_and_log_err(cx);
     }
 
-    fn flush_rejected_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
-        match self.edit_prediction_model {
-            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
-            ZetaEditPredictionModel::Sweep => return Task::ready(anyhow::Ok(())),
-        }
+    async fn handle_rejected_predictions(
+        rx: UnboundedReceiver<EditPredictionRejection>,
+        client: Arc<Client>,
+        llm_token: LlmApiToken,
+        app_version: Version,
+        background_executor: BackgroundExecutor,
+    ) {
+        let mut rx = std::pin::pin!(rx.peekable());
+        let mut batched = Vec::new();
 
-        let client = self.client.clone();
-        let llm_token = self.llm_token.clone();
-        let app_version = AppVersion::global(cx);
-        let last_rejection = self.rejected_predictions.last().cloned();
-        let Some(last_rejection) = last_rejection else {
-            return Task::ready(anyhow::Ok(()));
-        };
 
+        while let Some(rejection) = rx.next().await {
+            batched.push(rejection);
 
-        let body = serde_json::to_string(&RejectEditPredictionsBody {
-            rejections: self.rejected_predictions.clone(),
-        })
-        .ok();
+            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
+                select_biased! {
+                    next = rx.as_mut().peek().fuse() => {
+                        if next.is_some() {
+                            continue;
+                        }
+                    }
+                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
+                }
+            }
 
-        cx.spawn(async move |this, cx| {
             let url = client
                 .http_client()
-                .build_zed_llm_url("/predict_edits/reject", &[])?;
+                .build_zed_llm_url("/predict_edits/reject", &[])
+                .unwrap();
 
-            cx.background_spawn(Self::send_api_request::<()>(
-                move |builder| {
-                    let req = builder.uri(url.as_ref()).body(body.clone().into());
-                    Ok(req?)
+            let flush_count = batched
+                .len()
+                // in case items have accumulated after failure
+                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
+            let start = batched.len() - flush_count;
+
+            let body = RejectEditPredictionsBodyRef {
+                rejections: &batched[start..],
+            };
+
+            let result = Self::send_api_request::<()>(
+                |builder| {
+                    let req = builder
+                        .uri(url.as_ref())
+                        .body(serde_json::to_string(&body)?.into());
+                    anyhow::Ok(req?)
                 },
-                client,
-                llm_token,
-                app_version,
-            ))
-            .await
-            .context("Failed to reject edit predictions")?;
+                client.clone(),
+                llm_token.clone(),
+                app_version.clone(),
+            )
+            .await;
 
-            this.update(cx, |this, _| {
-                if let Some(ix) = this
-                    .rejected_predictions
-                    .iter()
-                    .position(|rejection| rejection.request_id == last_rejection.request_id)
-                {
-                    this.rejected_predictions.drain(..ix + 1);
-                }
-            })
-        })
+            if result.log_err().is_some() {
+                batched.drain(start..);
+            }
+        }
     }
 
     fn reject_current_prediction(
         &mut self,
         reason: EditPredictionRejectReason,
         project: &Entity<Project>,
-        cx: &mut Context<Self>,
     ) {
         if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
             project_state.pending_predictions.clear();
 
             if let Some(prediction) = project_state.current_prediction.take() {
-                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
+                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
             }
         };
     }
@@ -984,26 +997,14 @@
         prediction_id: EditPredictionId,
         reason: EditPredictionRejectReason,
         was_shown: bool,
-        cx: &mut Context<Self>,
     ) {
-        self.rejected_predictions.push(EditPredictionRejection {
-            request_id: prediction_id.to_string(),
-            reason,
-            was_shown,
-        });
-
-        let reached_request_limit =
-            self.rejected_predictions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2;
-        let reject_tx = self.reject_predictions_tx.clone();
-        self.reject_predictions_debounce_task = Some(cx.spawn(async move |_this, cx| {
-            const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
-            if !reached_request_limit {
-                cx.background_executor()
-                    .timer(REJECT_REQUEST_DEBOUNCE)
-                    .await;
-            }
-            reject_tx.unbounded_send(()).log_err();
-        }));
+        self.reject_predictions_tx
+            .unbounded_send(EditPredictionRejection {
+                request_id: prediction_id.to_string(),
+                reason,
+                was_shown,
+            })
+            .log_err();
     }
 
     fn is_refreshing(&self, project: &Entity<Project>) -> bool {
@@ -1211,7 +1212,6 @@
                         this.reject_current_prediction(
                             EditPredictionRejectReason::Replaced,
                             &project,
-                            cx,
                         );
 
                         Some(new_prediction)
@@ -1220,7 +1220,6 @@
                                 new_prediction.prediction.id,
                                 EditPredictionRejectReason::CurrentPreferred,
                                 false,
-                                cx,
                             );
                             None
                         }
@@ -1229,7 +1228,7 @@
                     }
                 }
                 Err(reject_reason) => {
-                    this.reject_prediction(prediction_result.id, reject_reason, false, cx);
+                    this.reject_prediction(prediction_result.id, reject_reason, false);
                     None
                 }
             }
@@ -2906,7 +2905,7 @@ fn feature_gate_predict_edits_actions(cx: &mut App) {
 
 #[cfg(test)]
 mod tests {
-    use std::{path::Path, sync::Arc};
+    use std::{path::Path, sync::Arc, time::Duration};
 
     use client::UserStore;
     use clock::FakeSystemClock;
@@ -2933,7 +2932,7 @@ mod tests {
     use util::path;
     use uuid::Uuid;
 
-    use crate::{BufferEditPrediction, Zeta};
+    use crate::{BufferEditPrediction, EditPredictionId, REJECT_REQUEST_DEBOUNCE, Zeta};
 
     #[gpui::test]
     async fn test_current_state(cx: &mut TestAppContext) {
@@ -3035,8 +3034,8 @@
             .unwrap();
         refresh_task.await.unwrap();
 
-        zeta.update(cx, |zeta, cx| {
-            zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
+        zeta.update(cx, |zeta, _cx| {
+            zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
         });
 
         // Prediction for another file
@@ -3545,14 +3544,17 @@
         let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
         let position = snapshot.anchor_before(language::Point::new(1, 3));
 
+        // start two refresh tasks
         zeta.update(cx, |zeta, cx| {
-            // start two refresh tasks
             zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+        });
+        let (_, respond_first) = requests.predict.next().await.unwrap();
+
+        zeta.update(cx, |zeta, cx| {
             zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
         });
 
-        let (_, respond_first) = requests.predict.next().await.unwrap();
         let (_, respond_second) = requests.predict.next().await.unwrap();
 
         // wait for throttle
@@ -3631,18 +3633,22 @@
         let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
         let position = snapshot.anchor_before(language::Point::new(1, 3));
 
+        // start two refresh tasks
         zeta.update(cx, |zeta, cx| {
-            // start two refresh tasks
             zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+        });
+
+        let (_, respond_first) = requests.predict.next().await.unwrap();
+
+        zeta.update(cx, |zeta, cx| {
             zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
         });
 
+        let (_, respond_second) = requests.predict.next().await.unwrap();
+
         // wait for throttle, so requests are sent
         cx.run_until_parked();
 
-        let (_, respond_first) = requests.predict.next().await.unwrap();
-        let (_, respond_second) = requests.predict.next().await.unwrap();
-
         zeta.update(cx, |zeta, cx| {
             // start a third request
             zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
         })
@@ -3736,6 +3742,118 @@
         );
     }
 
+    #[gpui::test]
+    async fn test_rejections_flushing(cx: &mut TestAppContext) {
+        let (zeta, mut requests) = init_test(cx);
+
+        zeta.update(cx, |zeta, _cx| {
+            zeta.reject_prediction(
+                EditPredictionId("test-1".into()),
+                EditPredictionRejectReason::Discarded,
+                false,
+            );
+            zeta.reject_prediction(
+                EditPredictionId("test-2".into()),
+                EditPredictionRejectReason::Canceled,
+                true,
+            );
+        });
+
+        cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
+        cx.run_until_parked();
+
+        let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
+        respond_tx.send(()).unwrap();
+
+        // batched
+        assert_eq!(reject_request.rejections.len(), 2);
+        assert_eq!(
+            reject_request.rejections[0],
+            EditPredictionRejection {
+                request_id: "test-1".to_string(),
+                reason: EditPredictionRejectReason::Discarded,
+                was_shown: false
+            }
+        );
+        assert_eq!(
+            reject_request.rejections[1],
+            EditPredictionRejection {
request_id: "test-2".to_string(), + reason: EditPredictionRejectReason::Canceled, + was_shown: true + } + ); + + // Reaching batch size limit sends without debounce + zeta.update(cx, |zeta, _cx| { + for i in 0..70 { + zeta.reject_prediction( + EditPredictionId(format!("batch-{}", i).into()), + EditPredictionRejectReason::Discarded, + false, + ); + } + }); + + // First MAX/2 items are sent immediately + cx.run_until_parked(); + let (reject_request, respond_tx) = requests.reject.next().await.unwrap(); + respond_tx.send(()).unwrap(); + + assert_eq!(reject_request.rejections.len(), 50); + assert_eq!(reject_request.rejections[0].request_id, "batch-0"); + assert_eq!(reject_request.rejections[49].request_id, "batch-49"); + + // Remaining items are debounced with the next batch + cx.executor().advance_clock(Duration::from_secs(15)); + cx.run_until_parked(); + + let (reject_request, respond_tx) = requests.reject.next().await.unwrap(); + respond_tx.send(()).unwrap(); + + assert_eq!(reject_request.rejections.len(), 20); + assert_eq!(reject_request.rejections[0].request_id, "batch-50"); + assert_eq!(reject_request.rejections[19].request_id, "batch-69"); + + // Request failure + zeta.update(cx, |zeta, _cx| { + zeta.reject_prediction( + EditPredictionId("retry-1".into()), + EditPredictionRejectReason::Discarded, + false, + ); + }); + + cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE); + cx.run_until_parked(); + + let (reject_request, _respond_tx) = requests.reject.next().await.unwrap(); + assert_eq!(reject_request.rejections.len(), 1); + assert_eq!(reject_request.rejections[0].request_id, "retry-1"); + // Simulate failure + drop(_respond_tx); + + // Add another rejection + zeta.update(cx, |zeta, _cx| { + zeta.reject_prediction( + EditPredictionId("retry-2".into()), + EditPredictionRejectReason::Discarded, + false, + ); + }); + + cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE); + cx.run_until_parked(); + + // Retry should include both the failed item and the new one + let (reject_request, respond_tx) = requests.reject.next().await.unwrap(); + respond_tx.send(()).unwrap(); + + assert_eq!(reject_request.rejections.len(), 2); + assert_eq!(reject_request.rejections[0].request_id, "retry-1"); + assert_eq!(reject_request.rejections[1].request_id, "retry-2"); + } + // Skipped until we start including diagnostics in prompt // #[gpui::test] // async fn test_request_diagnostics(cx: &mut TestAppContext) {