diff --git a/crates/acp/src/acp.rs b/crates/acp/src/acp.rs index ac9c1954bbf9c5811fb6896d36efef2a8a3a233f..cf0f81ddd2030d87f178f9bdea8b8858f3ce0d0a 100644 --- a/crates/acp/src/acp.rs +++ b/crates/acp/src/acp.rs @@ -121,10 +121,15 @@ pub enum AgentThreadEntryContent { } #[derive(Debug)] -pub enum ToolCall { +pub struct ToolCall { + id: ToolCallId, + tool_name: Entity, + status: ToolCallStatus, +} + +#[derive(Debug)] +pub enum ToolCallStatus { WaitingForConfirmation { - id: ToolCallId, - tool_name: Entity, description: Entity, respond_tx: oneshot::Sender, }, @@ -270,21 +275,23 @@ impl AcpThread { let language_registry = self.project.read(cx).languages().clone(); let entry_id = self.push_entry( - AgentThreadEntryContent::ToolCall(ToolCall::WaitingForConfirmation { + AgentThreadEntryContent::ToolCall(ToolCall { // todo! clean up id creation id: ToolCallId(ThreadEntryId(self.entries.len() as u64)), tool_name: cx.new(|cx| { Markdown::new(title.into(), Some(language_registry.clone()), None, cx) }), - description: cx.new(|cx| { - Markdown::new( - description.into(), - Some(language_registry.clone()), - None, - cx, - ) - }), - respond_tx, + status: ToolCallStatus::WaitingForConfirmation { + description: cx.new(|cx| { + Markdown::new( + description.into(), + Some(language_registry.clone()), + None, + cx, + ) + }), + respond_tx, + }, }), cx, ); @@ -302,21 +309,21 @@ impl AcpThread { return; }; - let new_state = if allowed { - ToolCall::Allowed + let new_status = if allowed { + ToolCallStatus::Allowed } else { - ToolCall::Rejected + ToolCallStatus::Rejected }; - let call = mem::replace(call, new_state); + let curr_status = mem::replace(&mut call.status, new_status); - if let ToolCall::WaitingForConfirmation { respond_tx, .. } = call { + if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status { respond_tx.send(allowed).log_err(); } else { debug_panic!("tried to authorize an already authorized tool call"); } - cx.emit(AcpThreadEvent::EntryUpdated(id.0.0 as usize)); + cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize)); } fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> { @@ -426,11 +433,10 @@ mod tests { run_until_tool_call(&thread, cx).await; let tool_call_id = thread.read_with(cx, |thread, cx| { - let AgentThreadEntryContent::ToolCall(ToolCall::WaitingForConfirmation { + let AgentThreadEntryContent::ToolCall(ToolCall { id, tool_name, - description, - .. + status: ToolCallStatus::WaitingForConfirmation { description, .. }, }) = &thread.entries().last().unwrap().content else { panic!(); @@ -454,7 +460,10 @@ mod tests { thread.authorize_tool_call(tool_call_id, true, cx); assert!(matches!( thread.entries().last().unwrap().content, - AgentThreadEntryContent::ToolCall(ToolCall::Allowed) + AgentThreadEntryContent::ToolCall(ToolCall { + status: ToolCallStatus::Allowed, + .. + }) )); }); @@ -471,7 +480,10 @@ mod tests { )); assert!(matches!( thread.entries[1].content, - AgentThreadEntryContent::ToolCall(ToolCall::Allowed) + AgentThreadEntryContent::ToolCall(ToolCall { + status: ToolCallStatus::Allowed, + .. + }) )); assert!(matches!( thread.entries[2].content, diff --git a/crates/acp/src/thread_view.rs b/crates/acp/src/thread_view.rs index 32aa14e45effe1089f61733d8f3397377d36f8df..70458acd3a45336efb9460934e6b67691bfe730c 100644 --- a/crates/acp/src/thread_view.rs +++ b/crates/acp/src/thread_view.rs @@ -20,7 +20,7 @@ use zed_actions::agent::Chat; use crate::{ AcpServer, AcpThread, AcpThreadEvent, AgentThreadEntryContent, MessageChunk, Role, ThreadEntry, - ToolCall, ToolCallId, + ToolCall, ToolCallId, ToolCallStatus, }; pub struct AcpThreadView { @@ -224,7 +224,7 @@ impl AcpThreadView { match message.role { Role::User => div() .p_2() - .pt_4() + .pt_5() .child( div() .text_xs() @@ -245,47 +245,99 @@ impl AcpThreadView { .into_any(), } } - AgentThreadEntryContent::ToolCall(tool_call) => match tool_call { - ToolCall::WaitingForConfirmation { - id, - tool_name, - description, - .. - } => { - let id = *id; - v_flex() - .elevation_1(cx) - .child(MarkdownElement::new( - tool_name.clone(), - default_markdown_style(window, cx), - )) - .child(MarkdownElement::new( - description.clone(), - default_markdown_style(window, cx), - )) + AgentThreadEntryContent::ToolCall(tool_call) => div() + .px_2() + .py_4() + .child(self.render_tool_call(tool_call, window, cx)) + .into_any(), + } + } + + fn render_tool_call(&self, tool_call: &ToolCall, window: &Window, cx: &Context) -> Div { + let status_icon = match &tool_call.status { + ToolCallStatus::WaitingForConfirmation { .. } => Empty.into_element().into_any(), + ToolCallStatus::Allowed => Icon::new(IconName::Check) + .color(Color::Success) + .size(IconSize::Small) + .into_any_element(), + ToolCallStatus::Rejected => Icon::new(IconName::X) + .color(Color::Error) + .size(IconSize::Small) + .into_any_element(), + }; + + let content = match &tool_call.status { + ToolCallStatus::WaitingForConfirmation { description, .. } => v_flex() + .border_color(cx.theme().colors().border) + .border_t_1() + .px_2() + .py_1p5() + .child(MarkdownElement::new( + description.clone(), + default_markdown_style(window, cx), + )) + .child( + h_flex() + .justify_end() + .gap_1() .child( - h_flex() - .child(Button::new(("allow", id.as_u64()), "Allow").on_click( - cx.listener({ - move |this, _, _, cx| { - this.authorize_tool_call(id, true, cx); - } - }), - )) - .child(Button::new(("reject", id.as_u64()), "Reject").on_click( - cx.listener({ - move |this, _, _, cx| { - this.authorize_tool_call(id, false, cx); - } - }), - )), + Button::new(("allow", tool_call.id.as_u64()), "Allow") + .icon(IconName::Check) + .icon_position(IconPosition::Start) + .icon_size(IconSize::Small) + .icon_color(Color::Success) + .on_click(cx.listener({ + let id = tool_call.id; + move |this, _, _, cx| { + this.authorize_tool_call(id, true, cx); + } + })), ) - .into_any() - } - ToolCall::Allowed => div().child("Allowed!").into_any(), - ToolCall::Rejected => div().child("Rejected!").into_any(), - }, - } + .child( + Button::new(("reject", tool_call.id.as_u64()), "Reject") + .icon(IconName::X) + .icon_position(IconPosition::Start) + .icon_size(IconSize::Small) + .icon_color(Color::Error) + .on_click(cx.listener({ + let id = tool_call.id; + move |this, _, _, cx| { + this.authorize_tool_call(id, false, cx); + } + })), + ), + ) + .into_any() + .into(), + ToolCallStatus::Allowed => None, + ToolCallStatus::Rejected => None, + }; + + v_flex() + .text_xs() + .rounded_md() + .border_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().editor_background) + .child( + h_flex() + .px_2() + .py_1p5() + .w_full() + .gap_1p5() + .child( + Icon::new(IconName::Cog) + .size(IconSize::Small) + .color(Color::Muted), + ) + .child(MarkdownElement::new( + tool_call.tool_name.clone(), + default_markdown_style(window, cx), + )) + .child(div().w_full()) + .child(status_icon), + ) + .children(content) } }