edit prediction: Respect enabled settings when refreshing from diagnostics (#44640)

Agus Zubiaga created

Release Notes:

- N/A

Change summary

crates/edit_prediction/src/edit_prediction.rs              |  85 ++++
crates/edit_prediction/src/edit_prediction_tests.rs        |  65 ++--
crates/edit_prediction/src/zed_edit_prediction_delegate.rs | 120 ++++----
3 files changed, 165 insertions(+), 105 deletions(-)

Detailed changes

crates/edit_prediction/src/edit_prediction.rs 🔗

@@ -283,6 +283,18 @@ impl ProjectState {
         })
         .detach()
     }
+
+    fn active_buffer(
+        &self,
+        project: &Entity<Project>,
+        cx: &App,
+    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
+        let project = project.read(cx);
+        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
+        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
+        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
+        Some((active_buffer, registered_buffer.last_position))
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -373,6 +385,7 @@ impl std::ops::Deref for BufferEditPrediction<'_> {
 
 struct RegisteredBuffer {
     snapshot: BufferSnapshot,
+    last_position: Option<Anchor>,
     _subscriptions: [gpui::Subscription; 2],
 }
 
@@ -795,6 +808,7 @@ impl EditPredictionStore {
                 let project_entity_id = project.entity_id();
                 entry.insert(RegisteredBuffer {
                     snapshot,
+                    last_position: None,
                     _subscriptions: [
                         cx.subscribe(buffer, {
                             let project = project.downgrade();
@@ -882,13 +896,21 @@ impl EditPredictionStore {
         });
     }
 
-    fn current_prediction_for_buffer(
-        &self,
+    fn prediction_at(
+        &mut self,
         buffer: &Entity<Buffer>,
+        position: Option<language::Anchor>,
         project: &Entity<Project>,
         cx: &App,
     ) -> Option<BufferEditPrediction<'_>> {
-        let project_state = self.projects.get(&project.entity_id())?;
+        let project_state = self.projects.get_mut(&project.entity_id())?;
+        if let Some(position) = position
+            && let Some(buffer) = project_state
+                .registered_buffers
+                .get_mut(&buffer.entity_id())
+        {
+            buffer.last_position = Some(position);
+        }
 
         let CurrentEditPrediction {
             requested_by,
@@ -1131,12 +1153,21 @@ impl EditPredictionStore {
         };
 
         self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
-            let Some(open_buffer_task) = project
-                .update(cx, |project, cx| {
-                    project
-                        .active_entry()
-                        .and_then(|entry| project.path_for_entry(entry, cx))
-                        .map(|path| project.open_buffer(path, cx))
+            let Some((active_buffer, snapshot, cursor_point)) = this
+                .read_with(cx, |this, cx| {
+                    let project_state = this.projects.get(&project.entity_id())?;
+                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
+                    let snapshot = buffer.read(cx).snapshot();
+
+                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
+                        return None;
+                    }
+
+                    let cursor_point = position
+                        .map(|pos| pos.to_point(&snapshot))
+                        .unwrap_or_default();
+
+                    Some((buffer, snapshot, cursor_point))
                 })
                 .log_err()
                 .flatten()
@@ -1145,14 +1176,11 @@ impl EditPredictionStore {
             };
 
             cx.spawn(async move |cx| {
-                let active_buffer = open_buffer_task.await?;
-                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
-
                 let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                     active_buffer,
                     &snapshot,
                     Default::default(),
-                    Default::default(),
+                    cursor_point,
                     &project,
                     cx,
                 )
@@ -1197,6 +1225,37 @@ impl EditPredictionStore {
         });
     }
 
+    fn predictions_enabled_at(
+        snapshot: &BufferSnapshot,
+        position: Option<language::Anchor>,
+        cx: &App,
+    ) -> bool {
+        let file = snapshot.file();
+        let all_settings = all_language_settings(file, cx);
+        if !all_settings.show_edit_predictions(snapshot.language(), cx)
+            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
+        {
+            return false;
+        }
+
+        if let Some(last_position) = position {
+            let settings = snapshot.settings_at(last_position, cx);
+
+            if !settings.edit_predictions_disabled_in.is_empty()
+                && let Some(scope) = snapshot.language_scope_at(last_position)
+                && let Some(scope_name) = scope.override_name()
+                && settings
+                    .edit_predictions_disabled_in
+                    .iter()
+                    .any(|s| s == scope_name)
+            {
+                return false;
+            }
+        }
+
+        true
+    }
+
     #[cfg(not(test))]
     pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
     #[cfg(test)]

crates/edit_prediction/src/edit_prediction_tests.rs 🔗

@@ -45,10 +45,6 @@ async fn test_current_state(cx: &mut TestAppContext) {
     .await;
     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 
-    ep_store.update(cx, |ep_store, cx| {
-        ep_store.register_project(&project, cx);
-    });
-
     let buffer1 = project
         .update(cx, |project, cx| {
             let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
@@ -60,6 +56,11 @@ async fn test_current_state(cx: &mut TestAppContext) {
     let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
     let position = snapshot1.anchor_before(language::Point::new(1, 3));
 
+    ep_store.update(cx, |ep_store, cx| {
+        ep_store.register_project(&project, cx);
+        ep_store.register_buffer(&buffer1, &project, cx);
+    });
+
     // Prediction for current file
 
     ep_store.update(cx, |ep_store, cx| {
@@ -84,9 +85,9 @@ async fn test_current_state(cx: &mut TestAppContext) {
 
     cx.run_until_parked();
 
-    ep_store.read_with(cx, |ep_store, cx| {
+    ep_store.update(cx, |ep_store, cx| {
         let prediction = ep_store
-            .current_prediction_for_buffer(&buffer1, &project, cx)
+            .prediction_at(&buffer1, None, &project, cx)
             .unwrap();
         assert_matches!(prediction, BufferEditPrediction::Local { .. });
     });
@@ -140,9 +141,9 @@ async fn test_current_state(cx: &mut TestAppContext) {
         .unwrap();
     cx.run_until_parked();
 
-    ep_store.read_with(cx, |ep_store, cx| {
+    ep_store.update(cx, |ep_store, cx| {
         let prediction = ep_store
-            .current_prediction_for_buffer(&buffer1, &project, cx)
+            .prediction_at(&buffer1, None, &project, cx)
             .unwrap();
         assert_matches!(
             prediction,
@@ -158,9 +159,9 @@ async fn test_current_state(cx: &mut TestAppContext) {
         .await
         .unwrap();
 
-    ep_store.read_with(cx, |ep_store, cx| {
+    ep_store.update(cx, |ep_store, cx| {
         let prediction = ep_store
-            .current_prediction_for_buffer(&buffer2, &project, cx)
+            .prediction_at(&buffer2, None, &project, cx)
             .unwrap();
         assert_matches!(prediction, BufferEditPrediction::Local { .. });
     });
@@ -344,10 +345,10 @@ async fn test_empty_prediction(cx: &mut TestAppContext) {
 
     cx.run_until_parked();
 
-    ep_store.read_with(cx, |ep_store, cx| {
+    ep_store.update(cx, |ep_store, cx| {
         assert!(
             ep_store
-                .current_prediction_for_buffer(&buffer, &project, cx)
+                .prediction_at(&buffer, None, &project, cx)
                 .is_none()
         );
     });
@@ -404,10 +405,10 @@ async fn test_interpolated_empty(cx: &mut TestAppContext) {
 
     cx.run_until_parked();
 
-    ep_store.read_with(cx, |ep_store, cx| {
+    ep_store.update(cx, |ep_store, cx| {
         assert!(
             ep_store
-                .current_prediction_for_buffer(&buffer, &project, cx)
+                .prediction_at(&buffer, None, &project, cx)
                 .is_none()
         );
     });
@@ -469,10 +470,10 @@ async fn test_replace_current(cx: &mut TestAppContext) {
 
     cx.run_until_parked();
 
-    ep_store.read_with(cx, |ep_store, cx| {
+    ep_store.update(cx, |ep_store, cx| {
         assert_eq!(
             ep_store
-                .current_prediction_for_buffer(&buffer, &project, cx)
+                .prediction_at(&buffer, None, &project, cx)
                 .unwrap()
                 .id
                 .0,
@@ -492,11 +493,11 @@ async fn test_replace_current(cx: &mut TestAppContext) {
 
     cx.run_until_parked();
 
-    ep_store.read_with(cx, |ep_store, cx| {
+    ep_store.update(cx, |ep_store, cx| {
         // second replaces first
         assert_eq!(
             ep_store
-                .current_prediction_for_buffer(&buffer, &project, cx)
+                .prediction_at(&buffer, None, &project, cx)
                 .unwrap()
                 .id
                 .0,
@@ -551,10 +552,10 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
 
     cx.run_until_parked();
 
-    ep_store.read_with(cx, |ep_store, cx| {
+    ep_store.update(cx, |ep_store, cx| {
         assert_eq!(
             ep_store
-                .current_prediction_for_buffer(&buffer, &project, cx)
+                .prediction_at(&buffer, None, &project, cx)
                 .unwrap()
                 .id
                 .0,
@@ -586,11 +587,11 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
 
     cx.run_until_parked();
 
-    ep_store.read_with(cx, |ep_store, cx| {
+    ep_store.update(cx, |ep_store, cx| {
         // first is preferred over second
         assert_eq!(
             ep_store
-                .current_prediction_for_buffer(&buffer, &project, cx)
+                .prediction_at(&buffer, None, &project, cx)
                 .unwrap()
                 .id
                 .0,
@@ -657,11 +658,11 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
 
     cx.run_until_parked();
 
-    ep_store.read_with(cx, |ep_store, cx| {
+    ep_store.update(cx, |ep_store, cx| {
         // current prediction is second
         assert_eq!(
             ep_store
-                .current_prediction_for_buffer(&buffer, &project, cx)
+                .prediction_at(&buffer, None, &project, cx)
                 .unwrap()
                 .id
                 .0,
@@ -675,11 +676,11 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
 
     cx.run_until_parked();
 
-    ep_store.read_with(cx, |ep_store, cx| {
+    ep_store.update(cx, |ep_store, cx| {
         // current prediction is still second, since first was cancelled
         assert_eq!(
             ep_store
-                .current_prediction_for_buffer(&buffer, &project, cx)
+                .prediction_at(&buffer, None, &project, cx)
                 .unwrap()
                 .id
                 .0,
@@ -768,11 +769,11 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
 
     cx.run_until_parked();
 
-    ep_store.read_with(cx, |ep_store, cx| {
+    ep_store.update(cx, |ep_store, cx| {
         // current prediction is first
         assert_eq!(
             ep_store
-                .current_prediction_for_buffer(&buffer, &project, cx)
+                .prediction_at(&buffer, None, &project, cx)
                 .unwrap()
                 .id
                 .0,
@@ -786,11 +787,11 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
 
     cx.run_until_parked();
 
-    ep_store.read_with(cx, |ep_store, cx| {
+    ep_store.update(cx, |ep_store, cx| {
         // current prediction is still first, since second was cancelled
         assert_eq!(
             ep_store
-                .current_prediction_for_buffer(&buffer, &project, cx)
+                .prediction_at(&buffer, None, &project, cx)
                 .unwrap()
                 .id
                 .0,
@@ -804,11 +805,11 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
 
     cx.run_until_parked();
 
-    ep_store.read_with(cx, |ep_store, cx| {
+    ep_store.update(cx, |ep_store, cx| {
         // third completes and replaces first
         assert_eq!(
             ep_store
-                .current_prediction_for_buffer(&buffer, &project, cx)
+                .prediction_at(&buffer, None, &project, cx)
                 .unwrap()
                 .id
                 .0,

crates/edit_prediction/src/zed_edit_prediction_delegate.rs 🔗

@@ -125,14 +125,15 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate {
             return;
         }
 
-        if let Some(current) = store.current_prediction_for_buffer(&buffer, &self.project, cx)
-            && let BufferEditPrediction::Local { prediction } = current
-            && prediction.interpolate(buffer.read(cx)).is_some()
-        {
-            return;
-        }
-
         self.store.update(cx, |store, cx| {
+            if let Some(current) =
+                store.prediction_at(&buffer, Some(cursor_position), &self.project, cx)
+                && let BufferEditPrediction::Local { prediction } = current
+                && prediction.interpolate(buffer.read(cx)).is_some()
+            {
+                return;
+            }
+
             store.refresh_context(&self.project, &buffer, cursor_position, cx);
             store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
         });
@@ -171,69 +172,68 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate {
         cursor_position: language::Anchor,
         cx: &mut Context<Self>,
     ) -> Option<edit_prediction_types::EditPrediction> {
-        let prediction =
-            self.store
-                .read(cx)
-                .current_prediction_for_buffer(buffer, &self.project, cx)?;
-
-        let prediction = match prediction {
-            BufferEditPrediction::Local { prediction } => prediction,
-            BufferEditPrediction::Jump { prediction } => {
-                return Some(edit_prediction_types::EditPrediction::Jump {
-                    id: Some(prediction.id.to_string().into()),
-                    snapshot: prediction.snapshot.clone(),
-                    target: prediction.edits.first().unwrap().0.start,
-                });
-            }
-        };
+        self.store.update(cx, |store, cx| {
+            let prediction =
+                store.prediction_at(buffer, Some(cursor_position), &self.project, cx)?;
+
+            let prediction = match prediction {
+                BufferEditPrediction::Local { prediction } => prediction,
+                BufferEditPrediction::Jump { prediction } => {
+                    return Some(edit_prediction_types::EditPrediction::Jump {
+                        id: Some(prediction.id.to_string().into()),
+                        snapshot: prediction.snapshot.clone(),
+                        target: prediction.edits.first().unwrap().0.start,
+                    });
+                }
+            };
 
-        let buffer = buffer.read(cx);
-        let snapshot = buffer.snapshot();
+            let buffer = buffer.read(cx);
+            let snapshot = buffer.snapshot();
 
-        let Some(edits) = prediction.interpolate(&snapshot) else {
-            self.store.update(cx, |store, _cx| {
+            let Some(edits) = prediction.interpolate(&snapshot) else {
                 store.reject_current_prediction(
                     EditPredictionRejectReason::InterpolatedEmpty,
                     &self.project,
                 );
-            });
-            return None;
-        };
-
-        let cursor_row = cursor_position.to_point(&snapshot).row;
-        let (closest_edit_ix, (closest_edit_range, _)) =
-            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
-                let distance_from_start = cursor_row.abs_diff(range.start.to_point(&snapshot).row);
-                let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
-                cmp::min(distance_from_start, distance_from_end)
-            })?;
-
-        let mut edit_start_ix = closest_edit_ix;
-        for (range, _) in edits[..edit_start_ix].iter().rev() {
-            let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
-                - range.end.to_point(&snapshot).row;
-            if distance_from_closest_edit <= 1 {
-                edit_start_ix -= 1;
-            } else {
-                break;
+                return None;
+            };
+
+            let cursor_row = cursor_position.to_point(&snapshot).row;
+            let (closest_edit_ix, (closest_edit_range, _)) =
+                edits.iter().enumerate().min_by_key(|(_, (range, _))| {
+                    let distance_from_start =
+                        cursor_row.abs_diff(range.start.to_point(&snapshot).row);
+                    let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
+                    cmp::min(distance_from_start, distance_from_end)
+                })?;
+
+            let mut edit_start_ix = closest_edit_ix;
+            for (range, _) in edits[..edit_start_ix].iter().rev() {
+                let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
+                    - range.end.to_point(&snapshot).row;
+                if distance_from_closest_edit <= 1 {
+                    edit_start_ix -= 1;
+                } else {
+                    break;
+                }
             }
-        }
 
-        let mut edit_end_ix = closest_edit_ix + 1;
-        for (range, _) in &edits[edit_end_ix..] {
-            let distance_from_closest_edit =
-                range.start.to_point(buffer).row - closest_edit_range.end.to_point(&snapshot).row;
-            if distance_from_closest_edit <= 1 {
-                edit_end_ix += 1;
-            } else {
-                break;
+            let mut edit_end_ix = closest_edit_ix + 1;
+            for (range, _) in &edits[edit_end_ix..] {
+                let distance_from_closest_edit = range.start.to_point(buffer).row
+                    - closest_edit_range.end.to_point(&snapshot).row;
+                if distance_from_closest_edit <= 1 {
+                    edit_end_ix += 1;
+                } else {
+                    break;
+                }
             }
-        }
 
-        Some(edit_prediction_types::EditPrediction::Local {
-            id: Some(prediction.id.to_string().into()),
-            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
-            edit_preview: Some(prediction.edit_preview.clone()),
+            Some(edit_prediction_types::EditPrediction::Local {
+                id: Some(prediction.id.to_string().into()),
+                edits: edits[edit_start_ix..edit_end_ix].to_vec(),
+                edit_preview: Some(prediction.edit_preview.clone()),
+            })
         })
     }
 }