@@ -5,7 +5,7 @@ use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
use cloud_llm_client::{
AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
- MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBody,
+ MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
ZED_VERSION_HEADER_NAME,
};
use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
@@ -19,8 +19,10 @@ use edit_prediction_context::{
SyntaxIndex, SyntaxIndexState,
};
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
+use futures::channel::mpsc::UnboundedReceiver;
use futures::channel::{mpsc, oneshot};
-use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _};
+use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, select_biased};
+use gpui::BackgroundExecutor;
use gpui::{
App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
http_client::{self, AsyncBody, Method},
@@ -100,6 +102,7 @@ actions!(
const EVENT_COUNT_MAX: usize = 6;
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
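+/// How long to wait for more rejections to arrive before flushing a batch to
+/// the server in a single request.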
+const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
pub struct SweepFeatureFlag;
@@ -195,9 +198,7 @@ pub struct Zeta {
edit_prediction_model: ZetaEditPredictionModel,
pub sweep_ai: SweepAi,
data_collection_choice: DataCollectionChoice,
- rejected_predictions: Vec<EditPredictionRejection>,
- reject_predictions_tx: mpsc::UnboundedSender<()>,
- reject_predictions_debounce_task: Option<Task<()>>,
+ reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
shown_predictions: VecDeque<EditPrediction>,
rated_predictions: HashSet<EditPredictionId>,
}
@@ -325,13 +326,8 @@ impl ZetaProject {
return;
};
- this.update(cx, |this, cx| {
- this.reject_prediction(
- prediction_id,
- EditPredictionRejectReason::Canceled,
- false,
- cx,
- );
+ this.update(cx, |this, _cx| {
+ this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
})
.ok();
})
@@ -504,14 +500,24 @@ impl Zeta {
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
let data_collection_choice = Self::load_data_collection_choice();
- let (reject_tx, mut reject_rx) = mpsc::unbounded();
- cx.spawn(async move |this, cx| {
- while let Some(()) = reject_rx.next().await {
- this.update(cx, |this, cx| this.flush_rejected_predictions(cx))?
- .await
- .log_err();
+ let llm_token = LlmApiToken::default();
+
+ let (reject_tx, reject_rx) = mpsc::unbounded();
+ cx.background_spawn({
+ let client = client.clone();
+ let llm_token = llm_token.clone();
+ let app_version = AppVersion::global(cx);
+ let background_executor = cx.background_executor().clone();
+ async move {
+ Self::handle_rejected_predictions(
+ reject_rx,
+ client,
+ llm_token,
+ app_version,
+ background_executor,
+ )
+ .await
}
- anyhow::Ok(())
})
.detach();
@@ -520,7 +526,7 @@ impl Zeta {
client,
user_store,
options: DEFAULT_OPTIONS,
- llm_token: LlmApiToken::default(),
+ llm_token,
_llm_token_subscription: cx.subscribe(
&refresh_llm_token_listener,
|this, _listener, _event, cx| {
@@ -540,8 +546,6 @@ impl Zeta {
edit_prediction_model: ZetaEditPredictionModel::Zeta2,
sweep_ai: SweepAi::new(cx),
data_collection_choice,
- rejected_predictions: Vec::new(),
- reject_predictions_debounce_task: None,
reject_predictions_tx: reject_tx,
rated_predictions: Default::default(),
shown_predictions: Default::default(),
@@ -901,64 +905,73 @@ impl Zeta {
.detach_and_log_err(cx);
}
- fn flush_rejected_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
- match self.edit_prediction_model {
- ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
- ZetaEditPredictionModel::Sweep => return Task::ready(anyhow::Ok(())),
- }
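+ /// Background task that batches rejected predictions and uploads them in
+ /// debounced requests. Items from a failed upload are kept in the batch
+ /// and retried with the next flush.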
+ async fn handle_rejected_predictions(
+ rx: UnboundedReceiver<EditPredictionRejection>,
+ client: Arc<Client>,
+ llm_token: LlmApiToken,
+ app_version: Version,
+ background_executor: BackgroundExecutor,
+ ) {
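+ // Peekable lets the loop below check for more queued rejections without consuming them.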
+ let mut rx = std::pin::pin!(rx.peekable());
+ let mut batched = Vec::new();
- let client = self.client.clone();
- let llm_token = self.llm_token.clone();
- let app_version = AppVersion::global(cx);
- let last_rejection = self.rejected_predictions.last().cloned();
- let Some(last_rejection) = last_rejection else {
- return Task::ready(anyhow::Ok(()));
- };
+ while let Some(rejection) = rx.next().await {
+ batched.push(rejection);
- let body = serde_json::to_string(&RejectEditPredictionsBody {
- rejections: self.rejected_predictions.clone(),
- })
- .ok();
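+ // While under half of a request's capacity, keep batching until either
+ // another rejection arrives or the debounce timer fires.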
+ if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
+ select_biased! {
+ next = rx.as_mut().peek().fuse() => {
+ if next.is_some() {
+ continue;
+ }
+ }
+ () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
+ }
+ }
- cx.spawn(async move |this, cx| {
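+ // The route is a fixed path, so building the URL only fails if the
+ // configured server URL is malformed.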
let url = client
.http_client()
- .build_zed_llm_url("/predict_edits/reject", &[])?;
+ .build_zed_llm_url("/predict_edits/reject", &[])
+ .unwrap();
- cx.background_spawn(Self::send_api_request::<()>(
- move |builder| {
- let req = builder.uri(url.as_ref()).body(body.clone().into());
- Ok(req?)
+ let flush_count = batched
+ .len()
+ // cap the batch in case items accumulated while earlier requests were failing
+ .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
+ let start = batched.len() - flush_count;
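+ // If more than one request's worth has accumulated, send the newest
+ // items; the rest are picked up by a later flush.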
+
+ let body = RejectEditPredictionsBodyRef {
+ rejections: &batched[start..],
+ };
+
+ let result = Self::send_api_request::<()>(
+ |builder| {
+ let req = builder
+ .uri(url.as_ref())
+ .body(serde_json::to_string(&body)?.into());
+ anyhow::Ok(req?)
},
- client,
- llm_token,
- app_version,
- ))
- .await
- .context("Failed to reject edit predictions")?;
+ client.clone(),
+ llm_token.clone(),
+ app_version.clone(),
+ )
+ .await;
- this.update(cx, |this, _| {
- if let Some(ix) = this
- .rejected_predictions
- .iter()
- .position(|rejection| rejection.request_id == last_rejection.request_id)
- {
- this.rejected_predictions.drain(..ix + 1);
- }
- })
- })
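+ // On success, drop what was sent; on failure, keep it batched so the
+ // next flush retries it.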
+ if result.log_err().is_some() {
+ batched.drain(start..);
+ }
+ }
}
fn reject_current_prediction(
&mut self,
reason: EditPredictionRejectReason,
project: &Entity<Project>,
- cx: &mut Context<Self>,
) {
if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
project_state.pending_predictions.clear();
if let Some(prediction) = project_state.current_prediction.take() {
- self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
+ self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
}
};
}
@@ -984,26 +997,14 @@ impl Zeta {
prediction_id: EditPredictionId,
reason: EditPredictionRejectReason,
was_shown: bool,
- cx: &mut Context<Self>,
) {
- self.rejected_predictions.push(EditPredictionRejection {
- request_id: prediction_id.to_string(),
- reason,
- was_shown,
- });
-
- let reached_request_limit =
- self.rejected_predictions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2;
- let reject_tx = self.reject_predictions_tx.clone();
- self.reject_predictions_debounce_task = Some(cx.spawn(async move |_this, cx| {
- const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
- if !reached_request_limit {
- cx.background_executor()
- .timer(REJECT_REQUEST_DEBOUNCE)
- .await;
- }
- reject_tx.unbounded_send(()).log_err();
- }));
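+ // Hand the rejection to the background task, which batches and debounces uploads.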
+ self.reject_predictions_tx
+ .unbounded_send(EditPredictionRejection {
+ request_id: prediction_id.to_string(),
+ reason,
+ was_shown,
+ })
+ .log_err();
}
fn is_refreshing(&self, project: &Entity<Project>) -> bool {
@@ -1211,7 +1212,6 @@ impl Zeta {
this.reject_current_prediction(
EditPredictionRejectReason::Replaced,
&project,
- cx,
);
Some(new_prediction)
@@ -1220,7 +1220,6 @@ impl Zeta {
new_prediction.prediction.id,
EditPredictionRejectReason::CurrentPreferred,
false,
- cx,
);
None
}
@@ -1229,7 +1228,7 @@ impl Zeta {
}
}
Err(reject_reason) => {
- this.reject_prediction(prediction_result.id, reject_reason, false, cx);
+ this.reject_prediction(prediction_result.id, reject_reason, false);
None
}
}
@@ -2906,7 +2905,7 @@ fn feature_gate_predict_edits_actions(cx: &mut App) {
#[cfg(test)]
mod tests {
- use std::{path::Path, sync::Arc};
+ use std::{path::Path, sync::Arc, time::Duration};
use client::UserStore;
use clock::FakeSystemClock;
@@ -2933,7 +2932,7 @@ mod tests {
use util::path;
use uuid::Uuid;
- use crate::{BufferEditPrediction, Zeta};
+ use cloud_llm_client::EditPredictionRejection;
+ use crate::{BufferEditPrediction, EditPredictionId, REJECT_REQUEST_DEBOUNCE, Zeta};
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
@@ -3035,8 +3034,8 @@ mod tests {
.unwrap();
refresh_task.await.unwrap();
- zeta.update(cx, |zeta, cx| {
- zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
+ zeta.update(cx, |zeta, _cx| {
+ zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
});
// Prediction for another file
@@ -3545,14 +3544,17 @@ mod tests {
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
let position = snapshot.anchor_before(language::Point::new(1, 3));
+ // start two refresh tasks
zeta.update(cx, |zeta, cx| {
- // start two refresh tasks
zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+ let (_, respond_first) = requests.predict.next().await.unwrap();
+
+ zeta.update(cx, |zeta, cx| {
zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_first) = requests.predict.next().await.unwrap();
let (_, respond_second) = requests.predict.next().await.unwrap();
// wait for throttle
@@ -3631,18 +3633,22 @@ mod tests {
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
let position = snapshot.anchor_before(language::Point::new(1, 3));
+ // start two refresh tasks
zeta.update(cx, |zeta, cx| {
- // start two refresh tasks
zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (_, respond_first) = requests.predict.next().await.unwrap();
+
+ zeta.update(cx, |zeta, cx| {
zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
+ let (_, respond_second) = requests.predict.next().await.unwrap();
+
// wait for throttle, so requests are sent
cx.run_until_parked();
- let (_, respond_first) = requests.predict.next().await.unwrap();
- let (_, respond_second) = requests.predict.next().await.unwrap();
-
zeta.update(cx, |zeta, cx| {
// start a third request
zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
@@ -3736,6 +3742,118 @@ mod tests {
);
}
+ #[gpui::test]
+ async fn test_rejections_flushing(cx: &mut TestAppContext) {
+ let (zeta, mut requests) = init_test(cx);
+
+ zeta.update(cx, |zeta, _cx| {
+ zeta.reject_prediction(
+ EditPredictionId("test-1".into()),
+ EditPredictionRejectReason::Discarded,
+ false,
+ );
+ zeta.reject_prediction(
+ EditPredictionId("test-2".into()),
+ EditPredictionRejectReason::Canceled,
+ true,
+ );
+ });
+
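+ // Nothing is uploaded until the debounce elapses.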
+ cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
+ cx.run_until_parked();
+
+ let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
+ respond_tx.send(()).unwrap();
+
+ // Both rejections were batched into a single request
+ assert_eq!(reject_request.rejections.len(), 2);
+ assert_eq!(
+ reject_request.rejections[0],
+ EditPredictionRejection {
+ request_id: "test-1".to_string(),
+ reason: EditPredictionRejectReason::Discarded,
+ was_shown: false
+ }
+ );
+ assert_eq!(
+ reject_request.rejections[1],
+ EditPredictionRejection {
+ request_id: "test-2".to_string(),
+ reason: EditPredictionRejectReason::Canceled,
+ was_shown: true
+ }
+ );
+
+ // Reaching the batch-size limit flushes immediately, without waiting for the debounce
+ zeta.update(cx, |zeta, _cx| {
+ for i in 0..70 {
+ zeta.reject_prediction(
+ EditPredictionId(format!("batch-{}", i).into()),
+ EditPredictionRejectReason::Discarded,
+ false,
+ );
+ }
+ });
+
+ // First MAX/2 items are sent immediately
+ cx.run_until_parked();
+ let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
+ respond_tx.send(()).unwrap();
+
+ assert_eq!(reject_request.rejections.len(), 50);
+ assert_eq!(reject_request.rejections[0].request_id, "batch-0");
+ assert_eq!(reject_request.rejections[49].request_id, "batch-49");
+
+ // The remaining items are flushed after the next debounce interval
+ cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
+ cx.run_until_parked();
+
+ let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
+ respond_tx.send(()).unwrap();
+
+ assert_eq!(reject_request.rejections.len(), 20);
+ assert_eq!(reject_request.rejections[0].request_id, "batch-50");
+ assert_eq!(reject_request.rejections[19].request_id, "batch-69");
+
+ // A failed request keeps its items batched for retry
+ zeta.update(cx, |zeta, _cx| {
+ zeta.reject_prediction(
+ EditPredictionId("retry-1".into()),
+ EditPredictionRejectReason::Discarded,
+ false,
+ );
+ });
+
+ cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
+ cx.run_until_parked();
+
+ let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
+ assert_eq!(reject_request.rejections.len(), 1);
+ assert_eq!(reject_request.rejections[0].request_id, "retry-1");
+ // Simulate a failed request by dropping the responder
+ drop(respond_tx);
+
+ // Add another rejection
+ zeta.update(cx, |zeta, _cx| {
+ zeta.reject_prediction(
+ EditPredictionId("retry-2".into()),
+ EditPredictionRejectReason::Discarded,
+ false,
+ );
+ });
+
+ cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
+ cx.run_until_parked();
+
+ // Retry should include both the failed item and the new one
+ let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
+ respond_tx.send(()).unwrap();
+
+ assert_eq!(reject_request.rejections.len(), 2);
+ assert_eq!(reject_request.rejections[0].request_id, "retry-1");
+ assert_eq!(reject_request.rejections[1].request_id, "retry-2");
+ }
+
// Skipped until we start including diagnostics in prompt
// #[gpui::test]
// async fn test_request_diagnostics(cx: &mut TestAppContext) {