Git: Fix hunks being skipped when staging too quickly (#27552)

João Marcos and Max Brunsfeld created

Release Notes:

- Git: Fix hunks being skipped when staging too quickly.

---------

Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>

Change summary

crates/buffer_diff/src/buffer_diff.rs | 78 ++++++++++++----------------
crates/project/src/git_store.rs       | 58 ++++++++++++++++++---
2 files changed, 83 insertions(+), 53 deletions(-)

Detailed changes

crates/buffer_diff/src/buffer_diff.rs 🔗

@@ -187,13 +187,13 @@ impl BufferDiffSnapshot {
 impl BufferDiffInner {
     /// Returns the new index text and new pending hunks.
     fn stage_or_unstage_hunks_impl(
-        &self,
+        &mut self,
         unstaged_diff: &Self,
         stage: bool,
         hunks: &[DiffHunk],
         buffer: &text::BufferSnapshot,
         file_exists: bool,
-    ) -> (Option<Rope>, SumTree<PendingHunk>) {
+    ) -> Option<Rope> {
         let head_text = self
             .base_text_exists
             .then(|| self.base_text.as_rope().clone());
@@ -206,7 +206,7 @@ impl BufferDiffInner {
         let (index_text, head_text) = match (index_text, head_text) {
             (Some(index_text), Some(head_text)) if file_exists || !stage => (index_text, head_text),
             (index_text, head_text) => {
-                let (rope, new_status) = if stage {
+                let (new_index_text, new_status) = if stage {
                     log::debug!("stage all");
                     (
                         file_exists.then(|| buffer.as_rope().clone()),
@@ -226,15 +226,13 @@ impl BufferDiffInner {
                     buffer_version: buffer.version().clone(),
                     new_status,
                 };
-                let tree = SumTree::from_item(hunk, buffer);
-                return (rope, tree);
+                self.pending_hunks = SumTree::from_item(hunk, buffer);
+                return new_index_text;
             }
         };
 
         let mut pending_hunks = SumTree::new(buffer);
-        let mut old_pending_hunks = unstaged_diff
-            .pending_hunks
-            .cursor::<DiffHunkSummary>(buffer);
+        let mut old_pending_hunks = self.pending_hunks.cursor::<DiffHunkSummary>(buffer);
 
         // first, merge new hunks into pending_hunks
         for DiffHunk {
@@ -366,6 +364,8 @@ impl BufferDiffInner {
             edits.push((index_byte_range, replacement_text));
         }
         drop(pending_hunks_iter);
+        drop(old_pending_hunks);
+        self.pending_hunks = pending_hunks;
 
         #[cfg(debug_assertions)] // invariants: non-overlapping and sorted
         {
@@ -384,7 +384,7 @@ impl BufferDiffInner {
             new_index_text.push(&replacement_text);
         }
         new_index_text.append(index_cursor.suffix());
-        (Some(new_index_text), pending_hunks)
+        Some(new_index_text)
     }
 
     fn hunks_intersecting_range<'a>(
@@ -421,15 +421,14 @@ impl BufferDiffInner {
             ]
         });
 
+        let mut pending_hunks_cursor = self.pending_hunks.cursor::<DiffHunkSummary>(buffer);
+        pending_hunks_cursor.next(buffer);
+
         let mut secondary_cursor = None;
-        let mut pending_hunks_cursor = None;
         if let Some(secondary) = secondary.as_ref() {
             let mut cursor = secondary.hunks.cursor::<DiffHunkSummary>(buffer);
             cursor.next(buffer);
             secondary_cursor = Some(cursor);
-            let mut cursor = secondary.pending_hunks.cursor::<DiffHunkSummary>(buffer);
-            cursor.next(buffer);
-            pending_hunks_cursor = Some(cursor);
         }
 
         let max_point = buffer.max_point();
@@ -451,29 +450,27 @@ impl BufferDiffInner {
             let mut secondary_status = DiffHunkSecondaryStatus::NoSecondaryHunk;
 
             let mut has_pending = false;
-            if let Some(pending_cursor) = pending_hunks_cursor.as_mut() {
-                if start_anchor
-                    .cmp(&pending_cursor.start().buffer_range.start, buffer)
-                    .is_gt()
-                {
-                    pending_cursor.seek_forward(&start_anchor, Bias::Left, buffer);
-                }
+            if start_anchor
+                .cmp(&pending_hunks_cursor.start().buffer_range.start, buffer)
+                .is_gt()
+            {
+                pending_hunks_cursor.seek_forward(&start_anchor, Bias::Left, buffer);
+            }
 
-                if let Some(pending_hunk) = pending_cursor.item() {
-                    let mut pending_range = pending_hunk.buffer_range.to_point(buffer);
-                    if pending_range.end.column > 0 {
-                        pending_range.end.row += 1;
-                        pending_range.end.column = 0;
-                    }
+            if let Some(pending_hunk) = pending_hunks_cursor.item() {
+                let mut pending_range = pending_hunk.buffer_range.to_point(buffer);
+                if pending_range.end.column > 0 {
+                    pending_range.end.row += 1;
+                    pending_range.end.column = 0;
+                }
 
-                    if pending_range == (start_point..end_point) {
-                        if !buffer.has_edits_since_in_range(
-                            &pending_hunk.buffer_version,
-                            start_anchor..end_anchor,
-                        ) {
-                            has_pending = true;
-                            secondary_status = pending_hunk.new_status;
-                        }
+                if pending_range == (start_point..end_point) {
+                    if !buffer.has_edits_since_in_range(
+                        &pending_hunk.buffer_version,
+                        start_anchor..end_anchor,
+                    ) {
+                        has_pending = true;
+                        secondary_status = pending_hunk.new_status;
                     }
                 }
             }
@@ -852,10 +849,8 @@ impl BufferDiff {
     }
 
     pub fn clear_pending_hunks(&mut self, cx: &mut Context<Self>) {
-        if let Some(secondary_diff) = &self.secondary_diff {
-            secondary_diff.update(cx, |diff, _| {
-                diff.inner.pending_hunks = SumTree::from_summary(DiffHunkSummary::default());
-            });
+        if self.secondary_diff.is_some() {
+            self.inner.pending_hunks = SumTree::from_summary(DiffHunkSummary::default());
             cx.emit(BufferDiffEvent::DiffChanged {
                 changed_range: Some(Anchor::MIN..Anchor::MAX),
             });
@@ -870,7 +865,7 @@ impl BufferDiff {
         file_exists: bool,
         cx: &mut Context<Self>,
     ) -> Option<Rope> {
-        let (new_index_text, new_pending_hunks) = self.inner.stage_or_unstage_hunks_impl(
+        let new_index_text = self.inner.stage_or_unstage_hunks_impl(
             &self.secondary_diff.as_ref()?.read(cx).inner,
             stage,
             &hunks,
@@ -878,11 +873,6 @@ impl BufferDiff {
             file_exists,
         );
 
-        if let Some(unstaged_diff) = &self.secondary_diff {
-            unstaged_diff.update(cx, |diff, _| {
-                diff.inner.pending_hunks = new_pending_hunks;
-            });
-        }
         cx.emit(BufferDiffEvent::HunksStagedOrUnstaged(
             new_index_text.clone(),
         ));

crates/project/src/git_store.rs 🔗

@@ -85,6 +85,7 @@ struct BufferDiffState {
     language: Option<Arc<Language>>,
     language_registry: Option<Arc<LanguageRegistry>>,
     diff_updated_futures: Vec<oneshot::Sender<()>>,
+    hunk_staging_operation_count: usize,
 
     head_text: Option<Arc<String>>,
     index_text: Option<Arc<String>>,
@@ -574,7 +575,7 @@ impl GitStore {
                     }
                 }
 
-                let rx = diff_state.diff_bases_changed(text_snapshot, diff_bases_change, cx);
+                let rx = diff_state.diff_bases_changed(text_snapshot, diff_bases_change, 0, cx);
 
                 anyhow::Ok(async move {
                     rx.await.ok();
@@ -1140,7 +1141,11 @@ impl GitStore {
             if let Some(diff_state) = self.diffs.get_mut(&buffer.read(cx).remote_id()) {
                 let buffer = buffer.read(cx).text_snapshot();
                 futures.push(diff_state.update(cx, |diff_state, cx| {
-                    diff_state.recalculate_diffs(buffer, cx)
+                    diff_state.recalculate_diffs(
+                        buffer,
+                        diff_state.hunk_staging_operation_count,
+                        cx,
+                    )
                 }));
             }
         }
@@ -1157,6 +1162,11 @@ impl GitStore {
     ) {
         if let BufferDiffEvent::HunksStagedOrUnstaged(new_index_text) = event {
             let buffer_id = diff.read(cx).buffer_id;
+            if let Some(diff_state) = self.diffs.get(&buffer_id) {
+                diff_state.update(cx, |diff_state, _| {
+                    diff_state.hunk_staging_operation_count += 1;
+                });
+            }
             if let Some((repo, path)) = self.repository_and_path_for_buffer_id(buffer_id, cx) {
                 let recv = repo.update(cx, |repo, cx| {
                     log::debug!("updating index text for buffer {}", path.display());
@@ -1229,6 +1239,7 @@ impl GitStore {
                 file.path.clone(),
                 has_unstaged_diff.then(|| diff_state.index_text.clone()),
                 has_uncommitted_diff.then(|| diff_state.head_text.clone()),
+                diff_state.hunk_staging_operation_count,
             );
             diff_state_updates.entry(repo_id).or_default().push(update);
         }
@@ -1252,8 +1263,13 @@ impl GitStore {
                     };
 
                     let mut diff_bases_changes_by_buffer = Vec::new();
-                    for (buffer, path, current_index_text, current_head_text) in
-                        &repo_diff_state_updates
+                    for (
+                        buffer,
+                        path,
+                        current_index_text,
+                        current_head_text,
+                        hunk_staging_operation_count,
+                    ) in &repo_diff_state_updates
                     {
                         let Some(local_repo) = snapshot.local_repo_containing_path(&path) else {
                             continue;
@@ -1306,12 +1322,18 @@ impl GitStore {
                                 (false, false) => None,
                             };
 
-                        diff_bases_changes_by_buffer.push((buffer, diff_bases_change))
+                        diff_bases_changes_by_buffer.push((
+                            buffer,
+                            diff_bases_change,
+                            *hunk_staging_operation_count,
+                        ))
                     }
 
                     git_store
                         .update(&mut cx, |git_store, cx| {
-                            for (buffer, diff_bases_change) in diff_bases_changes_by_buffer {
+                            for (buffer, diff_bases_change, hunk_staging_operation_count) in
+                                diff_bases_changes_by_buffer
+                            {
                                 let Some(diff_state) =
                                     git_store.diffs.get(&buffer.read(cx).remote_id())
                                 else {
@@ -1356,6 +1378,7 @@ impl GitStore {
                                     let _ = diff_state.diff_bases_changed(
                                         buffer.text_snapshot(),
                                         diff_bases_change,
+                                        hunk_staging_operation_count,
                                         cx,
                                     );
                                 });
@@ -2192,7 +2215,11 @@ impl BufferDiffState {
     fn buffer_language_changed(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
         self.language = buffer.read(cx).language().cloned();
         self.language_changed = true;
-        let _ = self.recalculate_diffs(buffer.read(cx).text_snapshot(), cx);
+        let _ = self.recalculate_diffs(
+            buffer.read(cx).text_snapshot(),
+            self.hunk_staging_operation_count,
+            cx,
+        );
     }
 
     fn unstaged_diff(&self) -> Option<Entity<BufferDiff>> {
@@ -2225,7 +2252,12 @@ impl BufferDiffState {
             },
         };
 
-        let _ = self.diff_bases_changed(buffer, diff_bases_change, cx);
+        let _ = self.diff_bases_changed(
+            buffer,
+            diff_bases_change,
+            self.hunk_staging_operation_count,
+            cx,
+        );
     }
 
     pub fn wait_for_recalculation(&mut self) -> Option<oneshot::Receiver<()>> {
@@ -2241,6 +2273,7 @@ impl BufferDiffState {
         &mut self,
         buffer: text::BufferSnapshot,
         diff_bases_change: DiffBasesChange,
+        prev_hunk_staging_operation_count: usize,
         cx: &mut Context<Self>,
     ) -> oneshot::Receiver<()> {
         match diff_bases_change {
@@ -2282,12 +2315,13 @@ impl BufferDiffState {
             }
         }
 
-        self.recalculate_diffs(buffer, cx)
+        self.recalculate_diffs(buffer, prev_hunk_staging_operation_count, cx)
     }
 
     fn recalculate_diffs(
         &mut self,
         buffer: text::BufferSnapshot,
+        prev_hunk_staging_operation_count: usize,
         cx: &mut Context<Self>,
     ) -> oneshot::Receiver<()> {
         log::debug!("recalculate diffs");
@@ -2347,6 +2381,12 @@ impl BufferDiffState {
                 }
             }
 
+            if this.update(cx, |this, _| {
+                this.hunk_staging_operation_count > prev_hunk_staging_operation_count
+            })? {
+                return Ok(());
+            }
+
             let unstaged_changed_range = if let Some((unstaged_diff, new_unstaged_diff)) =
                 unstaged_diff.as_ref().zip(new_unstaged_diff.clone())
             {