diff --git a/crates/edit_prediction/src/capture_example.rs b/crates/edit_prediction/src/capture_example.rs index 33d7d12f1e0eb07ae2e9f13efd7447997c46463a..00238983f800861cb6d94a4e49d8ca3a91d5bbaf 100644 --- a/crates/edit_prediction/src/capture_example.rs +++ b/crates/edit_prediction/src/capture_example.rs @@ -450,9 +450,7 @@ mod tests { cx.run_until_parked(); // Verify the external edit was recorded in events - let events = ep_store.update(cx, |store, cx| { - store.edit_history_for_project_with_pause_split_last_event(&project, cx) - }); + let events = ep_store.update(cx, |store, cx| store.edit_history_for_project(&project, cx)); assert!( matches!( events diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 1ec3c7ac44fc8f592fa094f668b3bfd84245eb5a..8ae47cf61c9d3709a96883dc8c979bc6155ad201 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -231,6 +231,71 @@ pub enum UserActionType { pub struct StoredEvent { pub event: Arc, pub old_snapshot: TextBufferSnapshot, + pub edit_range: Range, +} + +impl StoredEvent { + fn can_merge( + &self, + next_old_event: &&&StoredEvent, + new_snapshot: &TextBufferSnapshot, + last_edit_range: &Range, + ) -> bool { + // Events must be for the same buffer + if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() { + return false; + } + + let a_is_predicted = matches!( + self.event.as_ref(), + zeta_prompt::Event::BufferChange { + predicted: true, + .. + } + ); + let b_is_predicted = matches!( + next_old_event.event.as_ref(), + zeta_prompt::Event::BufferChange { + predicted: true, + .. + } + ); + + // If events come from the same source (both predicted or both manual) then + // we would have coalesced them already. 
+ if a_is_predicted == b_is_predicted { + return false; + } + + let left_range = self.edit_range.to_point(new_snapshot); + let right_range = next_old_event.edit_range.to_point(new_snapshot); + let latest_range = last_edit_range.to_point(&new_snapshot); + + // Events near to the latest edit are not merged if their sources differ. + if lines_between_ranges(&left_range, &latest_range) + .min(lines_between_ranges(&right_range, &latest_range)) + <= CHANGE_GROUPING_LINE_SPAN + { + return false; + } + + // Events that are distant from each other are not merged. + if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN { + return false; + } + + true + } +} + +fn lines_between_ranges(left: &Range, right: &Range) -> u32 { + if left.start > right.end { + return left.start.row - right.end.row; + } + if right.start > left.end { + return right.start.row - left.end.row; + } + 0 } struct ProjectState { @@ -260,18 +325,6 @@ impl ProjectState { } pub fn events(&self, cx: &App) -> Vec { - self.events - .iter() - .cloned() - .chain( - self.last_event - .as_ref() - .and_then(|event| event.finalize(&self.license_detection_watchers, cx)), - ) - .collect() - } - - pub fn events_split_by_pause(&self, cx: &App) -> Vec { self.events .iter() .cloned() @@ -430,6 +483,7 @@ struct LastEvent { old_file: Option>, new_file: Option>, edit_range: Option>, + predicted: bool, snapshot_after_last_editing_pause: Option, last_edit_time: Option, } @@ -454,7 +508,8 @@ impl LastEvent { }) }); - let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?; + let (diff, edit_range) = + compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?; if path == old_path && diff.is_empty() { None @@ -465,9 +520,10 @@ impl LastEvent { path, diff, in_open_source_repo, - // TODO: Actually detect if this edit was predicted or not - predicted: false, + predicted: self.predicted, }), + edit_range: self.new_snapshot.anchor_before(edit_range.start) + 
..self.new_snapshot.anchor_before(edit_range.end), old_snapshot: self.old_snapshot.clone(), }) } @@ -484,6 +540,7 @@ impl LastEvent { old_file: self.old_file.clone(), new_file: self.new_file.clone(), edit_range: None, + predicted: self.predicted, snapshot_after_last_editing_pause: None, last_edit_time: self.last_edit_time, }; @@ -494,6 +551,7 @@ impl LastEvent { old_file: self.old_file.clone(), new_file: self.new_file.clone(), edit_range: None, + predicted: self.predicted, snapshot_after_last_editing_pause: None, last_edit_time: self.last_edit_time, }; @@ -505,7 +563,7 @@ impl LastEvent { pub(crate) fn compute_diff_between_snapshots( old_snapshot: &TextBufferSnapshot, new_snapshot: &TextBufferSnapshot, -) -> Option { +) -> Option<(String, Range)> { let edits: Vec> = new_snapshot .edits_since::(&old_snapshot.version) .collect(); @@ -545,7 +603,7 @@ pub(crate) fn compute_diff_between_snapshots( new_context_start_row, ); - Some(diff) + Some((diff, new_start_point..new_end_point)) } fn buffer_path_with_id_fallback( @@ -716,17 +774,6 @@ impl EditPredictionStore { .unwrap_or_default() } - pub fn edit_history_for_project_with_pause_split_last_event( - &self, - project: &Entity, - cx: &App, - ) -> Vec { - self.projects - .get(&project.entity_id()) - .map(|project_state| project_state.events_split_by_pause(cx)) - .unwrap_or_default() - } - pub fn context_for_project<'a>( &'a self, project: &Entity, @@ -1011,7 +1058,7 @@ impl EditPredictionStore { if let language::BufferEvent::Edited = event && let Some(project) = project.upgrade() { - this.report_changes_for_buffer(&buffer, &project, cx); + this.report_changes_for_buffer(&buffer, &project, false, cx); } } }), @@ -1032,6 +1079,7 @@ impl EditPredictionStore { &mut self, buffer: &Entity, project: &Entity, + is_predicted: bool, cx: &mut Context, ) { let project_state = self.get_or_init_project(project, cx); @@ -1065,30 +1113,32 @@ impl EditPredictionStore { last_offset = Some(edit.new.end); } - if num_edits > 0 { - let 
action_type = match (total_deleted, total_inserted, num_edits) { - (0, ins, n) if ins == n => UserActionType::InsertChar, - (0, _, _) => UserActionType::InsertSelection, - (del, 0, n) if del == n => UserActionType::DeleteChar, - (_, 0, _) => UserActionType::DeleteSelection, - (_, ins, n) if ins == n => UserActionType::InsertChar, - (_, _, _) => UserActionType::InsertSelection, - }; + let Some(edit_range) = edit_range else { + return; + }; - if let Some(offset) = last_offset { - let point = new_snapshot.offset_to_point(offset); - let timestamp_epoch_ms = SystemTime::now() - .duration_since(UNIX_EPOCH) - .map(|d| d.as_millis() as u64) - .unwrap_or(0); - project_state.record_user_action(UserActionRecord { - action_type, - buffer_id: buffer.entity_id(), - line_number: point.row, - offset, - timestamp_epoch_ms, - }); - } + let action_type = match (total_deleted, total_inserted, num_edits) { + (0, ins, n) if ins == n => UserActionType::InsertChar, + (0, _, _) => UserActionType::InsertSelection, + (del, 0, n) if del == n => UserActionType::DeleteChar, + (_, 0, _) => UserActionType::DeleteSelection, + (_, ins, n) if ins == n => UserActionType::InsertChar, + (_, _, _) => UserActionType::InsertSelection, + }; + + if let Some(offset) = last_offset { + let point = new_snapshot.offset_to_point(offset); + let timestamp_epoch_ms = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_millis() as u64) + .unwrap_or(0); + project_state.record_user_action(UserActionRecord { + action_type, + buffer_id: buffer.entity_id(), + line_number: point.row, + offset, + timestamp_epoch_ms, + }); } let events = &mut project_state.events; @@ -1099,20 +1149,18 @@ impl EditPredictionStore { == last_event.new_snapshot.remote_id() && old_snapshot.version == last_event.new_snapshot.version; + let prediction_source_changed = is_predicted != last_event.predicted; + let should_coalesce = is_next_snapshot_of_same_buffer - && edit_range + && !prediction_source_changed + && last_event + .edit_range 
.as_ref() - .zip(last_event.edit_range.as_ref()) - .is_some_and(|(a, b)| { - let a = a.to_point(&new_snapshot); - let b = b.to_point(&new_snapshot); - if a.start > b.end { - a.start.row.abs_diff(b.end.row) <= CHANGE_GROUPING_LINE_SPAN - } else if b.start > a.end { - b.start.row.abs_diff(a.end.row) <= CHANGE_GROUPING_LINE_SPAN - } else { - true - } + .is_some_and(|last_edit_range| { + lines_between_ranges( + &edit_range.to_point(&new_snapshot), + &last_edit_range.to_point(&new_snapshot), + ) <= CHANGE_GROUPING_LINE_SPAN }); if should_coalesce { @@ -1125,7 +1173,7 @@ impl EditPredictionStore { Some(last_event.new_snapshot.clone()); } - last_event.edit_range = edit_range; + last_event.edit_range = Some(edit_range); last_event.new_snapshot = new_snapshot; last_event.last_edit_time = Some(now); return; @@ -1141,12 +1189,15 @@ impl EditPredictionStore { } } + merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range); + project_state.last_event = Some(LastEvent { old_file, new_file, old_snapshot, new_snapshot, - edit_range, + edit_range: Some(edit_range), + predicted: is_predicted, snapshot_after_last_editing_pause: None, last_edit_time: Some(now), }); @@ -1193,11 +1244,18 @@ impl EditPredictionStore { } fn accept_current_prediction(&mut self, project: &Entity, cx: &mut Context) { - let Some(project_state) = self.projects.get_mut(&project.entity_id()) else { + let Some(current_prediction) = self + .projects + .get_mut(&project.entity_id()) + .and_then(|project_state| project_state.current_prediction.take()) + else { return; }; - let Some(current_prediction) = project_state.current_prediction.take() else { + self.report_changes_for_buffer(&current_prediction.prediction.buffer, project, true, cx); + + // can't hold &mut project_state ref across report_changes_for_buffer_call + let Some(project_state) = self.projects.get_mut(&project.entity_id()) else { return; }; @@ -1719,7 +1777,7 @@ impl EditPredictionStore { self.get_or_init_project(&project, cx); let 
project_state = self.projects.get(&project.entity_id()).unwrap(); - let stored_events = project_state.events_split_by_pause(cx); + let stored_events = project_state.events(cx); let has_events = !stored_events.is_empty(); let events: Vec> = stored_events.into_iter().map(|e| e.event).collect(); @@ -2219,6 +2277,67 @@ impl EditPredictionStore { } } +fn merge_trailing_events_if_needed( + events: &mut VecDeque, + end_snapshot: &TextBufferSnapshot, + latest_snapshot: &TextBufferSnapshot, + latest_edit_range: &Range, +) { + let mut next_old_event = None; + let mut mergeable_count = 0; + for old_event in events.iter().rev() { + if let Some(next_old_event) = &next_old_event + && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range) + { + break; + } + mergeable_count += 1; + next_old_event = Some(old_event); + } + + if mergeable_count <= 1 { + return; + } + + let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable(); + let oldest_event = events_to_merge.peek().unwrap(); + let oldest_snapshot = oldest_event.old_snapshot.clone(); + + if let Some((diff, edited_range)) = + compute_diff_between_snapshots(&oldest_snapshot, end_snapshot) + { + let merged_event = match oldest_event.event.as_ref() { + zeta_prompt::Event::BufferChange { + old_path, + path, + in_open_source_repo, + .. + } => StoredEvent { + event: Arc::new(zeta_prompt::Event::BufferChange { + old_path: old_path.clone(), + path: path.clone(), + diff, + in_open_source_repo: *in_open_source_repo, + predicted: events_to_merge.all(|e| { + matches!( + e.event.as_ref(), + zeta_prompt::Event::BufferChange { + predicted: true, + .. 
+ } + ) + }), + }), + old_snapshot: oldest_snapshot.clone(), + edit_range: end_snapshot.anchor_before(edited_range.start) + ..end_snapshot.anchor_before(edited_range.end), + }, + }; + events.truncate(events.len() - mergeable_count); + events.push_back(merged_event); + } +} + pub(crate) fn filter_redundant_excerpts( mut related_files: Vec, cursor_path: &Path, diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index 19d2532de094b849952ca16c100cf2c8b4a598dc..8e3ebb8a2219ec35e83487efcf449fe81fbd9713 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -356,26 +356,9 @@ async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContex buffer.edit(vec![(19..19, "!")], None, cx); }); - // Without time-based splitting, there is one event. - let events = ep_store.update(cx, |ep_store, cx| { - ep_store.edit_history_for_project(&project, cx) - }); - assert_eq!(events.len(), 1); - let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref(); - assert_eq!( - diff.as_str(), - indoc! {" - @@ -1,3 +1,3 @@ - Hello! - - - +How are you?! - Bye - "} - ); - // With time-based splitting, there are two distinct events. let events = ep_store.update(cx, |ep_store, cx| { - ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx) + ep_store.edit_history_for_project(&project, cx) }); assert_eq!(events.len(), 2); let zeta_prompt::Event::BufferChange { diff, .. 
} = events[0].event.as_ref(); @@ -404,7 +387,7 @@ async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContex } #[gpui::test] -async fn test_event_grouping_line_span_coalescing(cx: &mut TestAppContext) { +async fn test_predicted_edits_are_separated_in_edit_history(cx: &mut TestAppContext) { let (ep_store, _requests) = init_test_with_fake_client(cx); let fs = FakeFs::new(cx.executor()); @@ -593,6 +576,278 @@ fn render_events(events: &[StoredEvent]) -> String { .join("\n---\n") } +fn render_events_with_predicted(events: &[StoredEvent]) -> Vec { + events + .iter() + .map(|e| { + let zeta_prompt::Event::BufferChange { + diff, predicted, .. + } = e.event.as_ref(); + let prefix = if *predicted { "predicted" } else { "manual" }; + format!("{}\n{}", prefix, diff) + }) + .collect() +} + +#[gpui::test] +async fn test_predicted_flag_coalescing(cx: &mut TestAppContext) { + let (ep_store, _requests) = init_test_with_fake_client(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.rs": "line 0\nline 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\nline 8\nline 9\nline 10\nline 11\nline 12\nline 13\nline 14\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + + ep_store.update(cx, |ep_store, cx| { + ep_store.register_buffer(&buffer, &project, cx); + }); + + // Case 1: Manual edits have `predicted` set to false. + buffer.update(cx, |buffer, cx| { + buffer.edit(vec![(0..6, "LINE ZERO")], None, cx); + }); + + let events = ep_store.update(cx, |ep_store, cx| { + ep_store.edit_history_for_project(&project, cx) + }); + + assert_eq!( + render_events_with_predicted(&events), + vec![indoc! 
{" + manual + @@ -1,4 +1,4 @@ + -line 0 + +LINE ZERO + line 1 + line 2 + line 3 + "}] + ); + + // Case 2: Multiple successive manual edits near each other are merged into one + // event with `predicted` set to false. + buffer.update(cx, |buffer, cx| { + let offset = Point::new(1, 0).to_offset(buffer); + let end = Point::new(1, 6).to_offset(buffer); + buffer.edit(vec![(offset..end, "LINE ONE")], None, cx); + }); + + let events = ep_store.update(cx, |ep_store, cx| { + ep_store.edit_history_for_project(&project, cx) + }); + assert_eq!( + render_events_with_predicted(&events), + vec![indoc! {" + manual + @@ -1,5 +1,5 @@ + -line 0 + -line 1 + +LINE ZERO + +LINE ONE + line 2 + line 3 + line 4 + "}] + ); + + // Case 3: Accepted predictions have `predicted` set to true. + // Case 5: A manual edit that follows a predicted edit is not merged with the + // predicted edit, even if it is nearby. + ep_store.update(cx, |ep_store, cx| { + buffer.update(cx, |buffer, cx| { + let offset = Point::new(2, 0).to_offset(buffer); + let end = Point::new(2, 6).to_offset(buffer); + buffer.edit(vec![(offset..end, "LINE TWO")], None, cx); + }); + ep_store.report_changes_for_buffer(&buffer, &project, true, cx); + }); + + let events = ep_store.update(cx, |ep_store, cx| { + ep_store.edit_history_for_project(&project, cx) + }); + assert_eq!( + render_events_with_predicted(&events), + vec![ + indoc! {" + manual + @@ -1,5 +1,5 @@ + -line 0 + -line 1 + +LINE ZERO + +LINE ONE + line 2 + line 3 + line 4 + "}, + indoc! {" + predicted + @@ -1,6 +1,6 @@ + LINE ZERO + LINE ONE + -line 2 + +LINE TWO + line 3 + line 4 + line 5 + "} + ] + ); + + // Case 4: Multiple successive accepted predictions near each other are merged + // into one event with `predicted` set to true. 
+ ep_store.update(cx, |ep_store, cx| { + buffer.update(cx, |buffer, cx| { + let offset = Point::new(3, 0).to_offset(buffer); + let end = Point::new(3, 6).to_offset(buffer); + buffer.edit(vec![(offset..end, "LINE THREE")], None, cx); + }); + ep_store.report_changes_for_buffer(&buffer, &project, true, cx); + }); + + let events = ep_store.update(cx, |ep_store, cx| { + ep_store.edit_history_for_project(&project, cx) + }); + assert_eq!( + render_events_with_predicted(&events), + vec![ + indoc! {" + manual + @@ -1,5 +1,5 @@ + -line 0 + -line 1 + +LINE ZERO + +LINE ONE + line 2 + line 3 + line 4 + "}, + indoc! {" + predicted + @@ -1,7 +1,7 @@ + LINE ZERO + LINE ONE + -line 2 + -line 3 + +LINE TWO + +LINE THREE + line 4 + line 5 + line 6 + "} + ] + ); + + // Case 5 (continued): A manual edit that follows a predicted edit is not merged + // with the predicted edit, even if it is nearby. + buffer.update(cx, |buffer, cx| { + let offset = Point::new(4, 0).to_offset(buffer); + let end = Point::new(4, 6).to_offset(buffer); + buffer.edit(vec![(offset..end, "LINE FOUR")], None, cx); + }); + + let events = ep_store.update(cx, |ep_store, cx| { + ep_store.edit_history_for_project(&project, cx) + }); + assert_eq!( + render_events_with_predicted(&events), + vec![ + indoc! {" + manual + @@ -1,5 +1,5 @@ + -line 0 + -line 1 + +LINE ZERO + +LINE ONE + line 2 + line 3 + line 4 + "}, + indoc! {" + predicted + @@ -1,7 +1,7 @@ + LINE ZERO + LINE ONE + -line 2 + -line 3 + +LINE TWO + +LINE THREE + line 4 + line 5 + line 6 + "}, + indoc! {" + manual + @@ -2,7 +2,7 @@ + LINE ONE + LINE TWO + LINE THREE + -line 4 + +LINE FOUR + line 5 + line 6 + line 7 + "} + ] + ); + + // Case 6: If we then perform a manual edit at a *different* location (more than + // 8 lines away), then the edits at the prior location can be merged with each + // other, even if some are predicted and some are not. `predicted` means all + // constituent edits were predicted. 
+ buffer.update(cx, |buffer, cx| { + let offset = Point::new(14, 0).to_offset(buffer); + let end = Point::new(14, 7).to_offset(buffer); + buffer.edit(vec![(offset..end, "LINE FOURTEEN")], None, cx); + }); + + let events = ep_store.update(cx, |ep_store, cx| { + ep_store.edit_history_for_project(&project, cx) + }); + assert_eq!( + render_events_with_predicted(&events), + vec![ + indoc! {" + manual + @@ -1,8 +1,8 @@ + -line 0 + -line 1 + -line 2 + -line 3 + -line 4 + +LINE ZERO + +LINE ONE + +LINE TWO + +LINE THREE + +LINE FOUR + line 5 + line 6 + line 7 + "}, + indoc! {" + manual + @@ -12,4 +12,4 @@ + line 11 + line 12 + line 13 + -line 14 + +LINE FOURTEEN + "} + ] + ); +} + #[gpui::test] async fn test_empty_prediction(cx: &mut TestAppContext) { let (ep_store, mut requests) = init_test_with_fake_client(cx); @@ -2261,7 +2516,7 @@ fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) { let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot()); - let diff = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap(); + let (diff, _) = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap(); assert_eq!( diff, diff --git a/crates/edit_prediction_ui/src/edit_prediction_ui.rs b/crates/edit_prediction_ui/src/edit_prediction_ui.rs index b684aa48512ff8da25ab4196fe73f8cf8c5412b4..774bc19af304d36cad43aedbfe088b4daca52d62 100644 --- a/crates/edit_prediction_ui/src/edit_prediction_ui.rs +++ b/crates/edit_prediction_ui/src/edit_prediction_ui.rs @@ -153,9 +153,7 @@ fn capture_example_as_markdown( .read(cx) .text_anchor_for_position(editor.selections.newest_anchor().head(), cx)?; let ep_store = EditPredictionStore::try_global(cx)?; - let events = ep_store.update(cx, |store, cx| { - store.edit_history_for_project_with_pause_split_last_event(&project, cx) - }); + let events = ep_store.update(cx, |store, cx| store.edit_history_for_project(&project, cx)); let example = capture_example( project.clone(), buffer, diff --git 
a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index 4c5cc94b4ff1de96ba9099477d5872de07667fc3..af6478a6199e15ece663b1f8b68240a8276950b2 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -7987,10 +7987,6 @@ impl Editor { match granularity { EditPredictionGranularity::Full => { - if let Some(provider) = self.edit_prediction_provider() { - provider.accept(cx); - } - let transaction_id_prev = self.buffer.read(cx).last_transaction_id(cx); // Compute fallback cursor position BEFORE applying the edit, @@ -8004,6 +8000,10 @@ impl Editor { buffer.edit(edits.iter().cloned(), None, cx) }); + if let Some(provider) = self.edit_prediction_provider() { + provider.accept(cx); + } + // Resolve cursor position after the edit is applied let cursor_target = if let Some((anchor, offset)) = cursor_position { // The anchor tracks through the edit, then we add the offset