@@ -1394,6 +1394,17 @@ impl AcpThread {
this.send_task.take();
}
+ // Truncate entries if the last prompt was refused.
+ if let Ok(Ok(acp::PromptResponse {
+ stop_reason: acp::StopReason::Refusal,
+ })) = result
+ && let Some((ix, _)) = this.last_user_message()
+ {
+ let range = ix..this.entries.len();
+ this.entries.truncate(ix);
+ cx.emit(AcpThreadEvent::EntriesRemoved(range));
+ }
+
cx.emit(AcpThreadEvent::Stopped);
Ok(())
}
@@ -2369,6 +2380,92 @@ mod tests {
assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
}
+ #[gpui::test]
+ async fn test_refusal(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = FakeFs::new(cx.background_executor.clone());
+ fs.insert_tree(path!("/"), json!({})).await;
+ let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
+
+ let refuse_next = Arc::new(AtomicBool::new(false));
+ let connection = Rc::new(FakeAgentConnection::new().on_user_message({
+ let refuse_next = refuse_next.clone();
+ move |request, thread, mut cx| {
+ let refuse_next = refuse_next.clone();
+ async move {
+ if refuse_next.load(SeqCst) {
+ return Ok(acp::PromptResponse {
+ stop_reason: acp::StopReason::Refusal,
+ });
+ }
+
+ let acp::ContentBlock::Text(content) = &request.prompt[0] else {
+ panic!("expected text content block");
+ };
+ thread.update(&mut cx, |thread, cx| {
+ thread
+ .handle_session_update(
+ acp::SessionUpdate::AgentMessageChunk {
+ content: content.text.to_uppercase().into(),
+ },
+ cx,
+ )
+ .unwrap();
+ })?;
+ Ok(acp::PromptResponse {
+ stop_reason: acp::StopReason::EndTurn,
+ })
+ }
+ .boxed_local()
+ }
+ }));
+ let thread = cx
+ .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
+ .await
+ .unwrap();
+
+ cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
+ .await
+ .unwrap();
+ thread.read_with(cx, |thread, cx| {
+ assert_eq!(
+ thread.to_markdown(cx),
+ indoc! {"
+ ## User
+
+ hello
+
+ ## Assistant
+
+ HELLO
+
+ "}
+ );
+ });
+
+ // Simulate refusing the second message, ensuring the conversation gets
+ // truncated to before sending it.
+ refuse_next.store(true, SeqCst);
+ cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
+ .await
+ .unwrap();
+ thread.read_with(cx, |thread, cx| {
+ assert_eq!(
+ thread.to_markdown(cx),
+ indoc! {"
+ ## User
+
+ hello
+
+ ## Assistant
+
+ HELLO
+
+ "}
+ );
+ });
+ }
+
async fn run_until_first_tool_call(
thread: &Entity<AcpThread>,
cx: &mut TestAppContext,
@@ -2398,7 +2398,6 @@ impl AcpThreadView {
})
.when(!changed_buffers.is_empty(), |this| {
this.child(self.render_edits_summary(
- action_log,
&changed_buffers,
self.edits_expanded,
pending_edits,
@@ -2550,7 +2549,6 @@ impl AcpThreadView {
fn render_edits_summary(
&self,
- action_log: &Entity<ActionLog>,
changed_buffers: &BTreeMap<Entity<Buffer>, Entity<BufferDiff>>,
expanded: bool,
pending_edits: bool,
@@ -2661,14 +2659,9 @@ impl AcpThreadView {
)
.map(|kb| kb.size(rems_from_px(10.))),
)
- .on_click({
- let action_log = action_log.clone();
- cx.listener(move |_, _, _, cx| {
- action_log.update(cx, |action_log, cx| {
- action_log.reject_all_edits(cx).detach();
- })
- })
- }),
+ .on_click(cx.listener(move |this, _, window, cx| {
+ this.reject_all(&RejectAll, window, cx);
+ })),
)
.child(
Button::new("keep-all-changes", "Keep All")
@@ -2681,14 +2674,9 @@ impl AcpThreadView {
KeyBinding::for_action_in(&KeepAll, &focus_handle, window, cx)
.map(|kb| kb.size(rems_from_px(10.))),
)
- .on_click({
- let action_log = action_log.clone();
- cx.listener(move |_, _, _, cx| {
- action_log.update(cx, |action_log, cx| {
- action_log.keep_all_edits(cx);
- })
- })
- }),
+ .on_click(cx.listener(move |this, _, window, cx| {
+ this.keep_all(&KeepAll, window, cx);
+ })),
),
)
}
@@ -3014,6 +3002,24 @@ impl AcpThreadView {
});
}
+ fn keep_all(&mut self, _: &KeepAll, _window: &mut Window, cx: &mut Context<Self>) {
+ let Some(thread) = self.thread() else {
+ return;
+ };
+ let action_log = thread.read(cx).action_log().clone();
+ action_log.update(cx, |action_log, cx| action_log.keep_all_edits(cx));
+ }
+
+ fn reject_all(&mut self, _: &RejectAll, _window: &mut Window, cx: &mut Context<Self>) {
+ let Some(thread) = self.thread() else {
+ return;
+ };
+ let action_log = thread.read(cx).action_log().clone();
+ action_log
+ .update(cx, |action_log, cx| action_log.reject_all_edits(cx))
+ .detach();
+ }
+
fn render_burn_mode_toggle(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
let thread = self.as_native_thread(cx)?.read(cx);
@@ -3952,6 +3958,8 @@ impl Render for AcpThreadView {
.key_context("AcpThread")
.on_action(cx.listener(Self::open_agent_diff))
.on_action(cx.listener(Self::toggle_burn_mode))
+ .on_action(cx.listener(Self::keep_all))
+ .on_action(cx.listener(Self::reject_all))
.bg(cx.theme().colors().panel_background)
.child(match &self.thread_state {
ThreadState::Unauthenticated {