Detailed changes
@@ -141,6 +141,7 @@ pub enum ToolCallStatus {
status: acp::ToolCallStatus,
},
Rejected,
+ Canceled,
}
#[derive(Debug)]
@@ -359,6 +360,7 @@ pub struct AcpThread {
server: Arc<AcpServer>,
title: SharedString,
project: Entity<Project>,
+ send_task: Option<Task<()>>,
}
enum AcpThreadEvent {
@@ -366,6 +368,13 @@ enum AcpThreadEvent {
EntryUpdated(usize),
}
+#[derive(PartialEq, Eq)]
+pub enum ThreadStatus {
+ Idle,
+ WaitingForToolConfirmation,
+ Generating,
+}
+
impl EventEmitter<AcpThreadEvent> for AcpThread {}
impl AcpThread {
@@ -378,7 +387,7 @@ impl AcpThread {
) -> Self {
let mut next_entry_id = ThreadEntryId(0);
Self {
- title: "A new agent2 thread".into(),
+ title: "ACP Thread".into(),
entries: entries
.into_iter()
.map(|entry| ThreadEntry {
@@ -390,6 +399,7 @@ impl AcpThread {
id: thread_id,
next_entry_id,
project,
+ send_task: None,
}
}
@@ -401,6 +411,18 @@ impl AcpThread {
&self.entries
}
+ pub fn status(&self) -> ThreadStatus {
+ if self.send_task.is_some() {
+ if self.waiting_for_tool_confirmation() {
+ ThreadStatus::WaitingForToolConfirmation
+ } else {
+ ThreadStatus::Generating
+ }
+ } else {
+ ThreadStatus::Idle
+ }
+ }
+
pub fn push_entry(
&mut self,
entry: AgentThreadEntryContent,
@@ -577,6 +599,10 @@ impl AcpThread {
ToolCallStatus::Rejected => {
anyhow::bail!("Tool call was rejected and therefore can't be updated")
}
+ ToolCallStatus::Canceled => {
+ // todo! test this case with fake server
+ call.status = ToolCallStatus::Allowed { status: new_status };
+ }
}
}
_ => anyhow::bail!("Entry is not a tool call"),
@@ -597,11 +623,14 @@ impl AcpThread {
/// Returns true if the last turn is awaiting tool authorization
pub fn waiting_for_tool_confirmation(&self) -> bool {
+ // todo!("should we use a hashmap?")
for entry in self.entries.iter().rev() {
match &entry.content {
AgentThreadEntryContent::ToolCall(call) => match call.status {
ToolCallStatus::WaitingForConfirmation { .. } => return true,
- ToolCallStatus::Allowed { .. } | ToolCallStatus::Rejected => continue,
+ ToolCallStatus::Allowed { .. }
+ | ToolCallStatus::Rejected
+ | ToolCallStatus::Canceled => continue,
},
AgentThreadEntryContent::Message(_) => {
// Reached the beginning of the turn
@@ -612,9 +641,14 @@ impl AcpThread {
false
}
- pub fn send(&mut self, message: &str, cx: &mut Context<Self>) -> Task<Result<()>> {
+ pub fn send(
+ &mut self,
+ message: &str,
+ cx: &mut Context<Self>,
+ ) -> impl use<> + Future<Output = Result<()>> {
let agent = self.server.clone();
let id = self.id.clone();
+
let chunk = MessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
let message = Message {
role: Role::User,
@@ -622,10 +656,65 @@ impl AcpThread {
};
self.push_entry(AgentThreadEntryContent::Message(message.clone()), cx);
let acp_message = message.into_acp(cx);
- cx.spawn(async move |_, cx| {
- agent.send_message(id, acp_message, cx).await?;
- Ok(())
- })
+
+ let (tx, rx) = oneshot::channel();
+ let cancel = self.cancel(cx);
+
+ self.send_task = Some(cx.spawn(async move |this, cx| {
+ cancel.await.log_err();
+
+ let result = agent.send_message(id, acp_message, cx).await;
+ tx.send(result).log_err();
+ this.update(cx, |this, _cx| this.send_task.take()).log_err();
+ }));
+
+ async move {
+ match rx.await {
+ Ok(result) => result,
+ Err(_) => Ok(()),
+ }
+ }
+ }
+
+ pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let agent = self.server.clone();
+ let id = self.id.clone();
+
+ if self.send_task.take().is_some() {
+ cx.spawn(async move |this, cx| {
+ agent.cancel_send_message(id, cx).await?;
+
+ this.update(cx, |this, _cx| {
+ for entry in this.entries.iter_mut() {
+ if let AgentThreadEntryContent::ToolCall(call) = &mut entry.content {
+ let cancel = matches!(
+ call.status,
+ ToolCallStatus::WaitingForConfirmation { .. }
+ | ToolCallStatus::Allowed {
+ status: acp::ToolCallStatus::Running
+ }
+ );
+
+ if cancel {
+ let curr_status =
+ mem::replace(&mut call.status, ToolCallStatus::Canceled);
+
+ if let ToolCallStatus::WaitingForConfirmation {
+ respond_tx, ..
+ } = curr_status
+ {
+ respond_tx
+ .send(acp::ToolCallConfirmationOutcome::Cancel)
+ .ok();
+ }
+ }
+ }
+ }
+ })
+ })
+ } else {
+ Task::ready(Ok(()))
+ }
}
}
@@ -815,6 +904,73 @@ mod tests {
});
}
+ #[gpui::test]
+ async fn test_gemini_cancel(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ cx.executor().allow_parking();
+
+ let fs = FakeFs::new(cx.executor());
+ let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
+ let server = gemini_acp_server(project.clone(), cx).await;
+ let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
+ let full_turn = thread.update(cx, |thread, cx| {
+ thread.send(r#"Run `echo "Hello, world!"`"#, cx)
+ });
+
+ run_until_tool_call(&thread, cx).await;
+
+ thread.read_with(cx, |thread, _cx| {
+ let AgentThreadEntryContent::ToolCall(ToolCall {
+ id,
+ status:
+ ToolCallStatus::WaitingForConfirmation {
+ confirmation: ToolCallConfirmation::Execute { root_command, .. },
+ ..
+ },
+ ..
+ }) = &thread.entries()[1].content
+ else {
+ panic!();
+ };
+
+ assert_eq!(root_command, "echo");
+
+ *id
+ });
+
+ thread
+ .update(cx, |thread, cx| thread.cancel(cx))
+ .await
+ .unwrap();
+ full_turn.await.unwrap();
+ thread.read_with(cx, |thread, _| {
+ let AgentThreadEntryContent::ToolCall(ToolCall {
+ status: ToolCallStatus::Canceled,
+ ..
+ }) = &thread.entries()[1].content
+ else {
+ panic!();
+ };
+ });
+
+ thread
+ .update(cx, |thread, cx| {
+ thread.send(r#"Stop running and say goodbye to me."#, cx)
+ })
+ .await
+ .unwrap();
+ thread.read_with(cx, |thread, _| {
+ let AgentThreadEntryContent::Message(Message {
+ role: Role::Assistant,
+ ..
+ }) = &thread.entries()[3].content
+ else {
+ panic!();
+ };
+ });
+ }
+
async fn run_until_tool_call(thread: &Entity<AcpThread>, cx: &mut TestAppContext) {
let (mut tx, mut rx) = mpsc::channel::<()>(1);
@@ -20,23 +20,14 @@ pub struct AcpServer {
}
struct AcpClientDelegate {
- project: Entity<Project>,
threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
cx: AsyncApp,
// sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
}
impl AcpClientDelegate {
- fn new(
- project: Entity<Project>,
- threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
- cx: AsyncApp,
- ) -> Self {
- Self {
- project,
- threads,
- cx: cx,
- }
+ fn new(threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>, cx: AsyncApp) -> Self {
+ Self { threads, cx: cx }
}
fn update_thread<R>(
@@ -143,7 +134,7 @@ impl AcpServer {
let threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>> = Default::default();
let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
- AcpClientDelegate::new(project.clone(), threads.clone(), cx.to_async()),
+ AcpClientDelegate::new(threads.clone(), cx.to_async()),
stdin,
stdout,
);
@@ -193,14 +184,14 @@ impl AcpServer {
let thread_id: ThreadId = response.thread_id.into();
let server = self.clone();
- let thread = cx.new(|_| AcpThread {
- // todo!
- title: "ACP Thread".into(),
- id: thread_id.clone(), // Either<ErrorState, Id>
- next_entry_id: ThreadEntryId(0),
- entries: Vec::default(),
- project: self.project.clone(),
- server,
+ let thread = cx.new(|cx| {
+ AcpThread::new(
+ server,
+ thread_id.clone(),
+ Vec::default(),
+ self.project.clone(),
+ cx,
+ )
})?;
self.threads.lock().insert(thread_id, thread.downgrade());
Ok(thread)
@@ -222,6 +213,16 @@ impl AcpServer {
Ok(())
}
+ pub async fn cancel_send_message(&self, thread_id: ThreadId, _cx: &mut AsyncApp) -> Result<()> {
+ self.connection
+ .request(acp::CancelSendMessageParams {
+ thread_id: thread_id.clone().into(),
+ })
+ .await
+ .map_err(to_anyhow)?;
+ Ok(())
+ }
+
pub fn exit_status(&self) -> Option<ExitStatus> {
*self.exit_status.lock()
}
@@ -4,7 +4,6 @@ use std::sync::Arc;
use std::time::Duration;
use agentic_coding_protocol::{self as acp};
-use anyhow::Result;
use editor::{Editor, EditorMode, MinimapVisibility, MultiBuffer};
use gpui::{
Animation, AnimationExt, App, EdgesRefinement, Empty, Entity, Focusable, ListState,
@@ -25,7 +24,8 @@ use zed_actions::agent::Chat;
use crate::{
AcpServer, AcpThread, AcpThreadEvent, AgentThreadEntryContent, Diff, MessageChunk, Role,
- ThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent, ToolCallId, ToolCallStatus,
+ ThreadEntry, ThreadStatus, ToolCall, ToolCallConfirmation, ToolCallContent, ToolCallId,
+ ToolCallStatus,
};
pub struct AcpThreadView {
@@ -36,7 +36,6 @@ pub struct AcpThreadView {
message_editor: Entity<Editor>,
last_error: Option<Entity<Markdown>>,
list_state: ListState,
- send_task: Option<Task<Result<()>>>,
auth_task: Option<Task<()>>,
}
@@ -123,7 +122,6 @@ impl AcpThreadView {
agent,
message_editor,
thread_entry_views: Vec::new(),
- send_task: None,
list_state: list_state,
last_error: None,
auth_task: None,
@@ -203,8 +201,12 @@ impl AcpThreadView {
}
}
- pub fn cancel(&mut self) {
- self.send_task.take();
+ pub fn cancel(&mut self, cx: &mut Context<Self>) {
+ self.last_error.take();
+
+ if let Some(thread) = self.thread() {
+ thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
+ }
}
fn chat(&mut self, _: &Chat, window: &mut Window, cx: &mut Context<Self>) {
@@ -217,7 +219,7 @@ impl AcpThreadView {
let task = thread.update(cx, |thread, cx| thread.send(&text, cx));
- self.send_task = Some(cx.spawn(async move |this, cx| {
+ cx.spawn(async move |this, cx| {
let result = task.await;
this.update(cx, |this, cx| {
@@ -227,9 +229,9 @@ impl AcpThreadView {
Markdown::new(format!("Error: {err}").into(), None, None, cx)
}))
}
- this.send_task.take();
})
- }));
+ })
+ .detach();
self.message_editor.update(cx, |editor, cx| {
editor.clear(window, cx);
@@ -467,6 +469,7 @@ impl AcpThreadView {
.size(IconSize::Small)
.into_any_element(),
ToolCallStatus::Rejected
+ | ToolCallStatus::Canceled
| ToolCallStatus::Allowed {
status: acp::ToolCallStatus::Error,
..
@@ -487,15 +490,17 @@ impl AcpThreadView {
cx,
))
}
- ToolCallStatus::Allowed { .. } => tool_call.content.as_ref().map(|content| {
- div()
- .border_color(cx.theme().colors().border)
- .border_t_1()
- .px_2()
- .py_1p5()
- .child(self.render_tool_call_content(entry_ix, content, window, cx))
- .into_any_element()
- }),
+ ToolCallStatus::Allowed { .. } | ToolCallStatus::Canceled => {
+ tool_call.content.as_ref().map(|content| {
+ div()
+ .border_color(cx.theme().colors().border)
+ .border_t_1()
+ .px_2()
+ .py_1p5()
+ .child(self.render_tool_call_content(entry_ix, content, window, cx))
+ .into_any_element()
+ })
+ }
ToolCallStatus::Rejected => None,
};
@@ -1016,18 +1021,21 @@ impl Render for AcpThreadView {
.with_sizing_behavior(gpui::ListSizingBehavior::Auto)
.flex_grow(),
)
- .child(div().px_3().children(if self.send_task.is_none() {
- None
- } else {
- Label::new(if thread.read(cx).waiting_for_tool_confirmation() {
- "Waiting for tool confirmation"
- } else {
- "Generating..."
- })
- .color(Color::Muted)
- .size(LabelSize::Small)
- .into()
- })),
+ .child(
+ div().px_3().children(match thread.read(cx).status() {
+ ThreadStatus::Idle => None,
+ ThreadStatus::WaitingForToolConfirmation => {
+ Label::new("Waiting for tool confirmation")
+ .color(Color::Muted)
+ .size(LabelSize::Small)
+ .into()
+ }
+ ThreadStatus::Generating => Label::new("Generating...")
+ .color(Color::Muted)
+ .size(LabelSize::Small)
+ .into(),
+ }),
+ ),
})
.when_some(self.last_error.clone(), |el, error| {
el.child(
@@ -1052,40 +1060,47 @@ impl Render for AcpThreadView {
.p_2()
.gap_2()
.child(self.message_editor.clone())
- .child(h_flex().justify_end().child(if self.send_task.is_some() {
- IconButton::new("stop-generation", IconName::StopFilled)
- .icon_color(Color::Error)
- .style(ButtonStyle::Tinted(ui::TintColor::Error))
- .tooltip(move |window, cx| {
- Tooltip::for_action(
- "Stop Generation",
- &editor::actions::Cancel,
- window,
- cx,
- )
- })
- .disabled(is_editor_empty)
- .on_click(cx.listener(|this, _event, _, _| this.cancel()))
- } else {
- IconButton::new("send-message", IconName::Send)
- .icon_color(Color::Accent)
- .style(ButtonStyle::Filled)
- .disabled(is_editor_empty)
- .on_click({
- let focus_handle = focus_handle.clone();
- move |_event, window, cx| {
- focus_handle.dispatch_action(&Chat, window, cx);
- }
- })
- .when(!is_editor_empty, |button| {
- button.tooltip(move |window, cx| {
- Tooltip::for_action("Send", &Chat, window, cx)
- })
- })
- .when(is_editor_empty, |button| {
- button.tooltip(Tooltip::text("Type a message to submit"))
- })
- })),
+ .child({
+ let thread = self.thread();
+
+ h_flex().justify_end().child(
+ if thread.map_or(true, |thread| {
+ thread.read(cx).status() == ThreadStatus::Idle
+ }) {
+ IconButton::new("send-message", IconName::Send)
+ .icon_color(Color::Accent)
+ .style(ButtonStyle::Filled)
+ .disabled(thread.is_none() || is_editor_empty)
+ .on_click({
+ let focus_handle = focus_handle.clone();
+ move |_event, window, cx| {
+ focus_handle.dispatch_action(&Chat, window, cx);
+ }
+ })
+ .when(!is_editor_empty, |button| {
+ button.tooltip(move |window, cx| {
+ Tooltip::for_action("Send", &Chat, window, cx)
+ })
+ })
+ .when(is_editor_empty, |button| {
+ button.tooltip(Tooltip::text("Type a message to submit"))
+ })
+ } else {
+ IconButton::new("stop-generation", IconName::StopFilled)
+ .icon_color(Color::Error)
+ .style(ButtonStyle::Tinted(ui::TintColor::Error))
+ .tooltip(move |window, cx| {
+ Tooltip::for_action(
+ "Stop Generation",
+ &editor::actions::Cancel,
+ window,
+ cx,
+ )
+ })
+ .on_click(cx.listener(|this, _event, _, cx| this.cancel(cx)))
+ },
+ )
+ }),
)
}
}
@@ -753,7 +753,7 @@ impl AgentPanel {
thread.update(cx, |thread, cx| thread.cancel_last_completion(window, cx));
}
ActiveView::AcpThread { thread_view, .. } => {
- thread_view.update(cx, |thread_element, _cx| thread_element.cancel());
+ thread_view.update(cx, |thread_element, cx| thread_element.cancel(cx));
}
ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {}
}