Detailed changes
@@ -467,6 +467,7 @@ dependencies = [
"futures 0.3.31",
"fuzzy",
"git",
+ "git_ui",
"gpui",
"heed",
"html_to_markdown",
@@ -5601,11 +5602,13 @@ dependencies = [
"serde_json",
"smol",
"sum_tree",
+ "tempfile",
"text",
"time",
"unindent",
"url",
"util",
+ "uuid",
]
[[package]]
@@ -39,6 +39,7 @@ fs.workspace = true
futures.workspace = true
fuzzy.workspace = true
git.workspace = true
+git_ui.workspace = true
gpui.workspace = true
heed.workspace = true
html_to_markdown.workspace = true
@@ -550,6 +550,7 @@ impl ActiveThread {
let thread = self.thread.read(cx);
// Get all the data we need from thread before we start using it in closures
+ let checkpoint = thread.checkpoint_for_message(message_id);
let context = thread.context_for_message(message_id);
let tool_uses = thread.tool_uses_for_message(message_id);
let scripting_tool_uses = thread.scripting_tool_uses_for_message(message_id);
@@ -734,7 +735,25 @@ impl ActiveThread {
),
};
- styled_message.into_any()
+ v_flex()
+ .when_some(checkpoint, |parent, checkpoint| {
+ parent.child(
+ h_flex().pl_2().child(
+ Button::new("restore-checkpoint", "Restore Checkpoint")
+ .icon(IconName::Undo)
+ .size(ButtonSize::Compact)
+ .on_click(cx.listener(move |this, _, _window, cx| {
+ this.thread.update(cx, |thread, cx| {
+ thread
+ .restore_checkpoint(checkpoint.clone(), cx)
+ .detach_and_log_err(cx);
+ });
+ })),
+ ),
+ )
+ })
+ .child(styled_message)
+ .into_any()
}
fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
@@ -3,23 +3,25 @@ use std::sync::Arc;
use collections::HashSet;
use editor::actions::MoveUp;
use editor::{Editor, EditorElement, EditorEvent, EditorStyle};
-use file_icons::FileIcons;
use fs::Fs;
+use git::ExpandCommitEditor;
+use git_ui::git_panel;
use gpui::{
Animation, AnimationExt, App, DismissEvent, Entity, Focusable, Subscription, TextStyle,
WeakEntity,
};
use language_model::LanguageModelRegistry;
use language_model_selector::ToggleModelSelector;
+use project::Project;
use rope::Point;
use settings::Settings;
use std::time::Duration;
use text::Bias;
use theme::ThemeSettings;
use ui::{
- prelude::*, ButtonLike, Disclosure, KeyBinding, PlatformStyle, PopoverMenu, PopoverMenuHandle,
- Tooltip,
+ prelude::*, ButtonLike, KeyBinding, PlatformStyle, PopoverMenu, PopoverMenuHandle, Tooltip,
};
+use util::ResultExt;
use vim_mode_setting::VimModeSetting;
use workspace::notifications::{NotificationId, NotifyTaskExt};
use workspace::{Toast, Workspace};
@@ -37,6 +39,7 @@ pub struct MessageEditor {
thread: Entity<Thread>,
editor: Entity<Editor>,
workspace: WeakEntity<Workspace>,
+ project: Entity<Project>,
context_store: Entity<ContextStore>,
context_strip: Entity<ContextStrip>,
context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
@@ -44,7 +47,6 @@ pub struct MessageEditor {
inline_context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
model_selector: Entity<AssistantModelSelector>,
tool_selector: Entity<ToolSelector>,
- edits_expanded: bool,
_subscriptions: Vec<Subscription>,
}
@@ -107,8 +109,9 @@ impl MessageEditor {
];
Self {
- thread,
editor: editor.clone(),
+ project: thread.read(cx).project().clone(),
+ thread,
workspace,
context_store,
context_strip,
@@ -125,7 +128,6 @@ impl MessageEditor {
)
}),
tool_selector: cx.new(|cx| ToolSelector::new(tools, cx)),
- edits_expanded: false,
_subscriptions: subscriptions,
}
}
@@ -206,12 +208,15 @@ impl MessageEditor {
let thread = self.thread.clone();
let context_store = self.context_store.clone();
+ let git_store = self.project.read(cx).git_store();
+ let checkpoint = git_store.read(cx).checkpoint(cx);
cx.spawn(async move |_, cx| {
refresh_task.await;
+ let checkpoint = checkpoint.await.log_err();
thread
.update(cx, |thread, cx| {
let context = context_store.read(cx).snapshot(cx).collect::<Vec<_>>();
- thread.insert_user_message(user_message, context, cx);
+ thread.insert_user_message(user_message, context, checkpoint, cx);
thread.send_to_model(model, request_kind, cx);
})
.ok();
@@ -347,8 +352,12 @@ impl Render for MessageEditor {
px(64.)
};
- let changed_buffers = self.thread.read(cx).scripting_changed_buffers(cx);
- let changed_buffers_count = changed_buffers.len();
+ let project = self.thread.read(cx).project();
+ let changed_files = if let Some(repository) = project.read(cx).active_repository(cx) {
+ repository.read(cx).status().count()
+ } else {
+ 0
+ };
v_flex()
.size_full()
@@ -410,7 +419,7 @@ impl Render for MessageEditor {
),
)
})
- .when(changed_buffers_count > 0, |parent| {
+ .when(changed_files > 0, |parent| {
parent.child(
v_flex()
.mx_2()
@@ -421,96 +430,60 @@ impl Render for MessageEditor {
.rounded_t_md()
.child(
h_flex()
- .gap_2()
+ .justify_between()
.p_2()
.child(
- Disclosure::new("edits-disclosure", self.edits_expanded)
- .on_click(cx.listener(|this, _ev, _window, cx| {
- this.edits_expanded = !this.edits_expanded;
- cx.notify();
- })),
- )
- .child(
- Label::new("Edits")
- .size(LabelSize::XSmall)
- .color(Color::Muted),
+ h_flex()
+ .gap_2()
+ .child(
+ IconButton::new(
+ "edits-disclosure",
+ IconName::GitBranchSmall,
+ )
+ .icon_size(IconSize::Small)
+ .on_click(
+ |_ev, _window, cx| {
+ cx.defer(|cx| {
+ cx.dispatch_action(&git_panel::ToggleFocus)
+ });
+ },
+ ),
+ )
+ .child(
+ Label::new(format!(
+ "{} {} changed",
+ changed_files,
+ if changed_files == 1 { "file" } else { "files" }
+ ))
+ .size(LabelSize::XSmall)
+ .color(Color::Muted),
+ ),
)
- .child(Label::new("•").size(LabelSize::XSmall).color(Color::Muted))
.child(
- Label::new(format!(
- "{} {}",
- changed_buffers_count,
- if changed_buffers_count == 1 {
- "file"
- } else {
- "files"
- }
- ))
- .size(LabelSize::XSmall)
- .color(Color::Muted),
- ),
- )
- .when(self.edits_expanded, |parent| {
- parent.child(
- v_flex().bg(cx.theme().colors().editor_background).children(
- changed_buffers.enumerate().flat_map(|(index, buffer)| {
- let file = buffer.read(cx).file()?;
- let path = file.path();
-
- let parent_label = path.parent().and_then(|parent| {
- let parent_str = parent.to_string_lossy();
-
- if parent_str.is_empty() {
- None
- } else {
- Some(
- Label::new(format!(
- "{}{}",
- parent_str,
- std::path::MAIN_SEPARATOR_STR
- ))
- .color(Color::Muted)
- .size(LabelSize::Small),
- )
- }
- });
-
- let name_label = path.file_name().map(|name| {
- Label::new(name.to_string_lossy().to_string())
- .size(LabelSize::Small)
- });
-
- let file_icon = FileIcons::get_icon(&path, cx)
- .map(Icon::from_path)
- .unwrap_or_else(|| Icon::new(IconName::File));
-
- let element = div()
- .p_2()
- .when(index + 1 < changed_buffers_count, |parent| {
- parent
- .border_color(cx.theme().colors().border)
- .border_b_1()
- })
- .child(
- h_flex()
- .gap_2()
- .child(file_icon)
- .child(
- // TODO: handle overflow
- h_flex()
- .children(parent_label)
- .children(name_label),
- )
- // TODO: show lines changed
- .child(Label::new("+").color(Color::Created))
- .child(Label::new("-").color(Color::Deleted)),
- );
-
- Some(element)
- }),
+ h_flex()
+ .gap_2()
+ .child(
+ Button::new("review", "Review")
+ .label_size(LabelSize::XSmall)
+ .on_click(|_event, _window, cx| {
+ cx.defer(|cx| {
+ cx.dispatch_action(
+ &git_ui::project_diff::Diff,
+ );
+ });
+ }),
+ )
+ .child(
+ Button::new("commit", "Commit")
+ .label_size(LabelSize::XSmall)
+ .on_click(|_event, _window, cx| {
+ cx.defer(|cx| {
+ cx.dispatch_action(&ExpandCommitEditor)
+ });
+ }),
+ ),
),
- )
- }),
+ ),
)
})
.child(
@@ -16,6 +16,7 @@ use language_model::{
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
Role, StopReason, TokenUsage,
};
+use project::git::GitStoreCheckpoint;
use project::Project;
use prompt_store::{AssistantSystemPromptWorktree, PromptBuilder};
use scripting_tool::{ScriptingSession, ScriptingTool};
@@ -89,6 +90,12 @@ pub struct GitState {
pub diff: Option<String>,
}
+#[derive(Clone)]
+pub struct ThreadCheckpoint {
+ message_id: MessageId,
+ git_checkpoint: GitStoreCheckpoint,
+}
+
/// A thread of conversation with the LLM.
pub struct Thread {
id: ThreadId,
@@ -99,6 +106,7 @@ pub struct Thread {
next_message_id: MessageId,
context: BTreeMap<ContextId, ContextSnapshot>,
context_by_message: HashMap<MessageId, Vec<ContextId>>,
+ checkpoints_by_message: HashMap<MessageId, GitStoreCheckpoint>,
completion_count: usize,
pending_completions: Vec<PendingCompletion>,
project: Entity<Project>,
@@ -128,6 +136,7 @@ impl Thread {
next_message_id: MessageId(0),
context: BTreeMap::default(),
context_by_message: HashMap::default(),
+ checkpoints_by_message: HashMap::default(),
completion_count: 0,
pending_completions: Vec::new(),
project: project.clone(),
@@ -188,6 +197,7 @@ impl Thread {
next_message_id,
context: BTreeMap::default(),
context_by_message: HashMap::default(),
+ checkpoints_by_message: HashMap::default(),
completion_count: 0,
pending_completions: Vec::new(),
project,
@@ -249,6 +259,45 @@ impl Thread {
&self.tools
}
+ pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
+ let checkpoint = self.checkpoints_by_message.get(&id).cloned()?;
+ Some(ThreadCheckpoint {
+ message_id: id,
+ git_checkpoint: checkpoint,
+ })
+ }
+
+ pub fn restore_checkpoint(
+ &mut self,
+ checkpoint: ThreadCheckpoint,
+ cx: &mut Context<Self>,
+ ) -> Task<Result<()>> {
+ let project = self.project.read(cx);
+ let restore = project
+ .git_store()
+ .read(cx)
+ .restore_checkpoint(checkpoint.git_checkpoint, cx);
+ cx.spawn(async move |this, cx| {
+ restore.await?;
+ this.update(cx, |this, cx| this.truncate(checkpoint.message_id, cx))
+ })
+ }
+
+ pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
+ let Some(message_ix) = self
+ .messages
+ .iter()
+ .rposition(|message| message.id == message_id)
+ else {
+ return;
+ };
+ for deleted_message in self.messages.drain(message_ix..) {
+ self.context_by_message.remove(&deleted_message.id);
+ self.checkpoints_by_message.remove(&deleted_message.id);
+ }
+ cx.notify();
+ }
+
pub fn context_for_message(&self, id: MessageId) -> Option<Vec<ContextSnapshot>> {
let context = self.context_by_message.get(&id)?;
Some(
@@ -296,13 +345,6 @@ impl Thread {
self.scripting_tool_use.tool_results_for_message(id)
}
- pub fn scripting_changed_buffers<'a>(
- &self,
- cx: &'a App,
- ) -> impl ExactSizeIterator<Item = &'a Entity<language::Buffer>> {
- self.scripting_session.read(cx).changed_buffers()
- }
-
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
self.tool_use.message_has_tool_results(message_id)
}
@@ -315,6 +357,7 @@ impl Thread {
&mut self,
text: impl Into<String>,
context: Vec<ContextSnapshot>,
+ checkpoint: Option<GitStoreCheckpoint>,
cx: &mut Context<Self>,
) -> MessageId {
let message_id = self.insert_message(Role::User, text, cx);
@@ -322,6 +365,9 @@ impl Thread {
self.context
.extend(context.into_iter().map(|context| (context.id, context)));
self.context_by_message.insert(message_id, context_ids);
+ if let Some(checkpoint) = checkpoint {
+ self.checkpoints_by_message.insert(message_id, checkpoint);
+ }
message_id
}
@@ -941,6 +987,7 @@ impl Thread {
// so for now we provide some text to keep the model on track.
"Here are the tool results.",
Vec::new(),
+ None,
cx,
);
}
@@ -1144,6 +1191,10 @@ impl Thread {
&self.action_log
}
+ pub fn project(&self) -> &Entity<Project> {
+ &self.project
+ }
+
pub fn cumulative_token_usage(&self) -> TokenUsage {
self.cumulative_token_usage.clone()
}
@@ -82,7 +82,7 @@ impl Eval {
assistant.update(cx, |assistant, cx| {
assistant.thread.update(cx, |thread, cx| {
let context = vec![];
- thread.insert_user_message(self.user_prompt.clone(), context, cx);
+ thread.insert_user_message(self.user_prompt.clone(), context, None, cx);
thread.send_to_model(model, RequestKind::Chat, cx);
});
})?;
@@ -408,4 +408,12 @@ impl GitRepository for FakeGitRepository {
) -> BoxFuture<Result<String>> {
unimplemented!()
}
+
+ fn checkpoint(&self, _cx: AsyncApp) -> BoxFuture<Result<git::Oid>> {
+ unimplemented!()
+ }
+
+ fn restore_checkpoint(&self, _oid: git::Oid, _cx: AsyncApp) -> BoxFuture<Result<()>> {
+ unimplemented!()
+ }
}
@@ -35,6 +35,7 @@ text.workspace = true
time.workspace = true
url.workspace = true
util.workspace = true
+uuid.workspace = true
futures.workspace = true
[dev-dependencies]
@@ -43,3 +44,4 @@ serde_json.workspace = true
text = { workspace = true, features = ["test-support"] }
unindent.workspace = true
gpui = { workspace = true, features = ["test-support"] }
+tempfile.workspace = true
@@ -1,5 +1,5 @@
use crate::status::GitStatus;
-use crate::SHORT_SHA_LENGTH;
+use crate::{Oid, SHORT_SHA_LENGTH};
use anyhow::{anyhow, Context as _, Result};
use collections::HashMap;
use futures::future::BoxFuture;
@@ -22,6 +22,7 @@ use std::{
use sum_tree::MapSeekTarget;
use util::command::new_smol_command;
use util::ResultExt;
+use uuid::Uuid;
pub use askpass::{AskPassResult, AskPassSession};
@@ -287,6 +288,12 @@ pub trait GitRepository: Send + Sync {
/// Run git diff
fn diff(&self, diff: DiffType, cx: AsyncApp) -> BoxFuture<Result<String>>;
+
+ /// Creates a checkpoint for the repository.
+ fn checkpoint(&self, cx: AsyncApp) -> BoxFuture<Result<Oid>>;
+
+ /// Resets to a previously-created checkpoint.
+ fn restore_checkpoint(&self, oid: Oid, cx: AsyncApp) -> BoxFuture<Result<()>>;
}
pub enum DiffType {
@@ -1025,6 +1032,89 @@ impl GitRepository for RealGitRepository {
})
.boxed()
}
+
+ fn checkpoint(&self, cx: AsyncApp) -> BoxFuture<Result<Oid>> {
+ let working_directory = self.working_directory();
+ let git_binary_path = self.git_binary_path.clone();
+ let executor = cx.background_executor().clone();
+ cx.background_spawn(async move {
+ let working_directory = working_directory?;
+ let index_file_path = working_directory.join(".git/index.tmp");
+
+ let delete_temp_index = util::defer({
+ let index_file_path = index_file_path.clone();
+ || {
+ executor
+ .spawn(async move {
+ smol::fs::remove_file(index_file_path).await.log_err();
+ })
+ .detach();
+ }
+ });
+
+ let run_git_command = async |args: &[&str]| {
+ let output = new_smol_command(&git_binary_path)
+ .current_dir(&working_directory)
+ .env("GIT_INDEX_FILE", &index_file_path)
+ .env("GIT_AUTHOR_NAME", "Zed")
+ .env("GIT_AUTHOR_EMAIL", "hi@zed.dev")
+ .env("GIT_COMMITTER_NAME", "Zed")
+ .env("GIT_COMMITTER_EMAIL", "hi@zed.dev")
+ .args(args)
+ .output()
+ .await?;
+ if output.status.success() {
+ anyhow::Ok(String::from_utf8(output.stdout)?.trim_end().to_string())
+ } else {
+ let error = String::from_utf8_lossy(&output.stderr);
+ Err(anyhow!("Git command failed: {:?}", error))
+ }
+ };
+
+ run_git_command(&["add", "--all"]).await?;
+ let tree = run_git_command(&["write-tree"]).await?;
+ let commit_sha = run_git_command(&["commit-tree", &tree, "-m", "Checkpoint"]).await?;
+ let ref_name = Uuid::new_v4().to_string();
+ run_git_command(&["update-ref", &format!("refs/heads/{ref_name}"), &commit_sha])
+ .await?;
+
+ smol::fs::remove_file(index_file_path).await.ok();
+ delete_temp_index.abort();
+
+ commit_sha.parse()
+ })
+ .boxed()
+ }
+
+ fn restore_checkpoint(&self, oid: Oid, cx: AsyncApp) -> BoxFuture<Result<()>> {
+ let working_directory = self.working_directory();
+ let git_binary_path = self.git_binary_path.clone();
+ cx.background_spawn(async move {
+ let working_directory = working_directory?;
+ let index_file_path = working_directory.join(".git/index.tmp");
+
+ let run_git_command = async |args: &[&str]| {
+ let output = new_smol_command(&git_binary_path)
+ .current_dir(&working_directory)
+ .env("GIT_INDEX_FILE", &index_file_path)
+ .args(args)
+ .output()
+ .await?;
+ if output.status.success() {
+ anyhow::Ok(String::from_utf8(output.stdout)?.trim_end().to_string())
+ } else {
+ let error = String::from_utf8_lossy(&output.stderr);
+ Err(anyhow!("Git command failed: {:?}", error))
+ }
+ };
+
+ run_git_command(&["restore", "--source", &oid.to_string(), "--worktree", "."]).await?;
+ run_git_command(&["read-tree", &oid.to_string()]).await?;
+ run_git_command(&["clean", "-d", "--force"]).await?;
+ Ok(())
+ })
+ .boxed()
+ }
}
async fn run_remote_command(
@@ -1260,29 +1350,72 @@ fn check_path_to_repo_path_errors(relative_file_path: &Path) -> Result<()> {
}
}
-#[test]
-fn test_branches_parsing() {
- // suppress "help: octal escapes are not supported, `\0` is always null"
- #[allow(clippy::octal_escapes)]
- let input = "*\0060964da10574cd9bf06463a53bf6e0769c5c45e\0\0refs/heads/zed-patches\0refs/remotes/origin/zed-patches\0\01733187470\0generated protobuf\n";
- assert_eq!(
- parse_branch_input(&input).unwrap(),
- vec![Branch {
- is_head: true,
- name: "zed-patches".into(),
- upstream: Some(Upstream {
- ref_name: "refs/remotes/origin/zed-patches".into(),
- tracking: UpstreamTracking::Tracked(UpstreamTrackingStatus {
- ahead: 0,
- behind: 0
+#[cfg(test)]
+mod tests {
+ use gpui::TestAppContext;
+
+ use super::*;
+
+ #[gpui::test]
+ async fn test_checkpoint(cx: &mut TestAppContext) {
+ cx.executor().allow_parking();
+
+ let repo_dir = tempfile::tempdir().unwrap();
+ git2::Repository::init(repo_dir.path()).unwrap();
+ let repo = RealGitRepository::new(&repo_dir.path().join(".git"), None).unwrap();
+
+ smol::fs::write(repo_dir.path().join("foo"), "foo")
+ .await
+ .unwrap();
+ let checkpoint_sha = repo.checkpoint(cx.to_async()).await.unwrap();
+
+ smol::fs::write(repo_dir.path().join("foo"), "bar")
+ .await
+ .unwrap();
+ smol::fs::write(repo_dir.path().join("baz"), "qux")
+ .await
+ .unwrap();
+ repo.restore_checkpoint(checkpoint_sha, cx.to_async())
+ .await
+ .unwrap();
+ assert_eq!(
+ smol::fs::read_to_string(repo_dir.path().join("foo"))
+ .await
+ .unwrap(),
+ "foo"
+ );
+ assert_eq!(
+ smol::fs::read_to_string(repo_dir.path().join("baz"))
+ .await
+ .ok(),
+ None
+ );
+ }
+
+ #[test]
+ fn test_branches_parsing() {
+ // suppress "help: octal escapes are not supported, `\0` is always null"
+ #[allow(clippy::octal_escapes)]
+ let input = "*\0060964da10574cd9bf06463a53bf6e0769c5c45e\0\0refs/heads/zed-patches\0refs/remotes/origin/zed-patches\0\01733187470\0generated protobuf\n";
+ assert_eq!(
+ parse_branch_input(&input).unwrap(),
+ vec![Branch {
+ is_head: true,
+ name: "zed-patches".into(),
+ upstream: Some(Upstream {
+ ref_name: "refs/remotes/origin/zed-patches".into(),
+ tracking: UpstreamTracking::Tracked(UpstreamTrackingStatus {
+ ahead: 0,
+ behind: 0
+ })
+ }),
+ most_recent_commit: Some(CommitSummary {
+ sha: "060964da10574cd9bf06463a53bf6e0769c5c45e".into(),
+ subject: "generated protobuf".into(),
+ commit_timestamp: 1733187470,
+ has_parent: false,
})
- }),
- most_recent_commit: Some(CommitSummary {
- sha: "060964da10574cd9bf06463a53bf6e0769c5c45e".into(),
- subject: "generated protobuf".into(),
- commit_timestamp: 1733187470,
- has_parent: false,
- })
- }]
- )
+ }]
+ )
+ }
}
@@ -11,10 +11,10 @@ use collections::HashMap;
use fs::Fs;
use futures::{
channel::{mpsc, oneshot},
- future::{OptionFuture, Shared},
+ future::{self, OptionFuture, Shared},
FutureExt as _, StreamExt as _,
};
-use git::repository::DiffType;
+use git::{repository::DiffType, Oid};
use git::{
repository::{
Branch, CommitDetails, GitRepository, PushOptions, Remote, RemoteCommandOutput, RepoPath,
@@ -117,6 +117,16 @@ enum GitStoreState {
},
}
+#[derive(Clone)]
+pub struct GitStoreCheckpoint {
+ checkpoints_by_dot_git_abs_path: HashMap<PathBuf, RepositoryCheckpoint>,
+}
+
+#[derive(Copy, Clone)]
+pub struct RepositoryCheckpoint {
+ sha: Oid,
+}
+
pub struct Repository {
commit_message_buffer: Option<Entity<Buffer>>,
git_store: WeakEntity<GitStore>,
@@ -506,6 +516,45 @@ impl GitStore {
diff_state.read(cx).uncommitted_diff.as_ref()?.upgrade()
}
+ pub fn checkpoint(&self, cx: &App) -> Task<Result<GitStoreCheckpoint>> {
+ let mut dot_git_abs_paths = Vec::new();
+ let mut checkpoints = Vec::new();
+ for repository in self.repositories.values() {
+ let repository = repository.read(cx);
+ dot_git_abs_paths.push(repository.dot_git_abs_path.clone());
+ checkpoints.push(repository.checkpoint().map(|checkpoint| checkpoint?));
+ }
+
+ cx.background_executor().spawn(async move {
+ let checkpoints: Vec<RepositoryCheckpoint> = future::try_join_all(checkpoints).await?;
+ Ok(GitStoreCheckpoint {
+ checkpoints_by_dot_git_abs_path: dot_git_abs_paths
+ .into_iter()
+ .zip(checkpoints)
+ .collect(),
+ })
+ })
+ }
+
+ pub fn restore_checkpoint(&self, checkpoint: GitStoreCheckpoint, cx: &App) -> Task<Result<()>> {
+ let repositories_by_dot_git_abs_path = self
+ .repositories
+ .values()
+ .map(|repo| (repo.read(cx).dot_git_abs_path.clone(), repo))
+ .collect::<HashMap<_, _>>();
+
+ let mut tasks = Vec::new();
+ for (dot_git_abs_path, checkpoint) in checkpoint.checkpoints_by_dot_git_abs_path {
+ if let Some(repository) = repositories_by_dot_git_abs_path.get(&dot_git_abs_path) {
+ tasks.push(repository.read(cx).restore_checkpoint(checkpoint));
+ }
+ }
+ cx.background_spawn(async move {
+ future::try_join_all(tasks).await?;
+ Ok(())
+ })
+ }
+
fn downstream_client(&self) -> Option<(AnyProtoClient, ProjectId)> {
match &self.state {
GitStoreState::Local {
@@ -2922,4 +2971,30 @@ impl Repository {
}
})
}
+
+ pub fn checkpoint(&self) -> oneshot::Receiver<Result<RepositoryCheckpoint>> {
+ self.send_job(|repo, cx| async move {
+ match repo {
+ GitRepo::Local(git_repository) => {
+ let sha = git_repository.checkpoint(cx).await?;
+ Ok(RepositoryCheckpoint { sha })
+ }
+ GitRepo::Remote { .. } => Err(anyhow!("not implemented yet")),
+ }
+ })
+ }
+
+ pub fn restore_checkpoint(
+ &self,
+ checkpoint: RepositoryCheckpoint,
+ ) -> oneshot::Receiver<Result<()>> {
+ self.send_job(move |repo, cx| async move {
+ match repo {
+ GitRepo::Local(git_repository) => {
+ git_repository.restore_checkpoint(checkpoint.sha, cx).await
+ }
+ GitRepo::Remote { .. } => Err(anyhow!("not implemented yet")),
+ }
+ })
+ }
}