Avoid polluting branch list and restore parent commit when using checkpoints (#27191)

Antonio Scandurra created

Release Notes:

- N/A

Change summary

crates/assistant2/src/active_thread.rs |   2 
crates/fs/src/fake_git_repo.rs         |  12 
crates/git/src/repository.rs           | 275 +++++++++++++++++++++++++--
crates/project/src/git.rs              |  23 -
4 files changed, 268 insertions(+), 44 deletions(-)

Detailed changes

crates/assistant2/src/active_thread.rs 🔗

@@ -791,7 +791,7 @@ impl ActiveThread {
             .when_some(checkpoint, |parent, checkpoint| {
                 parent.child(
                     h_flex().pl_2().child(
-                        Button::new("restore-checkpoint", "Restore Checkpoint")
+                        Button::new(("restore-checkpoint", ix), "Restore Checkpoint")
                             .icon(IconName::Undo)
                             .size(ButtonSize::Compact)
                             .on_click(cx.listener(move |this, _, _window, cx| {

crates/fs/src/fake_git_repo.rs 🔗

@@ -5,8 +5,8 @@ use futures::future::{self, BoxFuture};
 use git::{
     blame::Blame,
     repository::{
-        AskPassSession, Branch, CommitDetails, GitRepository, PushOptions, Remote, RepoPath,
-        ResetMode,
+        AskPassSession, Branch, CommitDetails, GitRepository, GitRepositoryCheckpoint, PushOptions,
+        Remote, RepoPath, ResetMode,
     },
     status::{FileStatus, GitStatus, StatusCode, TrackedStatus, UnmergedStatus},
 };
@@ -409,11 +409,15 @@ impl GitRepository for FakeGitRepository {
         unimplemented!()
     }
 
-    fn checkpoint(&self, _cx: AsyncApp) -> BoxFuture<Result<git::Oid>> {
+    fn checkpoint(&self, _cx: AsyncApp) -> BoxFuture<Result<GitRepositoryCheckpoint>> {
         unimplemented!()
     }
 
-    fn restore_checkpoint(&self, _oid: git::Oid, _cx: AsyncApp) -> BoxFuture<Result<()>> {
+    fn restore_checkpoint(
+        &self,
+        _checkpoint: GitRepositoryCheckpoint,
+        _cx: AsyncApp,
+    ) -> BoxFuture<Result<()>> {
         unimplemented!()
     }
 }

crates/git/src/repository.rs 🔗

@@ -290,10 +290,14 @@ pub trait GitRepository: Send + Sync {
     fn diff(&self, diff: DiffType, cx: AsyncApp) -> BoxFuture<Result<String>>;
 
     /// Creates a checkpoint for the repository.
-    fn checkpoint(&self, cx: AsyncApp) -> BoxFuture<Result<Oid>>;
+    fn checkpoint(&self, cx: AsyncApp) -> BoxFuture<Result<GitRepositoryCheckpoint>>;
 
     /// Resets to a previously-created checkpoint.
-    fn restore_checkpoint(&self, oid: Oid, cx: AsyncApp) -> BoxFuture<Result<()>>;
+    fn restore_checkpoint(
+        &self,
+        checkpoint: GitRepositoryCheckpoint,
+        cx: AsyncApp,
+    ) -> BoxFuture<Result<()>>;
 }
 
 pub enum DiffType {
@@ -337,6 +341,12 @@ impl RealGitRepository {
     }
 }
 
+#[derive(Copy, Clone)]
+pub struct GitRepositoryCheckpoint {
+    head_sha: Option<Oid>,
+    sha: Oid,
+}
+
 // https://git-scm.com/book/en/v2/Git-Internals-Git-Objects
 const GIT_MODE_SYMLINK: u32 = 0o120000;
 
@@ -1033,7 +1043,7 @@ impl GitRepository for RealGitRepository {
         .boxed()
     }
 
-    fn checkpoint(&self, cx: AsyncApp) -> BoxFuture<Result<Oid>> {
+    fn checkpoint(&self, cx: AsyncApp) -> BoxFuture<Result<GitRepositoryCheckpoint>> {
         let working_directory = self.working_directory();
         let git_binary_path = self.git_binary_path.clone();
         let executor = cx.background_executor().clone();
@@ -1056,10 +1066,7 @@ impl GitRepository for RealGitRepository {
                 let output = new_smol_command(&git_binary_path)
                     .current_dir(&working_directory)
                     .env("GIT_INDEX_FILE", &index_file_path)
-                    .env("GIT_AUTHOR_NAME", "Zed")
-                    .env("GIT_AUTHOR_EMAIL", "hi@zed.dev")
-                    .env("GIT_COMMITTER_NAME", "Zed")
-                    .env("GIT_COMMITTER_EMAIL", "hi@zed.dev")
+                    .envs(checkpoint_author_envs())
                     .args(args)
                     .output()
                     .await?;
@@ -1071,35 +1078,56 @@ impl GitRepository for RealGitRepository {
                 }
             };
 
+            let head_sha = run_git_command(&["rev-parse", "HEAD"]).await.ok();
             run_git_command(&["add", "--all"]).await?;
             let tree = run_git_command(&["write-tree"]).await?;
-            let commit_sha = run_git_command(&["commit-tree", &tree, "-m", "Checkpoint"]).await?;
+            let checkpoint_sha = if let Some(head_sha) = head_sha.as_deref() {
+                run_git_command(&["commit-tree", &tree, "-p", head_sha, "-m", "Checkpoint"]).await?
+            } else {
+                run_git_command(&["commit-tree", &tree, "-m", "Checkpoint"]).await?
+            };
             let ref_name = Uuid::new_v4().to_string();
-            run_git_command(&["update-ref", &format!("refs/heads/{ref_name}"), &commit_sha])
-                .await?;
+            run_git_command(&[
+                "update-ref",
+                &format!("refs/zed/{ref_name}"),
+                &checkpoint_sha,
+            ])
+            .await?;
 
             smol::fs::remove_file(index_file_path).await.ok();
             delete_temp_index.abort();
 
-            commit_sha.parse()
+            Ok(GitRepositoryCheckpoint {
+                head_sha: if let Some(head_sha) = head_sha {
+                    Some(head_sha.parse()?)
+                } else {
+                    None
+                },
+                sha: checkpoint_sha.parse()?,
+            })
         })
         .boxed()
     }
 
-    fn restore_checkpoint(&self, oid: Oid, cx: AsyncApp) -> BoxFuture<Result<()>> {
+    fn restore_checkpoint(
+        &self,
+        checkpoint: GitRepositoryCheckpoint,
+        cx: AsyncApp,
+    ) -> BoxFuture<Result<()>> {
         let working_directory = self.working_directory();
         let git_binary_path = self.git_binary_path.clone();
         cx.background_spawn(async move {
             let working_directory = working_directory?;
             let index_file_path = working_directory.join(".git/index.tmp");
 
-            let run_git_command = async |args: &[&str]| {
-                let output = new_smol_command(&git_binary_path)
-                    .current_dir(&working_directory)
-                    .env("GIT_INDEX_FILE", &index_file_path)
-                    .args(args)
-                    .output()
-                    .await?;
+            let run_git_command = async |args: &[&str], use_temp_index: bool| {
+                let mut command = new_smol_command(&git_binary_path);
+                command.current_dir(&working_directory);
+                command.args(args);
+                if use_temp_index {
+                    command.env("GIT_INDEX_FILE", &index_file_path);
+                }
+                let output = command.output().await?;
                 if output.status.success() {
                     anyhow::Ok(String::from_utf8(output.stdout)?.trim_end().to_string())
                 } else {
@@ -1108,9 +1136,26 @@ impl GitRepository for RealGitRepository {
                 }
             };
 
-            run_git_command(&["restore", "--source", &oid.to_string(), "--worktree", "."]).await?;
-            run_git_command(&["read-tree", &oid.to_string()]).await?;
-            run_git_command(&["clean", "-d", "--force"]).await?;
+            run_git_command(
+                &[
+                    "restore",
+                    "--source",
+                    &checkpoint.sha.to_string(),
+                    "--worktree",
+                    ".",
+                ],
+                false,
+            )
+            .await?;
+            run_git_command(&["read-tree", &checkpoint.sha.to_string()], true).await?;
+            run_git_command(&["clean", "-d", "--force"], true).await?;
+
+            if let Some(head_sha) = checkpoint.head_sha {
+                run_git_command(&["reset", "--mixed", &head_sha.to_string()], false).await?;
+            } else {
+                run_git_command(&["update-ref", "-d", "HEAD"], false).await?;
+            }
+
             Ok(())
         })
         .boxed()
@@ -1350,14 +1395,111 @@ fn check_path_to_repo_path_errors(relative_file_path: &Path) -> Result<()> {
     }
 }
 
+fn checkpoint_author_envs() -> HashMap<String, String> {
+    HashMap::from_iter([
+        ("GIT_AUTHOR_NAME".to_string(), "Zed".to_string()),
+        ("GIT_AUTHOR_EMAIL".to_string(), "hi@zed.dev".to_string()),
+        ("GIT_COMMITTER_NAME".to_string(), "Zed".to_string()),
+        ("GIT_COMMITTER_EMAIL".to_string(), "hi@zed.dev".to_string()),
+    ])
+}
+
 #[cfg(test)]
 mod tests {
+    use super::*;
+    use crate::status::FileStatus;
     use gpui::TestAppContext;
 
-    use super::*;
+    #[gpui::test]
+    async fn test_checkpoint_basic(cx: &mut TestAppContext) {
+        cx.executor().allow_parking();
+
+        let repo_dir = tempfile::tempdir().unwrap();
+
+        git2::Repository::init(repo_dir.path()).unwrap();
+        let file_path = repo_dir.path().join("file");
+        smol::fs::write(&file_path, "initial").await.unwrap();
+
+        let repo = RealGitRepository::new(&repo_dir.path().join(".git"), None).unwrap();
+        repo.stage_paths(
+            vec![RepoPath::from_str("file")],
+            HashMap::default(),
+            cx.to_async(),
+        )
+        .await
+        .unwrap();
+        repo.commit(
+            "Initial commit".into(),
+            None,
+            checkpoint_author_envs(),
+            cx.to_async(),
+        )
+        .await
+        .unwrap();
+
+        smol::fs::write(&file_path, "modified before checkpoint")
+            .await
+            .unwrap();
+        smol::fs::write(repo_dir.path().join("new_file_before_checkpoint"), "1")
+            .await
+            .unwrap();
+        let sha_before_checkpoint = repo.head_sha().unwrap();
+        let checkpoint = repo.checkpoint(cx.to_async()).await.unwrap();
+
+        // Ensure the user can't see any branches after creating a checkpoint.
+        assert_eq!(repo.branches().await.unwrap().len(), 1);
+
+        smol::fs::write(&file_path, "modified after checkpoint")
+            .await
+            .unwrap();
+        repo.stage_paths(
+            vec![RepoPath::from_str("file")],
+            HashMap::default(),
+            cx.to_async(),
+        )
+        .await
+        .unwrap();
+        repo.commit(
+            "Commit after checkpoint".into(),
+            None,
+            checkpoint_author_envs(),
+            cx.to_async(),
+        )
+        .await
+        .unwrap();
+
+        smol::fs::remove_file(repo_dir.path().join("new_file_before_checkpoint"))
+            .await
+            .unwrap();
+        smol::fs::write(repo_dir.path().join("new_file_after_checkpoint"), "2")
+            .await
+            .unwrap();
+
+        repo.restore_checkpoint(checkpoint, cx.to_async())
+            .await
+            .unwrap();
+
+        assert_eq!(repo.head_sha().unwrap(), sha_before_checkpoint);
+        assert_eq!(
+            smol::fs::read_to_string(&file_path).await.unwrap(),
+            "modified before checkpoint"
+        );
+        assert_eq!(
+            smol::fs::read_to_string(repo_dir.path().join("new_file_before_checkpoint"))
+                .await
+                .unwrap(),
+            "1"
+        );
+        assert_eq!(
+            smol::fs::read_to_string(repo_dir.path().join("new_file_after_checkpoint"))
+                .await
+                .ok(),
+            None
+        );
+    }
 
     #[gpui::test]
-    async fn test_checkpoint(cx: &mut TestAppContext) {
+    async fn test_checkpoint_empty_repo(cx: &mut TestAppContext) {
         cx.executor().allow_parking();
 
         let repo_dir = tempfile::tempdir().unwrap();
@@ -1369,6 +1511,9 @@ mod tests {
             .unwrap();
         let checkpoint_sha = repo.checkpoint(cx.to_async()).await.unwrap();
 
+        // Ensure the user can't see any branches after creating a checkpoint.
+        assert_eq!(repo.branches().await.unwrap().len(), 1);
+
         smol::fs::write(repo_dir.path().join("foo"), "bar")
             .await
             .unwrap();
@@ -1392,6 +1537,88 @@ mod tests {
         );
     }
 
+    #[gpui::test]
+    async fn test_undoing_commit_via_checkpoint(cx: &mut TestAppContext) {
+        cx.executor().allow_parking();
+
+        let repo_dir = tempfile::tempdir().unwrap();
+
+        git2::Repository::init(repo_dir.path()).unwrap();
+        let file_path = repo_dir.path().join("file");
+        smol::fs::write(&file_path, "initial").await.unwrap();
+
+        let repo = RealGitRepository::new(&repo_dir.path().join(".git"), None).unwrap();
+        repo.stage_paths(
+            vec![RepoPath::from_str("file")],
+            HashMap::default(),
+            cx.to_async(),
+        )
+        .await
+        .unwrap();
+        repo.commit(
+            "Initial commit".into(),
+            None,
+            checkpoint_author_envs(),
+            cx.to_async(),
+        )
+        .await
+        .unwrap();
+
+        let initial_commit_sha = repo.head_sha().unwrap();
+
+        smol::fs::write(repo_dir.path().join("new_file1"), "content1")
+            .await
+            .unwrap();
+        smol::fs::write(repo_dir.path().join("new_file2"), "content2")
+            .await
+            .unwrap();
+
+        let checkpoint = repo.checkpoint(cx.to_async()).await.unwrap();
+
+        repo.stage_paths(
+            vec![
+                RepoPath::from_str("new_file1"),
+                RepoPath::from_str("new_file2"),
+            ],
+            HashMap::default(),
+            cx.to_async(),
+        )
+        .await
+        .unwrap();
+        repo.commit(
+            "Commit new files".into(),
+            None,
+            checkpoint_author_envs(),
+            cx.to_async(),
+        )
+        .await
+        .unwrap();
+
+        repo.restore_checkpoint(checkpoint, cx.to_async())
+            .await
+            .unwrap();
+        assert_eq!(repo.head_sha().unwrap(), initial_commit_sha);
+        assert_eq!(
+            smol::fs::read_to_string(repo_dir.path().join("new_file1"))
+                .await
+                .unwrap(),
+            "content1"
+        );
+        assert_eq!(
+            smol::fs::read_to_string(repo_dir.path().join("new_file2"))
+                .await
+                .unwrap(),
+            "content2"
+        );
+        assert_eq!(
+            repo.status(&[]).unwrap().entries.as_ref(),
+            &[
+                (RepoPath::from_str("new_file1"), FileStatus::Untracked),
+                (RepoPath::from_str("new_file2"), FileStatus::Untracked)
+            ]
+        );
+    }
+
     #[test]
     fn test_branches_parsing() {
         // suppress "help: octal escapes are not supported, `\0` is always null"

crates/project/src/git.rs 🔗

@@ -14,7 +14,7 @@ use futures::{
     future::{self, OptionFuture, Shared},
     FutureExt as _, StreamExt as _,
 };
-use git::{repository::DiffType, Oid};
+use git::repository::{DiffType, GitRepositoryCheckpoint};
 use git::{
     repository::{
         Branch, CommitDetails, GitRepository, PushOptions, Remote, RemoteCommandOutput, RepoPath,
@@ -119,12 +119,7 @@ enum GitStoreState {
 
 #[derive(Clone)]
 pub struct GitStoreCheckpoint {
-    checkpoints_by_dot_git_abs_path: HashMap<PathBuf, RepositoryCheckpoint>,
-}
-
-#[derive(Copy, Clone)]
-pub struct RepositoryCheckpoint {
-    sha: Oid,
+    checkpoints_by_dot_git_abs_path: HashMap<PathBuf, GitRepositoryCheckpoint>,
 }
 
 pub struct Repository {
@@ -526,7 +521,8 @@ impl GitStore {
         }
 
         cx.background_executor().spawn(async move {
-            let checkpoints: Vec<RepositoryCheckpoint> = future::try_join_all(checkpoints).await?;
+            let checkpoints: Vec<GitRepositoryCheckpoint> =
+                future::try_join_all(checkpoints).await?;
             Ok(GitStoreCheckpoint {
                 checkpoints_by_dot_git_abs_path: dot_git_abs_paths
                     .into_iter()
@@ -2972,13 +2968,10 @@ impl Repository {
         })
     }
 
-    pub fn checkpoint(&self) -> oneshot::Receiver<Result<RepositoryCheckpoint>> {
+    pub fn checkpoint(&self) -> oneshot::Receiver<Result<GitRepositoryCheckpoint>> {
         self.send_job(|repo, cx| async move {
             match repo {
-                GitRepo::Local(git_repository) => {
-                    let sha = git_repository.checkpoint(cx).await?;
-                    Ok(RepositoryCheckpoint { sha })
-                }
+                GitRepo::Local(git_repository) => git_repository.checkpoint(cx).await,
                 GitRepo::Remote { .. } => Err(anyhow!("not implemented yet")),
             }
         })
@@ -2986,12 +2979,12 @@ impl Repository {
 
     pub fn restore_checkpoint(
         &self,
-        checkpoint: RepositoryCheckpoint,
+        checkpoint: GitRepositoryCheckpoint,
     ) -> oneshot::Receiver<Result<()>> {
         self.send_job(move |repo, cx| async move {
             match repo {
                 GitRepo::Local(git_repository) => {
-                    git_repository.restore_checkpoint(checkpoint.sha, cx).await
+                    git_repository.restore_checkpoint(checkpoint, cx).await
                 }
                 GitRepo::Remote { .. } => Err(anyhow!("not implemented yet")),
             }