git: Optimistically stage hunks when staging a file, take 2 (#45278)

Cole Miller and Ramon created

Relanding #43434 with an improved approach.

Release Notes:

- N/A

---------

Co-authored-by: Ramon <55579979+van-sprundel@users.noreply.github.com>

Change summary

crates/buffer_diff/src/buffer_diff.rs |  28 +++
crates/project/src/git_store.rs       | 229 ++++++++++++++++++----------
crates/project/src/project_tests.rs   | 143 ++++++++++++++++++
3 files changed, 321 insertions(+), 79 deletions(-)

Detailed changes

crates/buffer_diff/src/buffer_diff.rs 🔗

@@ -1159,6 +1159,34 @@ impl BufferDiff {
         new_index_text
     }
 
+    pub fn stage_or_unstage_all_hunks(
+        &mut self,
+        stage: bool,
+        buffer: &text::BufferSnapshot,
+        file_exists: bool,
+        cx: &mut Context<Self>,
+    ) {
+        let hunks = self
+            .hunks_intersecting_range(Anchor::MIN..Anchor::MAX, buffer, cx)
+            .collect::<Vec<_>>();
+        let Some(secondary) = self.secondary_diff.as_ref() else {
+            return;
+        };
+        self.inner.stage_or_unstage_hunks_impl(
+            &secondary.read(cx).inner,
+            stage,
+            &hunks,
+            buffer,
+            file_exists,
+        );
+        if let Some((first, last)) = hunks.first().zip(hunks.last()) {
+            let changed_range = first.buffer_range.start..last.buffer_range.end;
+            cx.emit(BufferDiffEvent::DiffChanged {
+                changed_range: Some(changed_range),
+            });
+        }
+    }
+
     pub fn range_to_hunk_range(
         &self,
         range: Range<Anchor>,

crates/project/src/git_store.rs 🔗

@@ -4205,74 +4205,29 @@ impl Repository {
         entries: Vec<RepoPath>,
         cx: &mut Context<Self>,
     ) -> Task<anyhow::Result<()>> {
-        if entries.is_empty() {
-            return Task::ready(Ok(()));
-        }
-        let id = self.id;
-        let save_tasks = self.save_buffers(&entries, cx);
-        let paths = entries
-            .iter()
-            .map(|p| p.as_unix_str())
-            .collect::<Vec<_>>()
-            .join(" ");
-        let status = format!("git add {paths}");
-        let job_key = GitJobKey::WriteIndex(entries.clone());
-
-        self.spawn_job_with_tracking(
-            entries.clone(),
-            pending_op::GitStatus::Staged,
-            cx,
-            async move |this, cx| {
-                for save_task in save_tasks {
-                    save_task.await?;
-                }
-
-                this.update(cx, |this, _| {
-                    this.send_keyed_job(
-                        Some(job_key),
-                        Some(status.into()),
-                        move |git_repo, _cx| async move {
-                            match git_repo {
-                                RepositoryState::Local(LocalRepositoryState {
-                                    backend,
-                                    environment,
-                                    ..
-                                }) => backend.stage_paths(entries, environment.clone()).await,
-                                RepositoryState::Remote(RemoteRepositoryState {
-                                    project_id,
-                                    client,
-                                }) => {
-                                    client
-                                        .request(proto::Stage {
-                                            project_id: project_id.0,
-                                            repository_id: id.to_proto(),
-                                            paths: entries
-                                                .into_iter()
-                                                .map(|repo_path| repo_path.to_proto())
-                                                .collect(),
-                                        })
-                                        .await
-                                        .context("sending stage request")?;
-
-                                    Ok(())
-                                }
-                            }
-                        },
-                    )
-                })?
-                .await?
-            },
-        )
+        self.stage_or_unstage_entries(true, entries, cx)
     }
 
     pub fn unstage_entries(
         &mut self,
         entries: Vec<RepoPath>,
         cx: &mut Context<Self>,
+    ) -> Task<anyhow::Result<()>> {
+        self.stage_or_unstage_entries(false, entries, cx)
+    }
+
+    fn stage_or_unstage_entries(
+        &mut self,
+        stage: bool,
+        entries: Vec<RepoPath>,
+        cx: &mut Context<Self>,
     ) -> Task<anyhow::Result<()>> {
         if entries.is_empty() {
             return Task::ready(Ok(()));
         }
+        let Some(git_store) = self.git_store.upgrade() else {
+            return Task::ready(Ok(()));
+        };
         let id = self.id;
         let save_tasks = self.save_buffers(&entries, cx);
         let paths = entries
@@ -4280,48 +4235,164 @@ impl Repository {
             .map(|p| p.as_unix_str())
             .collect::<Vec<_>>()
             .join(" ");
-        let status = format!("git reset {paths}");
+        let status = if stage {
+            format!("git add {paths}")
+        } else {
+            format!("git reset {paths}")
+        };
         let job_key = GitJobKey::WriteIndex(entries.clone());
 
         self.spawn_job_with_tracking(
             entries.clone(),
-            pending_op::GitStatus::Unstaged,
+            if stage {
+                pending_op::GitStatus::Staged
+            } else {
+                pending_op::GitStatus::Unstaged
+            },
             cx,
             async move |this, cx| {
                 for save_task in save_tasks {
                     save_task.await?;
                 }
 
-                this.update(cx, |this, _| {
+                this.update(cx, |this, cx| {
+                    let weak_this = cx.weak_entity();
                     this.send_keyed_job(
                         Some(job_key),
                         Some(status.into()),
-                        move |git_repo, _cx| async move {
-                            match git_repo {
+                        move |git_repo, mut cx| async move {
+                            let hunk_staging_operation_counts = weak_this
+                                .update(&mut cx, |this, cx| {
+                                    let mut hunk_staging_operation_counts = HashMap::default();
+                                    for path in &entries {
+                                        let Some(project_path) =
+                                            this.repo_path_to_project_path(path, cx)
+                                        else {
+                                            continue;
+                                        };
+                                        let Some(buffer) = git_store
+                                            .read(cx)
+                                            .buffer_store
+                                            .read(cx)
+                                            .get_by_path(&project_path)
+                                        else {
+                                            continue;
+                                        };
+                                        let Some(diff_state) = git_store
+                                            .read(cx)
+                                            .diffs
+                                            .get(&buffer.read(cx).remote_id())
+                                            .cloned()
+                                        else {
+                                            continue;
+                                        };
+                                        let Some(uncommitted_diff) =
+                                            diff_state.read(cx).uncommitted_diff.as_ref().and_then(
+                                                |uncommitted_diff| uncommitted_diff.upgrade(),
+                                            )
+                                        else {
+                                            continue;
+                                        };
+                                        let buffer_snapshot = buffer.read(cx).text_snapshot();
+                                        let file_exists = buffer
+                                            .read(cx)
+                                            .file()
+                                            .is_some_and(|file| file.disk_state().exists());
+                                        let hunk_staging_operation_count =
+                                            diff_state.update(cx, |diff_state, cx| {
+                                                uncommitted_diff.update(
+                                                    cx,
+                                                    |uncommitted_diff, cx| {
+                                                        uncommitted_diff
+                                                            .stage_or_unstage_all_hunks(
+                                                                stage,
+                                                                &buffer_snapshot,
+                                                                file_exists,
+                                                                cx,
+                                                            );
+                                                    },
+                                                );
+
+                                                diff_state.hunk_staging_operation_count += 1;
+                                                diff_state.hunk_staging_operation_count
+                                            });
+                                        hunk_staging_operation_counts.insert(
+                                            diff_state.downgrade(),
+                                            hunk_staging_operation_count,
+                                        );
+                                    }
+                                    hunk_staging_operation_counts
+                                })
+                                .unwrap_or_default();
+
+                            let result = match git_repo {
                                 RepositoryState::Local(LocalRepositoryState {
                                     backend,
                                     environment,
                                     ..
-                                }) => backend.unstage_paths(entries, environment).await,
+                                }) => {
+                                    if stage {
+                                        backend.stage_paths(entries, environment.clone()).await
+                                    } else {
+                                        backend.unstage_paths(entries, environment.clone()).await
+                                    }
+                                }
                                 RepositoryState::Remote(RemoteRepositoryState {
                                     project_id,
                                     client,
                                 }) => {
-                                    client
-                                        .request(proto::Unstage {
-                                            project_id: project_id.0,
-                                            repository_id: id.to_proto(),
-                                            paths: entries
-                                                .into_iter()
-                                                .map(|repo_path| repo_path.to_proto())
-                                                .collect(),
-                                        })
-                                        .await
-                                        .context("sending unstage request")?;
-
-                                    Ok(())
+                                    if stage {
+                                        client
+                                            .request(proto::Stage {
+                                                project_id: project_id.0,
+                                                repository_id: id.to_proto(),
+                                                paths: entries
+                                                    .into_iter()
+                                                    .map(|repo_path| repo_path.to_proto())
+                                                    .collect(),
+                                            })
+                                            .await
+                                            .context("sending stage request")
+                                            .map(|_| ())
+                                    } else {
+                                        client
+                                            .request(proto::Unstage {
+                                                project_id: project_id.0,
+                                                repository_id: id.to_proto(),
+                                                paths: entries
+                                                    .into_iter()
+                                                    .map(|repo_path| repo_path.to_proto())
+                                                    .collect(),
+                                            })
+                                            .await
+                                            .context("sending unstage request")
+                                            .map(|_| ())
+                                    }
                                 }
+                            };
+
+                            for (diff_state, hunk_staging_operation_count) in
+                                hunk_staging_operation_counts
+                            {
+                                diff_state
+                                    .update(&mut cx, |diff_state, cx| {
+                                        if result.is_ok() {
+                                            diff_state.hunk_staging_operation_count_as_of_write =
+                                                hunk_staging_operation_count;
+                                        } else if let Some(uncommitted_diff) =
+                                            &diff_state.uncommitted_diff
+                                        {
+                                            uncommitted_diff
+                                                .update(cx, |uncommitted_diff, cx| {
+                                                    uncommitted_diff.clear_pending_hunks(cx);
+                                                })
+                                                .ok();
+                                        }
+                                    })
+                                    .ok();
                             }
+
+                            result
                         },
                     )
                 })?
@@ -4347,7 +4418,7 @@ impl Repository {
                 }
             })
             .collect();
-        self.stage_entries(to_stage, cx)
+        self.stage_or_unstage_entries(true, to_stage, cx)
     }
 
     pub fn unstage_all(&mut self, cx: &mut Context<Self>) -> Task<anyhow::Result<()>> {
@@ -4367,7 +4438,7 @@ impl Repository {
                 }
             })
             .collect();
-        self.unstage_entries(to_unstage, cx)
+        self.stage_or_unstage_entries(false, to_unstage, cx)
     }
 
     pub fn stash_all(&mut self, cx: &mut Context<Self>) -> Task<anyhow::Result<()>> {

crates/project/src/project_tests.rs 🔗

@@ -10922,3 +10922,146 @@ async fn test_git_worktree_remove(cx: &mut gpui::TestAppContext) {
     });
     assert!(active_repo_path.is_none());
 }
+
+#[gpui::test]
+async fn test_optimistic_hunks_in_staged_files(cx: &mut gpui::TestAppContext) {
+    use DiffHunkSecondaryStatus::*;
+    init_test(cx);
+
+    let committed_contents = r#"
+        one
+        two
+        three
+    "#
+    .unindent();
+    let file_contents = r#"
+        one
+        TWO
+        three
+    "#
+    .unindent();
+
+    let fs = FakeFs::new(cx.background_executor.clone());
+    fs.insert_tree(
+        path!("/dir"),
+        json!({
+            ".git": {},
+            "file.txt": file_contents.clone()
+        }),
+    )
+    .await;
+
+    fs.set_head_and_index_for_repo(
+        path!("/dir/.git").as_ref(),
+        &[("file.txt", committed_contents.clone())],
+    );
+
+    let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+
+    let buffer = project
+        .update(cx, |project, cx| {
+            project.open_local_buffer(path!("/dir/file.txt"), cx)
+        })
+        .await
+        .unwrap();
+    let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
+    let uncommitted_diff = project
+        .update(cx, |project, cx| {
+            project.open_uncommitted_diff(buffer.clone(), cx)
+        })
+        .await
+        .unwrap();
+
+    // The hunk is initially unstaged.
+    uncommitted_diff.read_with(cx, |diff, cx| {
+        assert_hunks(
+            diff.hunks(&snapshot, cx),
+            &snapshot,
+            &diff.base_text_string().unwrap(),
+            &[(
+                1..2,
+                "two\n",
+                "TWO\n",
+                DiffHunkStatus::modified(HasSecondaryHunk),
+            )],
+        );
+    });
+
+    // Get the repository handle.
+    let repo = project.read_with(cx, |project, cx| {
+        project.repositories(cx).values().next().unwrap().clone()
+    });
+
+    // Stage the file.
+    let stage_task = repo.update(cx, |repo, cx| {
+        repo.stage_entries(vec![repo_path("file.txt")], cx)
+    });
+
+    // Run a few ticks to let the job start and mark hunks as pending,
+    // but don't run_until_parked which would complete the entire operation.
+    for _ in 0..10 {
+        cx.executor().tick();
+        let [hunk]: [_; 1] = uncommitted_diff
+            .read_with(cx, |diff, cx| diff.hunks(&snapshot, cx).collect::<Vec<_>>())
+            .try_into()
+            .unwrap();
+        match hunk.secondary_status {
+            HasSecondaryHunk => {}
+            SecondaryHunkRemovalPending => break,
+            NoSecondaryHunk => panic!("hunk was not optimistically staged"),
+            _ => panic!("unexpected hunk state"),
+        }
+    }
+    uncommitted_diff.read_with(cx, |diff, cx| {
+        assert_hunks(
+            diff.hunks(&snapshot, cx),
+            &snapshot,
+            &diff.base_text_string().unwrap(),
+            &[(
+                1..2,
+                "two\n",
+                "TWO\n",
+                DiffHunkStatus::modified(SecondaryHunkRemovalPending),
+            )],
+        );
+    });
+
+    // Let the staging complete.
+    stage_task.await.unwrap();
+    cx.run_until_parked();
+
+    // The hunk is now fully staged.
+    uncommitted_diff.read_with(cx, |diff, cx| {
+        assert_hunks(
+            diff.hunks(&snapshot, cx),
+            &snapshot,
+            &diff.base_text_string().unwrap(),
+            &[(
+                1..2,
+                "two\n",
+                "TWO\n",
+                DiffHunkStatus::modified(NoSecondaryHunk),
+            )],
+        );
+    });
+
+    // Simulate a commit by updating HEAD to match the current file contents.
+    // The FakeGitRepository's commit method is a no-op, so we need to manually
+    // update HEAD to simulate the commit completing.
+    fs.set_head_for_repo(
+        path!("/dir/.git").as_ref(),
+        &[("file.txt", file_contents.clone())],
+        "newhead",
+    );
+    cx.run_until_parked();
+
+    // After committing, there are no more hunks.
+    uncommitted_diff.read_with(cx, |diff, cx| {
+        assert_hunks(
+            diff.hunks(&snapshot, cx),
+            &snapshot,
+            &diff.base_text_string().unwrap(),
+            &[] as &[(Range<u32>, &str, &str, DiffHunkStatus)],
+        );
+    });
+}