@@ -231,6 +231,71 @@ pub enum UserActionType {
pub struct StoredEvent {
pub event: Arc<zeta_prompt::Event>,
pub old_snapshot: TextBufferSnapshot,
+ pub edit_range: Range<Anchor>,
+}
+
+impl StoredEvent {
+ fn can_merge(
+ &self,
+ next_old_event: &&&StoredEvent,
+ new_snapshot: &TextBufferSnapshot,
+ last_edit_range: &Range<Anchor>,
+ ) -> bool {
+ // Events must be for the same buffer
+ if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
+ return false;
+ }
+
+ let a_is_predicted = matches!(
+ self.event.as_ref(),
+ zeta_prompt::Event::BufferChange {
+ predicted: true,
+ ..
+ }
+ );
+ let b_is_predicted = matches!(
+ next_old_event.event.as_ref(),
+ zeta_prompt::Event::BufferChange {
+ predicted: true,
+ ..
+ }
+ );
+
+ // If events come from the same source (both predicted or both manual) then
+ // we would have coalesced them already.
+ if a_is_predicted == b_is_predicted {
+ return false;
+ }
+
+ let left_range = self.edit_range.to_point(new_snapshot);
+ let right_range = next_old_event.edit_range.to_point(new_snapshot);
+ let latest_range = last_edit_range.to_point(&new_snapshot);
+
+ // Events near to the latest edit are not merged if their sources differ.
+ if lines_between_ranges(&left_range, &latest_range)
+ .min(lines_between_ranges(&right_range, &latest_range))
+ <= CHANGE_GROUPING_LINE_SPAN
+ {
+ return false;
+ }
+
+ // Events that are distant from each other are not merged.
+ if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
+ return false;
+ }
+
+ true
+ }
+}
+
+fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
+ if left.start > right.end {
+ return left.start.row - right.end.row;
+ }
+ if right.start > left.end {
+ return right.start.row - left.end.row;
+ }
+ 0
}
struct ProjectState {
@@ -260,18 +325,6 @@ impl ProjectState {
}
pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
- self.events
- .iter()
- .cloned()
- .chain(
- self.last_event
- .as_ref()
- .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
- )
- .collect()
- }
-
- pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
self.events
.iter()
.cloned()
@@ -430,6 +483,7 @@ struct LastEvent {
old_file: Option<Arc<dyn File>>,
new_file: Option<Arc<dyn File>>,
edit_range: Option<Range<Anchor>>,
+ predicted: bool,
snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
last_edit_time: Option<Instant>,
}
@@ -454,7 +508,8 @@ impl LastEvent {
})
});
- let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
+ let (diff, edit_range) =
+ compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
if path == old_path && diff.is_empty() {
None
@@ -465,9 +520,10 @@ impl LastEvent {
path,
diff,
in_open_source_repo,
- // TODO: Actually detect if this edit was predicted or not
- predicted: false,
+ predicted: self.predicted,
}),
+ edit_range: self.new_snapshot.anchor_before(edit_range.start)
+ ..self.new_snapshot.anchor_before(edit_range.end),
old_snapshot: self.old_snapshot.clone(),
})
}
@@ -484,6 +540,7 @@ impl LastEvent {
old_file: self.old_file.clone(),
new_file: self.new_file.clone(),
edit_range: None,
+ predicted: self.predicted,
snapshot_after_last_editing_pause: None,
last_edit_time: self.last_edit_time,
};
@@ -494,6 +551,7 @@ impl LastEvent {
old_file: self.old_file.clone(),
new_file: self.new_file.clone(),
edit_range: None,
+ predicted: self.predicted,
snapshot_after_last_editing_pause: None,
last_edit_time: self.last_edit_time,
};
@@ -505,7 +563,7 @@ impl LastEvent {
pub(crate) fn compute_diff_between_snapshots(
old_snapshot: &TextBufferSnapshot,
new_snapshot: &TextBufferSnapshot,
-) -> Option<String> {
+) -> Option<(String, Range<Point>)> {
let edits: Vec<Edit<usize>> = new_snapshot
.edits_since::<usize>(&old_snapshot.version)
.collect();
@@ -545,7 +603,7 @@ pub(crate) fn compute_diff_between_snapshots(
new_context_start_row,
);
- Some(diff)
+ Some((diff, new_start_point..new_end_point))
}
fn buffer_path_with_id_fallback(
@@ -716,17 +774,6 @@ impl EditPredictionStore {
.unwrap_or_default()
}
- pub fn edit_history_for_project_with_pause_split_last_event(
- &self,
- project: &Entity<Project>,
- cx: &App,
- ) -> Vec<StoredEvent> {
- self.projects
- .get(&project.entity_id())
- .map(|project_state| project_state.events_split_by_pause(cx))
- .unwrap_or_default()
- }
-
pub fn context_for_project<'a>(
&'a self,
project: &Entity<Project>,
@@ -1011,7 +1058,7 @@ impl EditPredictionStore {
if let language::BufferEvent::Edited = event
&& let Some(project) = project.upgrade()
{
- this.report_changes_for_buffer(&buffer, &project, cx);
+ this.report_changes_for_buffer(&buffer, &project, false, cx);
}
}
}),
@@ -1032,6 +1079,7 @@ impl EditPredictionStore {
&mut self,
buffer: &Entity<Buffer>,
project: &Entity<Project>,
+ is_predicted: bool,
cx: &mut Context<Self>,
) {
let project_state = self.get_or_init_project(project, cx);
@@ -1065,30 +1113,32 @@ impl EditPredictionStore {
last_offset = Some(edit.new.end);
}
- if num_edits > 0 {
- let action_type = match (total_deleted, total_inserted, num_edits) {
- (0, ins, n) if ins == n => UserActionType::InsertChar,
- (0, _, _) => UserActionType::InsertSelection,
- (del, 0, n) if del == n => UserActionType::DeleteChar,
- (_, 0, _) => UserActionType::DeleteSelection,
- (_, ins, n) if ins == n => UserActionType::InsertChar,
- (_, _, _) => UserActionType::InsertSelection,
- };
+ let Some(edit_range) = edit_range else {
+ return;
+ };
- if let Some(offset) = last_offset {
- let point = new_snapshot.offset_to_point(offset);
- let timestamp_epoch_ms = SystemTime::now()
- .duration_since(UNIX_EPOCH)
- .map(|d| d.as_millis() as u64)
- .unwrap_or(0);
- project_state.record_user_action(UserActionRecord {
- action_type,
- buffer_id: buffer.entity_id(),
- line_number: point.row,
- offset,
- timestamp_epoch_ms,
- });
- }
+ let action_type = match (total_deleted, total_inserted, num_edits) {
+ (0, ins, n) if ins == n => UserActionType::InsertChar,
+ (0, _, _) => UserActionType::InsertSelection,
+ (del, 0, n) if del == n => UserActionType::DeleteChar,
+ (_, 0, _) => UserActionType::DeleteSelection,
+ (_, ins, n) if ins == n => UserActionType::InsertChar,
+ (_, _, _) => UserActionType::InsertSelection,
+ };
+
+ if let Some(offset) = last_offset {
+ let point = new_snapshot.offset_to_point(offset);
+ let timestamp_epoch_ms = SystemTime::now()
+ .duration_since(UNIX_EPOCH)
+ .map(|d| d.as_millis() as u64)
+ .unwrap_or(0);
+ project_state.record_user_action(UserActionRecord {
+ action_type,
+ buffer_id: buffer.entity_id(),
+ line_number: point.row,
+ offset,
+ timestamp_epoch_ms,
+ });
}
let events = &mut project_state.events;
@@ -1099,20 +1149,18 @@ impl EditPredictionStore {
== last_event.new_snapshot.remote_id()
&& old_snapshot.version == last_event.new_snapshot.version;
+ let prediction_source_changed = is_predicted != last_event.predicted;
+
let should_coalesce = is_next_snapshot_of_same_buffer
- && edit_range
+ && !prediction_source_changed
+ && last_event
+ .edit_range
.as_ref()
- .zip(last_event.edit_range.as_ref())
- .is_some_and(|(a, b)| {
- let a = a.to_point(&new_snapshot);
- let b = b.to_point(&new_snapshot);
- if a.start > b.end {
- a.start.row.abs_diff(b.end.row) <= CHANGE_GROUPING_LINE_SPAN
- } else if b.start > a.end {
- b.start.row.abs_diff(a.end.row) <= CHANGE_GROUPING_LINE_SPAN
- } else {
- true
- }
+ .is_some_and(|last_edit_range| {
+ lines_between_ranges(
+ &edit_range.to_point(&new_snapshot),
+ &last_edit_range.to_point(&new_snapshot),
+ ) <= CHANGE_GROUPING_LINE_SPAN
});
if should_coalesce {
@@ -1125,7 +1173,7 @@ impl EditPredictionStore {
Some(last_event.new_snapshot.clone());
}
- last_event.edit_range = edit_range;
+ last_event.edit_range = Some(edit_range);
last_event.new_snapshot = new_snapshot;
last_event.last_edit_time = Some(now);
return;
@@ -1141,12 +1189,15 @@ impl EditPredictionStore {
}
}
+ merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);
+
project_state.last_event = Some(LastEvent {
old_file,
new_file,
old_snapshot,
new_snapshot,
- edit_range,
+ edit_range: Some(edit_range),
+ predicted: is_predicted,
snapshot_after_last_editing_pause: None,
last_edit_time: Some(now),
});
@@ -1193,11 +1244,18 @@ impl EditPredictionStore {
}
fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
- let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
+ let Some(current_prediction) = self
+ .projects
+ .get_mut(&project.entity_id())
+ .and_then(|project_state| project_state.current_prediction.take())
+ else {
return;
};
- let Some(current_prediction) = project_state.current_prediction.take() else {
+ self.report_changes_for_buffer(¤t_prediction.prediction.buffer, project, true, cx);
+
+ // can't hold &mut project_state ref across report_changes_for_buffer_call
+ let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
return;
};
@@ -1719,7 +1777,7 @@ impl EditPredictionStore {
self.get_or_init_project(&project, cx);
let project_state = self.projects.get(&project.entity_id()).unwrap();
- let stored_events = project_state.events_split_by_pause(cx);
+ let stored_events = project_state.events(cx);
let has_events = !stored_events.is_empty();
let events: Vec<Arc<zeta_prompt::Event>> =
stored_events.into_iter().map(|e| e.event).collect();
@@ -2219,6 +2277,67 @@ impl EditPredictionStore {
}
}
+fn merge_trailing_events_if_needed(
+ events: &mut VecDeque<StoredEvent>,
+ end_snapshot: &TextBufferSnapshot,
+ latest_snapshot: &TextBufferSnapshot,
+ latest_edit_range: &Range<Anchor>,
+) {
+ let mut next_old_event = None;
+ let mut mergeable_count = 0;
+ for old_event in events.iter().rev() {
+ if let Some(next_old_event) = &next_old_event
+ && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
+ {
+ break;
+ }
+ mergeable_count += 1;
+ next_old_event = Some(old_event);
+ }
+
+ if mergeable_count <= 1 {
+ return;
+ }
+
+ let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
+ let oldest_event = events_to_merge.peek().unwrap();
+ let oldest_snapshot = oldest_event.old_snapshot.clone();
+
+ if let Some((diff, edited_range)) =
+ compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
+ {
+ let merged_event = match oldest_event.event.as_ref() {
+ zeta_prompt::Event::BufferChange {
+ old_path,
+ path,
+ in_open_source_repo,
+ ..
+ } => StoredEvent {
+ event: Arc::new(zeta_prompt::Event::BufferChange {
+ old_path: old_path.clone(),
+ path: path.clone(),
+ diff,
+ in_open_source_repo: *in_open_source_repo,
+ predicted: events_to_merge.all(|e| {
+ matches!(
+ e.event.as_ref(),
+ zeta_prompt::Event::BufferChange {
+ predicted: true,
+ ..
+ }
+ )
+ }),
+ }),
+ old_snapshot: oldest_snapshot.clone(),
+ edit_range: end_snapshot.anchor_before(edited_range.start)
+ ..end_snapshot.anchor_before(edited_range.end),
+ },
+ };
+ events.truncate(events.len() - mergeable_count);
+ events.push_back(merged_event);
+ }
+}
+
pub(crate) fn filter_redundant_excerpts(
mut related_files: Vec<RelatedFile>,
cursor_path: &Path,
@@ -356,26 +356,9 @@ async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContex
buffer.edit(vec![(19..19, "!")], None, cx);
});
- // Without time-based splitting, there is one event.
- let events = ep_store.update(cx, |ep_store, cx| {
- ep_store.edit_history_for_project(&project, cx)
- });
- assert_eq!(events.len(), 1);
- let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
- assert_eq!(
- diff.as_str(),
- indoc! {"
- @@ -1,3 +1,3 @@
- Hello!
- -
- +How are you?!
- Bye
- "}
- );
-
// With time-based splitting, there are two distinct events.
let events = ep_store.update(cx, |ep_store, cx| {
- ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx)
+ ep_store.edit_history_for_project(&project, cx)
});
assert_eq!(events.len(), 2);
let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
@@ -404,7 +387,7 @@ async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContex
}
#[gpui::test]
-async fn test_event_grouping_line_span_coalescing(cx: &mut TestAppContext) {
+async fn test_predicted_edits_are_separated_in_edit_history(cx: &mut TestAppContext) {
let (ep_store, _requests) = init_test_with_fake_client(cx);
let fs = FakeFs::new(cx.executor());
@@ -593,6 +576,278 @@ fn render_events(events: &[StoredEvent]) -> String {
.join("\n---\n")
}
+fn render_events_with_predicted(events: &[StoredEvent]) -> Vec<String> {
+ events
+ .iter()
+ .map(|e| {
+ let zeta_prompt::Event::BufferChange {
+ diff, predicted, ..
+ } = e.event.as_ref();
+ let prefix = if *predicted { "predicted" } else { "manual" };
+ format!("{}\n{}", prefix, diff)
+ })
+ .collect()
+}
+
+#[gpui::test]
+async fn test_predicted_flag_coalescing(cx: &mut TestAppContext) {
+ let (ep_store, _requests) = init_test_with_fake_client(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.rs": "line 0\nline 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\nline 8\nline 9\nline 10\nline 11\nline 12\nline 13\nline 14\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.register_buffer(&buffer, &project, cx);
+ });
+
+ // Case 1: Manual edits have `predicted` set to false.
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit(vec![(0..6, "LINE ZERO")], None, cx);
+ });
+
+ let events = ep_store.update(cx, |ep_store, cx| {
+ ep_store.edit_history_for_project(&project, cx)
+ });
+
+ assert_eq!(
+ render_events_with_predicted(&events),
+ vec![indoc! {"
+ manual
+ @@ -1,4 +1,4 @@
+ -line 0
+ +LINE ZERO
+ line 1
+ line 2
+ line 3
+ "}]
+ );
+
+ // Case 2: Multiple successive manual edits near each other are merged into one
+ // event with `predicted` set to false.
+ buffer.update(cx, |buffer, cx| {
+ let offset = Point::new(1, 0).to_offset(buffer);
+ let end = Point::new(1, 6).to_offset(buffer);
+ buffer.edit(vec![(offset..end, "LINE ONE")], None, cx);
+ });
+
+ let events = ep_store.update(cx, |ep_store, cx| {
+ ep_store.edit_history_for_project(&project, cx)
+ });
+ assert_eq!(
+ render_events_with_predicted(&events),
+ vec![indoc! {"
+ manual
+ @@ -1,5 +1,5 @@
+ -line 0
+ -line 1
+ +LINE ZERO
+ +LINE ONE
+ line 2
+ line 3
+ line 4
+ "}]
+ );
+
+ // Case 3: Accepted predictions have `predicted` set to true.
+ // Case 5: A manual edit that follows a predicted edit is not merged with the
+ // predicted edit, even if it is nearby.
+ ep_store.update(cx, |ep_store, cx| {
+ buffer.update(cx, |buffer, cx| {
+ let offset = Point::new(2, 0).to_offset(buffer);
+ let end = Point::new(2, 6).to_offset(buffer);
+ buffer.edit(vec![(offset..end, "LINE TWO")], None, cx);
+ });
+ ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
+ });
+
+ let events = ep_store.update(cx, |ep_store, cx| {
+ ep_store.edit_history_for_project(&project, cx)
+ });
+ assert_eq!(
+ render_events_with_predicted(&events),
+ vec![
+ indoc! {"
+ manual
+ @@ -1,5 +1,5 @@
+ -line 0
+ -line 1
+ +LINE ZERO
+ +LINE ONE
+ line 2
+ line 3
+ line 4
+ "},
+ indoc! {"
+ predicted
+ @@ -1,6 +1,6 @@
+ LINE ZERO
+ LINE ONE
+ -line 2
+ +LINE TWO
+ line 3
+ line 4
+ line 5
+ "}
+ ]
+ );
+
+ // Case 4: Multiple successive accepted predictions near each other are merged
+ // into one event with `predicted` set to true.
+ ep_store.update(cx, |ep_store, cx| {
+ buffer.update(cx, |buffer, cx| {
+ let offset = Point::new(3, 0).to_offset(buffer);
+ let end = Point::new(3, 6).to_offset(buffer);
+ buffer.edit(vec![(offset..end, "LINE THREE")], None, cx);
+ });
+ ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
+ });
+
+ let events = ep_store.update(cx, |ep_store, cx| {
+ ep_store.edit_history_for_project(&project, cx)
+ });
+ assert_eq!(
+ render_events_with_predicted(&events),
+ vec![
+ indoc! {"
+ manual
+ @@ -1,5 +1,5 @@
+ -line 0
+ -line 1
+ +LINE ZERO
+ +LINE ONE
+ line 2
+ line 3
+ line 4
+ "},
+ indoc! {"
+ predicted
+ @@ -1,7 +1,7 @@
+ LINE ZERO
+ LINE ONE
+ -line 2
+ -line 3
+ +LINE TWO
+ +LINE THREE
+ line 4
+ line 5
+ line 6
+ "}
+ ]
+ );
+
+ // Case 5 (continued): A manual edit that follows a predicted edit is not merged
+ // with the predicted edit, even if it is nearby.
+ buffer.update(cx, |buffer, cx| {
+ let offset = Point::new(4, 0).to_offset(buffer);
+ let end = Point::new(4, 6).to_offset(buffer);
+ buffer.edit(vec![(offset..end, "LINE FOUR")], None, cx);
+ });
+
+ let events = ep_store.update(cx, |ep_store, cx| {
+ ep_store.edit_history_for_project(&project, cx)
+ });
+ assert_eq!(
+ render_events_with_predicted(&events),
+ vec![
+ indoc! {"
+ manual
+ @@ -1,5 +1,5 @@
+ -line 0
+ -line 1
+ +LINE ZERO
+ +LINE ONE
+ line 2
+ line 3
+ line 4
+ "},
+ indoc! {"
+ predicted
+ @@ -1,7 +1,7 @@
+ LINE ZERO
+ LINE ONE
+ -line 2
+ -line 3
+ +LINE TWO
+ +LINE THREE
+ line 4
+ line 5
+ line 6
+ "},
+ indoc! {"
+ manual
+ @@ -2,7 +2,7 @@
+ LINE ONE
+ LINE TWO
+ LINE THREE
+ -line 4
+ +LINE FOUR
+ line 5
+ line 6
+ line 7
+ "}
+ ]
+ );
+
+ // Case 6: If we then perform a manual edit at a *different* location (more than
+ // 8 lines away), then the edits at the prior location can be merged with each
+ // other, even if some are predicted and some are not. `predicted` means all
+ // constituent edits were predicted.
+ buffer.update(cx, |buffer, cx| {
+ let offset = Point::new(14, 0).to_offset(buffer);
+ let end = Point::new(14, 7).to_offset(buffer);
+ buffer.edit(vec![(offset..end, "LINE FOURTEEN")], None, cx);
+ });
+
+ let events = ep_store.update(cx, |ep_store, cx| {
+ ep_store.edit_history_for_project(&project, cx)
+ });
+ assert_eq!(
+ render_events_with_predicted(&events),
+ vec![
+ indoc! {"
+ manual
+ @@ -1,8 +1,8 @@
+ -line 0
+ -line 1
+ -line 2
+ -line 3
+ -line 4
+ +LINE ZERO
+ +LINE ONE
+ +LINE TWO
+ +LINE THREE
+ +LINE FOUR
+ line 5
+ line 6
+ line 7
+ "},
+ indoc! {"
+ manual
+ @@ -12,4 +12,4 @@
+ line 11
+ line 12
+ line 13
+ -line 14
+ +LINE FOURTEEN
+ "}
+ ]
+ );
+}
+
#[gpui::test]
async fn test_empty_prediction(cx: &mut TestAppContext) {
let (ep_store, mut requests) = init_test_with_fake_client(cx);
@@ -2261,7 +2516,7 @@ fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
- let diff = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
+ let (diff, _) = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
assert_eq!(
diff,