diff --git a/crates/acp/src/acp.rs b/crates/acp/src/acp.rs index e5a1461e6d3a0bee3f43b46687aee37931508e2d..3799cb82081ff640d3b51b6c9116a07ddfe726f4 100644 --- a/crates/acp/src/acp.rs +++ b/crates/acp/src/acp.rs @@ -133,7 +133,7 @@ pub struct ToolCall { #[derive(Debug)] pub enum ToolCallStatus { WaitingForConfirmation { - confirmation: acp::ToolCallConfirmation, + confirmation: ToolCallConfirmation, respond_tx: oneshot::Sender, }, Allowed { @@ -144,18 +144,154 @@ pub enum ToolCallStatus { } #[derive(Debug)] -pub enum ToolCallContent { - Markdown { - markdown: Entity, +pub enum ToolCallConfirmation { + Edit { + diff: Diff, + description: Option>, }, - Diff { - path: PathBuf, - diff: Entity, - buffer: Entity, - _task: Task>, + Execute { + command: String, + root_command: String, + description: Option>, + }, + Mcp { + server_name: String, + tool_name: String, + tool_display_name: String, + description: Option>, + }, + Fetch { + urls: Vec, + description: Option>, + }, + Other { + description: Entity, }, } +impl ToolCallConfirmation { + pub fn from_acp( + confirmation: acp::ToolCallConfirmation, + language_registry: Arc, + cx: &mut App, + ) -> Self { + let to_md = |description: String, cx: &mut App| -> Entity { + cx.new(|cx| { + Markdown::new( + description.into(), + Some(language_registry.clone()), + None, + cx, + ) + }) + }; + + match confirmation { + acp::ToolCallConfirmation::Edit { diff, description } => Self::Edit { + diff: Diff::from_acp(diff, language_registry.clone(), cx), + description: description.map(|description| to_md(description, cx)), + }, + acp::ToolCallConfirmation::Execute { + command, + root_command, + description, + } => Self::Execute { + command, + root_command, + description: description.map(|description| to_md(description, cx)), + }, + acp::ToolCallConfirmation::Mcp { + server_name, + tool_name, + tool_display_name, + description, + } => Self::Mcp { + server_name, + tool_name, + tool_display_name, + description: description.map(|description| to_md(description, cx)), + }, + acp::ToolCallConfirmation::Fetch { urls, description } => Self::Fetch { + urls, + description: description.map(|description| to_md(description, cx)), + }, + acp::ToolCallConfirmation::Other { description } => Self::Other { + description: to_md(description, cx), + }, + } + } +} + +#[derive(Debug)] +pub enum ToolCallContent { + Markdown { markdown: Entity }, + Diff { diff: Diff }, +} + +#[derive(Debug)] +pub struct Diff { + // todo! show path somewhere + buffer: Entity, + _path: PathBuf, + _task: Task>, +} + +impl Diff { + pub fn from_acp( + diff: acp::Diff, + language_registry: Arc, + cx: &mut App, + ) -> Self { + let acp::Diff { + path, + old_text, + new_text, + } = diff; + + let buffer = cx.new(|cx| Buffer::local(new_text, cx)); + let text_snapshot = buffer.read(cx).text_snapshot(); + let buffer_diff = cx.new(|cx| BufferDiff::new(&text_snapshot, cx)); + + let multibuffer = cx.new(|cx| { + let mut multibuffer = MultiBuffer::singleton(buffer.clone(), cx); + multibuffer.add_diff(buffer_diff.clone(), cx); + multibuffer + }); + + Self { + buffer: multibuffer, + _path: path.clone(), + _task: cx.spawn(async move |cx| { + let diff_snapshot = BufferDiff::update_diff( + buffer_diff.clone(), + text_snapshot.clone(), + old_text.map(|o| o.into()), + true, + true, + None, + Some(language_registry.clone()), + cx, + ) + .await?; + + buffer_diff.update(cx, |diff, cx| { + diff.set_snapshot(diff_snapshot, &text_snapshot, cx) + })?; + + if let Some(language) = language_registry + .language_for_file_path(&path) + .await + .log_err() + { + buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?; + } + + anyhow::Ok(()) + }), + } + } +} + /// A `ThreadEntryId` that is known to be a ToolCall #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct ToolCallId(ThreadEntryId); @@ -293,7 +429,11 @@ impl AcpThread { let (tx, rx) = oneshot::channel(); let status = ToolCallStatus::WaitingForConfirmation { - confirmation, + confirmation: ToolCallConfirmation::from_acp( + confirmation, + self.project.read(cx).languages().clone(), + cx, + ), respond_tx: tx, }; @@ -399,56 +539,9 @@ impl AcpThread { ) }), }, - acp::ToolCallContent::Diff { - path, - old_text, - new_text, - } => { - let buffer = cx.new(|cx| Buffer::local(new_text, cx)); - let text_snapshot = buffer.read(cx).text_snapshot(); - let buffer_diff = cx.new(|cx| BufferDiff::new(&text_snapshot, cx)); - - let multibuffer = cx.new(|cx| { - let mut multibuffer = MultiBuffer::singleton(buffer.clone(), cx); - multibuffer.add_diff(buffer_diff.clone(), cx); - multibuffer - }); - - ToolCallContent::Diff { - path: path.clone(), - diff: buffer_diff.clone(), - buffer: multibuffer, - _task: cx.spawn(async move |_this, cx| { - let diff_snapshot = BufferDiff::update_diff( - buffer_diff.clone(), - text_snapshot.clone(), - old_text.map(|o| o.into()), - true, - true, - None, - Some(language_registry.clone()), - cx, - ) - .await?; - - buffer_diff.update(cx, |diff, cx| { - diff.set_snapshot(diff_snapshot, &text_snapshot, cx) - })?; - - if let Some(language) = language_registry - .language_for_file_path(&path) - .await - .log_err() - { - buffer.update(cx, |buffer, cx| { - buffer.set_language(Some(language), cx) - })?; - } - - anyhow::Ok(()) - }), - } - } + acp::ToolCallContent::Diff { diff } => ToolCallContent::Diff { + diff: Diff::from_acp(diff, language_registry, cx), + }, }); *status = new_status; } @@ -647,7 +740,7 @@ mod tests { id, status: ToolCallStatus::WaitingForConfirmation { - confirmation: acp::ToolCallConfirmation::Execute { root_command, .. }, + confirmation: ToolCallConfirmation::Execute { root_command, .. }, .. }, .. diff --git a/crates/acp/src/thread_view.rs b/crates/acp/src/thread_view.rs index 209e15aed290d995df95339d8967b0244aec58cd..a7a15e8a12a8f5094273aaf3890c873eb7f2f12a 100644 --- a/crates/acp/src/thread_view.rs +++ b/crates/acp/src/thread_view.rs @@ -2,7 +2,7 @@ use std::path::Path; use std::rc::Rc; use std::time::Duration; -use agentic_coding_protocol::{self as acp, ToolCallConfirmation}; +use agentic_coding_protocol::{self as acp}; use anyhow::Result; use editor::{Editor, EditorMode, MinimapVisibility, MultiBuffer}; use gpui::{ @@ -24,7 +24,7 @@ use zed_actions::agent::Chat; use crate::{ AcpServer, AcpThread, AcpThreadEvent, AgentThreadEntryContent, MessageChunk, Role, ThreadEntry, - ToolCall, ToolCallContent, ToolCallId, ToolCallStatus, + ToolCall, ToolCallConfirmation, ToolCallContent, ToolCallId, ToolCallStatus, }; pub struct AcpThreadView { @@ -232,21 +232,34 @@ impl AcpThreadView { cx.notify(); } + // todo! should we do this on the fly from render? fn sync_thread_entry_view( &mut self, entry_ix: usize, window: &mut Window, cx: &mut Context, ) { - let Some(buffer) = self.entry_diff_buffer(entry_ix, cx) else { - return; + let buffer = match ( + self.entry_diff_buffer(entry_ix, cx), + self.thread_entry_views.get(entry_ix), + ) { + (Some(buffer), Some(Some(ThreadEntryView::Diff { editor }))) => { + if editor.read(cx).buffer() == &buffer { + // same buffer, all synced up + return; + } + // new buffer, replace editor + buffer + } + (Some(buffer), _) => buffer, + (None, Some(Some(ThreadEntryView::Diff { .. }))) => { + // no longer displaying a diff, drop editor + self.thread_entry_views[entry_ix] = None; + return; + } + (None, _) => return, }; - if let Some(Some(ThreadEntryView::Diff { .. })) = self.thread_entry_views.get(entry_ix) { - return; - } - // todo! should we do this on the fly from render? - let editor = cx.new(|cx| { let mut editor = Editor::new( EditorMode::Full { @@ -297,16 +310,20 @@ impl AcpThreadView { fn entry_diff_buffer(&self, entry_ix: usize, cx: &App) -> Option> { let entry = self.thread()?.read(cx).entries().get(entry_ix)?; - if let AgentThreadEntryContent::ToolCall(ToolCall { - status: - crate::ToolCallStatus::Allowed { - content: Some(ToolCallContent::Diff { buffer, .. }), - .. - }, - .. - }) = &entry.content - { - Some(buffer.clone()) + if let AgentThreadEntryContent::ToolCall(ToolCall { status, .. }) = &entry.content { + if let ToolCallStatus::WaitingForConfirmation { + confirmation: ToolCallConfirmation::Edit { diff, .. }, + .. + } + | ToolCallStatus::Allowed { + content: Some(ToolCallContent::Diff { diff }), + .. + } = status + { + Some(diff.buffer.clone()) + } else { + None + } } else { None } @@ -423,7 +440,13 @@ impl AcpThreadView { let content = match &tool_call.status { ToolCallStatus::WaitingForConfirmation { confirmation, .. } => { - Some(self.render_tool_call_confirmation(tool_call.id, confirmation, cx)) + Some(self.render_tool_call_confirmation( + entry_ix, + tool_call.id, + confirmation, + window, + cx, + )) } ToolCallStatus::Allowed { content, .. } => content.as_ref().map(|content| { div() @@ -437,15 +460,7 @@ impl AcpThreadView { default_markdown_style(window, cx), ) .into_any_element(), - ToolCallContent::Diff { .. } => { - if let Some(Some(ThreadEntryView::Diff { editor })) = - self.thread_entry_views.get(entry_ix) - { - editor.clone().into_any_element() - } else { - Empty.into_any() - } - } + ToolCallContent::Diff { .. } => self.render_diff_editor(entry_ix), }) .into_any_element() }), @@ -482,24 +497,25 @@ impl AcpThreadView { fn render_tool_call_confirmation( &self, + entry_ix: usize, tool_call_id: ToolCallId, confirmation: &ToolCallConfirmation, + window: &Window, cx: &Context, ) -> AnyElement { match confirmation { ToolCallConfirmation::Edit { - file_name, - file_diff, description, + diff: _, } => v_flex() .border_color(cx.theme().colors().border) .border_t_1() .px_2() .py_1p5() - // todo! nicer rendering - .child(file_name.clone()) - .child(file_diff.clone()) - .children(description.clone()) + .child(self.render_diff_editor(entry_ix)) + .children(description.clone().map(|description| { + MarkdownElement::new(description, default_markdown_style(window, cx)) + })) .child( h_flex() .justify_end() @@ -571,7 +587,9 @@ impl AcpThreadView { .py_1p5() // todo! nicer rendering .child(command.clone()) - .children(description.clone()) + .children(description.clone().map(|description| { + MarkdownElement::new(description, default_markdown_style(window, cx)) + })) .child( h_flex() .justify_end() @@ -644,7 +662,9 @@ impl AcpThreadView { .py_1p5() // todo! nicer rendering .child(format!("{server_name} - {tool_display_name}")) - .children(description.clone()) + .children(description.clone().map(|description| { + MarkdownElement::new(description, default_markdown_style(window, cx)) + })) .child( h_flex() .justify_end() @@ -732,7 +752,9 @@ impl AcpThreadView { .py_1p5() // todo! nicer rendering .children(urls.clone()) - .children(description.clone()) + .children(description.clone().map(|description| { + MarkdownElement::new(description, default_markdown_style(window, cx)) + })) .child( h_flex() .justify_end() @@ -796,7 +818,10 @@ impl AcpThreadView { .px_2() .py_1p5() // todo! nicer rendering - .child(description.clone()) + .child(MarkdownElement::new( + description.clone(), + default_markdown_style(window, cx), + )) .child( h_flex() .justify_end() @@ -856,6 +881,15 @@ impl AcpThreadView { .into_any(), } } + + fn render_diff_editor(&self, entry_ix: usize) -> AnyElement { + if let Some(Some(ThreadEntryView::Diff { editor })) = self.thread_entry_views.get(entry_ix) + { + editor.clone().into_any_element() + } else { + Empty.into_any() + } + } } impl Focusable for AcpThreadView {