Add UI feedback for checkpoint restoration (#27203)

Antonio Scandurra , Agus Zubiaga , and Bennet Bo Fenner created

Release Notes:

- N/A

Co-authored-by: Agus Zubiaga <hi@aguz.me>
Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>

Change summary

crates/assistant2/src/active_thread.rs | 86 +++++++++++++++++++++------
crates/assistant2/src/thread.rs        | 48 ++++++++++++++
crates/project/src/git.rs              |  3 
3 files changed, 113 insertions(+), 24 deletions(-)

Detailed changes

crates/assistant2/src/active_thread.rs 🔗

@@ -1,14 +1,16 @@
-use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
+use crate::thread::{
+    LastRestoreCheckpoint, MessageId, RequestKind, Thread, ThreadError, ThreadEvent,
+};
 use crate::thread_store::ThreadStore;
 use crate::tool_use::{ToolUse, ToolUseStatus};
 use crate::ui::ContextPill;
 use collections::HashMap;
 use editor::{Editor, MultiBuffer};
 use gpui::{
-    list, percentage, AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent,
-    DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Length, ListAlignment, ListOffset,
-    ListState, StyleRefinement, Subscription, Task, TextStyleRefinement, Transformation,
-    UnderlineStyle, WeakEntity,
+    list, percentage, pulsating_between, AbsoluteLength, Animation, AnimationExt, AnyElement, App,
+    ClickEvent, DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Length, ListAlignment,
+    ListOffset, ListState, StyleRefinement, Subscription, Task, TextStyleRefinement,
+    Transformation, UnderlineStyle, WeakEntity,
 };
 use language::{Buffer, LanguageRegistry};
 use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
@@ -18,7 +20,7 @@ use settings::Settings as _;
 use std::sync::Arc;
 use std::time::Duration;
 use theme::ThemeSettings;
-use ui::{prelude::*, Disclosure, KeyBinding};
+use ui::{prelude::*, Disclosure, KeyBinding, Tooltip};
 use util::ResultExt as _;
 use workspace::{OpenOptions, Workspace};
 
@@ -401,7 +403,6 @@ impl ActiveThread {
                         window,
                         cx,
                     );
-
                     self.render_scripting_tool_use_markdown(
                         tool_use.id.clone(),
                         tool_use.name.as_ref(),
@@ -463,6 +464,7 @@ impl ActiveThread {
                     }
                 }
             }
+            ThreadEvent::CheckpointChanged => cx.notify(),
         }
     }
 
@@ -789,20 +791,62 @@ impl ActiveThread {
         v_flex()
             .when(ix == 0, |parent| parent.child(self.render_rules_item(cx)))
             .when_some(checkpoint, |parent, checkpoint| {
-                parent.child(
-                    h_flex().pl_2().child(
-                        Button::new(("restore-checkpoint", ix), "Restore Checkpoint")
-                            .icon(IconName::Undo)
-                            .size(ButtonSize::Compact)
-                            .on_click(cx.listener(move |this, _, _window, cx| {
-                                this.thread.update(cx, |thread, cx| {
-                                    thread
-                                        .restore_checkpoint(checkpoint.clone(), cx)
-                                        .detach_and_log_err(cx);
-                                });
-                            })),
-                    ),
-                )
+                let mut is_pending = false;
+                let mut error = None;
+                if let Some(last_restore_checkpoint) =
+                    self.thread.read(cx).last_restore_checkpoint()
+                {
+                    if last_restore_checkpoint.message_id() == message_id {
+                        match last_restore_checkpoint {
+                            LastRestoreCheckpoint::Pending { .. } => is_pending = true,
+                            LastRestoreCheckpoint::Error { error: err, .. } => {
+                                error = Some(err.clone());
+                            }
+                        }
+                    }
+                }
+
+                let restore_checkpoint_button =
+                    Button::new(("restore-checkpoint", ix), "Restore Checkpoint")
+                        .icon(if error.is_some() {
+                            IconName::XCircle
+                        } else {
+                            IconName::Undo
+                        })
+                        .size(ButtonSize::Compact)
+                        .disabled(is_pending)
+                        .icon_color(if error.is_some() {
+                            Some(Color::Error)
+                        } else {
+                            None
+                        })
+                        .on_click(cx.listener(move |this, _, _window, cx| {
+                            this.thread.update(cx, |thread, cx| {
+                                thread
+                                    .restore_checkpoint(checkpoint.clone(), cx)
+                                    .detach_and_log_err(cx);
+                            });
+                        }));
+
+                let restore_checkpoint_button = if is_pending {
+                    restore_checkpoint_button
+                        .with_animation(
+                            ("pulsating-restore-checkpoint-button", ix),
+                            Animation::new(Duration::from_secs(2))
+                                .repeat()
+                                .with_easing(pulsating_between(0.6, 1.)),
+                            |label, delta| label.alpha(delta),
+                        )
+                        .into_any_element()
+                } else if let Some(error) = error {
+                    restore_checkpoint_button
+                        .tooltip(Tooltip::text(error.to_string()))
+                        .into_any_element()
+                } else {
+                    restore_checkpoint_button.into_any_element()
+                };
+
+                parent.child(h_flex().pl_2().child(restore_checkpoint_button))
             })
             .child(styled_message)
             .into_any()

crates/assistant2/src/thread.rs 🔗

@@ -99,6 +99,25 @@ pub struct ThreadCheckpoint {
     git_checkpoint: GitStoreCheckpoint,
 }
 
+pub enum LastRestoreCheckpoint {
+    Pending {
+        message_id: MessageId,
+    },
+    Error {
+        message_id: MessageId,
+        error: String,
+    },
+}
+
+impl LastRestoreCheckpoint {
+    pub fn message_id(&self) -> MessageId {
+        match self {
+            LastRestoreCheckpoint::Pending { message_id } => *message_id,
+            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
+        }
+    }
+}
+
 /// A thread of conversation with the LLM.
 pub struct Thread {
     id: ThreadId,
@@ -118,6 +137,7 @@ pub struct Thread {
     tools: Arc<ToolWorkingSet>,
     tool_use: ToolUseState,
     action_log: Entity<ActionLog>,
+    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
     scripting_session: Entity<ScriptingSession>,
     scripting_tool_use: ToolUseState,
     initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
@@ -147,6 +167,7 @@ impl Thread {
             project: project.clone(),
             prompt_builder,
             tools: tools.clone(),
+            last_restore_checkpoint: None,
             tool_use: ToolUseState::new(tools.clone()),
             scripting_session: cx.new(|cx| ScriptingSession::new(project.clone(), cx)),
             scripting_tool_use: ToolUseState::new(tools),
@@ -207,6 +228,7 @@ impl Thread {
             checkpoints_by_message: HashMap::default(),
             completion_count: 0,
             pending_completions: Vec::new(),
+            last_restore_checkpoint: None,
             project,
             prompt_builder,
             tools,
@@ -279,17 +301,38 @@ impl Thread {
         checkpoint: ThreadCheckpoint,
         cx: &mut Context<Self>,
     ) -> Task<Result<()>> {
+        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
+            message_id: checkpoint.message_id,
+        });
+        cx.emit(ThreadEvent::CheckpointChanged);
+
         let project = self.project.read(cx);
         let restore = project
             .git_store()
             .read(cx)
             .restore_checkpoint(checkpoint.git_checkpoint, cx);
         cx.spawn(async move |this, cx| {
-            restore.await?;
-            this.update(cx, |this, cx| this.truncate(checkpoint.message_id, cx))
+            let result = restore.await;
+            this.update(cx, |this, cx| {
+                if let Err(err) = result.as_ref() {
+                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
+                        message_id: checkpoint.message_id,
+                        error: err.to_string(),
+                    });
+                } else {
+                    this.last_restore_checkpoint = None;
+                    this.truncate(checkpoint.message_id, cx);
+                }
+                cx.emit(ThreadEvent::CheckpointChanged);
+            })?;
+            result
         })
     }
 
+    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
+        self.last_restore_checkpoint.as_ref()
+    }
+
     pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
         let Some(message_ix) = self
             .messages
@@ -1361,6 +1404,7 @@ pub enum ThreadEvent {
         /// Whether the tool was canceled by the user.
         canceled: bool,
     },
+    CheckpointChanged,
 }
 
 impl EventEmitter<ThreadEvent> for Thread {}

crates/project/src/git.rs 🔗

@@ -542,7 +542,8 @@ impl GitStore {
         let mut tasks = Vec::new();
         for (dot_git_abs_path, checkpoint) in checkpoint.checkpoints_by_dot_git_abs_path {
             if let Some(repository) = repositories_by_dot_git_abs_path.get(&dot_git_abs_path) {
-                tasks.push(repository.read(cx).restore_checkpoint(checkpoint));
+                let restore = repository.read(cx).restore_checkpoint(checkpoint);
+                tasks.push(async move { restore.await? });
             }
         }
         cx.background_spawn(async move {