Miscellaneous UX fixes for agent2 (#36591)

Antonio Scandurra created

Release Notes:

- N/A

Change summary

crates/acp_thread/src/acp_thread.rs    | 97 ++++++++++++++++++++++++++++
crates/agent_ui/src/acp/thread_view.rs | 44 +++++++-----
2 files changed, 123 insertions(+), 18 deletions(-)

Detailed changes

crates/acp_thread/src/acp_thread.rs 🔗

@@ -1394,6 +1394,17 @@ impl AcpThread {
                             this.send_task.take();
                         }
 
+                        // Truncate entries if the last prompt was refused.
+                        if let Ok(Ok(acp::PromptResponse {
+                            stop_reason: acp::StopReason::Refusal,
+                        })) = result
+                            && let Some((ix, _)) = this.last_user_message()
+                        {
+                            let range = ix..this.entries.len();
+                            this.entries.truncate(ix);
+                            cx.emit(AcpThreadEvent::EntriesRemoved(range));
+                        }
+
                         cx.emit(AcpThreadEvent::Stopped);
                         Ok(())
                     }
@@ -2369,6 +2380,92 @@ mod tests {
         assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
     }
 
+    #[gpui::test]
+    async fn test_refusal(cx: &mut TestAppContext) {
+        init_test(cx);
+        let fs = FakeFs::new(cx.background_executor.clone());
+        fs.insert_tree(path!("/"), json!({})).await;
+        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
+
+        let refuse_next = Arc::new(AtomicBool::new(false));
+        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
+            let refuse_next = refuse_next.clone();
+            move |request, thread, mut cx| {
+                let refuse_next = refuse_next.clone();
+                async move {
+                    if refuse_next.load(SeqCst) {
+                        return Ok(acp::PromptResponse {
+                            stop_reason: acp::StopReason::Refusal,
+                        });
+                    }
+
+                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
+                        panic!("expected text content block");
+                    };
+                    thread.update(&mut cx, |thread, cx| {
+                        thread
+                            .handle_session_update(
+                                acp::SessionUpdate::AgentMessageChunk {
+                                    content: content.text.to_uppercase().into(),
+                                },
+                                cx,
+                            )
+                            .unwrap();
+                    })?;
+                    Ok(acp::PromptResponse {
+                        stop_reason: acp::StopReason::EndTurn,
+                    })
+                }
+                .boxed_local()
+            }
+        }));
+        let thread = cx
+            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
+            .await
+            .unwrap();
+
+        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
+            .await
+            .unwrap();
+        thread.read_with(cx, |thread, cx| {
+            assert_eq!(
+                thread.to_markdown(cx),
+                indoc! {"
+                    ## User
+
+                    hello
+
+                    ## Assistant
+
+                    HELLO
+
+                "}
+            );
+        });
+
+        // Simulate refusing the second message, ensuring the conversation gets
+        // truncated to before sending it.
+        refuse_next.store(true, SeqCst);
+        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
+            .await
+            .unwrap();
+        thread.read_with(cx, |thread, cx| {
+            assert_eq!(
+                thread.to_markdown(cx),
+                indoc! {"
+                    ## User
+
+                    hello
+
+                    ## Assistant
+
+                    HELLO
+
+                "}
+            );
+        });
+    }
+
     async fn run_until_first_tool_call(
         thread: &Entity<AcpThread>,
         cx: &mut TestAppContext,

crates/agent_ui/src/acp/thread_view.rs 🔗

@@ -2398,7 +2398,6 @@ impl AcpThreadView {
             })
             .when(!changed_buffers.is_empty(), |this| {
                 this.child(self.render_edits_summary(
-                    action_log,
                     &changed_buffers,
                     self.edits_expanded,
                     pending_edits,
@@ -2550,7 +2549,6 @@ impl AcpThreadView {
 
     fn render_edits_summary(
         &self,
-        action_log: &Entity<ActionLog>,
         changed_buffers: &BTreeMap<Entity<Buffer>, Entity<BufferDiff>>,
         expanded: bool,
         pending_edits: bool,
@@ -2661,14 +2659,9 @@ impl AcpThreadView {
                                 )
                                 .map(|kb| kb.size(rems_from_px(10.))),
                             )
-                            .on_click({
-                                let action_log = action_log.clone();
-                                cx.listener(move |_, _, _, cx| {
-                                    action_log.update(cx, |action_log, cx| {
-                                        action_log.reject_all_edits(cx).detach();
-                                    })
-                                })
-                            }),
+                            .on_click(cx.listener(move |this, _, window, cx| {
+                                this.reject_all(&RejectAll, window, cx);
+                            })),
                     )
                     .child(
                         Button::new("keep-all-changes", "Keep All")
@@ -2681,14 +2674,9 @@ impl AcpThreadView {
                                 KeyBinding::for_action_in(&KeepAll, &focus_handle, window, cx)
                                     .map(|kb| kb.size(rems_from_px(10.))),
                             )
-                            .on_click({
-                                let action_log = action_log.clone();
-                                cx.listener(move |_, _, _, cx| {
-                                    action_log.update(cx, |action_log, cx| {
-                                        action_log.keep_all_edits(cx);
-                                    })
-                                })
-                            }),
+                            .on_click(cx.listener(move |this, _, window, cx| {
+                                this.keep_all(&KeepAll, window, cx);
+                            })),
                     ),
             )
     }
@@ -3014,6 +3002,24 @@ impl AcpThreadView {
         });
     }
 
+    fn keep_all(&mut self, _: &KeepAll, _window: &mut Window, cx: &mut Context<Self>) {
+        let Some(thread) = self.thread() else {
+            return;
+        };
+        let action_log = thread.read(cx).action_log().clone();
+        action_log.update(cx, |action_log, cx| action_log.keep_all_edits(cx));
+    }
+
+    fn reject_all(&mut self, _: &RejectAll, _window: &mut Window, cx: &mut Context<Self>) {
+        let Some(thread) = self.thread() else {
+            return;
+        };
+        let action_log = thread.read(cx).action_log().clone();
+        action_log
+            .update(cx, |action_log, cx| action_log.reject_all_edits(cx))
+            .detach();
+    }
+
     fn render_burn_mode_toggle(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
         let thread = self.as_native_thread(cx)?.read(cx);
 
@@ -3952,6 +3958,8 @@ impl Render for AcpThreadView {
             .key_context("AcpThread")
             .on_action(cx.listener(Self::open_agent_diff))
             .on_action(cx.listener(Self::toggle_burn_mode))
+            .on_action(cx.listener(Self::keep_all))
+            .on_action(cx.listener(Self::reject_all))
             .bg(cx.theme().colors().panel_background)
             .child(match &self.thread_state {
                 ThreadState::Unauthenticated {