From 33faa66e35323b6f51a3feace2c345cff57b6b63 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 19 Mar 2025 20:00:21 +0100 Subject: [PATCH] Start on a Git-based review flow (#27103) Release Notes: - N/A --------- Co-authored-by: Nathan Sobo --- Cargo.lock | 3 + crates/assistant2/Cargo.toml | 1 + crates/assistant2/src/active_thread.rs | 21 ++- crates/assistant2/src/message_editor.rs | 163 +++++++++------------ crates/assistant2/src/thread.rs | 65 ++++++++- crates/assistant_eval/src/eval.rs | 2 +- crates/fs/src/fake_git_repo.rs | 8 ++ crates/git/Cargo.toml | 2 + crates/git/src/repository.rs | 183 ++++++++++++++++++++---- crates/project/src/git.rs | 79 +++++++++- 10 files changed, 396 insertions(+), 131 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c6280c7fb3608918ffb70642b11880ac8ec8bfab..4e9e69b7baec529184d9a3858316cf4af4c6a7b9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -467,6 +467,7 @@ dependencies = [ "futures 0.3.31", "fuzzy", "git", + "git_ui", "gpui", "heed", "html_to_markdown", @@ -5601,11 +5602,13 @@ dependencies = [ "serde_json", "smol", "sum_tree", + "tempfile", "text", "time", "unindent", "url", "util", + "uuid", ] [[package]] diff --git a/crates/assistant2/Cargo.toml b/crates/assistant2/Cargo.toml index df79bf77a45e7475557a4c54ee38723407c286f6..f14b6c184a8c807eb337ee40720570fd6fa4c9f1 100644 --- a/crates/assistant2/Cargo.toml +++ b/crates/assistant2/Cargo.toml @@ -39,6 +39,7 @@ fs.workspace = true futures.workspace = true fuzzy.workspace = true git.workspace = true +git_ui.workspace = true gpui.workspace = true heed.workspace = true html_to_markdown.workspace = true diff --git a/crates/assistant2/src/active_thread.rs b/crates/assistant2/src/active_thread.rs index 08e55e3dcea1ac86880e845cc32917e5620b3a91..3d4ae49928f0febce829970d63be25da02df816a 100644 --- a/crates/assistant2/src/active_thread.rs +++ b/crates/assistant2/src/active_thread.rs @@ -550,6 +550,7 @@ impl ActiveThread { let thread = self.thread.read(cx); // Get all the data we need from thread before we start using it in closures + let checkpoint = thread.checkpoint_for_message(message_id); let context = thread.context_for_message(message_id); let tool_uses = thread.tool_uses_for_message(message_id); let scripting_tool_uses = thread.scripting_tool_uses_for_message(message_id); @@ -734,7 +735,25 @@ impl ActiveThread { ), }; - styled_message.into_any() + v_flex() + .when_some(checkpoint, |parent, checkpoint| { + parent.child( + h_flex().pl_2().child( + Button::new("restore-checkpoint", "Restore Checkpoint") + .icon(IconName::Undo) + .size(ButtonSize::Compact) + .on_click(cx.listener(move |this, _, _window, cx| { + this.thread.update(cx, |thread, cx| { + thread + .restore_checkpoint(checkpoint.clone(), cx) + .detach_and_log_err(cx); + }); + })), + ), + ) + }) + .child(styled_message) + .into_any() } fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context) -> impl IntoElement { diff --git a/crates/assistant2/src/message_editor.rs b/crates/assistant2/src/message_editor.rs index a095e8bb3eb2faa9451765d8dcc7b43a35bbe80d..ae986743a32814ccfd25eeab43c011f83f79fff3 100644 --- a/crates/assistant2/src/message_editor.rs +++ b/crates/assistant2/src/message_editor.rs @@ -3,23 +3,25 @@ use std::sync::Arc; use collections::HashSet; use editor::actions::MoveUp; use editor::{Editor, EditorElement, EditorEvent, EditorStyle}; -use file_icons::FileIcons; use fs::Fs; +use git::ExpandCommitEditor; +use git_ui::git_panel; use gpui::{ Animation, AnimationExt, App, DismissEvent, Entity, Focusable, Subscription, TextStyle, WeakEntity, }; use language_model::LanguageModelRegistry; use language_model_selector::ToggleModelSelector; +use project::Project; use rope::Point; use settings::Settings; use std::time::Duration; use text::Bias; use theme::ThemeSettings; use ui::{ - prelude::*, ButtonLike, Disclosure, KeyBinding, PlatformStyle, PopoverMenu, PopoverMenuHandle, - Tooltip, + prelude::*, ButtonLike, KeyBinding, PlatformStyle, PopoverMenu, PopoverMenuHandle, Tooltip, }; +use util::ResultExt; use vim_mode_setting::VimModeSetting; use workspace::notifications::{NotificationId, NotifyTaskExt}; use workspace::{Toast, Workspace}; @@ -37,6 +39,7 @@ pub struct MessageEditor { thread: Entity, editor: Entity, workspace: WeakEntity, + project: Entity, context_store: Entity, context_strip: Entity, context_picker_menu_handle: PopoverMenuHandle, @@ -44,7 +47,6 @@ pub struct MessageEditor { inline_context_picker_menu_handle: PopoverMenuHandle, model_selector: Entity, tool_selector: Entity, - edits_expanded: bool, _subscriptions: Vec, } @@ -107,8 +109,9 @@ impl MessageEditor { ]; Self { - thread, editor: editor.clone(), + project: thread.read(cx).project().clone(), + thread, workspace, context_store, context_strip, @@ -125,7 +128,6 @@ impl MessageEditor { ) }), tool_selector: cx.new(|cx| ToolSelector::new(tools, cx)), - edits_expanded: false, _subscriptions: subscriptions, } } @@ -206,12 +208,15 @@ impl MessageEditor { let thread = self.thread.clone(); let context_store = self.context_store.clone(); + let git_store = self.project.read(cx).git_store(); + let checkpoint = git_store.read(cx).checkpoint(cx); cx.spawn(async move |_, cx| { refresh_task.await; + let checkpoint = checkpoint.await.log_err(); thread .update(cx, |thread, cx| { let context = context_store.read(cx).snapshot(cx).collect::>(); - thread.insert_user_message(user_message, context, cx); + thread.insert_user_message(user_message, context, checkpoint, cx); thread.send_to_model(model, request_kind, cx); }) .ok(); @@ -347,8 +352,12 @@ impl Render for MessageEditor { px(64.) }; - let changed_buffers = self.thread.read(cx).scripting_changed_buffers(cx); - let changed_buffers_count = changed_buffers.len(); + let project = self.thread.read(cx).project(); + let changed_files = if let Some(repository) = project.read(cx).active_repository(cx) { + repository.read(cx).status().count() + } else { + 0 + }; v_flex() .size_full() @@ -410,7 +419,7 @@ impl Render for MessageEditor { ), ) }) - .when(changed_buffers_count > 0, |parent| { + .when(changed_files > 0, |parent| { parent.child( v_flex() .mx_2() @@ -421,96 +430,60 @@ impl Render for MessageEditor { .rounded_t_md() .child( h_flex() - .gap_2() + .justify_between() .p_2() .child( - Disclosure::new("edits-disclosure", self.edits_expanded) - .on_click(cx.listener(|this, _ev, _window, cx| { - this.edits_expanded = !this.edits_expanded; - cx.notify(); - })), - ) - .child( - Label::new("Edits") - .size(LabelSize::XSmall) - .color(Color::Muted), + h_flex() + .gap_2() + .child( + IconButton::new( + "edits-disclosure", + IconName::GitBranchSmall, + ) + .icon_size(IconSize::Small) + .on_click( + |_ev, _window, cx| { + cx.defer(|cx| { + cx.dispatch_action(&git_panel::ToggleFocus) + }); + }, + ), + ) + .child( + Label::new(format!( + "{} {} changed", + changed_files, + if changed_files == 1 { "file" } else { "files" } + )) + .size(LabelSize::XSmall) + .color(Color::Muted), + ), ) - .child(Label::new("•").size(LabelSize::XSmall).color(Color::Muted)) .child( - Label::new(format!( - "{} {}", - changed_buffers_count, - if changed_buffers_count == 1 { - "file" - } else { - "files" - } - )) - .size(LabelSize::XSmall) - .color(Color::Muted), - ), - ) - .when(self.edits_expanded, |parent| { - parent.child( - v_flex().bg(cx.theme().colors().editor_background).children( - changed_buffers.enumerate().flat_map(|(index, buffer)| { - let file = buffer.read(cx).file()?; - let path = file.path(); - - let parent_label = path.parent().and_then(|parent| { - let parent_str = parent.to_string_lossy(); - - if parent_str.is_empty() { - None - } else { - Some( - Label::new(format!( - "{}{}", - parent_str, - std::path::MAIN_SEPARATOR_STR - )) - .color(Color::Muted) - .size(LabelSize::Small), - ) - } - }); - - let name_label = path.file_name().map(|name| { - Label::new(name.to_string_lossy().to_string()) - .size(LabelSize::Small) - }); - - let file_icon = FileIcons::get_icon(&path, cx) - .map(Icon::from_path) - .unwrap_or_else(|| Icon::new(IconName::File)); - - let element = div() - .p_2() - .when(index + 1 < changed_buffers_count, |parent| { - parent - .border_color(cx.theme().colors().border) - .border_b_1() - }) - .child( - h_flex() - .gap_2() - .child(file_icon) - .child( - // TODO: handle overflow - h_flex() - .children(parent_label) - .children(name_label), - ) - // TODO: show lines changed - .child(Label::new("+").color(Color::Created)) - .child(Label::new("-").color(Color::Deleted)), - ); - - Some(element) - }), + h_flex() + .gap_2() + .child( + Button::new("review", "Review") + .label_size(LabelSize::XSmall) + .on_click(|_event, _window, cx| { + cx.defer(|cx| { + cx.dispatch_action( + &git_ui::project_diff::Diff, + ); + }); + }), + ) + .child( + Button::new("commit", "Commit") + .label_size(LabelSize::XSmall) + .on_click(|_event, _window, cx| { + cx.defer(|cx| { + cx.dispatch_action(&ExpandCommitEditor) + }); + }), + ), ), - ) - }), + ), ) }) .child( diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index 8f65a1d1b36f6939db41c16ec0e27112efa99421..98daa6fbcb4e401a4a22037fd76ed32f0d382de5 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -16,6 +16,7 @@ use language_model::{ LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError, Role, StopReason, TokenUsage, }; +use project::git::GitStoreCheckpoint; use project::Project; use prompt_store::{AssistantSystemPromptWorktree, PromptBuilder}; use scripting_tool::{ScriptingSession, ScriptingTool}; @@ -89,6 +90,12 @@ pub struct GitState { pub diff: Option, } +#[derive(Clone)] +pub struct ThreadCheckpoint { + message_id: MessageId, + git_checkpoint: GitStoreCheckpoint, +} + /// A thread of conversation with the LLM. pub struct Thread { id: ThreadId, @@ -99,6 +106,7 @@ pub struct Thread { next_message_id: MessageId, context: BTreeMap, context_by_message: HashMap>, + checkpoints_by_message: HashMap, completion_count: usize, pending_completions: Vec, project: Entity, @@ -128,6 +136,7 @@ impl Thread { next_message_id: MessageId(0), context: BTreeMap::default(), context_by_message: HashMap::default(), + checkpoints_by_message: HashMap::default(), completion_count: 0, pending_completions: Vec::new(), project: project.clone(), @@ -188,6 +197,7 @@ impl Thread { next_message_id, context: BTreeMap::default(), context_by_message: HashMap::default(), + checkpoints_by_message: HashMap::default(), completion_count: 0, pending_completions: Vec::new(), project, @@ -249,6 +259,45 @@ impl Thread { &self.tools } + pub fn checkpoint_for_message(&self, id: MessageId) -> Option { + let checkpoint = self.checkpoints_by_message.get(&id).cloned()?; + Some(ThreadCheckpoint { + message_id: id, + git_checkpoint: checkpoint, + }) + } + + pub fn restore_checkpoint( + &mut self, + checkpoint: ThreadCheckpoint, + cx: &mut Context, + ) -> Task> { + let project = self.project.read(cx); + let restore = project + .git_store() + .read(cx) + .restore_checkpoint(checkpoint.git_checkpoint, cx); + cx.spawn(async move |this, cx| { + restore.await?; + this.update(cx, |this, cx| this.truncate(checkpoint.message_id, cx)) + }) + } + + pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context) { + let Some(message_ix) = self + .messages + .iter() + .rposition(|message| message.id == message_id) + else { + return; + }; + for deleted_message in self.messages.drain(message_ix..) { + self.context_by_message.remove(&deleted_message.id); + self.checkpoints_by_message.remove(&deleted_message.id); + } + cx.notify(); + } + pub fn context_for_message(&self, id: MessageId) -> Option> { let context = self.context_by_message.get(&id)?; Some( @@ -296,13 +345,6 @@ impl Thread { self.scripting_tool_use.tool_results_for_message(id) } - pub fn scripting_changed_buffers<'a>( - &self, - cx: &'a App, - ) -> impl ExactSizeIterator> { - self.scripting_session.read(cx).changed_buffers() - } - pub fn message_has_tool_results(&self, message_id: MessageId) -> bool { self.tool_use.message_has_tool_results(message_id) } @@ -315,6 +357,7 @@ impl Thread { &mut self, text: impl Into, context: Vec, + checkpoint: Option, cx: &mut Context, ) -> MessageId { let message_id = self.insert_message(Role::User, text, cx); @@ -322,6 +365,9 @@ impl Thread { self.context .extend(context.into_iter().map(|context| (context.id, context))); self.context_by_message.insert(message_id, context_ids); + if let Some(checkpoint) = checkpoint { + self.checkpoints_by_message.insert(message_id, checkpoint); + } message_id } @@ -941,6 +987,7 @@ impl Thread { // so for now we provide some text to keep the model on track. "Here are the tool results.", Vec::new(), + None, cx, ); } @@ -1144,6 +1191,10 @@ impl Thread { &self.action_log } + pub fn project(&self) -> &Entity { + &self.project + } + pub fn cumulative_token_usage(&self) -> TokenUsage { self.cumulative_token_usage.clone() } diff --git a/crates/assistant_eval/src/eval.rs b/crates/assistant_eval/src/eval.rs index 5ce7c02d8efe6694e48b6e49a81f4877bb4b3b0c..8f5def88e35aecc21614f87c2965ea3905ba21a8 100644 --- a/crates/assistant_eval/src/eval.rs +++ b/crates/assistant_eval/src/eval.rs @@ -82,7 +82,7 @@ impl Eval { assistant.update(cx, |assistant, cx| { assistant.thread.update(cx, |thread, cx| { let context = vec![]; - thread.insert_user_message(self.user_prompt.clone(), context, cx); + thread.insert_user_message(self.user_prompt.clone(), context, None, cx); thread.send_to_model(model, RequestKind::Chat, cx); }); })?; diff --git a/crates/fs/src/fake_git_repo.rs b/crates/fs/src/fake_git_repo.rs index 6c2c230b5ba7fc46acd090ae16bf86fbce88c294..89633158382bb13cf231bbdb43ec29be31ca287f 100644 --- a/crates/fs/src/fake_git_repo.rs +++ b/crates/fs/src/fake_git_repo.rs @@ -408,4 +408,12 @@ impl GitRepository for FakeGitRepository { ) -> BoxFuture> { unimplemented!() } + + fn checkpoint(&self, _cx: AsyncApp) -> BoxFuture> { + unimplemented!() + } + + fn restore_checkpoint(&self, _oid: git::Oid, _cx: AsyncApp) -> BoxFuture> { + unimplemented!() + } } diff --git a/crates/git/Cargo.toml b/crates/git/Cargo.toml index c32fe3491f72b8bff369f07fe54970e1e38406bb..23e145c0c6775a45057cb73a94f3dd4e80b8448e 100644 --- a/crates/git/Cargo.toml +++ b/crates/git/Cargo.toml @@ -35,6 +35,7 @@ text.workspace = true time.workspace = true url.workspace = true util.workspace = true +uuid.workspace = true futures.workspace = true [dev-dependencies] @@ -43,3 +44,4 @@ serde_json.workspace = true text = { workspace = true, features = ["test-support"] } unindent.workspace = true gpui = { workspace = true, features = ["test-support"] } +tempfile.workspace = true diff --git a/crates/git/src/repository.rs b/crates/git/src/repository.rs index 8bd43fe95954186cfcba6aa89e2530e7611ac444..4286f0eba320c0d513693030bc145cee5e968503 100644 --- a/crates/git/src/repository.rs +++ b/crates/git/src/repository.rs @@ -1,5 +1,5 @@ use crate::status::GitStatus; -use crate::SHORT_SHA_LENGTH; +use crate::{Oid, SHORT_SHA_LENGTH}; use anyhow::{anyhow, Context as _, Result}; use collections::HashMap; use futures::future::BoxFuture; @@ -22,6 +22,7 @@ use std::{ use sum_tree::MapSeekTarget; use util::command::new_smol_command; use util::ResultExt; +use uuid::Uuid; pub use askpass::{AskPassResult, AskPassSession}; @@ -287,6 +288,12 @@ pub trait GitRepository: Send + Sync { /// Run git diff fn diff(&self, diff: DiffType, cx: AsyncApp) -> BoxFuture>; + + /// Creates a checkpoint for the repository. + fn checkpoint(&self, cx: AsyncApp) -> BoxFuture>; + + /// Resets to a previously-created checkpoint. + fn restore_checkpoint(&self, oid: Oid, cx: AsyncApp) -> BoxFuture>; } pub enum DiffType { @@ -1025,6 +1032,89 @@ impl GitRepository for RealGitRepository { }) .boxed() } + + fn checkpoint(&self, cx: AsyncApp) -> BoxFuture> { + let working_directory = self.working_directory(); + let git_binary_path = self.git_binary_path.clone(); + let executor = cx.background_executor().clone(); + cx.background_spawn(async move { + let working_directory = working_directory?; + let index_file_path = working_directory.join(".git/index.tmp"); + + let delete_temp_index = util::defer({ + let index_file_path = index_file_path.clone(); + || { + executor + .spawn(async move { + smol::fs::remove_file(index_file_path).await.log_err(); + }) + .detach(); + } + }); + + let run_git_command = async |args: &[&str]| { + let output = new_smol_command(&git_binary_path) + .current_dir(&working_directory) + .env("GIT_INDEX_FILE", &index_file_path) + .env("GIT_AUTHOR_NAME", "Zed") + .env("GIT_AUTHOR_EMAIL", "hi@zed.dev") + .env("GIT_COMMITTER_NAME", "Zed") + .env("GIT_COMMITTER_EMAIL", "hi@zed.dev") + .args(args) + .output() + .await?; + if output.status.success() { + anyhow::Ok(String::from_utf8(output.stdout)?.trim_end().to_string()) + } else { + let error = String::from_utf8_lossy(&output.stderr); + Err(anyhow!("Git command failed: {:?}", error)) + } + }; + + run_git_command(&["add", "--all"]).await?; + let tree = run_git_command(&["write-tree"]).await?; + let commit_sha = run_git_command(&["commit-tree", &tree, "-m", "Checkpoint"]).await?; + let ref_name = Uuid::new_v4().to_string(); + run_git_command(&["update-ref", &format!("refs/heads/{ref_name}"), &commit_sha]) + .await?; + + smol::fs::remove_file(index_file_path).await.ok(); + delete_temp_index.abort(); + + commit_sha.parse() + }) + .boxed() + } + + fn restore_checkpoint(&self, oid: Oid, cx: AsyncApp) -> BoxFuture> { + let working_directory = self.working_directory(); + let git_binary_path = self.git_binary_path.clone(); + cx.background_spawn(async move { + let working_directory = working_directory?; + let index_file_path = working_directory.join(".git/index.tmp"); + + let run_git_command = async |args: &[&str]| { + let output = new_smol_command(&git_binary_path) + .current_dir(&working_directory) + .env("GIT_INDEX_FILE", &index_file_path) + .args(args) + .output() + .await?; + if output.status.success() { + anyhow::Ok(String::from_utf8(output.stdout)?.trim_end().to_string()) + } else { + let error = String::from_utf8_lossy(&output.stderr); + Err(anyhow!("Git command failed: {:?}", error)) + } + }; + + run_git_command(&["restore", "--source", &oid.to_string(), "--worktree", "."]).await?; + run_git_command(&["read-tree", &oid.to_string()]).await?; + run_git_command(&["clean", "-d", "--force"]).await?; + Ok(()) + }) + .boxed() + } } async fn run_remote_command( @@ -1260,29 +1350,72 @@ fn check_path_to_repo_path_errors(relative_file_path: &Path) -> Result<()> { } } -#[test] -fn test_branches_parsing() { - // suppress "help: octal escapes are not supported, `\0` is always null" - #[allow(clippy::octal_escapes)] - let input = "*\0060964da10574cd9bf06463a53bf6e0769c5c45e\0\0refs/heads/zed-patches\0refs/remotes/origin/zed-patches\0\01733187470\0generated protobuf\n"; - assert_eq!( - parse_branch_input(&input).unwrap(), - vec![Branch { - is_head: true, - name: "zed-patches".into(), - upstream: Some(Upstream { - ref_name: "refs/remotes/origin/zed-patches".into(), - tracking: UpstreamTracking::Tracked(UpstreamTrackingStatus { - ahead: 0, - behind: 0 +#[cfg(test)] +mod tests { + use gpui::TestAppContext; + + use super::*; + + #[gpui::test] + async fn test_checkpoint(cx: &mut TestAppContext) { + cx.executor().allow_parking(); + + let repo_dir = tempfile::tempdir().unwrap(); + git2::Repository::init(repo_dir.path()).unwrap(); + let repo = RealGitRepository::new(&repo_dir.path().join(".git"), None).unwrap(); + + smol::fs::write(repo_dir.path().join("foo"), "foo") + .await + .unwrap(); + let checkpoint_sha = repo.checkpoint(cx.to_async()).await.unwrap(); + + smol::fs::write(repo_dir.path().join("foo"), "bar") + .await + .unwrap(); + smol::fs::write(repo_dir.path().join("baz"), "qux") + .await + .unwrap(); + repo.restore_checkpoint(checkpoint_sha, cx.to_async()) + .await + .unwrap(); + assert_eq!( + smol::fs::read_to_string(repo_dir.path().join("foo")) + .await + .unwrap(), + "foo" + ); + assert_eq!( + smol::fs::read_to_string(repo_dir.path().join("baz")) + .await + .ok(), + None + ); + } + + #[test] + fn test_branches_parsing() { + // suppress "help: octal escapes are not supported, `\0` is always null" + #[allow(clippy::octal_escapes)] + let input = "*\0060964da10574cd9bf06463a53bf6e0769c5c45e\0\0refs/heads/zed-patches\0refs/remotes/origin/zed-patches\0\01733187470\0generated protobuf\n"; + assert_eq!( + parse_branch_input(&input).unwrap(), + vec![Branch { + is_head: true, + name: "zed-patches".into(), + upstream: Some(Upstream { + ref_name: "refs/remotes/origin/zed-patches".into(), + tracking: UpstreamTracking::Tracked(UpstreamTrackingStatus { + ahead: 0, + behind: 0 + }) + }), + most_recent_commit: Some(CommitSummary { + sha: "060964da10574cd9bf06463a53bf6e0769c5c45e".into(), + subject: "generated protobuf".into(), + commit_timestamp: 1733187470, + has_parent: false, }) - }), - most_recent_commit: Some(CommitSummary { - sha: "060964da10574cd9bf06463a53bf6e0769c5c45e".into(), - subject: "generated protobuf".into(), - commit_timestamp: 1733187470, - has_parent: false, - }) - }] - ) + }] + ) + } } diff --git a/crates/project/src/git.rs b/crates/project/src/git.rs index c7afc90228ae73ed141abc98ebfdec711f0dfdfe..bb63d66423934d1763d61e20a77b23b7d4091471 100644 --- a/crates/project/src/git.rs +++ b/crates/project/src/git.rs @@ -11,10 +11,10 @@ use collections::HashMap; use fs::Fs; use futures::{ channel::{mpsc, oneshot}, - future::{OptionFuture, Shared}, + future::{self, OptionFuture, Shared}, FutureExt as _, StreamExt as _, }; -use git::repository::DiffType; +use git::{repository::DiffType, Oid}; use git::{ repository::{ Branch, CommitDetails, GitRepository, PushOptions, Remote, RemoteCommandOutput, RepoPath, @@ -117,6 +117,16 @@ enum GitStoreState { }, } +#[derive(Clone)] +pub struct GitStoreCheckpoint { + checkpoints_by_dot_git_abs_path: HashMap, +} + +#[derive(Copy, Clone)] +pub struct RepositoryCheckpoint { + sha: Oid, +} + pub struct Repository { commit_message_buffer: Option>, git_store: WeakEntity, @@ -506,6 +516,45 @@ impl GitStore { diff_state.read(cx).uncommitted_diff.as_ref()?.upgrade() } + pub fn checkpoint(&self, cx: &App) -> Task> { + let mut dot_git_abs_paths = Vec::new(); + let mut checkpoints = Vec::new(); + for repository in self.repositories.values() { + let repository = repository.read(cx); + dot_git_abs_paths.push(repository.dot_git_abs_path.clone()); + checkpoints.push(repository.checkpoint().map(|checkpoint| checkpoint?)); + } + + cx.background_executor().spawn(async move { + let checkpoints: Vec = future::try_join_all(checkpoints).await?; + Ok(GitStoreCheckpoint { + checkpoints_by_dot_git_abs_path: dot_git_abs_paths + .into_iter() + .zip(checkpoints) + .collect(), + }) + }) + } + + pub fn restore_checkpoint(&self, checkpoint: GitStoreCheckpoint, cx: &App) -> Task> { + let repositories_by_dot_git_abs_path = self + .repositories + .values() + .map(|repo| (repo.read(cx).dot_git_abs_path.clone(), repo)) + .collect::>(); + + let mut tasks = Vec::new(); + for (dot_git_abs_path, checkpoint) in checkpoint.checkpoints_by_dot_git_abs_path { + if let Some(repository) = repositories_by_dot_git_abs_path.get(&dot_git_abs_path) { + tasks.push(repository.read(cx).restore_checkpoint(checkpoint)); + } + } + cx.background_spawn(async move { + future::try_join_all(tasks).await?; + Ok(()) + }) + } + fn downstream_client(&self) -> Option<(AnyProtoClient, ProjectId)> { match &self.state { GitStoreState::Local { @@ -2922,4 +2971,30 @@ impl Repository { } }) } + + pub fn checkpoint(&self) -> oneshot::Receiver> { + self.send_job(|repo, cx| async move { + match repo { + GitRepo::Local(git_repository) => { + let sha = git_repository.checkpoint(cx).await?; + Ok(RepositoryCheckpoint { sha }) + } + GitRepo::Remote { .. } => Err(anyhow!("not implemented yet")), + } + }) + } + + pub fn restore_checkpoint( + &self, + checkpoint: RepositoryCheckpoint, + ) -> oneshot::Receiver> { + self.send_job(move |repo, cx| async move { + match repo { + GitRepo::Local(git_repository) => { + git_repository.restore_checkpoint(checkpoint.sha, cx).await + } + GitRepo::Remote { .. } => Err(anyhow!("not implemented yet")), + } + }) + } }