diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index ea7708c5769fd20b891a9a7c518a4ab7563b6401..c0dbc800618ec5d6d1b393ab9c88e441df3433c3 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -310,7 +310,8 @@ struct ProjectState { next_pending_prediction_id: usize, pending_predictions: ArrayVec, debug_tx: Option>, - last_prediction_refresh: Option<(EntityId, Instant)>, + last_edit_prediction_refresh: Option<(EntityId, Instant)>, + last_jump_prediction_refresh: Option<(EntityId, Instant)>, cancelled_predictions: HashSet, context: Entity, license_detection_watchers: HashMap>, @@ -444,6 +445,14 @@ impl PredictionRequestedBy { } } +const DIAGNOSTIC_LINES_RANGE: u32 = 20; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum DiagnosticSearchScope { + Local, + Global, +} + #[derive(Debug)] struct PendingPrediction { id: usize, @@ -899,7 +908,8 @@ impl EditPredictionStore { cancelled_predictions: HashSet::default(), pending_predictions: ArrayVec::new(), next_pending_prediction_id: 0, - last_prediction_refresh: None, + last_edit_prediction_refresh: None, + last_jump_prediction_refresh: None, license_detection_watchers: HashMap::default(), user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE), _subscriptions: [ @@ -1014,7 +1024,11 @@ impl EditPredictionStore { } project::Event::DiagnosticsUpdated { .. } => { if cx.has_flag::() { - self.refresh_prediction_from_diagnostics(project, cx); + self.refresh_prediction_from_diagnostics( + project, + DiagnosticSearchScope::Global, + cx, + ); } } _ => (), @@ -1458,38 +1472,45 @@ impl EditPredictionStore { position: language::Anchor, cx: &mut Context, ) { - self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| { - let Some(request_task) = this - .update(cx, |this, cx| { - this.request_prediction( - &project, - &buffer, - position, - PredictEditsRequestTrigger::Other, - cx, - ) - }) - .log_err() - else { - return Task::ready(anyhow::Ok(None)); - }; - - cx.spawn(async move |_cx| { - request_task.await.map(|prediction_result| { - prediction_result.map(|prediction_result| { - ( - prediction_result, - PredictionRequestedBy::Buffer(buffer.entity_id()), + self.queue_prediction_refresh( + project.clone(), + PredictEditsRequestTrigger::Other, + buffer.entity_id(), + cx, + move |this, cx| { + let Some(request_task) = this + .update(cx, |this, cx| { + this.request_prediction( + &project, + &buffer, + position, + PredictEditsRequestTrigger::Other, + cx, ) }) + .log_err() + else { + return Task::ready(anyhow::Ok(None)); + }; + + cx.spawn(async move |_cx| { + request_task.await.map(|prediction_result| { + prediction_result.map(|prediction_result| { + ( + prediction_result, + PredictionRequestedBy::Buffer(buffer.entity_id()), + ) + }) + }) }) - }) - }) + }, + ) } pub fn refresh_prediction_from_diagnostics( &mut self, project: Entity, + scope: DiagnosticSearchScope, cx: &mut Context, ) { let Some(project_state) = self.projects.get_mut(&project.entity_id()) else { @@ -1499,79 +1520,96 @@ impl EditPredictionStore { // Prefer predictions from buffer if project_state.current_prediction.is_some() { return; - }; - - self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| { - let Some((active_buffer, snapshot, cursor_point)) = this - .read_with(cx, |this, cx| { - let project_state = this.projects.get(&project.entity_id())?; - let (buffer, position) = project_state.active_buffer(&project, cx)?; - let snapshot = buffer.read(cx).snapshot(); - - if !Self::predictions_enabled_at(&snapshot, position, cx) { - return None; - } + } - let cursor_point = position - .map(|pos| pos.to_point(&snapshot)) - .unwrap_or_default(); + self.queue_prediction_refresh( + project.clone(), + PredictEditsRequestTrigger::Diagnostics, + project.entity_id(), + cx, + move |this, cx| { + let Some((active_buffer, snapshot, cursor_point)) = this + .read_with(cx, |this, cx| { + let project_state = this.projects.get(&project.entity_id())?; + let (buffer, position) = project_state.active_buffer(&project, cx)?; + let snapshot = buffer.read(cx).snapshot(); + + if !Self::predictions_enabled_at(&snapshot, position, cx) { + return None; + } - Some((buffer, snapshot, cursor_point)) - }) - .log_err() - .flatten() - else { - return Task::ready(anyhow::Ok(None)); - }; + let cursor_point = position + .map(|pos| pos.to_point(&snapshot)) + .unwrap_or_default(); - cx.spawn(async move |cx| { - let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location( - active_buffer, - &snapshot, - Default::default(), - cursor_point, - &project, - cx, - ) - .await? + Some((buffer, snapshot, cursor_point)) + }) + .log_err() + .flatten() else { - return anyhow::Ok(None); + return Task::ready(anyhow::Ok(None)); }; - let Some(prediction_result) = this - .update(cx, |this, cx| { - this.request_prediction( - &project, - &jump_buffer, - jump_position, - PredictEditsRequestTrigger::Diagnostics, - cx, - ) - })? + cx.spawn(async move |cx| { + let diagnostic_search_range = match scope { + DiagnosticSearchScope::Local => { + let diagnostic_search_start = + cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE); + let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE; + Point::new(diagnostic_search_start, 0) + ..Point::new(diagnostic_search_end, 0) + } + DiagnosticSearchScope::Global => Default::default(), + }; + + let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location( + active_buffer, + &snapshot, + diagnostic_search_range, + cursor_point, + &project, + cx, + ) .await? - else { - return anyhow::Ok(None); - }; + else { + return anyhow::Ok(None); + }; - this.update(cx, |this, cx| { - Some(( - if this - .get_or_init_project(&project, cx) - .current_prediction - .is_none() - { - prediction_result - } else { - EditPredictionResult { - id: prediction_result.id, - prediction: Err(EditPredictionRejectReason::CurrentPreferred), - } - }, - PredictionRequestedBy::DiagnosticsUpdate, - )) + let Some(prediction_result) = this + .update(cx, |this, cx| { + this.request_prediction( + &project, + &jump_buffer, + jump_position, + PredictEditsRequestTrigger::Diagnostics, + cx, + ) + })? + .await? + else { + return anyhow::Ok(None); + }; + + this.update(cx, |this, cx| { + Some(( + if this + .get_or_init_project(&project, cx) + .current_prediction + .is_none() + { + prediction_result + } else { + EditPredictionResult { + id: prediction_result.id, + prediction: Err(EditPredictionRejectReason::CurrentPreferred), + } + }, + PredictionRequestedBy::DiagnosticsUpdate, + )) + }) }) - }) - }); + }, + ); } fn predictions_enabled_at( @@ -1605,14 +1643,12 @@ impl EditPredictionStore { true } - #[cfg(not(test))] pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300); - #[cfg(test)] - pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO; fn queue_prediction_refresh( &mut self, project: Entity, + request_trigger: PredictEditsRequestTrigger, throttle_entity: EntityId, cx: &mut Context, do_refresh: impl FnOnce( @@ -1622,20 +1658,34 @@ impl EditPredictionStore { -> Task>> + 'static, ) { + fn select_throttle( + project_state: &mut ProjectState, + request_trigger: PredictEditsRequestTrigger, + ) -> &mut Option<(EntityId, Instant)> { + match request_trigger { + PredictEditsRequestTrigger::Diagnostics => { + &mut project_state.last_jump_prediction_refresh + } + _ => &mut project_state.last_edit_prediction_refresh, + } + } + let is_ollama = self.edit_prediction_model == EditPredictionModel::Ollama; let drop_on_cancel = is_ollama; let max_pending_predictions = if is_ollama { 1 } else { 2 }; + let throttle_timeout = Self::THROTTLE_TIMEOUT; let project_state = self.get_or_init_project(&project, cx); let pending_prediction_id = project_state.next_pending_prediction_id; project_state.next_pending_prediction_id += 1; - let last_request = project_state.last_prediction_refresh; + let last_request = *select_throttle(project_state, request_trigger); let task = cx.spawn(async move |this, cx| { - if let Some((last_entity, last_timestamp)) = last_request - && throttle_entity == last_entity - && let Some(timeout) = - (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now()) - { + if let Some(timeout) = last_request.and_then(|(last_entity, last_timestamp)| { + if throttle_entity != last_entity { + return None; + } + (last_timestamp + throttle_timeout).checked_duration_since(Instant::now()) + }) { cx.background_executor().timer(timeout).await; } @@ -1644,11 +1694,12 @@ impl EditPredictionStore { let mut is_cancelled = true; this.update(cx, |this, cx| { let project_state = this.get_or_init_project(&project, cx); - if !project_state + let was_cancelled = project_state .cancelled_predictions - .remove(&pending_prediction_id) - { - project_state.last_prediction_refresh = Some((throttle_entity, Instant::now())); + .remove(&pending_prediction_id); + if !was_cancelled { + let new_refresh = (throttle_entity, Instant::now()); + *select_throttle(project_state, request_trigger) = Some(new_refresh); is_cancelled = false; } }) @@ -1785,8 +1836,6 @@ impl EditPredictionStore { allow_jump: bool, cx: &mut Context, ) -> Task>> { - const DIAGNOSTIC_LINES_RANGE: u32 = 20; - self.get_or_init_project(&project, cx); let project_state = self.projects.get(&project.entity_id()).unwrap(); let stored_events = project_state.events(cx); @@ -1828,14 +1877,14 @@ impl EditPredictionStore { let inputs = EditPredictionModelInput { project: project.clone(), - buffer: active_buffer.clone(), - snapshot: snapshot.clone(), + buffer: active_buffer, + snapshot: snapshot, position, events, related_files, recent_paths: project_state.recent_paths.clone(), trigger, - diagnostic_search_range: diagnostic_search_range.clone(), + diagnostic_search_range: diagnostic_search_range, debug_tx, user_actions, }; @@ -1861,33 +1910,14 @@ impl EditPredictionStore { cx.spawn(async move |this, cx| { let prediction = task.await?; - if prediction.is_none() && allow_jump { - let cursor_point = position.to_point(&snapshot); - if has_events - && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location( - active_buffer.clone(), - &snapshot, - diagnostic_search_range, - cursor_point, - &project, + if prediction.is_none() && allow_jump && has_events { + this.update(cx, |this, cx| { + this.refresh_prediction_from_diagnostics( + project, + DiagnosticSearchScope::Local, cx, - ) - .await? - { - return this - .update(cx, |this, cx| { - this.request_prediction_internal( - project, - jump_buffer, - jump_position, - trigger, - false, - cx, - ) - })? - .await; - } - + ); + })?; return anyhow::Ok(None); } diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index 242a2bf3fff5f0eb87b183ec6c65280cbe75256a..0b6d68df320a581200b5f04a9bcf19db11b495b7 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -8,7 +8,7 @@ use cloud_llm_client::{ predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response}, }; use futures::{ - AsyncReadExt, StreamExt, + AsyncReadExt, FutureExt, StreamExt, channel::{mpsc, oneshot}, }; use gpui::App; @@ -1375,6 +1375,107 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) { ); } +#[gpui::test] +async fn test_jump_and_edit_throttles_are_independent(cx: &mut TestAppContext) { + let (ep_store, mut requests) = init_test_with_fake_client(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.md": "Hello!\nHow\nBye\n", + "bar.md": "Hola!\nComo\nAdios\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.set_active_path(Some(path.clone()), cx); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(1, 3)); + + ep_store.update(cx, |ep_store, cx| { + ep_store.register_project(&project, cx); + ep_store.register_buffer(&buffer, &project, cx); + }); + + // First edit request - no prior edit, so not throttled. + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + let (_edit_request, edit_response_tx) = requests.predict.next().await.unwrap(); + edit_response_tx.send(empty_response()).unwrap(); + cx.run_until_parked(); + + let diagnostic = lsp::Diagnostic { + range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)), + severity: Some(lsp::DiagnosticSeverity::ERROR), + message: "Sentence is incomplete".to_string(), + ..Default::default() + }; + + // First jump request triggered by diagnostic event on buffer - no prior jump, so not throttled (independent from edit). + project.update(cx, |project, cx| { + project.lsp_store().update(cx, |lsp_store, cx| { + lsp_store + .update_diagnostics( + LanguageServerId(0), + lsp::PublishDiagnosticsParams { + uri: lsp::Uri::from_file_path(path!("/root/bar.md")).unwrap(), + diagnostics: vec![diagnostic], + version: None, + }, + None, + language::DiagnosticSourceKind::Pushed, + &[], + cx, + ) + .unwrap(); + }); + }); + let (_jump_request, jump_response_tx) = requests.predict.next().await.unwrap(); + jump_response_tx.send(empty_response()).unwrap(); + cx.run_until_parked(); + + // Second edit request - should be throttled by the first edit. + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + assert_no_predict_request_ready(&mut requests.predict); + + // Second jump request - should be throttled by the first jump. + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_diagnostics( + project.clone(), + DiagnosticSearchScope::Global, + cx, + ); + }); + assert_no_predict_request_ready(&mut requests.predict); + + // Wait for both throttles to expire. + cx.background_executor + .advance_clock(EditPredictionStore::THROTTLE_TIMEOUT); + cx.background_executor.run_until_parked(); + cx.run_until_parked(); + + // Both requests should now go through. + let (_request_1, response_tx_1) = requests.predict.next().await.unwrap(); + response_tx_1.send(empty_response()).unwrap(); + cx.run_until_parked(); + + let (_request_2, response_tx_2) = requests.predict.next().await.unwrap(); + response_tx_2.send(empty_response()).unwrap(); + cx.run_until_parked(); +} + #[gpui::test] async fn test_rejections_flushing(cx: &mut TestAppContext) { let (ep_store, mut requests) = init_test_with_fake_client(cx); @@ -1596,10 +1697,28 @@ fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> Predi } } +fn empty_response() -> PredictEditsV3Response { + PredictEditsV3Response { + request_id: Uuid::new_v4().to_string(), + output: String::new(), + } +} + fn prompt_from_request(request: &PredictEditsV3Request) -> String { zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default()) } +fn assert_no_predict_request_ready( + requests: &mut mpsc::UnboundedReceiver<( + PredictEditsV3Request, + oneshot::Sender, + )>, +) { + if requests.next().now_or_never().flatten().is_some() { + panic!("Unexpected prediction request while throttled."); + } +} + struct RequestChannels { predict: mpsc::UnboundedReceiver<( PredictEditsV3Request,