@@ -283,6 +283,18 @@ impl ProjectState {
})
.detach()
}
+
+ fn active_buffer(
+ &self,
+ project: &Entity<Project>,
+ cx: &App,
+ ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
+ let project = project.read(cx);
+ let active_path = project.path_for_entry(project.active_entry()?, cx)?;
+ let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
+ let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
+ Some((active_buffer, registered_buffer.last_position))
+ }
}
#[derive(Debug, Clone)]
@@ -373,6 +385,7 @@ impl std::ops::Deref for BufferEditPrediction<'_> {
struct RegisteredBuffer {
snapshot: BufferSnapshot,
+ last_position: Option<Anchor>,
_subscriptions: [gpui::Subscription; 2],
}
@@ -795,6 +808,7 @@ impl EditPredictionStore {
let project_entity_id = project.entity_id();
entry.insert(RegisteredBuffer {
snapshot,
+ last_position: None,
_subscriptions: [
cx.subscribe(buffer, {
let project = project.downgrade();
@@ -882,13 +896,21 @@ impl EditPredictionStore {
});
}
- fn current_prediction_for_buffer(
- &self,
+ fn prediction_at(
+ &mut self,
buffer: &Entity<Buffer>,
+ position: Option<language::Anchor>,
project: &Entity<Project>,
cx: &App,
) -> Option<BufferEditPrediction<'_>> {
- let project_state = self.projects.get(&project.entity_id())?;
+ let project_state = self.projects.get_mut(&project.entity_id())?;
+ if let Some(position) = position
+ && let Some(buffer) = project_state
+ .registered_buffers
+ .get_mut(&buffer.entity_id())
+ {
+ buffer.last_position = Some(position);
+ }
let CurrentEditPrediction {
requested_by,
@@ -1131,12 +1153,21 @@ impl EditPredictionStore {
};
self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
- let Some(open_buffer_task) = project
- .update(cx, |project, cx| {
- project
- .active_entry()
- .and_then(|entry| project.path_for_entry(entry, cx))
- .map(|path| project.open_buffer(path, cx))
+ let Some((active_buffer, snapshot, cursor_point)) = this
+ .read_with(cx, |this, cx| {
+ let project_state = this.projects.get(&project.entity_id())?;
+ let (buffer, position) = project_state.active_buffer(&project, cx)?;
+ let snapshot = buffer.read(cx).snapshot();
+
+ if !Self::predictions_enabled_at(&snapshot, position, cx) {
+ return None;
+ }
+
+ let cursor_point = position
+ .map(|pos| pos.to_point(&snapshot))
+ .unwrap_or_default();
+
+ Some((buffer, snapshot, cursor_point))
})
.log_err()
.flatten()
@@ -1145,14 +1176,11 @@ impl EditPredictionStore {
};
cx.spawn(async move |cx| {
- let active_buffer = open_buffer_task.await?;
- let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
-
let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
active_buffer,
&snapshot,
Default::default(),
- Default::default(),
+ cursor_point,
&project,
cx,
)
@@ -1197,6 +1225,37 @@ impl EditPredictionStore {
});
}
+ fn predictions_enabled_at(
+ snapshot: &BufferSnapshot,
+ position: Option<language::Anchor>,
+ cx: &App,
+ ) -> bool {
+ let file = snapshot.file();
+ let all_settings = all_language_settings(file, cx);
+ if !all_settings.show_edit_predictions(snapshot.language(), cx)
+ || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
+ {
+ return false;
+ }
+
+ if let Some(last_position) = position {
+ let settings = snapshot.settings_at(last_position, cx);
+
+ if !settings.edit_predictions_disabled_in.is_empty()
+ && let Some(scope) = snapshot.language_scope_at(last_position)
+ && let Some(scope_name) = scope.override_name()
+ && settings
+ .edit_predictions_disabled_in
+ .iter()
+ .any(|s| s == scope_name)
+ {
+ return false;
+ }
+ }
+
+ true
+ }
+
#[cfg(not(test))]
pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
#[cfg(test)]
@@ -45,10 +45,6 @@ async fn test_current_state(cx: &mut TestAppContext) {
.await;
let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
- ep_store.update(cx, |ep_store, cx| {
- ep_store.register_project(&project, cx);
- });
-
let buffer1 = project
.update(cx, |project, cx| {
let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
@@ -60,6 +56,11 @@ async fn test_current_state(cx: &mut TestAppContext) {
let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
let position = snapshot1.anchor_before(language::Point::new(1, 3));
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.register_project(&project, cx);
+ ep_store.register_buffer(&buffer1, &project, cx);
+ });
+
// Prediction for current file
ep_store.update(cx, |ep_store, cx| {
@@ -84,9 +85,9 @@ async fn test_current_state(cx: &mut TestAppContext) {
cx.run_until_parked();
- ep_store.read_with(cx, |ep_store, cx| {
+ ep_store.update(cx, |ep_store, cx| {
let prediction = ep_store
- .current_prediction_for_buffer(&buffer1, &project, cx)
+ .prediction_at(&buffer1, None, &project, cx)
.unwrap();
assert_matches!(prediction, BufferEditPrediction::Local { .. });
});
@@ -140,9 +141,9 @@ async fn test_current_state(cx: &mut TestAppContext) {
.unwrap();
cx.run_until_parked();
- ep_store.read_with(cx, |ep_store, cx| {
+ ep_store.update(cx, |ep_store, cx| {
let prediction = ep_store
- .current_prediction_for_buffer(&buffer1, &project, cx)
+ .prediction_at(&buffer1, None, &project, cx)
.unwrap();
assert_matches!(
prediction,
@@ -158,9 +159,9 @@ async fn test_current_state(cx: &mut TestAppContext) {
.await
.unwrap();
- ep_store.read_with(cx, |ep_store, cx| {
+ ep_store.update(cx, |ep_store, cx| {
let prediction = ep_store
- .current_prediction_for_buffer(&buffer2, &project, cx)
+ .prediction_at(&buffer2, None, &project, cx)
.unwrap();
assert_matches!(prediction, BufferEditPrediction::Local { .. });
});
@@ -344,10 +345,10 @@ async fn test_empty_prediction(cx: &mut TestAppContext) {
cx.run_until_parked();
- ep_store.read_with(cx, |ep_store, cx| {
+ ep_store.update(cx, |ep_store, cx| {
assert!(
ep_store
- .current_prediction_for_buffer(&buffer, &project, cx)
+ .prediction_at(&buffer, None, &project, cx)
.is_none()
);
});
@@ -404,10 +405,10 @@ async fn test_interpolated_empty(cx: &mut TestAppContext) {
cx.run_until_parked();
- ep_store.read_with(cx, |ep_store, cx| {
+ ep_store.update(cx, |ep_store, cx| {
assert!(
ep_store
- .current_prediction_for_buffer(&buffer, &project, cx)
+ .prediction_at(&buffer, None, &project, cx)
.is_none()
);
});
@@ -469,10 +470,10 @@ async fn test_replace_current(cx: &mut TestAppContext) {
cx.run_until_parked();
- ep_store.read_with(cx, |ep_store, cx| {
+ ep_store.update(cx, |ep_store, cx| {
assert_eq!(
ep_store
- .current_prediction_for_buffer(&buffer, &project, cx)
+ .prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -492,11 +493,11 @@ async fn test_replace_current(cx: &mut TestAppContext) {
cx.run_until_parked();
- ep_store.read_with(cx, |ep_store, cx| {
+ ep_store.update(cx, |ep_store, cx| {
// second replaces first
assert_eq!(
ep_store
- .current_prediction_for_buffer(&buffer, &project, cx)
+ .prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -551,10 +552,10 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
cx.run_until_parked();
- ep_store.read_with(cx, |ep_store, cx| {
+ ep_store.update(cx, |ep_store, cx| {
assert_eq!(
ep_store
- .current_prediction_for_buffer(&buffer, &project, cx)
+ .prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -586,11 +587,11 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
cx.run_until_parked();
- ep_store.read_with(cx, |ep_store, cx| {
+ ep_store.update(cx, |ep_store, cx| {
// first is preferred over second
assert_eq!(
ep_store
- .current_prediction_for_buffer(&buffer, &project, cx)
+ .prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -657,11 +658,11 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
cx.run_until_parked();
- ep_store.read_with(cx, |ep_store, cx| {
+ ep_store.update(cx, |ep_store, cx| {
// current prediction is second
assert_eq!(
ep_store
- .current_prediction_for_buffer(&buffer, &project, cx)
+ .prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -675,11 +676,11 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
cx.run_until_parked();
- ep_store.read_with(cx, |ep_store, cx| {
+ ep_store.update(cx, |ep_store, cx| {
// current prediction is still second, since first was cancelled
assert_eq!(
ep_store
- .current_prediction_for_buffer(&buffer, &project, cx)
+ .prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -768,11 +769,11 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
cx.run_until_parked();
- ep_store.read_with(cx, |ep_store, cx| {
+ ep_store.update(cx, |ep_store, cx| {
// current prediction is first
assert_eq!(
ep_store
- .current_prediction_for_buffer(&buffer, &project, cx)
+ .prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -786,11 +787,11 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
cx.run_until_parked();
- ep_store.read_with(cx, |ep_store, cx| {
+ ep_store.update(cx, |ep_store, cx| {
// current prediction is still first, since second was cancelled
assert_eq!(
ep_store
- .current_prediction_for_buffer(&buffer, &project, cx)
+ .prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -804,11 +805,11 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
cx.run_until_parked();
- ep_store.read_with(cx, |ep_store, cx| {
+ ep_store.update(cx, |ep_store, cx| {
// third completes and replaces first
assert_eq!(
ep_store
- .current_prediction_for_buffer(&buffer, &project, cx)
+ .prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -125,14 +125,15 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate {
return;
}
- if let Some(current) = store.current_prediction_for_buffer(&buffer, &self.project, cx)
- && let BufferEditPrediction::Local { prediction } = current
- && prediction.interpolate(buffer.read(cx)).is_some()
- {
- return;
- }
-
self.store.update(cx, |store, cx| {
+ if let Some(current) =
+ store.prediction_at(&buffer, Some(cursor_position), &self.project, cx)
+ && let BufferEditPrediction::Local { prediction } = current
+ && prediction.interpolate(buffer.read(cx)).is_some()
+ {
+ return;
+ }
+
store.refresh_context(&self.project, &buffer, cursor_position, cx);
store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
});
@@ -171,69 +172,68 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate {
cursor_position: language::Anchor,
cx: &mut Context<Self>,
) -> Option<edit_prediction_types::EditPrediction> {
- let prediction =
- self.store
- .read(cx)
- .current_prediction_for_buffer(buffer, &self.project, cx)?;
-
- let prediction = match prediction {
- BufferEditPrediction::Local { prediction } => prediction,
- BufferEditPrediction::Jump { prediction } => {
- return Some(edit_prediction_types::EditPrediction::Jump {
- id: Some(prediction.id.to_string().into()),
- snapshot: prediction.snapshot.clone(),
- target: prediction.edits.first().unwrap().0.start,
- });
- }
- };
+ self.store.update(cx, |store, cx| {
+ let prediction =
+ store.prediction_at(buffer, Some(cursor_position), &self.project, cx)?;
+
+ let prediction = match prediction {
+ BufferEditPrediction::Local { prediction } => prediction,
+ BufferEditPrediction::Jump { prediction } => {
+ return Some(edit_prediction_types::EditPrediction::Jump {
+ id: Some(prediction.id.to_string().into()),
+ snapshot: prediction.snapshot.clone(),
+ target: prediction.edits.first().unwrap().0.start,
+ });
+ }
+ };
- let buffer = buffer.read(cx);
- let snapshot = buffer.snapshot();
+ let buffer = buffer.read(cx);
+ let snapshot = buffer.snapshot();
- let Some(edits) = prediction.interpolate(&snapshot) else {
- self.store.update(cx, |store, _cx| {
+ let Some(edits) = prediction.interpolate(&snapshot) else {
store.reject_current_prediction(
EditPredictionRejectReason::InterpolatedEmpty,
&self.project,
);
- });
- return None;
- };
-
- let cursor_row = cursor_position.to_point(&snapshot).row;
- let (closest_edit_ix, (closest_edit_range, _)) =
- edits.iter().enumerate().min_by_key(|(_, (range, _))| {
- let distance_from_start = cursor_row.abs_diff(range.start.to_point(&snapshot).row);
- let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
- cmp::min(distance_from_start, distance_from_end)
- })?;
-
- let mut edit_start_ix = closest_edit_ix;
- for (range, _) in edits[..edit_start_ix].iter().rev() {
- let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
- - range.end.to_point(&snapshot).row;
- if distance_from_closest_edit <= 1 {
- edit_start_ix -= 1;
- } else {
- break;
+ return None;
+ };
+
+ let cursor_row = cursor_position.to_point(&snapshot).row;
+ let (closest_edit_ix, (closest_edit_range, _)) =
+ edits.iter().enumerate().min_by_key(|(_, (range, _))| {
+ let distance_from_start =
+ cursor_row.abs_diff(range.start.to_point(&snapshot).row);
+ let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
+ cmp::min(distance_from_start, distance_from_end)
+ })?;
+
+ let mut edit_start_ix = closest_edit_ix;
+ for (range, _) in edits[..edit_start_ix].iter().rev() {
+ let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
+ - range.end.to_point(&snapshot).row;
+ if distance_from_closest_edit <= 1 {
+ edit_start_ix -= 1;
+ } else {
+ break;
+ }
}
- }
- let mut edit_end_ix = closest_edit_ix + 1;
- for (range, _) in &edits[edit_end_ix..] {
- let distance_from_closest_edit =
- range.start.to_point(buffer).row - closest_edit_range.end.to_point(&snapshot).row;
- if distance_from_closest_edit <= 1 {
- edit_end_ix += 1;
- } else {
- break;
+ let mut edit_end_ix = closest_edit_ix + 1;
+ for (range, _) in &edits[edit_end_ix..] {
+ let distance_from_closest_edit = range.start.to_point(buffer).row
+ - closest_edit_range.end.to_point(&snapshot).row;
+ if distance_from_closest_edit <= 1 {
+ edit_end_ix += 1;
+ } else {
+ break;
+ }
}
- }
- Some(edit_prediction_types::EditPrediction::Local {
- id: Some(prediction.id.to_string().into()),
- edits: edits[edit_start_ix..edit_end_ix].to_vec(),
- edit_preview: Some(prediction.edit_preview.clone()),
+ Some(edit_prediction_types::EditPrediction::Local {
+ id: Some(prediction.id.to_string().into()),
+ edits: edits[edit_start_ix..edit_end_ix].to_vec(),
+ edit_preview: Some(prediction.edit_preview.clone()),
+ })
})
}
}