Start on a Git-based review flow (#27103)

Antonio Scandurra and Nathan Sobo created

Release Notes:

- N/A

---------

Co-authored-by: Nathan Sobo <nathan@zed.dev>

Change summary

Cargo.lock                              |   3 
crates/assistant2/Cargo.toml            |   1 
crates/assistant2/src/active_thread.rs  |  21 ++
crates/assistant2/src/message_editor.rs | 163 ++++++++++--------------
crates/assistant2/src/thread.rs         |  65 ++++++++-
crates/assistant_eval/src/eval.rs       |   2 
crates/fs/src/fake_git_repo.rs          |   8 +
crates/git/Cargo.toml                   |   2 
crates/git/src/repository.rs            | 183 +++++++++++++++++++++++---
crates/project/src/git.rs               |  79 +++++++++++
10 files changed, 396 insertions(+), 131 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -467,6 +467,7 @@ dependencies = [
  "futures 0.3.31",
  "fuzzy",
  "git",
+ "git_ui",
  "gpui",
  "heed",
  "html_to_markdown",
@@ -5601,11 +5602,13 @@ dependencies = [
  "serde_json",
  "smol",
  "sum_tree",
+ "tempfile",
  "text",
  "time",
  "unindent",
  "url",
  "util",
+ "uuid",
 ]
 
 [[package]]

crates/assistant2/Cargo.toml 🔗

@@ -39,6 +39,7 @@ fs.workspace = true
 futures.workspace = true
 fuzzy.workspace = true
 git.workspace = true
+git_ui.workspace = true
 gpui.workspace = true
 heed.workspace = true
 html_to_markdown.workspace = true

crates/assistant2/src/active_thread.rs 🔗

@@ -550,6 +550,7 @@ impl ActiveThread {
 
         let thread = self.thread.read(cx);
         // Get all the data we need from thread before we start using it in closures
+        let checkpoint = thread.checkpoint_for_message(message_id);
         let context = thread.context_for_message(message_id);
         let tool_uses = thread.tool_uses_for_message(message_id);
         let scripting_tool_uses = thread.scripting_tool_uses_for_message(message_id);
@@ -734,7 +735,25 @@ impl ActiveThread {
             ),
         };
 
-        styled_message.into_any()
+        v_flex()
+            .when_some(checkpoint, |parent, checkpoint| {
+                parent.child(
+                    h_flex().pl_2().child(
+                        Button::new("restore-checkpoint", "Restore Checkpoint")
+                            .icon(IconName::Undo)
+                            .size(ButtonSize::Compact)
+                            .on_click(cx.listener(move |this, _, _window, cx| {
+                                this.thread.update(cx, |thread, cx| {
+                                    thread
+                                        .restore_checkpoint(checkpoint.clone(), cx)
+                                        .detach_and_log_err(cx);
+                                });
+                            })),
+                    ),
+                )
+            })
+            .child(styled_message)
+            .into_any()
     }
 
     fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {

crates/assistant2/src/message_editor.rs 🔗

@@ -3,23 +3,25 @@ use std::sync::Arc;
 use collections::HashSet;
 use editor::actions::MoveUp;
 use editor::{Editor, EditorElement, EditorEvent, EditorStyle};
-use file_icons::FileIcons;
 use fs::Fs;
+use git::ExpandCommitEditor;
+use git_ui::git_panel;
 use gpui::{
     Animation, AnimationExt, App, DismissEvent, Entity, Focusable, Subscription, TextStyle,
     WeakEntity,
 };
 use language_model::LanguageModelRegistry;
 use language_model_selector::ToggleModelSelector;
+use project::Project;
 use rope::Point;
 use settings::Settings;
 use std::time::Duration;
 use text::Bias;
 use theme::ThemeSettings;
 use ui::{
-    prelude::*, ButtonLike, Disclosure, KeyBinding, PlatformStyle, PopoverMenu, PopoverMenuHandle,
-    Tooltip,
+    prelude::*, ButtonLike, KeyBinding, PlatformStyle, PopoverMenu, PopoverMenuHandle, Tooltip,
 };
+use util::ResultExt;
 use vim_mode_setting::VimModeSetting;
 use workspace::notifications::{NotificationId, NotifyTaskExt};
 use workspace::{Toast, Workspace};
@@ -37,6 +39,7 @@ pub struct MessageEditor {
     thread: Entity<Thread>,
     editor: Entity<Editor>,
     workspace: WeakEntity<Workspace>,
+    project: Entity<Project>,
     context_store: Entity<ContextStore>,
     context_strip: Entity<ContextStrip>,
     context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
@@ -44,7 +47,6 @@ pub struct MessageEditor {
     inline_context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
     model_selector: Entity<AssistantModelSelector>,
     tool_selector: Entity<ToolSelector>,
-    edits_expanded: bool,
     _subscriptions: Vec<Subscription>,
 }
 
@@ -107,8 +109,9 @@ impl MessageEditor {
         ];
 
         Self {
-            thread,
             editor: editor.clone(),
+            project: thread.read(cx).project().clone(),
+            thread,
             workspace,
             context_store,
             context_strip,
@@ -125,7 +128,6 @@ impl MessageEditor {
                 )
             }),
             tool_selector: cx.new(|cx| ToolSelector::new(tools, cx)),
-            edits_expanded: false,
             _subscriptions: subscriptions,
         }
     }
@@ -206,12 +208,15 @@ impl MessageEditor {
 
         let thread = self.thread.clone();
         let context_store = self.context_store.clone();
+        let git_store = self.project.read(cx).git_store();
+        let checkpoint = git_store.read(cx).checkpoint(cx);
         cx.spawn(async move |_, cx| {
             refresh_task.await;
+            let checkpoint = checkpoint.await.log_err();
             thread
                 .update(cx, |thread, cx| {
                     let context = context_store.read(cx).snapshot(cx).collect::<Vec<_>>();
-                    thread.insert_user_message(user_message, context, cx);
+                    thread.insert_user_message(user_message, context, checkpoint, cx);
                     thread.send_to_model(model, request_kind, cx);
                 })
                 .ok();
@@ -347,8 +352,12 @@ impl Render for MessageEditor {
             px(64.)
         };
 
-        let changed_buffers = self.thread.read(cx).scripting_changed_buffers(cx);
-        let changed_buffers_count = changed_buffers.len();
+        let project = self.thread.read(cx).project();
+        let changed_files = if let Some(repository) = project.read(cx).active_repository(cx) {
+            repository.read(cx).status().count()
+        } else {
+            0
+        };
 
         v_flex()
             .size_full()
@@ -410,7 +419,7 @@ impl Render for MessageEditor {
                     ),
                 )
             })
-            .when(changed_buffers_count > 0, |parent| {
+            .when(changed_files > 0, |parent| {
                 parent.child(
                     v_flex()
                         .mx_2()
@@ -421,96 +430,60 @@ impl Render for MessageEditor {
                         .rounded_t_md()
                         .child(
                             h_flex()
-                                .gap_2()
+                                .justify_between()
                                 .p_2()
                                 .child(
-                                    Disclosure::new("edits-disclosure", self.edits_expanded)
-                                        .on_click(cx.listener(|this, _ev, _window, cx| {
-                                            this.edits_expanded = !this.edits_expanded;
-                                            cx.notify();
-                                        })),
-                                )
-                                .child(
-                                    Label::new("Edits")
-                                        .size(LabelSize::XSmall)
-                                        .color(Color::Muted),
+                                    h_flex()
+                                        .gap_2()
+                                        .child(
+                                            IconButton::new(
+                                                "edits-disclosure",
+                                                IconName::GitBranchSmall,
+                                            )
+                                            .icon_size(IconSize::Small)
+                                            .on_click(
+                                                |_ev, _window, cx| {
+                                                    cx.defer(|cx| {
+                                                        cx.dispatch_action(&git_panel::ToggleFocus)
+                                                    });
+                                                },
+                                            ),
+                                        )
+                                        .child(
+                                            Label::new(format!(
+                                                "{} {} changed",
+                                                changed_files,
+                                                if changed_files == 1 { "file" } else { "files" }
+                                            ))
+                                            .size(LabelSize::XSmall)
+                                            .color(Color::Muted),
+                                        ),
                                 )
-                                .child(Label::new("•").size(LabelSize::XSmall).color(Color::Muted))
                                 .child(
-                                    Label::new(format!(
-                                        "{} {}",
-                                        changed_buffers_count,
-                                        if changed_buffers_count == 1 {
-                                            "file"
-                                        } else {
-                                            "files"
-                                        }
-                                    ))
-                                    .size(LabelSize::XSmall)
-                                    .color(Color::Muted),
-                                ),
-                        )
-                        .when(self.edits_expanded, |parent| {
-                            parent.child(
-                                v_flex().bg(cx.theme().colors().editor_background).children(
-                                    changed_buffers.enumerate().flat_map(|(index, buffer)| {
-                                        let file = buffer.read(cx).file()?;
-                                        let path = file.path();
-
-                                        let parent_label = path.parent().and_then(|parent| {
-                                            let parent_str = parent.to_string_lossy();
-
-                                            if parent_str.is_empty() {
-                                                None
-                                            } else {
-                                                Some(
-                                                    Label::new(format!(
-                                                        "{}{}",
-                                                        parent_str,
-                                                        std::path::MAIN_SEPARATOR_STR
-                                                    ))
-                                                    .color(Color::Muted)
-                                                    .size(LabelSize::Small),
-                                                )
-                                            }
-                                        });
-
-                                        let name_label = path.file_name().map(|name| {
-                                            Label::new(name.to_string_lossy().to_string())
-                                                .size(LabelSize::Small)
-                                        });
-
-                                        let file_icon = FileIcons::get_icon(&path, cx)
-                                            .map(Icon::from_path)
-                                            .unwrap_or_else(|| Icon::new(IconName::File));
-
-                                        let element = div()
-                                            .p_2()
-                                            .when(index + 1 < changed_buffers_count, |parent| {
-                                                parent
-                                                    .border_color(cx.theme().colors().border)
-                                                    .border_b_1()
-                                            })
-                                            .child(
-                                                h_flex()
-                                                    .gap_2()
-                                                    .child(file_icon)
-                                                    .child(
-                                                        // TODO: handle overflow
-                                                        h_flex()
-                                                            .children(parent_label)
-                                                            .children(name_label),
-                                                    )
-                                                    // TODO: show lines changed
-                                                    .child(Label::new("+").color(Color::Created))
-                                                    .child(Label::new("-").color(Color::Deleted)),
-                                            );
-
-                                        Some(element)
-                                    }),
+                                    h_flex()
+                                        .gap_2()
+                                        .child(
+                                            Button::new("review", "Review")
+                                                .label_size(LabelSize::XSmall)
+                                                .on_click(|_event, _window, cx| {
+                                                    cx.defer(|cx| {
+                                                        cx.dispatch_action(
+                                                            &git_ui::project_diff::Diff,
+                                                        );
+                                                    });
+                                                }),
+                                        )
+                                        .child(
+                                            Button::new("commit", "Commit")
+                                                .label_size(LabelSize::XSmall)
+                                                .on_click(|_event, _window, cx| {
+                                                    cx.defer(|cx| {
+                                                        cx.dispatch_action(&ExpandCommitEditor)
+                                                    });
+                                                }),
+                                        ),
                                 ),
-                            )
-                        }),
+                        ),
                 )
             })
             .child(

crates/assistant2/src/thread.rs 🔗

@@ -16,6 +16,7 @@ use language_model::{
     LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
     Role, StopReason, TokenUsage,
 };
+use project::git::GitStoreCheckpoint;
 use project::Project;
 use prompt_store::{AssistantSystemPromptWorktree, PromptBuilder};
 use scripting_tool::{ScriptingSession, ScriptingTool};
@@ -89,6 +90,12 @@ pub struct GitState {
     pub diff: Option<String>,
 }
 
+#[derive(Clone)]
+pub struct ThreadCheckpoint {
+    message_id: MessageId,
+    git_checkpoint: GitStoreCheckpoint,
+}
+
 /// A thread of conversation with the LLM.
 pub struct Thread {
     id: ThreadId,
@@ -99,6 +106,7 @@ pub struct Thread {
     next_message_id: MessageId,
     context: BTreeMap<ContextId, ContextSnapshot>,
     context_by_message: HashMap<MessageId, Vec<ContextId>>,
+    checkpoints_by_message: HashMap<MessageId, GitStoreCheckpoint>,
     completion_count: usize,
     pending_completions: Vec<PendingCompletion>,
     project: Entity<Project>,
@@ -128,6 +136,7 @@ impl Thread {
             next_message_id: MessageId(0),
             context: BTreeMap::default(),
             context_by_message: HashMap::default(),
+            checkpoints_by_message: HashMap::default(),
             completion_count: 0,
             pending_completions: Vec::new(),
             project: project.clone(),
@@ -188,6 +197,7 @@ impl Thread {
             next_message_id,
             context: BTreeMap::default(),
             context_by_message: HashMap::default(),
+            checkpoints_by_message: HashMap::default(),
             completion_count: 0,
             pending_completions: Vec::new(),
             project,
@@ -249,6 +259,45 @@ impl Thread {
         &self.tools
     }
 
+    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
+        let checkpoint = self.checkpoints_by_message.get(&id).cloned()?;
+        Some(ThreadCheckpoint {
+            message_id: id,
+            git_checkpoint: checkpoint,
+        })
+    }
+
+    pub fn restore_checkpoint(
+        &mut self,
+        checkpoint: ThreadCheckpoint,
+        cx: &mut Context<Self>,
+    ) -> Task<Result<()>> {
+        let project = self.project.read(cx);
+        let restore = project
+            .git_store()
+            .read(cx)
+            .restore_checkpoint(checkpoint.git_checkpoint, cx);
+        cx.spawn(async move |this, cx| {
+            restore.await?;
+            this.update(cx, |this, cx| this.truncate(checkpoint.message_id, cx))
+        })
+    }
+
+    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
+        let Some(message_ix) = self
+            .messages
+            .iter()
+            .rposition(|message| message.id == message_id)
+        else {
+            return;
+        };
+        for deleted_message in self.messages.drain(message_ix..) {
+            self.context_by_message.remove(&deleted_message.id);
+            self.checkpoints_by_message.remove(&deleted_message.id);
+        }
+        cx.notify();
+    }
+
     pub fn context_for_message(&self, id: MessageId) -> Option<Vec<ContextSnapshot>> {
         let context = self.context_by_message.get(&id)?;
         Some(
@@ -296,13 +345,6 @@ impl Thread {
         self.scripting_tool_use.tool_results_for_message(id)
     }
 
-    pub fn scripting_changed_buffers<'a>(
-        &self,
-        cx: &'a App,
-    ) -> impl ExactSizeIterator<Item = &'a Entity<language::Buffer>> {
-        self.scripting_session.read(cx).changed_buffers()
-    }
-
     pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
         self.tool_use.message_has_tool_results(message_id)
     }
@@ -315,6 +357,7 @@ impl Thread {
         &mut self,
         text: impl Into<String>,
         context: Vec<ContextSnapshot>,
+        checkpoint: Option<GitStoreCheckpoint>,
         cx: &mut Context<Self>,
     ) -> MessageId {
         let message_id = self.insert_message(Role::User, text, cx);
@@ -322,6 +365,9 @@ impl Thread {
         self.context
             .extend(context.into_iter().map(|context| (context.id, context)));
         self.context_by_message.insert(message_id, context_ids);
+        if let Some(checkpoint) = checkpoint {
+            self.checkpoints_by_message.insert(message_id, checkpoint);
+        }
         message_id
     }
 
@@ -941,6 +987,7 @@ impl Thread {
             // so for now we provide some text to keep the model on track.
             "Here are the tool results.",
             Vec::new(),
+            None,
             cx,
         );
     }
@@ -1144,6 +1191,10 @@ impl Thread {
         &self.action_log
     }
 
+    pub fn project(&self) -> &Entity<Project> {
+        &self.project
+    }
+
     pub fn cumulative_token_usage(&self) -> TokenUsage {
         self.cumulative_token_usage.clone()
     }

crates/assistant_eval/src/eval.rs 🔗

@@ -82,7 +82,7 @@ impl Eval {
             assistant.update(cx, |assistant, cx| {
                 assistant.thread.update(cx, |thread, cx| {
                     let context = vec![];
-                    thread.insert_user_message(self.user_prompt.clone(), context, cx);
+                    thread.insert_user_message(self.user_prompt.clone(), context, None, cx);
                     thread.send_to_model(model, RequestKind::Chat, cx);
                 });
             })?;

crates/fs/src/fake_git_repo.rs 🔗

@@ -408,4 +408,12 @@ impl GitRepository for FakeGitRepository {
     ) -> BoxFuture<Result<String>> {
         unimplemented!()
     }
+
+    fn checkpoint(&self, _cx: AsyncApp) -> BoxFuture<Result<git::Oid>> {
+        unimplemented!()
+    }
+
+    fn restore_checkpoint(&self, _oid: git::Oid, _cx: AsyncApp) -> BoxFuture<Result<()>> {
+        unimplemented!()
+    }
 }

crates/git/Cargo.toml 🔗

@@ -35,6 +35,7 @@ text.workspace = true
 time.workspace = true
 url.workspace = true
 util.workspace = true
+uuid.workspace = true
 futures.workspace = true
 
 [dev-dependencies]
@@ -43,3 +44,4 @@ serde_json.workspace = true
 text = { workspace = true, features = ["test-support"] }
 unindent.workspace = true
 gpui = { workspace = true, features = ["test-support"] }
+tempfile.workspace = true

crates/git/src/repository.rs 🔗

@@ -1,5 +1,5 @@
 use crate::status::GitStatus;
-use crate::SHORT_SHA_LENGTH;
+use crate::{Oid, SHORT_SHA_LENGTH};
 use anyhow::{anyhow, Context as _, Result};
 use collections::HashMap;
 use futures::future::BoxFuture;
@@ -22,6 +22,7 @@ use std::{
 use sum_tree::MapSeekTarget;
 use util::command::new_smol_command;
 use util::ResultExt;
+use uuid::Uuid;
 
 pub use askpass::{AskPassResult, AskPassSession};
 
@@ -287,6 +288,12 @@ pub trait GitRepository: Send + Sync {
 
     /// Run git diff
     fn diff(&self, diff: DiffType, cx: AsyncApp) -> BoxFuture<Result<String>>;
+
+    /// Creates a checkpoint for the repository.
+    fn checkpoint(&self, cx: AsyncApp) -> BoxFuture<Result<Oid>>;
+
+    /// Resets to a previously-created checkpoint.
+    fn restore_checkpoint(&self, oid: Oid, cx: AsyncApp) -> BoxFuture<Result<()>>;
 }
 
 pub enum DiffType {
@@ -1025,6 +1032,89 @@ impl GitRepository for RealGitRepository {
         })
         .boxed()
     }
+
+    fn checkpoint(&self, cx: AsyncApp) -> BoxFuture<Result<Oid>> {
+        let working_directory = self.working_directory();
+        let git_binary_path = self.git_binary_path.clone();
+        let executor = cx.background_executor().clone();
+        cx.background_spawn(async move {
+            let working_directory = working_directory?;
+            let index_file_path = working_directory.join(".git/index.tmp");
+
+            let delete_temp_index = util::defer({
+                let index_file_path = index_file_path.clone();
+                || {
+                    executor
+                        .spawn(async move {
+                            smol::fs::remove_file(index_file_path).await.log_err();
+                        })
+                        .detach();
+                }
+            });
+
+            let run_git_command = async |args: &[&str]| {
+                let output = new_smol_command(&git_binary_path)
+                    .current_dir(&working_directory)
+                    .env("GIT_INDEX_FILE", &index_file_path)
+                    .env("GIT_AUTHOR_NAME", "Zed")
+                    .env("GIT_AUTHOR_EMAIL", "hi@zed.dev")
+                    .env("GIT_COMMITTER_NAME", "Zed")
+                    .env("GIT_COMMITTER_EMAIL", "hi@zed.dev")
+                    .args(args)
+                    .output()
+                    .await?;
+                if output.status.success() {
+                    anyhow::Ok(String::from_utf8(output.stdout)?.trim_end().to_string())
+                } else {
+                    let error = String::from_utf8_lossy(&output.stderr);
+                    Err(anyhow!("Git command failed: {:?}", error))
+                }
+            };
+
+            run_git_command(&["add", "--all"]).await?;
+            let tree = run_git_command(&["write-tree"]).await?;
+            let commit_sha = run_git_command(&["commit-tree", &tree, "-m", "Checkpoint"]).await?;
+            let ref_name = Uuid::new_v4().to_string();
+            run_git_command(&["update-ref", &format!("refs/heads/{ref_name}"), &commit_sha])
+                .await?;
+
+            smol::fs::remove_file(index_file_path).await.ok();
+            delete_temp_index.abort();
+
+            commit_sha.parse()
+        })
+        .boxed()
+    }
+
+    fn restore_checkpoint(&self, oid: Oid, cx: AsyncApp) -> BoxFuture<Result<()>> {
+        let working_directory = self.working_directory();
+        let git_binary_path = self.git_binary_path.clone();
+        cx.background_spawn(async move {
+            let working_directory = working_directory?;
+            let index_file_path = working_directory.join(".git/index.tmp");
+
+            let run_git_command = async |args: &[&str]| {
+                let output = new_smol_command(&git_binary_path)
+                    .current_dir(&working_directory)
+                    .env("GIT_INDEX_FILE", &index_file_path)
+                    .args(args)
+                    .output()
+                    .await?;
+                if output.status.success() {
+                    anyhow::Ok(String::from_utf8(output.stdout)?.trim_end().to_string())
+                } else {
+                    let error = String::from_utf8_lossy(&output.stderr);
+                    Err(anyhow!("Git command failed: {:?}", error))
+                }
+            };
+
+            run_git_command(&["restore", "--source", &oid.to_string(), "--worktree", "."]).await?;
+            run_git_command(&["read-tree", &oid.to_string()]).await?;
+            run_git_command(&["clean", "-d", "--force"]).await?;
+            Ok(())
+        })
+        .boxed()
+    }
 }
 
 async fn run_remote_command(
@@ -1260,29 +1350,72 @@ fn check_path_to_repo_path_errors(relative_file_path: &Path) -> Result<()> {
     }
 }
 
-#[test]
-fn test_branches_parsing() {
-    // suppress "help: octal escapes are not supported, `\0` is always null"
-    #[allow(clippy::octal_escapes)]
-    let input = "*\0060964da10574cd9bf06463a53bf6e0769c5c45e\0\0refs/heads/zed-patches\0refs/remotes/origin/zed-patches\0\01733187470\0generated protobuf\n";
-    assert_eq!(
-        parse_branch_input(&input).unwrap(),
-        vec![Branch {
-            is_head: true,
-            name: "zed-patches".into(),
-            upstream: Some(Upstream {
-                ref_name: "refs/remotes/origin/zed-patches".into(),
-                tracking: UpstreamTracking::Tracked(UpstreamTrackingStatus {
-                    ahead: 0,
-                    behind: 0
+#[cfg(test)]
+mod tests {
+    use gpui::TestAppContext;
+
+    use super::*;
+
+    #[gpui::test]
+    async fn test_checkpoint(cx: &mut TestAppContext) {
+        cx.executor().allow_parking();
+
+        let repo_dir = tempfile::tempdir().unwrap();
+        git2::Repository::init(repo_dir.path()).unwrap();
+        let repo = RealGitRepository::new(&repo_dir.path().join(".git"), None).unwrap();
+
+        smol::fs::write(repo_dir.path().join("foo"), "foo")
+            .await
+            .unwrap();
+        let checkpoint_sha = repo.checkpoint(cx.to_async()).await.unwrap();
+
+        smol::fs::write(repo_dir.path().join("foo"), "bar")
+            .await
+            .unwrap();
+        smol::fs::write(repo_dir.path().join("baz"), "qux")
+            .await
+            .unwrap();
+        repo.restore_checkpoint(checkpoint_sha, cx.to_async())
+            .await
+            .unwrap();
+        assert_eq!(
+            smol::fs::read_to_string(repo_dir.path().join("foo"))
+                .await
+                .unwrap(),
+            "foo"
+        );
+        assert_eq!(
+            smol::fs::read_to_string(repo_dir.path().join("baz"))
+                .await
+                .ok(),
+            None
+        );
+    }
+
+    #[test]
+    fn test_branches_parsing() {
+        // suppress "help: octal escapes are not supported, `\0` is always null"
+        #[allow(clippy::octal_escapes)]
+        let input = "*\0060964da10574cd9bf06463a53bf6e0769c5c45e\0\0refs/heads/zed-patches\0refs/remotes/origin/zed-patches\0\01733187470\0generated protobuf\n";
+        assert_eq!(
+            parse_branch_input(&input).unwrap(),
+            vec![Branch {
+                is_head: true,
+                name: "zed-patches".into(),
+                upstream: Some(Upstream {
+                    ref_name: "refs/remotes/origin/zed-patches".into(),
+                    tracking: UpstreamTracking::Tracked(UpstreamTrackingStatus {
+                        ahead: 0,
+                        behind: 0
+                    })
+                }),
+                most_recent_commit: Some(CommitSummary {
+                    sha: "060964da10574cd9bf06463a53bf6e0769c5c45e".into(),
+                    subject: "generated protobuf".into(),
+                    commit_timestamp: 1733187470,
+                    has_parent: false,
                 })
-            }),
-            most_recent_commit: Some(CommitSummary {
-                sha: "060964da10574cd9bf06463a53bf6e0769c5c45e".into(),
-                subject: "generated protobuf".into(),
-                commit_timestamp: 1733187470,
-                has_parent: false,
-            })
-        }]
-    )
+            }]
+        )
+    }
 }

crates/project/src/git.rs 🔗

@@ -11,10 +11,10 @@ use collections::HashMap;
 use fs::Fs;
 use futures::{
     channel::{mpsc, oneshot},
-    future::{OptionFuture, Shared},
+    future::{self, OptionFuture, Shared},
     FutureExt as _, StreamExt as _,
 };
-use git::repository::DiffType;
+use git::{repository::DiffType, Oid};
 use git::{
     repository::{
         Branch, CommitDetails, GitRepository, PushOptions, Remote, RemoteCommandOutput, RepoPath,
@@ -117,6 +117,16 @@ enum GitStoreState {
     },
 }
 
+#[derive(Clone)]
+pub struct GitStoreCheckpoint {
+    checkpoints_by_dot_git_abs_path: HashMap<PathBuf, RepositoryCheckpoint>,
+}
+
+#[derive(Copy, Clone)]
+pub struct RepositoryCheckpoint {
+    sha: Oid,
+}
+
 pub struct Repository {
     commit_message_buffer: Option<Entity<Buffer>>,
     git_store: WeakEntity<GitStore>,
@@ -506,6 +516,45 @@ impl GitStore {
         diff_state.read(cx).uncommitted_diff.as_ref()?.upgrade()
     }
 
+    pub fn checkpoint(&self, cx: &App) -> Task<Result<GitStoreCheckpoint>> {
+        let mut dot_git_abs_paths = Vec::new();
+        let mut checkpoints = Vec::new();
+        for repository in self.repositories.values() {
+            let repository = repository.read(cx);
+            dot_git_abs_paths.push(repository.dot_git_abs_path.clone());
+            checkpoints.push(repository.checkpoint().map(|checkpoint| checkpoint?));
+        }
+
+        cx.background_executor().spawn(async move {
+            let checkpoints: Vec<RepositoryCheckpoint> = future::try_join_all(checkpoints).await?;
+            Ok(GitStoreCheckpoint {
+                checkpoints_by_dot_git_abs_path: dot_git_abs_paths
+                    .into_iter()
+                    .zip(checkpoints)
+                    .collect(),
+            })
+        })
+    }
+
+    pub fn restore_checkpoint(&self, checkpoint: GitStoreCheckpoint, cx: &App) -> Task<Result<()>> {
+        let repositories_by_dot_git_abs_path = self
+            .repositories
+            .values()
+            .map(|repo| (repo.read(cx).dot_git_abs_path.clone(), repo))
+            .collect::<HashMap<_, _>>();
+
+        let mut tasks = Vec::new();
+        for (dot_git_abs_path, checkpoint) in checkpoint.checkpoints_by_dot_git_abs_path {
+            if let Some(repository) = repositories_by_dot_git_abs_path.get(&dot_git_abs_path) {
+                tasks.push(repository.read(cx).restore_checkpoint(checkpoint));
+            }
+        }
+        cx.background_spawn(async move {
+            future::try_join_all(tasks).await?;
+            Ok(())
+        })
+    }
+
     fn downstream_client(&self) -> Option<(AnyProtoClient, ProjectId)> {
         match &self.state {
             GitStoreState::Local {
@@ -2922,4 +2971,30 @@ impl Repository {
             }
         })
     }
+
+    pub fn checkpoint(&self) -> oneshot::Receiver<Result<RepositoryCheckpoint>> {
+        self.send_job(|repo, cx| async move {
+            match repo {
+                GitRepo::Local(git_repository) => {
+                    let sha = git_repository.checkpoint(cx).await?;
+                    Ok(RepositoryCheckpoint { sha })
+                }
+                GitRepo::Remote { .. } => Err(anyhow!("not implemented yet")),
+            }
+        })
+    }
+
+    pub fn restore_checkpoint(
+        &self,
+        checkpoint: RepositoryCheckpoint,
+    ) -> oneshot::Receiver<Result<()>> {
+        self.send_job(move |repo, cx| async move {
+            match repo {
+                GitRepo::Local(git_repository) => {
+                    git_repository.restore_checkpoint(checkpoint.sha, cx).await
+                }
+                GitRepo::Remote { .. } => Err(anyhow!("not implemented yet")),
+            }
+        })
+    }
 }