diff --git a/crates/acp/src/acp.rs b/crates/acp/src/acp.rs index 84df93a61061dea004f060c18e42a6d3c60eaf86..bf1bb0bf58391b9a49121a90e98a45b8c63304db 100644 --- a/crates/acp/src/acp.rs +++ b/crates/acp/src/acp.rs @@ -141,6 +141,7 @@ pub enum ToolCallStatus { status: acp::ToolCallStatus, }, Rejected, + Canceled, } #[derive(Debug)] @@ -359,6 +360,7 @@ pub struct AcpThread { server: Arc, title: SharedString, project: Entity, + send_task: Option>, } enum AcpThreadEvent { @@ -366,6 +368,13 @@ enum AcpThreadEvent { EntryUpdated(usize), } +#[derive(PartialEq, Eq)] +pub enum ThreadStatus { + Idle, + WaitingForToolConfirmation, + Generating, +} + impl EventEmitter for AcpThread {} impl AcpThread { @@ -378,7 +387,7 @@ impl AcpThread { ) -> Self { let mut next_entry_id = ThreadEntryId(0); Self { - title: "A new agent2 thread".into(), + title: "ACP Thread".into(), entries: entries .into_iter() .map(|entry| ThreadEntry { @@ -390,6 +399,7 @@ impl AcpThread { id: thread_id, next_entry_id, project, + send_task: None, } } @@ -401,6 +411,18 @@ impl AcpThread { &self.entries } + pub fn status(&self) -> ThreadStatus { + if self.send_task.is_some() { + if self.waiting_for_tool_confirmation() { + ThreadStatus::WaitingForToolConfirmation + } else { + ThreadStatus::Generating + } + } else { + ThreadStatus::Idle + } + } + pub fn push_entry( &mut self, entry: AgentThreadEntryContent, @@ -577,6 +599,10 @@ impl AcpThread { ToolCallStatus::Rejected => { anyhow::bail!("Tool call was rejected and therefore can't be updated") } + ToolCallStatus::Canceled => { + // todo! test this case with fake server + call.status = ToolCallStatus::Allowed { status: new_status }; + } } } _ => anyhow::bail!("Entry is not a tool call"), @@ -597,11 +623,14 @@ impl AcpThread { /// Returns true if the last turn is awaiting tool authorization pub fn waiting_for_tool_confirmation(&self) -> bool { + // todo!("should we use a hashmap?") for entry in self.entries.iter().rev() { match &entry.content { AgentThreadEntryContent::ToolCall(call) => match call.status { ToolCallStatus::WaitingForConfirmation { .. } => return true, - ToolCallStatus::Allowed { .. } | ToolCallStatus::Rejected => continue, + ToolCallStatus::Allowed { .. } + | ToolCallStatus::Rejected + | ToolCallStatus::Canceled => continue, }, AgentThreadEntryContent::Message(_) => { // Reached the beginning of the turn @@ -612,9 +641,14 @@ impl AcpThread { false } - pub fn send(&mut self, message: &str, cx: &mut Context) -> Task> { + pub fn send( + &mut self, + message: &str, + cx: &mut Context, + ) -> impl use<> + Future> { let agent = self.server.clone(); let id = self.id.clone(); + let chunk = MessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx); let message = Message { role: Role::User, @@ -622,10 +656,65 @@ impl AcpThread { }; self.push_entry(AgentThreadEntryContent::Message(message.clone()), cx); let acp_message = message.into_acp(cx); - cx.spawn(async move |_, cx| { - agent.send_message(id, acp_message, cx).await?; - Ok(()) - }) + + let (tx, rx) = oneshot::channel(); + let cancel = self.cancel(cx); + + self.send_task = Some(cx.spawn(async move |this, cx| { + cancel.await.log_err(); + + let result = agent.send_message(id, acp_message, cx).await; + tx.send(result).log_err(); + this.update(cx, |this, _cx| this.send_task.take()).log_err(); + })); + + async move { + match rx.await { + Ok(result) => result, + Err(_) => Ok(()), + } + } + } + + pub fn cancel(&mut self, cx: &mut Context) -> Task> { + let agent = self.server.clone(); + let id = self.id.clone(); + + if self.send_task.take().is_some() { + cx.spawn(async move |this, cx| { + agent.cancel_send_message(id, cx).await?; + + this.update(cx, |this, _cx| { + for entry in this.entries.iter_mut() { + if let AgentThreadEntryContent::ToolCall(call) = &mut entry.content { + let cancel = matches!( + call.status, + ToolCallStatus::WaitingForConfirmation { .. } + | ToolCallStatus::Allowed { + status: acp::ToolCallStatus::Running + } + ); + + if cancel { + let curr_status = + mem::replace(&mut call.status, ToolCallStatus::Canceled); + + if let ToolCallStatus::WaitingForConfirmation { + respond_tx, .. + } = curr_status + { + respond_tx + .send(acp::ToolCallConfirmationOutcome::Cancel) + .ok(); + } + } + } + } + }) + }) + } else { + Task::ready(Ok(())) + } } } @@ -815,6 +904,73 @@ mod tests { }); } + #[gpui::test] + async fn test_gemini_cancel(cx: &mut TestAppContext) { + init_test(cx); + + cx.executor().allow_parking(); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; + let server = gemini_acp_server(project.clone(), cx).await; + let thread = server.create_thread(&mut cx.to_async()).await.unwrap(); + let full_turn = thread.update(cx, |thread, cx| { + thread.send(r#"Run `echo "Hello, world!"`"#, cx) + }); + + run_until_tool_call(&thread, cx).await; + + thread.read_with(cx, |thread, _cx| { + let AgentThreadEntryContent::ToolCall(ToolCall { + id, + status: + ToolCallStatus::WaitingForConfirmation { + confirmation: ToolCallConfirmation::Execute { root_command, .. }, + .. + }, + .. + }) = &thread.entries()[1].content + else { + panic!(); + }; + + assert_eq!(root_command, "echo"); + + *id + }); + + thread + .update(cx, |thread, cx| thread.cancel(cx)) + .await + .unwrap(); + full_turn.await.unwrap(); + thread.read_with(cx, |thread, _| { + let AgentThreadEntryContent::ToolCall(ToolCall { + status: ToolCallStatus::Canceled, + .. + }) = &thread.entries()[1].content + else { + panic!(); + }; + }); + + thread + .update(cx, |thread, cx| { + thread.send(r#"Stop running and say goodbye to me."#, cx) + }) + .await + .unwrap(); + thread.read_with(cx, |thread, _| { + let AgentThreadEntryContent::Message(Message { + role: Role::Assistant, + .. + }) = &thread.entries()[3].content + else { + panic!(); + }; + }); + } + async fn run_until_tool_call(thread: &Entity, cx: &mut TestAppContext) { let (mut tx, mut rx) = mpsc::channel::<()>(1); diff --git a/crates/acp/src/server.rs b/crates/acp/src/server.rs index 139db5afcfd8858fcae7e868696c55a66295f440..79bea3b5ba535a27db096b4dc83788c97d778f76 100644 --- a/crates/acp/src/server.rs +++ b/crates/acp/src/server.rs @@ -20,23 +20,14 @@ pub struct AcpServer { } struct AcpClientDelegate { - project: Entity, threads: Arc>>>, cx: AsyncApp, // sent_buffer_versions: HashMap, HashMap>, } impl AcpClientDelegate { - fn new( - project: Entity, - threads: Arc>>>, - cx: AsyncApp, - ) -> Self { - Self { - project, - threads, - cx: cx, - } + fn new(threads: Arc>>>, cx: AsyncApp) -> Self { + Self { threads, cx: cx } } fn update_thread( @@ -143,7 +134,7 @@ impl AcpServer { let threads: Arc>>> = Default::default(); let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent( - AcpClientDelegate::new(project.clone(), threads.clone(), cx.to_async()), + AcpClientDelegate::new(threads.clone(), cx.to_async()), stdin, stdout, ); @@ -193,14 +184,14 @@ impl AcpServer { let thread_id: ThreadId = response.thread_id.into(); let server = self.clone(); - let thread = cx.new(|_| AcpThread { - // todo! - title: "ACP Thread".into(), - id: thread_id.clone(), // Either - next_entry_id: ThreadEntryId(0), - entries: Vec::default(), - project: self.project.clone(), - server, + let thread = cx.new(|cx| { + AcpThread::new( + server, + thread_id.clone(), + Vec::default(), + self.project.clone(), + cx, + ) })?; self.threads.lock().insert(thread_id, thread.downgrade()); Ok(thread) @@ -222,6 +213,16 @@ impl AcpServer { Ok(()) } + pub async fn cancel_send_message(&self, thread_id: ThreadId, _cx: &mut AsyncApp) -> Result<()> { + self.connection + .request(acp::CancelSendMessageParams { + thread_id: thread_id.clone().into(), + }) + .await + .map_err(to_anyhow)?; + Ok(()) + } + pub fn exit_status(&self) -> Option { *self.exit_status.lock() } diff --git a/crates/acp/src/thread_view.rs b/crates/acp/src/thread_view.rs index 0219470fe1005783ce408893222b259ab7b9d09f..86e6bd93a060ce3b24d481e7152cca93e11d0655 100644 --- a/crates/acp/src/thread_view.rs +++ b/crates/acp/src/thread_view.rs @@ -4,7 +4,6 @@ use std::sync::Arc; use std::time::Duration; use agentic_coding_protocol::{self as acp}; -use anyhow::Result; use editor::{Editor, EditorMode, MinimapVisibility, MultiBuffer}; use gpui::{ Animation, AnimationExt, App, EdgesRefinement, Empty, Entity, Focusable, ListState, @@ -25,7 +24,8 @@ use zed_actions::agent::Chat; use crate::{ AcpServer, AcpThread, AcpThreadEvent, AgentThreadEntryContent, Diff, MessageChunk, Role, - ThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent, ToolCallId, ToolCallStatus, + ThreadEntry, ThreadStatus, ToolCall, ToolCallConfirmation, ToolCallContent, ToolCallId, + ToolCallStatus, }; pub struct AcpThreadView { @@ -36,7 +36,6 @@ pub struct AcpThreadView { message_editor: Entity, last_error: Option>, list_state: ListState, - send_task: Option>>, auth_task: Option>, } @@ -123,7 +122,6 @@ impl AcpThreadView { agent, message_editor, thread_entry_views: Vec::new(), - send_task: None, list_state: list_state, last_error: None, auth_task: None, @@ -203,8 +201,12 @@ impl AcpThreadView { } } - pub fn cancel(&mut self) { - self.send_task.take(); + pub fn cancel(&mut self, cx: &mut Context) { + self.last_error.take(); + + if let Some(thread) = self.thread() { + thread.update(cx, |thread, cx| thread.cancel(cx)).detach(); + } } fn chat(&mut self, _: &Chat, window: &mut Window, cx: &mut Context) { @@ -217,7 +219,7 @@ impl AcpThreadView { let task = thread.update(cx, |thread, cx| thread.send(&text, cx)); - self.send_task = Some(cx.spawn(async move |this, cx| { + cx.spawn(async move |this, cx| { let result = task.await; this.update(cx, |this, cx| { @@ -227,9 +229,9 @@ impl AcpThreadView { Markdown::new(format!("Error: {err}").into(), None, None, cx) })) } - this.send_task.take(); }) - })); + }) + .detach(); self.message_editor.update(cx, |editor, cx| { editor.clear(window, cx); @@ -467,6 +469,7 @@ impl AcpThreadView { .size(IconSize::Small) .into_any_element(), ToolCallStatus::Rejected + | ToolCallStatus::Canceled | ToolCallStatus::Allowed { status: acp::ToolCallStatus::Error, .. @@ -487,15 +490,17 @@ impl AcpThreadView { cx, )) } - ToolCallStatus::Allowed { .. } => tool_call.content.as_ref().map(|content| { - div() - .border_color(cx.theme().colors().border) - .border_t_1() - .px_2() - .py_1p5() - .child(self.render_tool_call_content(entry_ix, content, window, cx)) - .into_any_element() - }), + ToolCallStatus::Allowed { .. } | ToolCallStatus::Canceled => { + tool_call.content.as_ref().map(|content| { + div() + .border_color(cx.theme().colors().border) + .border_t_1() + .px_2() + .py_1p5() + .child(self.render_tool_call_content(entry_ix, content, window, cx)) + .into_any_element() + }) + } ToolCallStatus::Rejected => None, }; @@ -1016,18 +1021,21 @@ impl Render for AcpThreadView { .with_sizing_behavior(gpui::ListSizingBehavior::Auto) .flex_grow(), ) - .child(div().px_3().children(if self.send_task.is_none() { - None - } else { - Label::new(if thread.read(cx).waiting_for_tool_confirmation() { - "Waiting for tool confirmation" - } else { - "Generating..." - }) - .color(Color::Muted) - .size(LabelSize::Small) - .into() - })), + .child( + div().px_3().children(match thread.read(cx).status() { + ThreadStatus::Idle => None, + ThreadStatus::WaitingForToolConfirmation => { + Label::new("Waiting for tool confirmation") + .color(Color::Muted) + .size(LabelSize::Small) + .into() + } + ThreadStatus::Generating => Label::new("Generating...") + .color(Color::Muted) + .size(LabelSize::Small) + .into(), + }), + ), }) .when_some(self.last_error.clone(), |el, error| { el.child( @@ -1052,40 +1060,47 @@ impl Render for AcpThreadView { .p_2() .gap_2() .child(self.message_editor.clone()) - .child(h_flex().justify_end().child(if self.send_task.is_some() { - IconButton::new("stop-generation", IconName::StopFilled) - .icon_color(Color::Error) - .style(ButtonStyle::Tinted(ui::TintColor::Error)) - .tooltip(move |window, cx| { - Tooltip::for_action( - "Stop Generation", - &editor::actions::Cancel, - window, - cx, - ) - }) - .disabled(is_editor_empty) - .on_click(cx.listener(|this, _event, _, _| this.cancel())) - } else { - IconButton::new("send-message", IconName::Send) - .icon_color(Color::Accent) - .style(ButtonStyle::Filled) - .disabled(is_editor_empty) - .on_click({ - let focus_handle = focus_handle.clone(); - move |_event, window, cx| { - focus_handle.dispatch_action(&Chat, window, cx); - } - }) - .when(!is_editor_empty, |button| { - button.tooltip(move |window, cx| { - Tooltip::for_action("Send", &Chat, window, cx) - }) - }) - .when(is_editor_empty, |button| { - button.tooltip(Tooltip::text("Type a message to submit")) - }) - })), + .child({ + let thread = self.thread(); + + h_flex().justify_end().child( + if thread.map_or(true, |thread| { + thread.read(cx).status() == ThreadStatus::Idle + }) { + IconButton::new("send-message", IconName::Send) + .icon_color(Color::Accent) + .style(ButtonStyle::Filled) + .disabled(thread.is_none() || is_editor_empty) + .on_click({ + let focus_handle = focus_handle.clone(); + move |_event, window, cx| { + focus_handle.dispatch_action(&Chat, window, cx); + } + }) + .when(!is_editor_empty, |button| { + button.tooltip(move |window, cx| { + Tooltip::for_action("Send", &Chat, window, cx) + }) + }) + .when(is_editor_empty, |button| { + button.tooltip(Tooltip::text("Type a message to submit")) + }) + } else { + IconButton::new("stop-generation", IconName::StopFilled) + .icon_color(Color::Error) + .style(ButtonStyle::Tinted(ui::TintColor::Error)) + .tooltip(move |window, cx| { + Tooltip::for_action( + "Stop Generation", + &editor::actions::Cancel, + window, + cx, + ) + }) + .on_click(cx.listener(|this, _event, _, cx| this.cancel(cx))) + }, + ) + }), ) } } diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index ce38ff023b80cc7e28e25bde2f4c3c589b7739d2..ae718bd7bf7c741079c8cf70ef16fccb03698236 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -753,7 +753,7 @@ impl AgentPanel { thread.update(cx, |thread, cx| thread.cancel_last_completion(window, cx)); } ActiveView::AcpThread { thread_view, .. } => { - thread_view.update(cx, |thread_element, _cx| thread_element.cancel()); + thread_view.update(cx, |thread_element, cx| thread_element.cancel(cx)); } ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {} }