git: Implement branch diff line counts efficiently (#52582)

Cole Miller created

Previously we were iterating over all hunks across all diffs on every
frame. Now we can read off the required information as a `SumTree`
summary in constant time.

Self-Review Checklist:

- [x] I've reviewed my own diff for quality, security, and reliability
- [x] Unsafe blocks (if any) have justifying comments
- [x] The content is consistent with the [UI/UX
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)
- [ ] Tests cover the new/changed behavior
- [x] Performance impact has been considered and is acceptable

Release Notes:

- git: Fixed bad performance in large branch diffs.

Change summary

crates/buffer_diff/src/buffer_diff.rs   |  34 +++++
crates/git_ui/src/commit_view.rs        |  33 -----
crates/git_ui/src/project_diff.rs       |  33 -----
crates/multi_buffer/src/anchor.rs       |   9 -
crates/multi_buffer/src/multi_buffer.rs | 160 +++++++++++++++++++-------
5 files changed, 153 insertions(+), 116 deletions(-)

Detailed changes

crates/buffer_diff/src/buffer_diff.rs 🔗

@@ -115,6 +115,7 @@ pub struct DiffHunk {
 struct InternalDiffHunk {
     buffer_range: Range<Anchor>,
     diff_base_byte_range: Range<usize>,
+    diff_base_point_range: Range<Point>,
     base_word_diffs: Vec<Range<usize>>,
     buffer_word_diffs: Vec<Range<Anchor>>,
 }
@@ -131,15 +132,25 @@ struct PendingHunk {
 pub struct DiffHunkSummary {
     buffer_range: Range<Anchor>,
     diff_base_byte_range: Range<usize>,
+    added_rows: u32,
+    removed_rows: u32,
 }
 
 impl sum_tree::Item for InternalDiffHunk {
     type Summary = DiffHunkSummary;
 
-    fn summary(&self, _cx: &text::BufferSnapshot) -> Self::Summary {
+    fn summary(&self, buffer: &text::BufferSnapshot) -> Self::Summary {
+        let buffer_start = self.buffer_range.start.to_point(buffer);
+        let buffer_end = self.buffer_range.end.to_point(buffer);
         DiffHunkSummary {
             buffer_range: self.buffer_range.clone(),
             diff_base_byte_range: self.diff_base_byte_range.clone(),
+            added_rows: buffer_end.row.saturating_sub(buffer_start.row),
+            removed_rows: self
+                .diff_base_point_range
+                .end
+                .row
+                .saturating_sub(self.diff_base_point_range.start.row),
         }
     }
 }
@@ -151,6 +162,8 @@ impl sum_tree::Item for PendingHunk {
         DiffHunkSummary {
             buffer_range: self.buffer_range.clone(),
             diff_base_byte_range: self.diff_base_byte_range.clone(),
+            added_rows: 0,
+            removed_rows: 0,
         }
     }
 }
@@ -162,6 +175,8 @@ impl sum_tree::Summary for DiffHunkSummary {
         DiffHunkSummary {
             buffer_range: Anchor::MIN..Anchor::MIN,
             diff_base_byte_range: 0..0,
+            added_rows: 0,
+            removed_rows: 0,
         }
     }
 
@@ -180,6 +195,9 @@ impl sum_tree::Summary for DiffHunkSummary {
             .diff_base_byte_range
             .end
             .max(other.diff_base_byte_range.end);
+
+        self.added_rows += other.added_rows;
+        self.removed_rows += other.removed_rows;
     }
 }
 
@@ -234,6 +252,11 @@ impl BufferDiffSnapshot {
         self.inner.hunks.is_empty()
     }
 
+    pub fn changed_row_counts(&self) -> (u32, u32) {
+        let summary = self.inner.hunks.summary();
+        (summary.added_rows, summary.removed_rows)
+    }
+
     pub fn base_text_string(&self) -> Option<String> {
         self.inner
             .base_text_exists
@@ -1120,6 +1143,8 @@ fn compute_hunks(
                 InternalDiffHunk {
                     buffer_range: buffer.anchor_before(0)..buffer.anchor_before(0),
                     diff_base_byte_range: 0..diff_base.len() - 1,
+                    diff_base_point_range: Point::new(0, 0)
+                        ..diff_base_rope.offset_to_point(diff_base.len() - 1),
                     base_word_diffs: Vec::default(),
                     buffer_word_diffs: Vec::default(),
                 },
@@ -1147,6 +1172,7 @@ fn compute_hunks(
             InternalDiffHunk {
                 buffer_range: Anchor::min_max_range_for_buffer(buffer.remote_id()),
                 diff_base_byte_range: 0..0,
+                diff_base_point_range: Point::new(0, 0)..Point::new(0, 0),
                 base_word_diffs: Vec::default(),
                 buffer_word_diffs: Vec::default(),
             },
@@ -1460,7 +1486,9 @@ fn process_patch_hunk(
 
     InternalDiffHunk {
         buffer_range,
-        diff_base_byte_range,
+        diff_base_byte_range: diff_base_byte_range.clone(),
+        diff_base_point_range: diff_base.offset_to_point(diff_base_byte_range.start)
+            ..diff_base.offset_to_point(diff_base_byte_range.end),
         base_word_diffs,
         buffer_word_diffs,
     }
@@ -1565,6 +1593,8 @@ impl BufferDiff {
             self.inner.pending_hunks = SumTree::from_summary(DiffHunkSummary {
                 buffer_range: Anchor::min_min_range_for_buffer(self.buffer_id),
                 diff_base_byte_range: 0..0,
+                added_rows: 0,
+                removed_rows: 0,
             });
             let changed_range = Some(Anchor::min_max_range_for_buffer(self.buffer_id));
             let base_text_range = Some(0..self.base_text(cx).len());

crates/git_ui/src/commit_view.rs 🔗

@@ -414,38 +414,7 @@ impl CommitView {
     }
 
     fn calculate_changed_lines(&self, cx: &App) -> (u32, u32) {
-        let snapshot = self.multibuffer.read(cx).snapshot(cx);
-        let mut total_additions = 0u32;
-        let mut total_deletions = 0u32;
-
-        let mut seen_buffers = std::collections::HashSet::new();
-        for (_, buffer, _) in snapshot.excerpts() {
-            let buffer_id = buffer.remote_id();
-            if !seen_buffers.insert(buffer_id) {
-                continue;
-            }
-
-            let Some(diff) = snapshot.diff_for_buffer_id(buffer_id) else {
-                continue;
-            };
-
-            let base_text = diff.base_text();
-
-            for hunk in diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, buffer) {
-                let added_rows = hunk.range.end.row.saturating_sub(hunk.range.start.row);
-                total_additions += added_rows;
-
-                let base_start = base_text
-                    .offset_to_point(hunk.diff_base_byte_range.start)
-                    .row;
-                let base_end = base_text.offset_to_point(hunk.diff_base_byte_range.end).row;
-                let deleted_rows = base_end.saturating_sub(base_start);
-
-                total_deletions += deleted_rows;
-            }
-        }
-
-        (total_additions, total_deletions)
+        self.multibuffer.read(cx).snapshot(cx).total_changed_lines()
     }
 
     fn render_header(&self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {

crates/git_ui/src/project_diff.rs 🔗

@@ -544,38 +544,7 @@ impl ProjectDiff {
     }
 
     pub fn calculate_changed_lines(&self, cx: &App) -> (u32, u32) {
-        let snapshot = self.multibuffer.read(cx).snapshot(cx);
-        let mut total_additions = 0u32;
-        let mut total_deletions = 0u32;
-
-        let mut seen_buffers = HashSet::default();
-        for (_, buffer, _) in snapshot.excerpts() {
-            let buffer_id = buffer.remote_id();
-            if !seen_buffers.insert(buffer_id) {
-                continue;
-            }
-
-            let Some(diff) = snapshot.diff_for_buffer_id(buffer_id) else {
-                continue;
-            };
-
-            let base_text = diff.base_text();
-
-            for hunk in diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, buffer) {
-                let added_rows = hunk.range.end.row.saturating_sub(hunk.range.start.row);
-                total_additions += added_rows;
-
-                let base_start = base_text
-                    .offset_to_point(hunk.diff_base_byte_range.start)
-                    .row;
-                let base_end = base_text.offset_to_point(hunk.diff_base_byte_range.end).row;
-                let deleted_rows = base_end.saturating_sub(base_start);
-
-                total_deletions += deleted_rows;
-            }
-        }
-
-        (total_additions, total_deletions)
+        self.multibuffer.read(cx).snapshot(cx).total_changed_lines()
     }
 
     /// Returns the total count of review comments across all hunks/files.

crates/multi_buffer/src/anchor.rs 🔗

@@ -119,8 +119,7 @@ impl Anchor {
             }
             if (self.diff_base_anchor.is_some() || other.diff_base_anchor.is_some())
                 && let Some(base_text) = snapshot
-                    .diffs
-                    .get(&excerpt.buffer_id)
+                    .diff_state(excerpt.buffer_id)
                     .map(|diff| diff.base_text())
             {
                 let self_anchor = self.diff_base_anchor.filter(|a| a.is_valid(base_text));
@@ -155,8 +154,7 @@ impl Anchor {
                 text_anchor: self.text_anchor.bias_left(&excerpt.buffer),
                 diff_base_anchor: self.diff_base_anchor.map(|a| {
                     if let Some(base_text) = snapshot
-                        .diffs
-                        .get(&excerpt.buffer_id)
+                        .diff_state(excerpt.buffer_id)
                         .map(|diff| diff.base_text())
                         && a.is_valid(&base_text)
                     {
@@ -178,8 +176,7 @@ impl Anchor {
                 text_anchor: self.text_anchor.bias_right(&excerpt.buffer),
                 diff_base_anchor: self.diff_base_anchor.map(|a| {
                     if let Some(base_text) = snapshot
-                        .diffs
-                        .get(&excerpt.buffer_id)
+                        .diff_state(excerpt.buffer_id)
                         .map(|diff| diff.base_text())
                         && a.is_valid(&base_text)
                     {

crates/multi_buffer/src/multi_buffer.rs 🔗

@@ -532,8 +532,9 @@ struct DiffState {
 }
 
 impl DiffState {
-    fn snapshot(&self, cx: &App) -> DiffStateSnapshot {
+    fn snapshot(&self, buffer_id: BufferId, cx: &App) -> DiffStateSnapshot {
         DiffStateSnapshot {
+            buffer_id,
             diff: self.diff.read(cx).snapshot(cx),
             main_buffer: self.main_buffer.as_ref().map(|b| b.read(cx).snapshot()),
         }
@@ -542,6 +543,7 @@ impl DiffState {
 
 #[derive(Clone)]
 struct DiffStateSnapshot {
+    buffer_id: BufferId,
     diff: BufferDiffSnapshot,
     main_buffer: Option<language::BufferSnapshot>,
 }
@@ -554,6 +556,77 @@ impl std::ops::Deref for DiffStateSnapshot {
     }
 }
 
+#[derive(Clone, Debug, Default)]
+struct DiffStateSummary {
+    max_buffer_id: Option<BufferId>,
+    added_rows: u32,
+    removed_rows: u32,
+}
+
+impl sum_tree::ContextLessSummary for DiffStateSummary {
+    fn zero() -> Self {
+        Self::default()
+    }
+
+    fn add_summary(&mut self, other: &Self) {
+        self.max_buffer_id = std::cmp::max(self.max_buffer_id, other.max_buffer_id);
+        self.added_rows += other.added_rows;
+        self.removed_rows += other.removed_rows;
+    }
+}
+
+impl sum_tree::Item for DiffStateSnapshot {
+    type Summary = DiffStateSummary;
+
+    fn summary(&self, _cx: ()) -> DiffStateSummary {
+        let (added_rows, removed_rows) = self.diff.changed_row_counts();
+        DiffStateSummary {
+            max_buffer_id: Some(self.buffer_id),
+            added_rows,
+            removed_rows,
+        }
+    }
+}
+
+impl sum_tree::KeyedItem for DiffStateSnapshot {
+    type Key = Option<BufferId>;
+
+    fn key(&self) -> Option<BufferId> {
+        Some(self.buffer_id)
+    }
+}
+
+impl<'a> Dimension<'a, DiffStateSummary> for Option<BufferId> {
+    fn zero(_cx: ()) -> Self {
+        None
+    }
+
+    fn add_summary(&mut self, summary: &DiffStateSummary, _cx: ()) {
+        *self = std::cmp::max(*self, summary.max_buffer_id);
+    }
+}
+
+fn find_diff_state(
+    diffs: &SumTree<DiffStateSnapshot>,
+    buffer_id: BufferId,
+) -> Option<&DiffStateSnapshot> {
+    let key = Some(buffer_id);
+    let (.., item) = diffs.find::<Option<BufferId>, _>((), &key, Bias::Left);
+    item.filter(|entry| entry.buffer_id == buffer_id)
+}
+
+fn remove_diff_state(diffs: &mut SumTree<DiffStateSnapshot>, buffer_id: BufferId) {
+    let key = Some(buffer_id);
+    let mut cursor = diffs.cursor::<Option<BufferId>>(());
+    let mut new_tree = cursor.slice(&key, Bias::Left);
+    if key == cursor.end() {
+        cursor.next();
+    }
+    new_tree.append(cursor.suffix(), ());
+    drop(cursor);
+    *diffs = new_tree;
+}
+
 impl DiffState {
     fn new(diff: Entity<BufferDiff>, cx: &mut Context<MultiBuffer>) -> Self {
         DiffState {
@@ -626,7 +699,7 @@ impl DiffState {
 pub struct MultiBufferSnapshot {
     excerpts: SumTree<Excerpt>,
     buffer_locators: TreeMap<BufferId, Arc<[Locator]>>,
-    diffs: TreeMap<BufferId, DiffStateSnapshot>,
+    diffs: SumTree<DiffStateSnapshot>,
     diff_transforms: SumTree<DiffTransform>,
     excerpt_ids: SumTree<ExcerptIdMapping>,
     replaced_excerpts: Arc<HashMap<ExcerptId, ExcerptId>>,
@@ -995,7 +1068,7 @@ pub struct MultiBufferChunks<'a> {
     excerpts: Cursor<'a, 'static, Excerpt, ExcerptOffset>,
     diff_transforms:
         Cursor<'a, 'static, DiffTransform, Dimensions<MultiBufferOffset, ExcerptOffset>>,
-    diffs: &'a TreeMap<BufferId, DiffStateSnapshot>,
+    diffs: &'a SumTree<DiffStateSnapshot>,
     diff_base_chunks: Option<(BufferId, BufferChunks<'a>)>,
     buffer_chunk: Option<Chunk<'a>>,
     range: Range<MultiBufferOffset>,
@@ -1055,7 +1128,7 @@ impl<'a, MBD: MultiBufferDimension> Dimension<'a, DiffTransformSummary> for Diff
 struct MultiBufferCursor<'a, MBD, BD> {
     excerpts: Cursor<'a, 'static, Excerpt, ExcerptDimension<MBD>>,
     diff_transforms: Cursor<'a, 'static, DiffTransform, DiffTransforms<MBD>>,
-    diffs: &'a TreeMap<BufferId, DiffStateSnapshot>,
+    diffs: &'a SumTree<DiffStateSnapshot>,
     cached_region: OnceCell<Option<MultiBufferRegion<'a, MBD, BD>>>,
 }
 
@@ -2321,15 +2394,12 @@ impl MultiBuffer {
         snapshot.excerpts = new_excerpts;
         for buffer_id in &removed_buffer_ids {
             self.diffs.remove(buffer_id);
-            snapshot.diffs.remove(buffer_id);
+            remove_diff_state(&mut snapshot.diffs, *buffer_id);
         }
 
-        // Recalculate has_inverted_diff after removing diffs
         if !removed_buffer_ids.is_empty() {
-            snapshot.has_inverted_diff = snapshot
-                .diffs
-                .iter()
-                .any(|(_, diff)| diff.main_buffer.is_some());
+            snapshot.has_inverted_diff =
+                snapshot.diffs.iter().any(|diff| diff.main_buffer.is_some());
         }
 
         if changed_trailing_excerpt {
@@ -2432,10 +2502,11 @@ impl MultiBuffer {
         let diff = diff.read(cx);
         let buffer_id = diff.buffer_id;
         let diff = DiffStateSnapshot {
+            buffer_id,
             diff: diff.snapshot(cx),
             main_buffer: None,
         };
-        self.snapshot.get_mut().diffs.insert(buffer_id, diff);
+        self.snapshot.get_mut().diffs.insert_or_replace(diff, ());
     }
 
     fn inverted_buffer_diff_language_changed(
@@ -2448,13 +2519,11 @@ impl MultiBuffer {
         let main_buffer_snapshot = main_buffer.read(cx).snapshot();
         let diff = diff.read(cx);
         let diff = DiffStateSnapshot {
+            buffer_id: base_text_buffer_id,
             diff: diff.snapshot(cx),
             main_buffer: Some(main_buffer_snapshot),
         };
-        self.snapshot
-            .get_mut()
-            .diffs
-            .insert(base_text_buffer_id, diff);
+        self.snapshot.get_mut().diffs.insert_or_replace(diff, ());
     }
 
     fn buffer_diff_changed(
@@ -2472,15 +2541,14 @@ impl MultiBuffer {
             return;
         };
         let new_diff = DiffStateSnapshot {
+            buffer_id,
             diff: diff.snapshot(cx),
             main_buffer: None,
         };
         let mut snapshot = self.snapshot.get_mut();
-        let base_text_changed = snapshot
-            .diffs
-            .get(&buffer_id)
+        let base_text_changed = find_diff_state(&snapshot.diffs, buffer_id)
             .is_none_or(|old_diff| !new_diff.base_texts_definitely_eq(old_diff));
-        snapshot.diffs.insert_or_replace(buffer_id, new_diff);
+        snapshot.diffs.insert_or_replace(new_diff, ());
 
         let buffer = buffer_state.buffer.read(cx);
         let diff_change_range = range.to_offset(buffer);
@@ -2519,13 +2587,12 @@ impl MultiBuffer {
         let main_buffer_snapshot = main_buffer.read(cx).snapshot();
         let diff = diff.read(cx);
         let new_diff = DiffStateSnapshot {
+            buffer_id: base_text_buffer_id,
             diff: diff.snapshot(cx),
             main_buffer: Some(main_buffer_snapshot),
         };
         let mut snapshot = self.snapshot.get_mut();
-        snapshot
-            .diffs
-            .insert_or_replace(base_text_buffer_id, new_diff);
+        snapshot.diffs.insert_or_replace(new_diff, ());
 
         let Some(diff_change_range) = diff_change_range else {
             return;
@@ -3141,11 +3208,7 @@ impl MultiBuffer {
         if !diffs.is_empty() {
             let mut diffs_to_add = Vec::new();
             for (id, diff) in diffs {
-                // For inverted diffs, we excerpt the diff base texts in the multibuffer
-                // and use the diff hunk base text ranges to compute diff transforms.
-                // Those base text ranges are usize, so make sure if the base text changed
-                // we also update the diff snapshot so that we don't use stale offsets
-                if buffer_diff.get(id).is_none_or(|existing_diff| {
+                if find_diff_state(buffer_diff, *id).is_none_or(|existing_diff| {
                     if existing_diff.main_buffer.is_none() {
                         return false;
                     }
@@ -3156,14 +3219,12 @@ impl MultiBuffer {
                             .changed_since(existing_diff.base_text().version())
                 }) {
                     if diffs_to_add.capacity() == 0 {
-                        // we'd rather overallocate than reallocate as buffer diffs are quite big
-                        // meaning re-allocations will be fairly expensive
                         diffs_to_add.reserve(diffs.len());
                     }
-                    diffs_to_add.push((*id, diff.snapshot(cx)));
+                    diffs_to_add.push(sum_tree::Edit::Insert(diff.snapshot(*id, cx)));
                 }
             }
-            buffer_diff.extend(diffs_to_add);
+            buffer_diff.edit(diffs_to_add, ());
         }
 
         let mut excerpts_to_edit = Vec::new();
@@ -3350,7 +3411,7 @@ impl MultiBuffer {
                             inserted_hunk_info: Some(hunk),
                             ..
                         }) => excerpts.item().is_some_and(|excerpt| {
-                            if let Some(diff) = snapshot.diffs.get(&excerpt.buffer_id)
+                            if let Some(diff) = find_diff_state(&snapshot.diffs, excerpt.buffer_id)
                                 && diff.main_buffer.is_some()
                             {
                                 return true;
@@ -3451,7 +3512,7 @@ impl MultiBuffer {
         while let Some(excerpt) = excerpts.item() {
             // Recompute the expanded hunks in the portion of the excerpt that
             // intersects the edit.
-            if let Some(diff) = snapshot.diffs.get(&excerpt.buffer_id) {
+            if let Some(diff) = find_diff_state(&snapshot.diffs, excerpt.buffer_id) {
                 let buffer = &excerpt.buffer;
                 let excerpt_start = *excerpts.start();
                 let excerpt_end = excerpt_start + excerpt.text_summary.len;
@@ -4031,7 +4092,7 @@ impl MultiBufferSnapshot {
     ) -> impl Iterator<Item = MultiBufferDiffHunk> + '_ {
         let query_range = range.start.to_point(self)..range.end.to_point(self);
         self.lift_buffer_metadata(query_range.clone(), move |buffer, buffer_range| {
-            let diff = self.diffs.get(&buffer.remote_id())?;
+            let diff = self.diff_state(buffer.remote_id())?;
             let iter = if let Some(main_buffer) = &diff.main_buffer {
                 let buffer_start = buffer.point_to_offset(buffer_range.start);
                 let buffer_end = buffer.point_to_offset(buffer_range.end);
@@ -4614,7 +4675,7 @@ impl MultiBufferSnapshot {
             .text_anchor
             .to_offset(&excerpt.buffer);
 
-        if let Some(diff) = self.diffs.get(&excerpt.buffer_id) {
+        if let Some(diff) = self.diff_state(excerpt.buffer_id) {
             if let Some(main_buffer) = &diff.main_buffer {
                 for hunk in diff
                     .hunks_intersecting_base_text_range_rev(excerpt_start..excerpt_end, main_buffer)
@@ -4649,7 +4710,7 @@ impl MultiBufferSnapshot {
             cursor.prev_excerpt();
             let excerpt = cursor.excerpt()?;
 
-            let Some(diff) = self.diffs.get(&excerpt.buffer_id) else {
+            let Some(diff) = self.diff_state(excerpt.buffer_id) else {
                 continue;
             };
             if let Some(main_buffer) = &diff.main_buffer {
@@ -4679,7 +4740,7 @@ impl MultiBufferSnapshot {
     }
 
     pub fn has_diff_hunks(&self) -> bool {
-        self.diffs.values().any(|diff| !diff.is_empty())
+        self.diffs.iter().any(|diff| !diff.is_empty())
     }
 
     pub fn is_inside_word<T: ToOffset>(
@@ -5262,7 +5323,8 @@ impl MultiBufferSnapshot {
             } => {
                 let buffer_start = base_text_byte_range.start + start_overshoot;
                 let mut buffer_end = base_text_byte_range.start + end_overshoot;
-                let Some(base_text) = self.diffs.get(buffer_id).map(|diff| diff.base_text()) else {
+                let Some(base_text) = self.diff_state(*buffer_id).map(|diff| diff.base_text())
+                else {
                     panic!("{:?} is in non-existent deleted hunk", range.start)
                 };
 
@@ -5314,7 +5376,8 @@ impl MultiBufferSnapshot {
                 ..
             } => {
                 let buffer_end = base_text_byte_range.start + overshoot;
-                let Some(base_text) = self.diffs.get(buffer_id).map(|diff| diff.base_text()) else {
+                let Some(base_text) = self.diff_state(*buffer_id).map(|diff| diff.base_text())
+                else {
                     panic!("{:?} is in non-existent deleted hunk", range.end)
                 };
 
@@ -5534,7 +5597,7 @@ impl MultiBufferSnapshot {
                 }) => {
                     if let Some(diff_base_anchor) = &anchor.diff_base_anchor
                         && let Some(base_text) =
-                            self.diffs.get(buffer_id).map(|diff| diff.base_text())
+                            self.diff_state(*buffer_id).map(|diff| diff.base_text())
                         && diff_base_anchor.is_valid(&base_text)
                     {
                         // The anchor carries a diff-base position — resolve it
@@ -5913,7 +5976,7 @@ impl MultiBufferSnapshot {
             ..
         }) = diff_transforms.item()
         {
-            let diff = self.diffs.get(buffer_id).expect("missing diff");
+            let diff = self.diff_state(*buffer_id).expect("missing diff");
             if offset_in_transform > base_text_byte_range.len() {
                 debug_assert!(*has_trailing_newline);
                 bias = Bias::Right;
@@ -7167,7 +7230,16 @@ impl MultiBufferSnapshot {
     }
 
     pub fn diff_for_buffer_id(&self, buffer_id: BufferId) -> Option<&BufferDiffSnapshot> {
-        self.diffs.get(&buffer_id).map(|diff| &diff.diff)
+        self.diff_state(buffer_id).map(|diff| &diff.diff)
+    }
+
+    fn diff_state(&self, buffer_id: BufferId) -> Option<&DiffStateSnapshot> {
+        find_diff_state(&self.diffs, buffer_id)
+    }
+
+    pub fn total_changed_lines(&self) -> (u32, u32) {
+        let summary = self.diffs.summary();
+        (summary.added_rows, summary.removed_rows)
     }
 
     pub fn all_diff_hunks_expanded(&self) -> bool {
@@ -7536,7 +7608,7 @@ where
                 hunk_info,
                 ..
             } => {
-                let diff = self.diffs.get(buffer_id)?;
+                let diff = find_diff_state(self.diffs, *buffer_id)?;
                 let buffer = diff.base_text();
                 let mut rope_cursor = buffer.as_rope().cursor(0);
                 let buffer_start = rope_cursor.summary::<BD>(base_text_byte_range.start);
@@ -8564,7 +8636,7 @@ impl<'a> Iterator for MultiBufferChunks<'a> {
                     }
                     chunks
                 } else {
-                    let base_buffer = &self.diffs.get(buffer_id)?.base_text();
+                    let base_buffer = &find_diff_state(self.diffs, *buffer_id)?.base_text();
                     base_buffer.chunks(base_text_start..base_text_end, self.language_aware)
                 };