implement and test prev_diff_hunk for the inverted case

Cole Miller created

Change summary

crates/buffer_diff/src/buffer_diff.rs         | 33 ++++--
crates/editor/src/split.rs                    |  7 -
crates/multi_buffer/src/multi_buffer.rs       | 99 +++++++++++++-------
crates/multi_buffer/src/multi_buffer_tests.rs | 62 +++++++++++-
4 files changed, 142 insertions(+), 59 deletions(-)

Detailed changes

crates/buffer_diff/src/buffer_diff.rs 🔗

@@ -226,7 +226,12 @@ impl BufferDiffSnapshot {
         range: Range<Anchor>,
         buffer: &'a text::BufferSnapshot,
     ) -> impl 'a + Iterator<Item = DiffHunk> {
-        self.inner.hunks_intersecting_range_rev(range, buffer)
+        let filter = move |summary: &DiffHunkSummary| {
+            let before_start = summary.buffer_range.end.cmp(&range.start, buffer).is_lt();
+            let after_end = summary.buffer_range.start.cmp(&range.end, buffer).is_gt();
+            !before_start && !after_end
+        };
+        self.inner.hunks_intersecting_range_rev_impl(filter, buffer)
     }
 
     pub fn hunks_intersecting_base_text_range<'a>(
@@ -244,6 +249,20 @@ impl BufferDiffSnapshot {
             .hunks_intersecting_range_impl(filter, main_buffer, unstaged_counterpart)
     }
 
+    pub fn hunks_intersecting_base_text_range_rev<'a>(
+        &'a self,
+        range: Range<usize>,
+        main_buffer: &'a text::BufferSnapshot,
+    ) -> impl 'a + Iterator<Item = DiffHunk> {
+        let filter = move |summary: &DiffHunkSummary| {
+            let before_start = summary.diff_base_byte_range.end.cmp(&range.start).is_lt();
+            let after_end = summary.diff_base_byte_range.start.cmp(&range.end).is_gt();
+            !before_start && !after_end
+        };
+        self.inner
+            .hunks_intersecting_range_rev_impl(filter, main_buffer)
+    }
+
     pub fn hunks<'a>(
         &'a self,
         buffer_snapshot: &'a text::BufferSnapshot,
@@ -705,18 +724,12 @@ impl BufferDiffInner {
         })
     }
 
-    fn hunks_intersecting_range_rev<'a>(
+    fn hunks_intersecting_range_rev_impl<'a>(
         &'a self,
-        range: Range<Anchor>,
+        filter: impl 'a + Fn(&DiffHunkSummary) -> bool,
         buffer: &'a text::BufferSnapshot,
     ) -> impl 'a + Iterator<Item = DiffHunk> {
-        let mut cursor = self
-            .hunks
-            .filter::<_, DiffHunkSummary>(buffer, move |summary| {
-                let before_start = summary.buffer_range.end.cmp(&range.start, buffer).is_lt();
-                let after_end = summary.buffer_range.start.cmp(&range.end, buffer).is_gt();
-                !before_start && !after_end
-            });
+        let mut cursor = self.hunks.filter::<_, DiffHunkSummary>(buffer, filter);
 
         iter::from_fn(move || {
             cursor.prev();

crates/editor/src/split.rs 🔗

@@ -358,12 +358,7 @@ impl SecondaryEditor {
                     new,
                     cx,
                 );
-                buffer.add_inverted_diff(
-                    base_text_buffer_snapshot.remote_id(),
-                    diff,
-                    main_buffer,
-                    cx,
-                );
+                buffer.add_inverted_diff(diff, main_buffer, cx);
             })
         });
     }

crates/multi_buffer/src/multi_buffer.rs 🔗

@@ -524,7 +524,7 @@ impl DiffState {
             diff: self.diff.read(cx).snapshot(cx),
             main_buffer: self.main_buffer.as_ref().and_then(|main_buffer| {
                 main_buffer
-                    .read_with(cx, |main_buffer, _| main_buffer.snapshot())
+                    .read_with(cx, |main_buffer, _| main_buffer.text_snapshot())
                     .ok()
             }),
         }
@@ -534,7 +534,7 @@ impl DiffState {
 #[derive(Clone)]
 struct DiffStateSnapshot {
     diff: BufferDiffSnapshot,
-    main_buffer: Option<BufferSnapshot>,
+    main_buffer: Option<text::BufferSnapshot>,
 }
 
 impl std::ops::Deref for DiffStateSnapshot {
@@ -568,7 +568,6 @@ impl DiffState {
 
     fn new_inverted(
         diff: Entity<BufferDiff>,
-        base_text_buffer_id: BufferId,
         main_buffer: Entity<Buffer>,
         cx: &mut Context<MultiBuffer>,
     ) -> Self {
@@ -585,19 +584,15 @@ impl DiffState {
                             this.inverted_buffer_diff_changed(
                                 diff,
                                 base_text_changed_range,
-                                base_text_buffer_id,
                                 main_buffer.clone(),
                                 cx,
                             )
                         }
                         cx.emit(Event::BufferDiffChanged);
                     }
-                    BufferDiffEvent::LanguageChanged => this.inverted_buffer_diff_language_changed(
-                        base_text_buffer_id,
-                        diff,
-                        main_buffer.clone(),
-                        cx,
-                    ),
+                    BufferDiffEvent::LanguageChanged => {
+                        this.inverted_buffer_diff_language_changed(diff, main_buffer.clone(), cx)
+                    }
                     _ => {}
                 }
             }),
@@ -2290,16 +2285,16 @@ impl MultiBuffer {
 
     fn inverted_buffer_diff_language_changed(
         &mut self,
-        base_text_buffer_id: BufferId,
         diff: Entity<BufferDiff>,
         main_buffer: WeakEntity<Buffer>,
         cx: &mut Context<Self>,
     ) {
+        let base_text_buffer_id = diff.read(cx).base_text(cx).remote_id();
         let diff = diff.read(cx);
         let diff = DiffStateSnapshot {
             diff: diff.snapshot(cx),
             main_buffer: main_buffer
-                .update(cx, |main_buffer, _| main_buffer.snapshot())
+                .update(cx, |main_buffer, _| main_buffer.text_snapshot())
                 .ok(),
         };
         self.snapshot
@@ -2357,13 +2352,13 @@ impl MultiBuffer {
         &mut self,
         diff: Entity<BufferDiff>,
         diff_change_range: Range<usize>,
-        base_text_buffer_id: BufferId,
         main_buffer: WeakEntity<Buffer>,
         cx: &mut Context<Self>,
     ) {
         self.sync_mut(cx);
 
         let diff = diff.read(cx);
+        let base_text_buffer_id = diff.base_text(cx).remote_id();
         let Some(buffer_state) = self.buffers.get(&base_text_buffer_id) else {
             return;
         };
@@ -2372,7 +2367,7 @@ impl MultiBuffer {
         let new_diff = DiffStateSnapshot {
             diff: diff.snapshot(cx),
             main_buffer: main_buffer
-                .update(cx, |main_buffer, _| main_buffer.snapshot())
+                .update(cx, |main_buffer, _| main_buffer.text_snapshot())
                 .ok(),
         };
         let mut snapshot = self.snapshot.get_mut();
@@ -2563,24 +2558,23 @@ impl MultiBuffer {
 
     pub fn add_inverted_diff(
         &mut self,
-        base_text_buffer_id: BufferId,
         diff: Entity<BufferDiff>,
         main_buffer: Entity<Buffer>,
         cx: &mut Context<Self>,
     ) {
         debug_assert!(self.diffs.values().all(|diff| diff.main_buffer.is_some()));
 
+        let base_text_buffer_id = diff.read(cx).base_text(cx).remote_id();
         let diff_change_range = 0..diff.read(cx).base_text(cx).len();
         self.inverted_buffer_diff_changed(
             diff.clone(),
             diff_change_range,
-            base_text_buffer_id,
             main_buffer.downgrade(),
             cx,
         );
         self.diffs.insert(
             base_text_buffer_id,
-            DiffState::new_inverted(diff, base_text_buffer_id, main_buffer, cx),
+            DiffState::new_inverted(diff, main_buffer, cx),
         );
     }
 
@@ -4183,7 +4177,6 @@ impl MultiBufferSnapshot {
         })
     }
 
-    // FIXME need to make this work with inverted diffs
     pub fn diff_hunk_before<T: ToOffset>(&self, position: T) -> Option<MultiBufferRow> {
         let offset = position.to_offset(self);
 
@@ -4197,26 +4190,42 @@ impl MultiBufferSnapshot {
         cursor.seek_to_start_of_current_excerpt();
         let excerpt = cursor.excerpt()?;
 
+        let excerpt_start = excerpt.range.context.start.to_offset(&excerpt.buffer);
         let excerpt_end = excerpt.range.context.end.to_offset(&excerpt.buffer);
         let current_position = self
             .anchor_before(offset)
             .text_anchor
             .to_offset(&excerpt.buffer);
-        let excerpt_end = excerpt
-            .buffer
-            .anchor_before(excerpt_end.min(current_position));
 
         if let Some(diff) = self.diffs.get(&excerpt.buffer_id) {
-            for hunk in diff.hunks_intersecting_range_rev(
-                excerpt.range.context.start..excerpt_end,
-                &excerpt.buffer,
-            ) {
-                let hunk_end = hunk.buffer_range.end.to_offset(&excerpt.buffer);
-                if hunk_end >= current_position {
-                    continue;
+            if let Some(main_buffer) = diff.main_buffer.as_ref() {
+                for hunk in diff.hunks_intersecting_base_text_range_rev(
+                    excerpt_start..excerpt_end,
+                    &main_buffer,
+                ) {
+                    if hunk.diff_base_byte_range.end >= current_position {
+                        continue;
+                    }
+                    let hunk_start = excerpt.buffer.anchor_after(hunk.diff_base_byte_range.start);
+                    let start = Anchor::in_buffer(excerpt.id, hunk_start).to_point(self);
+                    return Some(MultiBufferRow(start.row));
+                }
+            } else {
+                let excerpt_end = excerpt
+                    .buffer
+                    .anchor_before(excerpt_end.min(current_position));
+                for hunk in diff.hunks_intersecting_range_rev(
+                    excerpt.range.context.start..excerpt_end,
+                    &excerpt.buffer,
+                ) {
+                    let hunk_end = hunk.buffer_range.end.to_offset(&excerpt.buffer);
+                    if hunk_end >= current_position {
+                        continue;
+                    }
+                    let start =
+                        Anchor::in_buffer(excerpt.id, hunk.buffer_range.start).to_point(self);
+                    return Some(MultiBufferRow(start.row));
                 }
-                let start = Anchor::in_buffer(excerpt.id, hunk.buffer_range.start).to_point(self);
-                return Some(MultiBufferRow(start.row));
             }
         }
 
@@ -4227,13 +4236,29 @@ impl MultiBufferSnapshot {
             let Some(diff) = self.diffs.get(&excerpt.buffer_id) else {
                 continue;
             };
-            let mut hunks =
-                diff.hunks_intersecting_range_rev(excerpt.range.context.clone(), &excerpt.buffer);
-            let Some(hunk) = hunks.next() else {
-                continue;
-            };
-            let start = Anchor::in_buffer(excerpt.id, hunk.buffer_range.start).to_point(self);
-            return Some(MultiBufferRow(start.row));
+            if let Some(main_buffer) = diff.main_buffer.as_ref() {
+                let Some(hunk) = diff
+                    .hunks_intersecting_base_text_range_rev(
+                        excerpt.range.context.to_offset(&excerpt.buffer),
+                        main_buffer,
+                    )
+                    .next()
+                else {
+                    continue;
+                };
+                let hunk_start = excerpt.buffer.anchor_after(hunk.diff_base_byte_range.start);
+                let start = Anchor::in_buffer(excerpt.id, hunk_start).to_point(self);
+                return Some(MultiBufferRow(start.row));
+            } else {
+                let Some(hunk) = diff
+                    .hunks_intersecting_range_rev(excerpt.range.context.clone(), &excerpt.buffer)
+                    .next()
+                else {
+                    continue;
+                };
+                let start = Anchor::in_buffer(excerpt.id, hunk.buffer_range.start).to_point(self);
+                return Some(MultiBufferRow(start.row));
+            }
         }
     }
 

crates/multi_buffer/src/multi_buffer_tests.rs 🔗

@@ -473,6 +473,61 @@ async fn test_diff_hunks_in_range(cx: &mut TestAppContext) {
     );
 }
 
+#[gpui::test]
+async fn test_inverted_diff_hunks_in_range(cx: &mut TestAppContext) {
+    let base_text = "one\ntwo\nthree\nfour\nfive\nsix\nseven\neight\n";
+    let text = "one\nTHREE\nfour\nseven\nEIGHT\nNINE\n";
+    let buffer = cx.new(|cx| Buffer::local(text, cx));
+    let diff = cx
+        .new(|cx| BufferDiff::new_with_base_text(base_text, &buffer.read(cx).text_snapshot(), cx));
+    let base_text_buffer = diff.read_with(cx, |diff, _| diff.base_text_buffer());
+    let multibuffer = cx.new(|cx| MultiBuffer::singleton(base_text_buffer.clone(), cx));
+    let (mut snapshot, mut subscription) = multibuffer.update(cx, |multibuffer, cx| {
+        (multibuffer.snapshot(cx), multibuffer.subscribe())
+    });
+
+    multibuffer.update(cx, |multibuffer, cx| {
+        multibuffer.add_inverted_diff(diff, buffer.clone(), cx);
+        multibuffer.expand_diff_hunks(vec![Anchor::min()..Anchor::max()], cx);
+    });
+
+    assert_new_snapshot(
+        &multibuffer,
+        &mut snapshot,
+        &mut subscription,
+        cx,
+        indoc! {
+            "  one
+             - two
+             - three
+               four
+             - five
+             - six
+               seven
+             - eight
+            "
+        },
+    );
+
+    assert_eq!(
+        snapshot
+            .diff_hunks_in_range(Point::new(1, 0)..Point::MAX)
+            .map(|hunk| hunk.row_range.start.0..hunk.row_range.end.0)
+            .collect::<Vec<_>>(),
+        vec![1..3, 4..6, 7..8]
+    );
+
+    assert_eq!(snapshot.diff_hunk_before(Point::new(1, 1)), None,);
+    assert_eq!(
+        snapshot.diff_hunk_before(Point::new(7, 0)),
+        Some(MultiBufferRow(4))
+    );
+    assert_eq!(
+        snapshot.diff_hunk_before(Point::new(4, 0)),
+        Some(MultiBufferRow(1))
+    );
+}
+
 #[gpui::test]
 async fn test_editing_text_in_diff_hunks(cx: &mut TestAppContext) {
     let base_text = "one\ntwo\nfour\nfive\nsix\nseven\n";
@@ -3697,12 +3752,7 @@ async fn test_inverted_diff(cx: &mut TestAppContext) {
     let multibuffer = cx.new(|cx| {
         let mut multibuffer = MultiBuffer::singleton(base_text_buffer.clone(), cx);
         multibuffer.set_all_diff_hunks_expanded(cx);
-        multibuffer.add_inverted_diff(
-            base_text_buffer.read(cx).remote_id(),
-            diff.clone(),
-            buffer.clone(),
-            cx,
-        );
+        multibuffer.add_inverted_diff(diff.clone(), buffer.clone(), cx);
         multibuffer
     });