Only reject agent actions, don't restore checkpoint on revert (#37801)

Conrad Irwin created

Updates #37623

Release Notes:

- Changed the behaviour when editing an old message in a native agent
thread.
Prior to this, it would automatically restore the checkpoint (which
could
lead to a surprising amount of work being discarded). Now it will just
reject
any unaccepted agent edits, and you can use the "restore checkpoint"
button
  for the original behavior.

Change summary

crates/acp_thread/src/acp_thread.rs    | 55 +++++++++++++++------------
crates/agent_ui/src/acp/thread_view.rs | 33 +++++++++-------
2 files changed, 49 insertions(+), 39 deletions(-)

Detailed changes

crates/acp_thread/src/acp_thread.rs 🔗

@@ -1640,13 +1640,13 @@ impl AcpThread {
         cx.foreground_executor().spawn(send_task)
     }
 
-    /// Rewinds this thread to before the entry at `index`, removing it and all
-    /// subsequent entries while reverting any changes made from that point.
-    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
-        let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
-            return Task::ready(Err(anyhow!("not supported")));
-        };
-        let Some(message) = self.user_message(&id) else {
+    /// Restores the git working tree to the state at the given checkpoint (if one exists)
+    pub fn restore_checkpoint(
+        &mut self,
+        id: UserMessageId,
+        cx: &mut Context<Self>,
+    ) -> Task<Result<()>> {
+        let Some((_, message)) = self.user_message_mut(&id) else {
             return Task::ready(Err(anyhow!("message not found")));
         };
 
@@ -1654,15 +1654,30 @@ impl AcpThread {
             .checkpoint
             .as_ref()
             .map(|c| c.git_checkpoint.clone());
-
+        let rewind = self.rewind(id.clone(), cx);
         let git_store = self.project.read(cx).git_store().clone();
-        cx.spawn(async move |this, cx| {
+
+        cx.spawn(async move |_, cx| {
+            rewind.await?;
             if let Some(checkpoint) = checkpoint {
                 git_store
                     .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
                     .await?;
             }
 
+            Ok(())
+        })
+    }
+
+    /// Rewinds this thread to before the entry at `index`, removing it and all
+    /// subsequent entries while rejecting any action_log changes made from that point.
+    /// Unlike `restore_checkpoint`, this method does not restore from git.
+    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
+        let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
+            return Task::ready(Err(anyhow!("not supported")));
+        };
+
+        cx.spawn(async move |this, cx| {
             cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
             this.update(cx, |this, cx| {
                 if let Some((ix, _)) = this.user_message_mut(&id) {
@@ -1670,7 +1685,11 @@ impl AcpThread {
                     this.entries.truncate(ix);
                     cx.emit(AcpThreadEvent::EntriesRemoved(range));
                 }
-            })
+                this.action_log()
+                    .update(cx, |action_log, cx| action_log.reject_all_edits(cx))
+            })?
+            .await;
+            Ok(())
         })
     }
 
@@ -1727,20 +1746,6 @@ impl AcpThread {
             })
     }
 
-    fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
-        self.entries.iter().find_map(|entry| {
-            if let AgentThreadEntry::UserMessage(message) = entry {
-                if message.id.as_ref() == Some(id) {
-                    Some(message)
-                } else {
-                    None
-                }
-            } else {
-                None
-            }
-        })
-    }
-
     fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
         self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
             if let AgentThreadEntry::UserMessage(message) = entry {
@@ -2684,7 +2689,7 @@ mod tests {
                 let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                     panic!("unexpected entries {:?}", thread.entries)
                 };
-                thread.rewind(message.id.clone().unwrap(), cx)
+                thread.restore_checkpoint(message.id.clone().unwrap(), cx)
             })
             .await
             .unwrap();

crates/agent_ui/src/acp/thread_view.rs 🔗

@@ -927,7 +927,7 @@ impl AcpThreadView {
                 }
             }
             ViewEvent::MessageEditorEvent(editor, MessageEditorEvent::Send) => {
-                self.regenerate(event.entry_index, editor, window, cx);
+                self.regenerate(event.entry_index, editor.clone(), window, cx);
             }
             ViewEvent::MessageEditorEvent(_editor, MessageEditorEvent::Cancel) => {
                 self.cancel_editing(&Default::default(), window, cx);
@@ -1151,7 +1151,7 @@ impl AcpThreadView {
     fn regenerate(
         &mut self,
         entry_ix: usize,
-        message_editor: &Entity<MessageEditor>,
+        message_editor: Entity<MessageEditor>,
         window: &mut Window,
         cx: &mut Context<Self>,
     ) {
@@ -1168,16 +1168,18 @@ impl AcpThreadView {
             return;
         };
 
-        let contents = message_editor.update(cx, |message_editor, cx| message_editor.contents(cx));
-
-        let task = cx.spawn(async move |_, cx| {
-            let contents = contents.await?;
+        cx.spawn_in(window, async move |this, cx| {
             thread
                 .update(cx, |thread, cx| thread.rewind(user_message_id, cx))?
                 .await?;
-            Ok(contents)
-        });
-        self.send_impl(task, window, cx);
+            let contents =
+                message_editor.update(cx, |message_editor, cx| message_editor.contents(cx))?;
+            this.update_in(cx, |this, window, cx| {
+                this.send_impl(contents, window, cx);
+            })?;
+            anyhow::Ok(())
+        })
+        .detach();
     }
 
     fn open_agent_diff(&mut self, _: &OpenAgentDiff, window: &mut Window, cx: &mut Context<Self>) {
@@ -1635,14 +1637,16 @@ impl AcpThreadView {
         cx.notify();
     }
 
-    fn rewind(&mut self, message_id: &UserMessageId, cx: &mut Context<Self>) {
+    fn restore_checkpoint(&mut self, message_id: &UserMessageId, cx: &mut Context<Self>) {
         let Some(thread) = self.thread() else {
             return;
         };
+
         thread
-            .update(cx, |thread, cx| thread.rewind(message_id.clone(), cx))
+            .update(cx, |thread, cx| {
+                thread.restore_checkpoint(message_id.clone(), cx)
+            })
             .detach_and_log_err(cx);
-        cx.notify();
     }
 
     fn render_entry(
@@ -1712,8 +1716,9 @@ impl AcpThreadView {
                                         .label_size(LabelSize::XSmall)
                                         .icon_color(Color::Muted)
                                         .color(Color::Muted)
+                                        .tooltip(Tooltip::text("Restores all files in the project to the content they had at this point in the conversation."))
                                         .on_click(cx.listener(move |this, _, _window, cx| {
-                                            this.rewind(&message_id, cx);
+                                            this.restore_checkpoint(&message_id, cx);
                                         }))
                                 )
                                 .child(Divider::horizontal())
@@ -1784,7 +1789,7 @@ impl AcpThreadView {
                                                             let editor = editor.clone();
                                                             move |this, _, window, cx| {
                                                                 this.regenerate(
-                                                                    entry_ix, &editor, window, cx,
+                                                                    entry_ix, editor.clone(), window, cx,
                                                                 );
                                                             }
                                                         })).into_any_element()