diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 8afa466bb607c02b7cdfe795b3168c2e20a0ba10..f36ed6e7a04876dcb057f87889c6c224934681bc 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -1640,13 +1640,13 @@ impl AcpThread { cx.foreground_executor().spawn(send_task) } - /// Rewinds this thread to before the entry at `index`, removing it and all - /// subsequent entries while reverting any changes made from that point. - pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context) -> Task> { - let Some(truncate) = self.connection.truncate(&self.session_id, cx) else { - return Task::ready(Err(anyhow!("not supported"))); - }; - let Some(message) = self.user_message(&id) else { + /// Restores the git working tree to the state at the given checkpoint (if one exists) + pub fn restore_checkpoint( + &mut self, + id: UserMessageId, + cx: &mut Context, + ) -> Task> { + let Some((_, message)) = self.user_message_mut(&id) else { return Task::ready(Err(anyhow!("message not found"))); }; @@ -1654,15 +1654,30 @@ impl AcpThread { .checkpoint .as_ref() .map(|c| c.git_checkpoint.clone()); - + let rewind = self.rewind(id.clone(), cx); let git_store = self.project.read(cx).git_store().clone(); - cx.spawn(async move |this, cx| { + + cx.spawn(async move |_, cx| { + rewind.await?; if let Some(checkpoint) = checkpoint { git_store .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))? .await?; } + Ok(()) + }) + } + + /// Rewinds this thread to before the entry at `index`, removing it and all + /// subsequent entries while rejecting any action_log changes made from that point. + /// Unlike `restore_checkpoint`, this method does not restore from git. + pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context) -> Task> { + let Some(truncate) = self.connection.truncate(&self.session_id, cx) else { + return Task::ready(Err(anyhow!("not supported"))); + }; + + cx.spawn(async move |this, cx| { cx.update(|cx| truncate.run(id.clone(), cx))?.await?; this.update(cx, |this, cx| { if let Some((ix, _)) = this.user_message_mut(&id) { @@ -1670,7 +1685,11 @@ impl AcpThread { this.entries.truncate(ix); cx.emit(AcpThreadEvent::EntriesRemoved(range)); } - }) + this.action_log() + .update(cx, |action_log, cx| action_log.reject_all_edits(cx)) + })? + .await; + Ok(()) }) } @@ -1727,20 +1746,6 @@ impl AcpThread { }) } - fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> { - self.entries.iter().find_map(|entry| { - if let AgentThreadEntry::UserMessage(message) = entry { - if message.id.as_ref() == Some(id) { - Some(message) - } else { - None - } - } else { - None - } - }) - } - fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> { self.entries.iter_mut().enumerate().find_map(|(ix, entry)| { if let AgentThreadEntry::UserMessage(message) = entry { @@ -2684,7 +2689,7 @@ mod tests { let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else { panic!("unexpected entries {:?}", thread.entries) }; - thread.rewind(message.id.clone().unwrap(), cx) + thread.restore_checkpoint(message.id.clone().unwrap(), cx) }) .await .unwrap(); diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index e9794160a2c6facc4f5a9aacf700aae8b1a1eb72..8627455b4f33029152dc09e63ae868506defb430 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -927,7 +927,7 @@ impl AcpThreadView { } } ViewEvent::MessageEditorEvent(editor, MessageEditorEvent::Send) => { - self.regenerate(event.entry_index, editor, window, cx); + self.regenerate(event.entry_index, editor.clone(), window, cx); } ViewEvent::MessageEditorEvent(_editor, MessageEditorEvent::Cancel) => { self.cancel_editing(&Default::default(), window, cx); @@ -1151,7 +1151,7 @@ impl AcpThreadView { fn regenerate( &mut self, entry_ix: usize, - message_editor: &Entity, + message_editor: Entity, window: &mut Window, cx: &mut Context, ) { @@ -1168,16 +1168,18 @@ impl AcpThreadView { return; }; - let contents = message_editor.update(cx, |message_editor, cx| message_editor.contents(cx)); - - let task = cx.spawn(async move |_, cx| { - let contents = contents.await?; + cx.spawn_in(window, async move |this, cx| { thread .update(cx, |thread, cx| thread.rewind(user_message_id, cx))? .await?; - Ok(contents) - }); - self.send_impl(task, window, cx); + let contents = + message_editor.update(cx, |message_editor, cx| message_editor.contents(cx))?; + this.update_in(cx, |this, window, cx| { + this.send_impl(contents, window, cx); + })?; + anyhow::Ok(()) + }) + .detach(); } fn open_agent_diff(&mut self, _: &OpenAgentDiff, window: &mut Window, cx: &mut Context) { @@ -1635,14 +1637,16 @@ impl AcpThreadView { cx.notify(); } - fn rewind(&mut self, message_id: &UserMessageId, cx: &mut Context) { + fn restore_checkpoint(&mut self, message_id: &UserMessageId, cx: &mut Context) { let Some(thread) = self.thread() else { return; }; + thread - .update(cx, |thread, cx| thread.rewind(message_id.clone(), cx)) + .update(cx, |thread, cx| { + thread.restore_checkpoint(message_id.clone(), cx) + }) .detach_and_log_err(cx); - cx.notify(); } fn render_entry( @@ -1712,8 +1716,9 @@ impl AcpThreadView { .label_size(LabelSize::XSmall) .icon_color(Color::Muted) .color(Color::Muted) + .tooltip(Tooltip::text("Restores all files in the project to the content they had at this point in the conversation.")) .on_click(cx.listener(move |this, _, _window, cx| { - this.rewind(&message_id, cx); + this.restore_checkpoint(&message_id, cx); })) ) .child(Divider::horizontal()) @@ -1784,7 +1789,7 @@ impl AcpThreadView { let editor = editor.clone(); move |this, _, window, cx| { this.regenerate( - entry_ix, &editor, window, cx, + entry_ix, editor.clone(), window, cx, ); } })).into_any_element()