Use sumtree instead of iterator linear search for diff hunks in range

Julia, Max Brunsfeld, and Mikayla Maki created

Co-Authored-By: Max Brunsfeld <max@zed.dev>
Co-Authored-By: Mikayla Maki <mikayla@zed.dev>

Change summary

crates/git/src/diff.rs    | 206 ++++++++++++++++++++++++----------------
crates/text/src/anchor.rs |   2 
2 files changed, 126 insertions(+), 82 deletions(-)

Detailed changes

crates/git/src/diff.rs 🔗

@@ -1,7 +1,7 @@
 use std::ops::Range;
 
 use sum_tree::SumTree;
-use text::{Anchor, BufferSnapshot, OffsetRangeExt, Point, ToPoint};
+use text::{Anchor, BufferSnapshot, OffsetRangeExt, Point};
 
 pub use git2 as libgit;
 use libgit::{DiffLineType as GitDiffLineType, DiffOptions as GitOptions, Patch as GitPatch};
@@ -37,7 +37,6 @@ impl sum_tree::Item for DiffHunk<Anchor> {
     fn summary(&self) -> Self::Summary {
         DiffHunkSummary {
             buffer_range: self.buffer_range.clone(),
-            head_range: self.head_byte_range.clone(),
         }
     }
 }
@@ -45,54 +44,17 @@ impl sum_tree::Item for DiffHunk<Anchor> {
 #[derive(Debug, Default, Clone)]
 pub struct DiffHunkSummary {
     buffer_range: Range<Anchor>,
-    head_range: Range<usize>,
 }
 
 impl sum_tree::Summary for DiffHunkSummary {
     type Context = text::BufferSnapshot;
 
-    fn add_summary(&mut self, other: &Self, _: &Self::Context) {
-        self.head_range.start = self.head_range.start.min(other.head_range.start);
-        self.head_range.end = self.head_range.end.max(other.head_range.end);
-    }
-}
-
-#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord)]
-struct HunkHeadEnd(usize);
-
-impl<'a> sum_tree::Dimension<'a, DiffHunkSummary> for HunkHeadEnd {
-    fn add_summary(&mut self, summary: &'a DiffHunkSummary, _: &text::BufferSnapshot) {
-        self.0 = summary.head_range.end;
-    }
-
-    fn from_summary(summary: &'a DiffHunkSummary, _: &text::BufferSnapshot) -> Self {
-        HunkHeadEnd(summary.head_range.end)
-    }
-}
-
-#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord)]
-struct HunkBufferStart(u32);
-
-impl<'a> sum_tree::Dimension<'a, DiffHunkSummary> for HunkBufferStart {
-    fn add_summary(&mut self, summary: &'a DiffHunkSummary, buffer: &text::BufferSnapshot) {
-        self.0 = summary.buffer_range.start.to_point(buffer).row;
-    }
-
-    fn from_summary(summary: &'a DiffHunkSummary, buffer: &text::BufferSnapshot) -> Self {
-        HunkBufferStart(summary.buffer_range.start.to_point(buffer).row)
-    }
-}
-
-#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord)]
-struct HunkBufferEnd(u32);
-
-impl<'a> sum_tree::Dimension<'a, DiffHunkSummary> for HunkBufferEnd {
-    fn add_summary(&mut self, summary: &'a DiffHunkSummary, buffer: &text::BufferSnapshot) {
-        self.0 = summary.buffer_range.end.to_point(buffer).row;
-    }
-
-    fn from_summary(summary: &'a DiffHunkSummary, buffer: &text::BufferSnapshot) -> Self {
-        HunkBufferEnd(summary.buffer_range.end.to_point(buffer).row)
+    fn add_summary(&mut self, other: &Self, buffer: &Self::Context) {
+        self.buffer_range.start = self
+            .buffer_range
+            .start
+            .min(&other.buffer_range.start, buffer);
+        self.buffer_range.end = self.buffer_range.end.max(&other.buffer_range.end, buffer);
     }
 }
 
@@ -115,23 +77,30 @@ impl BufferDiff {
         query_row_range: Range<u32>,
         buffer: &'a BufferSnapshot,
     ) -> impl 'a + Iterator<Item = DiffHunk<u32>> {
-        self.tree.iter().filter_map(move |hunk| {
-            let range = hunk.buffer_range.to_point(&buffer);
-
-            if range.start.row <= query_row_range.end && query_row_range.start <= range.end.row {
-                let end_row = if range.end.column > 0 {
-                    range.end.row + 1
-                } else {
-                    range.end.row
-                };
-
-                Some(DiffHunk {
-                    buffer_range: range.start.row..end_row,
-                    head_byte_range: hunk.head_byte_range.clone(),
-                })
+        let start = buffer.anchor_before(Point::new(query_row_range.start, 0));
+        let end = buffer.anchor_after(Point::new(query_row_range.end, 0));
+
+        let mut cursor = self.tree.filter::<_, DiffHunkSummary>(move |summary| {
+            let before_start = summary.buffer_range.end.cmp(&start, buffer).is_lt();
+            let after_end = summary.buffer_range.start.cmp(&end, buffer).is_gt();
+            !before_start && !after_end
+        });
+
+        std::iter::from_fn(move || {
+            cursor.next(buffer);
+            let hunk = cursor.item()?;
+
+            let range = hunk.buffer_range.to_point(buffer);
+            let end_row = if range.end.column > 0 {
+                range.end.row + 1
             } else {
-                None
-            }
+                range.end.row
+            };
+
+            Some(DiffHunk {
+                buffer_range: range.start.row..end_row,
+                head_byte_range: hunk.head_byte_range.clone(),
+            })
         })
     }
 
@@ -270,7 +239,7 @@ mod tests {
 
         let buffer_text = "
             one
-            hello
+            HELLO
             three
         "
         .unindent();
@@ -278,10 +247,78 @@ mod tests {
         let mut buffer = Buffer::new(0, 0, buffer_text);
         let mut diff = BufferDiff::new();
         smol::block_on(diff.update(&head_text, &buffer));
-        assert_hunks(&diff, &buffer, &head_text, &[(1..2, "two\n")]);
+        assert_hunks(
+            &diff,
+            &buffer,
+            &head_text,
+            &[(1..2, "two\n", "HELLO\n")],
+            None,
+        );
 
         buffer.edit([(0..0, "point five\n")]);
-        assert_hunks(&diff, &buffer, &head_text, &[(2..3, "two\n")]);
+        smol::block_on(diff.update(&head_text, &buffer));
+        assert_hunks(
+            &diff,
+            &buffer,
+            &head_text,
+            &[(0..1, "", "point five\n"), (2..3, "two\n", "HELLO\n")],
+            None,
+        );
+    }
+
+    #[test]
+    fn test_buffer_diff_range() {
+        let head_text = "
+            one
+            two
+            three
+            four
+            five
+            six
+            seven
+            eight
+            nine
+            ten
+        "
+        .unindent();
+
+        let buffer_text = "
+            A
+            one
+            B
+            two
+            C
+            three
+            HELLO
+            four
+            five
+            SIXTEEN
+            seven
+            eight
+            WORLD
+            nine
+
+            ten
+
+        "
+        .unindent();
+
+        let buffer = Buffer::new(0, 0, buffer_text);
+        let mut diff = BufferDiff::new();
+        smol::block_on(diff.update(&head_text, &buffer));
+        assert_eq!(diff.hunks(&buffer).count(), 8);
+
+        assert_hunks(
+            &diff,
+            &buffer,
+            &head_text,
+            &[
+                (6..7, "", "HELLO\n"),
+                (9..10, "six\n", "SIXTEEN\n"),
+                (12..13, "", "WORLD\n"),
+            ],
+            Some(7..12),
+        );
     }
 
     #[track_caller]
@@ -289,23 +326,30 @@ mod tests {
         diff: &BufferDiff,
         buffer: &BufferSnapshot,
         head_text: &str,
-        expected_hunks: &[(Range<u32>, &str)],
+        expected_hunks: &[(Range<u32>, &str, &str)],
+        range: Option<Range<u32>>,
     ) {
-        let hunks = diff.hunks(buffer).collect::<Vec<_>>();
-        assert_eq!(
-            hunks.len(),
-            expected_hunks.len(),
-            "actual hunks are {hunks:#?}"
-        );
-
-        let diff_iter = hunks.iter().enumerate();
-        for ((index, hunk), (expected_range, expected_str)) in diff_iter.zip(expected_hunks) {
-            assert_eq!(&hunk.buffer_range, expected_range, "for hunk {index}");
-            assert_eq!(
-                &head_text[hunk.head_byte_range.clone()],
-                *expected_str,
-                "for hunk {index}"
-            );
-        }
+        let actual_hunks = diff
+            .hunks_in_range(range.unwrap_or(0..u32::MAX), buffer)
+            .map(|hunk| {
+                (
+                    hunk.buffer_range.clone(),
+                    &head_text[hunk.head_byte_range],
+                    buffer
+                        .text_for_range(
+                            Point::new(hunk.buffer_range.start, 0)
+                                ..Point::new(hunk.buffer_range.end, 0),
+                        )
+                        .collect::<String>(),
+                )
+            })
+            .collect::<Vec<_>>();
+
+        let expected_hunks: Vec<_> = expected_hunks
+            .iter()
+            .map(|(r, s, h)| (r.clone(), *s, h.to_string()))
+            .collect();
+
+        assert_eq!(actual_hunks, expected_hunks);
     }
 }

crates/text/src/anchor.rs 🔗

@@ -26,7 +26,7 @@ impl Anchor {
         bias: Bias::Right,
         buffer_id: None,
     };
-
+    
     pub fn cmp(&self, other: &Anchor, buffer: &BufferSnapshot) -> Ordering {
         let fragment_id_comparison = if self.timestamp == other.timestamp {
             Ordering::Equal