Delete unused checkpoints (#27260)

Antonio Scandurra created

Release Notes:

- N/A

Change summary

crates/assistant2/src/message_editor.rs |   7 +
crates/assistant2/src/thread.rs         | 107 +++++++++++++++-----------
crates/assistant_eval/src/eval.rs       |   2 
crates/fs/src/fake_git_repo.rs          |   8 ++
crates/git/src/repository.rs            |  83 ++++++++++++++++----
crates/project/src/git_store.rs         |  37 ++++++++
6 files changed, 177 insertions(+), 67 deletions(-)

Detailed changes

crates/assistant2/src/message_editor.rs 🔗

@@ -12,6 +12,7 @@ use gpui::{
 };
 use language_model::LanguageModelRegistry;
 use language_model_selector::ToggleModelSelector;
+use project::Project;
 use rope::Point;
 use settings::Settings;
 use std::time::Duration;
@@ -37,6 +38,7 @@ pub struct MessageEditor {
     editor: Entity<Editor>,
     #[allow(dead_code)]
     workspace: WeakEntity<Workspace>,
+    project: Entity<Project>,
     context_store: Entity<ContextStore>,
     context_strip: Entity<ContextStrip>,
     context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
@@ -107,6 +109,7 @@ impl MessageEditor {
 
         Self {
             editor: editor.clone(),
+            project: thread.read(cx).project().clone(),
             thread,
             workspace,
             context_store,
@@ -205,7 +208,9 @@ impl MessageEditor {
 
         let thread = self.thread.clone();
         let context_store = self.context_store.clone();
+        let checkpoint = self.project.read(cx).git_store().read(cx).checkpoint(cx);
         cx.spawn(async move |_, cx| {
+            let checkpoint = checkpoint.await.ok();
             refresh_task.await;
             let (system_prompt_context, load_error) = system_prompt_context_task.await;
             thread
@@ -219,7 +224,7 @@ impl MessageEditor {
             thread
                 .update(cx, |thread, cx| {
                     let context = context_store.read(cx).snapshot(cx).collect::<Vec<_>>();
-                    thread.insert_user_message(user_message, context, cx);
+                    thread.insert_user_message(user_message, context, checkpoint, cx);
                     thread.send_to_model(model, request_kind, cx);
                 })
                 .ok();

crates/assistant2/src/thread.rs 🔗

@@ -1,6 +1,5 @@
 use std::fmt::Write as _;
 use std::io::Write;
-use std::mem;
 use std::sync::Arc;
 
 use anyhow::{Context as _, Result};
@@ -186,8 +185,7 @@ pub struct Thread {
     tool_use: ToolUseState,
     action_log: Entity<ActionLog>,
     last_restore_checkpoint: Option<LastRestoreCheckpoint>,
-    pending_checkpoint: Option<Task<Result<ThreadCheckpoint>>>,
-    checkpoint_on_next_user_message: bool,
+    pending_checkpoint: Option<ThreadCheckpoint>,
     scripting_session: Entity<ScriptingSession>,
     scripting_tool_use: ToolUseState,
     initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
@@ -220,7 +218,6 @@ impl Thread {
             tools: tools.clone(),
             last_restore_checkpoint: None,
             pending_checkpoint: None,
-            checkpoint_on_next_user_message: true,
             tool_use: ToolUseState::new(tools.clone()),
             scripting_session: cx.new(|cx| ScriptingSession::new(project.clone(), cx)),
             scripting_tool_use: ToolUseState::new(tools),
@@ -293,7 +290,6 @@ impl Thread {
             pending_completions: Vec::new(),
             last_restore_checkpoint: None,
             pending_checkpoint: None,
-            checkpoint_on_next_user_message: true,
             project,
             prompt_builder,
             tools,
@@ -385,11 +381,8 @@ impl Thread {
                 } else {
                     this.truncate(checkpoint.message_id, cx);
                     this.last_restore_checkpoint = None;
-                    this.pending_checkpoint = Some(Task::ready(Ok(ThreadCheckpoint {
-                        message_id: this.next_message_id,
-                        git_checkpoint: checkpoint.git_checkpoint,
-                    })));
                 }
+                this.pending_checkpoint = None;
                 cx.emit(ThreadEvent::CheckpointChanged);
                 cx.notify();
             })?;
@@ -397,46 +390,62 @@ impl Thread {
         })
     }
 
-    fn checkpoint(&mut self, cx: &mut Context<Self>) {
-        if self.is_generating() {
+    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
+        let pending_checkpoint = if self.is_generating() {
             return;
-        }
+        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
+            checkpoint
+        } else {
+            return;
+        };
 
         let git_store = self.project.read(cx).git_store().clone();
-        let new_checkpoint = git_store.read(cx).checkpoint(cx);
-        let old_checkpoint = self.pending_checkpoint.take();
-        let next_user_message_id = self.next_message_id;
-        self.pending_checkpoint = Some(cx.spawn(async move |this, cx| {
-            let new_checkpoint = new_checkpoint.await?;
-
-            if let Some(old_checkpoint) = old_checkpoint {
-                if let Ok(old_checkpoint) = old_checkpoint.await {
-                    let equal = git_store
+        let final_checkpoint = git_store.read(cx).checkpoint(cx);
+        cx.spawn(async move |this, cx| match final_checkpoint.await {
+            Ok(final_checkpoint) => {
+                let equal = git_store
+                    .read_with(cx, |store, cx| {
+                        store.compare_checkpoints(
+                            pending_checkpoint.git_checkpoint.clone(),
+                            final_checkpoint.clone(),
+                            cx,
+                        )
+                    })?
+                    .await
+                    .unwrap_or(false);
+
+                if equal {
+                    git_store
                         .read_with(cx, |store, cx| {
-                            store.compare_checkpoints(
-                                old_checkpoint.git_checkpoint.clone(),
-                                new_checkpoint.clone(),
-                                cx,
-                            )
+                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
                         })?
-                        .await;
-
-                    if equal.ok() != Some(true) {
-                        this.update(cx, |this, cx| {
-                            this.checkpoints_by_message
-                                .insert(old_checkpoint.message_id, old_checkpoint);
-                            cx.emit(ThreadEvent::CheckpointChanged);
-                            cx.notify();
-                        })?;
-                    }
+                        .detach();
+                } else {
+                    this.update(cx, |this, cx| {
+                        this.insert_checkpoint(pending_checkpoint, cx)
+                    })?;
                 }
+
+                git_store
+                    .read_with(cx, |store, cx| {
+                        store.delete_checkpoint(final_checkpoint, cx)
+                    })?
+                    .detach();
+
+                Ok(())
             }
+            Err(_) => this.update(cx, |this, cx| {
+                this.insert_checkpoint(pending_checkpoint, cx)
+            }),
+        })
+        .detach();
+    }
 
-            Ok(ThreadCheckpoint {
-                message_id: next_user_message_id,
-                git_checkpoint: new_checkpoint,
-            })
-        }));
+    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
+        self.checkpoints_by_message
+            .insert(checkpoint.message_id, checkpoint);
+        cx.emit(ThreadEvent::CheckpointChanged);
+        cx.notify();
     }
 
     pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
@@ -517,18 +526,21 @@ impl Thread {
         &mut self,
         text: impl Into<String>,
         context: Vec<ContextSnapshot>,
+        git_checkpoint: Option<GitStoreCheckpoint>,
         cx: &mut Context<Self>,
     ) -> MessageId {
-        if mem::take(&mut self.checkpoint_on_next_user_message) {
-            self.checkpoint(cx);
-        }
-
         let message_id =
             self.insert_message(Role::User, vec![MessageSegment::Text(text.into())], cx);
         let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
         self.context
             .extend(context.into_iter().map(|context| (context.id, context)));
         self.context_by_message.insert(message_id, context_ids);
+        if let Some(git_checkpoint) = git_checkpoint {
+            self.pending_checkpoint = Some(ThreadCheckpoint {
+                message_id,
+                git_checkpoint,
+            });
+        }
         message_id
     }
 
@@ -1050,7 +1062,7 @@ impl Thread {
 
             thread
                 .update(cx, |thread, cx| {
-                    thread.checkpoint(cx);
+                    thread.finalize_pending_checkpoint(cx);
                     match result.as_ref() {
                         Ok(stop_reason) => match stop_reason {
                             StopReason::ToolUse => {
@@ -1319,6 +1331,7 @@ impl Thread {
             // so for now we provide some text to keep the model on track.
             "Here are the tool results.",
             Vec::new(),
+            None,
             cx,
         );
     }
@@ -1341,7 +1354,7 @@ impl Thread {
             }
             canceled
         };
-        self.checkpoint(cx);
+        self.finalize_pending_checkpoint(cx);
         canceled
     }
 

crates/assistant_eval/src/eval.rs 🔗

@@ -96,7 +96,7 @@ impl Eval {
             assistant.update(cx, |assistant, cx| {
                 assistant.thread.update(cx, |thread, cx| {
                     let context = vec![];
-                    thread.insert_user_message(self.user_prompt.clone(), context, cx);
+                    thread.insert_user_message(self.user_prompt.clone(), context, None, cx);
                     thread.set_system_prompt_context(system_prompt_context);
                     thread.send_to_model(model, RequestKind::Chat, cx);
                 });

crates/fs/src/fake_git_repo.rs 🔗

@@ -429,4 +429,12 @@ impl GitRepository for FakeGitRepository {
     ) -> BoxFuture<Result<bool>> {
         unimplemented!()
     }
+
+    fn delete_checkpoint(
+        &self,
+        _checkpoint: GitRepositoryCheckpoint,
+        _cx: AsyncApp,
+    ) -> BoxFuture<Result<()>> {
+        unimplemented!()
+    }
 }

crates/git/src/repository.rs 🔗

@@ -308,6 +308,13 @@ pub trait GitRepository: Send + Sync {
         right: GitRepositoryCheckpoint,
         cx: AsyncApp,
     ) -> BoxFuture<Result<bool>>;
+
+    /// Deletes a previously-created checkpoint.
+    fn delete_checkpoint(
+        &self,
+        checkpoint: GitRepositoryCheckpoint,
+        cx: AsyncApp,
+    ) -> BoxFuture<Result<()>>;
 }
 
 pub enum DiffType {
@@ -351,10 +358,11 @@ impl RealGitRepository {
     }
 }
 
-#[derive(Copy, Clone)]
+#[derive(Clone, Debug)]
 pub struct GitRepositoryCheckpoint {
+    ref_name: String,
     head_sha: Option<Oid>,
-    sha: Oid,
+    commit_sha: Oid,
 }
 
 // https://git-scm.com/book/en/v2/Git-Internals-Git-Objects
@@ -1071,21 +1079,17 @@ impl GitRepository for RealGitRepository {
                 } else {
                     git.run(&["commit-tree", &tree, "-m", "Checkpoint"]).await?
                 };
-                let ref_name = Uuid::new_v4().to_string();
-                git.run(&[
-                    "update-ref",
-                    &format!("refs/zed/{ref_name}"),
-                    &checkpoint_sha,
-                ])
-                .await?;
+                let ref_name = format!("refs/zed/{}", Uuid::new_v4());
+                git.run(&["update-ref", &ref_name, &checkpoint_sha]).await?;
 
                 Ok(GitRepositoryCheckpoint {
+                    ref_name,
                     head_sha: if let Some(head_sha) = head_sha {
                         Some(head_sha.parse()?)
                     } else {
                         None
                     },
-                    sha: checkpoint_sha.parse()?,
+                    commit_sha: checkpoint_sha.parse()?,
                 })
             })
             .await
@@ -1109,14 +1113,15 @@ impl GitRepository for RealGitRepository {
             git.run(&[
                 "restore",
                 "--source",
-                &checkpoint.sha.to_string(),
+                &checkpoint.commit_sha.to_string(),
                 "--worktree",
                 ".",
             ])
             .await?;
 
             git.with_temp_index(async move |git| {
-                git.run(&["read-tree", &checkpoint.sha.to_string()]).await?;
+                git.run(&["read-tree", &checkpoint.commit_sha.to_string()])
+                    .await?;
                 git.run(&["clean", "-d", "--force"]).await
             })
             .await?;
@@ -1154,8 +1159,8 @@ impl GitRepository for RealGitRepository {
                 .run(&[
                     "diff-tree",
                     "--quiet",
-                    &left.sha.to_string(),
-                    &right.sha.to_string(),
+                    &left.commit_sha.to_string(),
+                    &right.commit_sha.to_string(),
                 ])
                 .await;
             match result {
@@ -1175,6 +1180,24 @@ impl GitRepository for RealGitRepository {
         })
         .boxed()
     }
+
+    fn delete_checkpoint(
+        &self,
+        checkpoint: GitRepositoryCheckpoint,
+        cx: AsyncApp,
+    ) -> BoxFuture<Result<()>> {
+        let working_directory = self.working_directory();
+        let git_binary_path = self.git_binary_path.clone();
+
+        let executor = cx.background_executor().clone();
+        cx.background_spawn(async move {
+            let working_directory = working_directory?;
+            let git = GitBinary::new(git_binary_path, working_directory, executor);
+            git.run(&["update-ref", "-d", &checkpoint.ref_name]).await?;
+            Ok(())
+        })
+        .boxed()
+    }
 }
 
 struct GitBinary {
@@ -1574,7 +1597,9 @@ mod tests {
             .await
             .unwrap();
 
-        repo.restore_checkpoint(checkpoint, cx.to_async())
+        // Ensure checkpoint stays alive even after a Git GC.
+        repo.gc(cx.to_async()).await.unwrap();
+        repo.restore_checkpoint(checkpoint.clone(), cx.to_async())
             .await
             .unwrap();
 
@@ -1595,6 +1620,15 @@ mod tests {
                 .ok(),
             None
         );
+
+        // Garbage collecting after deleting a checkpoint makes it unreachable.
+        repo.delete_checkpoint(checkpoint.clone(), cx.to_async())
+            .await
+            .unwrap();
+        repo.gc(cx.to_async()).await.unwrap();
+        repo.restore_checkpoint(checkpoint.clone(), cx.to_async())
+            .await
+            .unwrap_err();
     }
 
     #[gpui::test]
@@ -1737,7 +1771,7 @@ mod tests {
         let checkpoint2 = repo.checkpoint(cx.to_async()).await.unwrap();
 
         assert!(!repo
-            .compare_checkpoints(checkpoint1, checkpoint2, cx.to_async())
+            .compare_checkpoints(checkpoint1, checkpoint2.clone(), cx.to_async())
             .await
             .unwrap());
 
@@ -1774,4 +1808,21 @@ mod tests {
             }]
         )
     }
+
+    impl RealGitRepository {
+        /// Force a Git garbage collection on the repository.
+        fn gc(&self, cx: AsyncApp) -> BoxFuture<Result<()>> {
+            let working_directory = self.working_directory();
+            let git_binary_path = self.git_binary_path.clone();
+            let executor = cx.background_executor().clone();
+            cx.background_spawn(async move {
+                let git_binary_path = git_binary_path.clone();
+                let working_directory = working_directory?;
+                let git = GitBinary::new(git_binary_path, working_directory, executor);
+                git.run(&["gc", "--prune=now"]).await?;
+                Ok(())
+            })
+            .boxed()
+        }
+    }
 }

crates/project/src/git_store.rs 🔗

@@ -549,8 +549,7 @@ impl GitStore {
         }
 
         cx.background_executor().spawn(async move {
-            let checkpoints: Vec<GitRepositoryCheckpoint> =
-                future::try_join_all(checkpoints).await?;
+            let checkpoints = future::try_join_all(checkpoints).await?;
             Ok(GitStoreCheckpoint {
                 checkpoints_by_dot_git_abs_path: dot_git_abs_paths
                     .into_iter()
@@ -617,6 +616,26 @@ impl GitStore {
         })
     }
 
+    pub fn delete_checkpoint(&self, checkpoint: GitStoreCheckpoint, cx: &App) -> Task<Result<()>> {
+        let repositories_by_dot_git_abs_path = self
+            .repositories
+            .values()
+            .map(|repo| (repo.read(cx).dot_git_abs_path.clone(), repo))
+            .collect::<HashMap<_, _>>();
+
+        let mut tasks = Vec::new();
+        for (dot_git_abs_path, checkpoint) in checkpoint.checkpoints_by_dot_git_abs_path {
+            if let Some(repository) = repositories_by_dot_git_abs_path.get(&dot_git_abs_path) {
+                let delete = repository.read(cx).delete_checkpoint(checkpoint);
+                tasks.push(async move { delete.await? });
+            }
+        }
+        cx.background_spawn(async move {
+            future::try_join_all(tasks).await?;
+            Ok(())
+        })
+    }
+
     /// Blames a buffer.
     pub fn blame_buffer(
         &self,
@@ -3319,6 +3338,20 @@ impl Repository {
             }
         })
     }
+
+    pub fn delete_checkpoint(
+        &self,
+        checkpoint: GitRepositoryCheckpoint,
+    ) -> oneshot::Receiver<Result<()>> {
+        self.send_job(move |repo, cx| async move {
+            match repo {
+                RepositoryState::Local(git_repository) => {
+                    git_repository.delete_checkpoint(checkpoint, cx).await
+                }
+                RepositoryState::Remote { .. } => Err(anyhow!("not implemented yet")),
+            }
+        })
+    }
 }
 
 fn get_permalink_in_rust_registry_src(