diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index b0d4a5f4d69c357fb0a153bee267a64dc0c465dd..11151c1fc19437655075c00589d02725131445ed 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -283,6 +283,18 @@ impl ProjectState { }) .detach() } + + fn active_buffer( + &self, + project: &Entity, + cx: &App, + ) -> Option<(Entity, Option)> { + let project = project.read(cx); + let active_path = project.path_for_entry(project.active_entry()?, cx)?; + let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?; + let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?; + Some((active_buffer, registered_buffer.last_position)) + } } #[derive(Debug, Clone)] @@ -373,6 +385,7 @@ impl std::ops::Deref for BufferEditPrediction<'_> { struct RegisteredBuffer { snapshot: BufferSnapshot, + last_position: Option, _subscriptions: [gpui::Subscription; 2], } @@ -795,6 +808,7 @@ impl EditPredictionStore { let project_entity_id = project.entity_id(); entry.insert(RegisteredBuffer { snapshot, + last_position: None, _subscriptions: [ cx.subscribe(buffer, { let project = project.downgrade(); @@ -882,13 +896,21 @@ impl EditPredictionStore { }); } - fn current_prediction_for_buffer( - &self, + fn prediction_at( + &mut self, buffer: &Entity, + position: Option, project: &Entity, cx: &App, ) -> Option> { - let project_state = self.projects.get(&project.entity_id())?; + let project_state = self.projects.get_mut(&project.entity_id())?; + if let Some(position) = position + && let Some(buffer) = project_state + .registered_buffers + .get_mut(&buffer.entity_id()) + { + buffer.last_position = Some(position); + } let CurrentEditPrediction { requested_by, @@ -1131,12 +1153,21 @@ impl EditPredictionStore { }; self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| { - let Some(open_buffer_task) = project - .update(cx, |project, cx| { - project - .active_entry() - .and_then(|entry| project.path_for_entry(entry, cx)) - .map(|path| project.open_buffer(path, cx)) + let Some((active_buffer, snapshot, cursor_point)) = this + .read_with(cx, |this, cx| { + let project_state = this.projects.get(&project.entity_id())?; + let (buffer, position) = project_state.active_buffer(&project, cx)?; + let snapshot = buffer.read(cx).snapshot(); + + if !Self::predictions_enabled_at(&snapshot, position, cx) { + return None; + } + + let cursor_point = position + .map(|pos| pos.to_point(&snapshot)) + .unwrap_or_default(); + + Some((buffer, snapshot, cursor_point)) }) .log_err() .flatten() @@ -1145,14 +1176,11 @@ impl EditPredictionStore { }; cx.spawn(async move |cx| { - let active_buffer = open_buffer_task.await?; - let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; - let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location( active_buffer, &snapshot, Default::default(), - Default::default(), + cursor_point, &project, cx, ) @@ -1197,6 +1225,37 @@ impl EditPredictionStore { }); } + fn predictions_enabled_at( + snapshot: &BufferSnapshot, + position: Option, + cx: &App, + ) -> bool { + let file = snapshot.file(); + let all_settings = all_language_settings(file, cx); + if !all_settings.show_edit_predictions(snapshot.language(), cx) + || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx)) + { + return false; + } + + if let Some(last_position) = position { + let settings = snapshot.settings_at(last_position, cx); + + if !settings.edit_predictions_disabled_in.is_empty() + && let Some(scope) = snapshot.language_scope_at(last_position) + && let Some(scope_name) = scope.override_name() + && settings + .edit_predictions_disabled_in + .iter() + .any(|s| s == scope_name) + { + return false; + } + } + + true + } + #[cfg(not(test))] pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300); #[cfg(test)] diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index f6465b14cbd1b3357349071bc5eda399253b5328..9e4baa78ef4564ce4348ef1b51085ba0a6abdffc 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -45,10 +45,6 @@ async fn test_current_state(cx: &mut TestAppContext) { .await; let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; - ep_store.update(cx, |ep_store, cx| { - ep_store.register_project(&project, cx); - }); - let buffer1 = project .update(cx, |project, cx| { let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap(); @@ -60,6 +56,11 @@ async fn test_current_state(cx: &mut TestAppContext) { let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot()); let position = snapshot1.anchor_before(language::Point::new(1, 3)); + ep_store.update(cx, |ep_store, cx| { + ep_store.register_project(&project, cx); + ep_store.register_buffer(&buffer1, &project, cx); + }); + // Prediction for current file ep_store.update(cx, |ep_store, cx| { @@ -84,9 +85,9 @@ async fn test_current_state(cx: &mut TestAppContext) { cx.run_until_parked(); - ep_store.read_with(cx, |ep_store, cx| { + ep_store.update(cx, |ep_store, cx| { let prediction = ep_store - .current_prediction_for_buffer(&buffer1, &project, cx) + .prediction_at(&buffer1, None, &project, cx) .unwrap(); assert_matches!(prediction, BufferEditPrediction::Local { .. }); }); @@ -140,9 +141,9 @@ async fn test_current_state(cx: &mut TestAppContext) { .unwrap(); cx.run_until_parked(); - ep_store.read_with(cx, |ep_store, cx| { + ep_store.update(cx, |ep_store, cx| { let prediction = ep_store - .current_prediction_for_buffer(&buffer1, &project, cx) + .prediction_at(&buffer1, None, &project, cx) .unwrap(); assert_matches!( prediction, @@ -158,9 +159,9 @@ async fn test_current_state(cx: &mut TestAppContext) { .await .unwrap(); - ep_store.read_with(cx, |ep_store, cx| { + ep_store.update(cx, |ep_store, cx| { let prediction = ep_store - .current_prediction_for_buffer(&buffer2, &project, cx) + .prediction_at(&buffer2, None, &project, cx) .unwrap(); assert_matches!(prediction, BufferEditPrediction::Local { .. }); }); @@ -344,10 +345,10 @@ async fn test_empty_prediction(cx: &mut TestAppContext) { cx.run_until_parked(); - ep_store.read_with(cx, |ep_store, cx| { + ep_store.update(cx, |ep_store, cx| { assert!( ep_store - .current_prediction_for_buffer(&buffer, &project, cx) + .prediction_at(&buffer, None, &project, cx) .is_none() ); }); @@ -404,10 +405,10 @@ async fn test_interpolated_empty(cx: &mut TestAppContext) { cx.run_until_parked(); - ep_store.read_with(cx, |ep_store, cx| { + ep_store.update(cx, |ep_store, cx| { assert!( ep_store - .current_prediction_for_buffer(&buffer, &project, cx) + .prediction_at(&buffer, None, &project, cx) .is_none() ); }); @@ -469,10 +470,10 @@ async fn test_replace_current(cx: &mut TestAppContext) { cx.run_until_parked(); - ep_store.read_with(cx, |ep_store, cx| { + ep_store.update(cx, |ep_store, cx| { assert_eq!( ep_store - .current_prediction_for_buffer(&buffer, &project, cx) + .prediction_at(&buffer, None, &project, cx) .unwrap() .id .0, @@ -492,11 +493,11 @@ async fn test_replace_current(cx: &mut TestAppContext) { cx.run_until_parked(); - ep_store.read_with(cx, |ep_store, cx| { + ep_store.update(cx, |ep_store, cx| { // second replaces first assert_eq!( ep_store - .current_prediction_for_buffer(&buffer, &project, cx) + .prediction_at(&buffer, None, &project, cx) .unwrap() .id .0, @@ -551,10 +552,10 @@ async fn test_current_preferred(cx: &mut TestAppContext) { cx.run_until_parked(); - ep_store.read_with(cx, |ep_store, cx| { + ep_store.update(cx, |ep_store, cx| { assert_eq!( ep_store - .current_prediction_for_buffer(&buffer, &project, cx) + .prediction_at(&buffer, None, &project, cx) .unwrap() .id .0, @@ -586,11 +587,11 @@ async fn test_current_preferred(cx: &mut TestAppContext) { cx.run_until_parked(); - ep_store.read_with(cx, |ep_store, cx| { + ep_store.update(cx, |ep_store, cx| { // first is preferred over second assert_eq!( ep_store - .current_prediction_for_buffer(&buffer, &project, cx) + .prediction_at(&buffer, None, &project, cx) .unwrap() .id .0, @@ -657,11 +658,11 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) { cx.run_until_parked(); - ep_store.read_with(cx, |ep_store, cx| { + ep_store.update(cx, |ep_store, cx| { // current prediction is second assert_eq!( ep_store - .current_prediction_for_buffer(&buffer, &project, cx) + .prediction_at(&buffer, None, &project, cx) .unwrap() .id .0, @@ -675,11 +676,11 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) { cx.run_until_parked(); - ep_store.read_with(cx, |ep_store, cx| { + ep_store.update(cx, |ep_store, cx| { // current prediction is still second, since first was cancelled assert_eq!( ep_store - .current_prediction_for_buffer(&buffer, &project, cx) + .prediction_at(&buffer, None, &project, cx) .unwrap() .id .0, @@ -768,11 +769,11 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) { cx.run_until_parked(); - ep_store.read_with(cx, |ep_store, cx| { + ep_store.update(cx, |ep_store, cx| { // current prediction is first assert_eq!( ep_store - .current_prediction_for_buffer(&buffer, &project, cx) + .prediction_at(&buffer, None, &project, cx) .unwrap() .id .0, @@ -786,11 +787,11 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) { cx.run_until_parked(); - ep_store.read_with(cx, |ep_store, cx| { + ep_store.update(cx, |ep_store, cx| { // current prediction is still first, since second was cancelled assert_eq!( ep_store - .current_prediction_for_buffer(&buffer, &project, cx) + .prediction_at(&buffer, None, &project, cx) .unwrap() .id .0, @@ -804,11 +805,11 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) { cx.run_until_parked(); - ep_store.read_with(cx, |ep_store, cx| { + ep_store.update(cx, |ep_store, cx| { // third completes and replaces first assert_eq!( ep_store - .current_prediction_for_buffer(&buffer, &project, cx) + .prediction_at(&buffer, None, &project, cx) .unwrap() .id .0, diff --git a/crates/edit_prediction/src/zed_edit_prediction_delegate.rs b/crates/edit_prediction/src/zed_edit_prediction_delegate.rs index 91371d539beca012e2ded4e9ec8702c8db39bd8a..6dcf7092240de64381ded611b47c2dd5940d6770 100644 --- a/crates/edit_prediction/src/zed_edit_prediction_delegate.rs +++ b/crates/edit_prediction/src/zed_edit_prediction_delegate.rs @@ -125,14 +125,15 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate { return; } - if let Some(current) = store.current_prediction_for_buffer(&buffer, &self.project, cx) - && let BufferEditPrediction::Local { prediction } = current - && prediction.interpolate(buffer.read(cx)).is_some() - { - return; - } - self.store.update(cx, |store, cx| { + if let Some(current) = + store.prediction_at(&buffer, Some(cursor_position), &self.project, cx) + && let BufferEditPrediction::Local { prediction } = current + && prediction.interpolate(buffer.read(cx)).is_some() + { + return; + } + store.refresh_context(&self.project, &buffer, cursor_position, cx); store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx) }); @@ -171,69 +172,68 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate { cursor_position: language::Anchor, cx: &mut Context, ) -> Option { - let prediction = - self.store - .read(cx) - .current_prediction_for_buffer(buffer, &self.project, cx)?; - - let prediction = match prediction { - BufferEditPrediction::Local { prediction } => prediction, - BufferEditPrediction::Jump { prediction } => { - return Some(edit_prediction_types::EditPrediction::Jump { - id: Some(prediction.id.to_string().into()), - snapshot: prediction.snapshot.clone(), - target: prediction.edits.first().unwrap().0.start, - }); - } - }; + self.store.update(cx, |store, cx| { + let prediction = + store.prediction_at(buffer, Some(cursor_position), &self.project, cx)?; + + let prediction = match prediction { + BufferEditPrediction::Local { prediction } => prediction, + BufferEditPrediction::Jump { prediction } => { + return Some(edit_prediction_types::EditPrediction::Jump { + id: Some(prediction.id.to_string().into()), + snapshot: prediction.snapshot.clone(), + target: prediction.edits.first().unwrap().0.start, + }); + } + }; - let buffer = buffer.read(cx); - let snapshot = buffer.snapshot(); + let buffer = buffer.read(cx); + let snapshot = buffer.snapshot(); - let Some(edits) = prediction.interpolate(&snapshot) else { - self.store.update(cx, |store, _cx| { + let Some(edits) = prediction.interpolate(&snapshot) else { store.reject_current_prediction( EditPredictionRejectReason::InterpolatedEmpty, &self.project, ); - }); - return None; - }; - - let cursor_row = cursor_position.to_point(&snapshot).row; - let (closest_edit_ix, (closest_edit_range, _)) = - edits.iter().enumerate().min_by_key(|(_, (range, _))| { - let distance_from_start = cursor_row.abs_diff(range.start.to_point(&snapshot).row); - let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row); - cmp::min(distance_from_start, distance_from_end) - })?; - - let mut edit_start_ix = closest_edit_ix; - for (range, _) in edits[..edit_start_ix].iter().rev() { - let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row - - range.end.to_point(&snapshot).row; - if distance_from_closest_edit <= 1 { - edit_start_ix -= 1; - } else { - break; + return None; + }; + + let cursor_row = cursor_position.to_point(&snapshot).row; + let (closest_edit_ix, (closest_edit_range, _)) = + edits.iter().enumerate().min_by_key(|(_, (range, _))| { + let distance_from_start = + cursor_row.abs_diff(range.start.to_point(&snapshot).row); + let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row); + cmp::min(distance_from_start, distance_from_end) + })?; + + let mut edit_start_ix = closest_edit_ix; + for (range, _) in edits[..edit_start_ix].iter().rev() { + let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row + - range.end.to_point(&snapshot).row; + if distance_from_closest_edit <= 1 { + edit_start_ix -= 1; + } else { + break; + } } - } - let mut edit_end_ix = closest_edit_ix + 1; - for (range, _) in &edits[edit_end_ix..] { - let distance_from_closest_edit = - range.start.to_point(buffer).row - closest_edit_range.end.to_point(&snapshot).row; - if distance_from_closest_edit <= 1 { - edit_end_ix += 1; - } else { - break; + let mut edit_end_ix = closest_edit_ix + 1; + for (range, _) in &edits[edit_end_ix..] { + let distance_from_closest_edit = range.start.to_point(buffer).row + - closest_edit_range.end.to_point(&snapshot).row; + if distance_from_closest_edit <= 1 { + edit_end_ix += 1; + } else { + break; + } } - } - Some(edit_prediction_types::EditPrediction::Local { - id: Some(prediction.id.to_string().into()), - edits: edits[edit_start_ix..edit_end_ix].to_vec(), - edit_preview: Some(prediction.edit_preview.clone()), + Some(edit_prediction_types::EditPrediction::Local { + id: Some(prediction.id.to_string().into()), + edits: edits[edit_start_ix..edit_end_ix].to_vec(), + edit_preview: Some(prediction.edit_preview.clone()), + }) }) } }