Render confirmation diffs and description as markdown

Agus Zubiaga created

Change summary

crates/acp/src/acp.rs         | 215 ++++++++++++++++++++++++++----------
crates/acp/src/thread_view.rs | 112 ++++++++++++------
2 files changed, 227 insertions(+), 100 deletions(-)

Detailed changes

crates/acp/src/acp.rs 🔗

@@ -133,7 +133,7 @@ pub struct ToolCall {
 #[derive(Debug)]
 pub enum ToolCallStatus {
     WaitingForConfirmation {
-        confirmation: acp::ToolCallConfirmation,
+        confirmation: ToolCallConfirmation,
         respond_tx: oneshot::Sender<acp::ToolCallConfirmationOutcome>,
     },
     Allowed {
@@ -144,18 +144,154 @@ pub enum ToolCallStatus {
 }
 
 #[derive(Debug)]
-pub enum ToolCallContent {
-    Markdown {
-        markdown: Entity<Markdown>,
+pub enum ToolCallConfirmation {
+    Edit {
+        diff: Diff,
+        description: Option<Entity<Markdown>>,
     },
-    Diff {
-        path: PathBuf,
-        diff: Entity<BufferDiff>,
-        buffer: Entity<MultiBuffer>,
-        _task: Task<Result<()>>,
+    Execute {
+        command: String,
+        root_command: String,
+        description: Option<Entity<Markdown>>,
+    },
+    Mcp {
+        server_name: String,
+        tool_name: String,
+        tool_display_name: String,
+        description: Option<Entity<Markdown>>,
+    },
+    Fetch {
+        urls: Vec<String>,
+        description: Option<Entity<Markdown>>,
+    },
+    Other {
+        description: Entity<Markdown>,
     },
 }
 
+impl ToolCallConfirmation {
+    pub fn from_acp(
+        confirmation: acp::ToolCallConfirmation,
+        language_registry: Arc<LanguageRegistry>,
+        cx: &mut App,
+    ) -> Self {
+        let to_md = |description: String, cx: &mut App| -> Entity<Markdown> {
+            cx.new(|cx| {
+                Markdown::new(
+                    description.into(),
+                    Some(language_registry.clone()),
+                    None,
+                    cx,
+                )
+            })
+        };
+
+        match confirmation {
+            acp::ToolCallConfirmation::Edit { diff, description } => Self::Edit {
+                diff: Diff::from_acp(diff, language_registry.clone(), cx),
+                description: description.map(|description| to_md(description, cx)),
+            },
+            acp::ToolCallConfirmation::Execute {
+                command,
+                root_command,
+                description,
+            } => Self::Execute {
+                command,
+                root_command,
+                description: description.map(|description| to_md(description, cx)),
+            },
+            acp::ToolCallConfirmation::Mcp {
+                server_name,
+                tool_name,
+                tool_display_name,
+                description,
+            } => Self::Mcp {
+                server_name,
+                tool_name,
+                tool_display_name,
+                description: description.map(|description| to_md(description, cx)),
+            },
+            acp::ToolCallConfirmation::Fetch { urls, description } => Self::Fetch {
+                urls,
+                description: description.map(|description| to_md(description, cx)),
+            },
+            acp::ToolCallConfirmation::Other { description } => Self::Other {
+                description: to_md(description, cx),
+            },
+        }
+    }
+}
+
+#[derive(Debug)]
+pub enum ToolCallContent {
+    Markdown { markdown: Entity<Markdown> },
+    Diff { diff: Diff },
+}
+
+#[derive(Debug)]
+pub struct Diff {
+    // todo! show path somewhere
+    buffer: Entity<MultiBuffer>,
+    _path: PathBuf,
+    _task: Task<Result<()>>,
+}
+
+impl Diff {
+    pub fn from_acp(
+        diff: acp::Diff,
+        language_registry: Arc<LanguageRegistry>,
+        cx: &mut App,
+    ) -> Self {
+        let acp::Diff {
+            path,
+            old_text,
+            new_text,
+        } = diff;
+
+        let buffer = cx.new(|cx| Buffer::local(new_text, cx));
+        let text_snapshot = buffer.read(cx).text_snapshot();
+        let buffer_diff = cx.new(|cx| BufferDiff::new(&text_snapshot, cx));
+
+        let multibuffer = cx.new(|cx| {
+            let mut multibuffer = MultiBuffer::singleton(buffer.clone(), cx);
+            multibuffer.add_diff(buffer_diff.clone(), cx);
+            multibuffer
+        });
+
+        Self {
+            buffer: multibuffer,
+            _path: path.clone(),
+            _task: cx.spawn(async move |cx| {
+                let diff_snapshot = BufferDiff::update_diff(
+                    buffer_diff.clone(),
+                    text_snapshot.clone(),
+                    old_text.map(|o| o.into()),
+                    true,
+                    true,
+                    None,
+                    Some(language_registry.clone()),
+                    cx,
+                )
+                .await?;
+
+                buffer_diff.update(cx, |diff, cx| {
+                    diff.set_snapshot(diff_snapshot, &text_snapshot, cx)
+                })?;
+
+                if let Some(language) = language_registry
+                    .language_for_file_path(&path)
+                    .await
+                    .log_err()
+                {
+                    buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
+                }
+
+                anyhow::Ok(())
+            }),
+        }
+    }
+}
+
 /// A `ThreadEntryId` that is known to be a ToolCall
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
 pub struct ToolCallId(ThreadEntryId);
@@ -293,7 +429,11 @@ impl AcpThread {
         let (tx, rx) = oneshot::channel();
 
         let status = ToolCallStatus::WaitingForConfirmation {
-            confirmation,
+            confirmation: ToolCallConfirmation::from_acp(
+                confirmation,
+                self.project.read(cx).languages().clone(),
+                cx,
+            ),
             respond_tx: tx,
         };
 
@@ -399,56 +539,9 @@ impl AcpThread {
                                 )
                             }),
                         },
-                        acp::ToolCallContent::Diff {
-                            path,
-                            old_text,
-                            new_text,
-                        } => {
-                            let buffer = cx.new(|cx| Buffer::local(new_text, cx));
-                            let text_snapshot = buffer.read(cx).text_snapshot();
-                            let buffer_diff = cx.new(|cx| BufferDiff::new(&text_snapshot, cx));
-
-                            let multibuffer = cx.new(|cx| {
-                                let mut multibuffer = MultiBuffer::singleton(buffer.clone(), cx);
-                                multibuffer.add_diff(buffer_diff.clone(), cx);
-                                multibuffer
-                            });
-
-                            ToolCallContent::Diff {
-                                path: path.clone(),
-                                diff: buffer_diff.clone(),
-                                buffer: multibuffer,
-                                _task: cx.spawn(async move |_this, cx| {
-                                    let diff_snapshot = BufferDiff::update_diff(
-                                        buffer_diff.clone(),
-                                        text_snapshot.clone(),
-                                        old_text.map(|o| o.into()),
-                                        true,
-                                        true,
-                                        None,
-                                        Some(language_registry.clone()),
-                                        cx,
-                                    )
-                                    .await?;
-
-                                    buffer_diff.update(cx, |diff, cx| {
-                                        diff.set_snapshot(diff_snapshot, &text_snapshot, cx)
-                                    })?;
-
-                                    if let Some(language) = language_registry
-                                        .language_for_file_path(&path)
-                                        .await
-                                        .log_err()
-                                    {
-                                        buffer.update(cx, |buffer, cx| {
-                                            buffer.set_language(Some(language), cx)
-                                        })?;
-                                    }
-
-                                    anyhow::Ok(())
-                                }),
-                            }
-                        }
+                        acp::ToolCallContent::Diff { diff } => ToolCallContent::Diff {
+                            diff: Diff::from_acp(diff, language_registry, cx),
+                        },
                     });
                     *status = new_status;
                 }
@@ -647,7 +740,7 @@ mod tests {
                 id,
                 status:
                     ToolCallStatus::WaitingForConfirmation {
-                        confirmation: acp::ToolCallConfirmation::Execute { root_command, .. },
+                        confirmation: ToolCallConfirmation::Execute { root_command, .. },
                         ..
                     },
                 ..

crates/acp/src/thread_view.rs 🔗

@@ -2,7 +2,7 @@ use std::path::Path;
 use std::rc::Rc;
 use std::time::Duration;
 
-use agentic_coding_protocol::{self as acp, ToolCallConfirmation};
+use agentic_coding_protocol::{self as acp};
 use anyhow::Result;
 use editor::{Editor, EditorMode, MinimapVisibility, MultiBuffer};
 use gpui::{
@@ -24,7 +24,7 @@ use zed_actions::agent::Chat;
 
 use crate::{
     AcpServer, AcpThread, AcpThreadEvent, AgentThreadEntryContent, MessageChunk, Role, ThreadEntry,
-    ToolCall, ToolCallContent, ToolCallId, ToolCallStatus,
+    ToolCall, ToolCallConfirmation, ToolCallContent, ToolCallId, ToolCallStatus,
 };
 
 pub struct AcpThreadView {
@@ -232,21 +232,34 @@ impl AcpThreadView {
         cx.notify();
     }
 
+    // todo! should we do this on the fly from render?
     fn sync_thread_entry_view(
         &mut self,
         entry_ix: usize,
         window: &mut Window,
         cx: &mut Context<Self>,
     ) {
-        let Some(buffer) = self.entry_diff_buffer(entry_ix, cx) else {
-            return;
+        let buffer = match (
+            self.entry_diff_buffer(entry_ix, cx),
+            self.thread_entry_views.get(entry_ix),
+        ) {
+            (Some(buffer), Some(Some(ThreadEntryView::Diff { editor }))) => {
+                if editor.read(cx).buffer() == &buffer {
+                    // same buffer, all synced up
+                    return;
+                }
+                // new buffer, replace editor
+                buffer
+            }
+            (Some(buffer), _) => buffer,
+            (None, Some(Some(ThreadEntryView::Diff { .. }))) => {
+                // no longer displaying a diff, drop editor
+                self.thread_entry_views[entry_ix] = None;
+                return;
+            }
+            (None, _) => return,
         };
 
-        if let Some(Some(ThreadEntryView::Diff { .. })) = self.thread_entry_views.get(entry_ix) {
-            return;
-        }
-        // todo! should we do this on the fly from render?
-
         let editor = cx.new(|cx| {
             let mut editor = Editor::new(
                 EditorMode::Full {
@@ -297,16 +310,20 @@ impl AcpThreadView {
     fn entry_diff_buffer(&self, entry_ix: usize, cx: &App) -> Option<Entity<MultiBuffer>> {
         let entry = self.thread()?.read(cx).entries().get(entry_ix)?;
 
-        if let AgentThreadEntryContent::ToolCall(ToolCall {
-            status:
-                crate::ToolCallStatus::Allowed {
-                    content: Some(ToolCallContent::Diff { buffer, .. }),
-                    ..
-                },
-            ..
-        }) = &entry.content
-        {
-            Some(buffer.clone())
+        if let AgentThreadEntryContent::ToolCall(ToolCall { status, .. }) = &entry.content {
+            if let ToolCallStatus::WaitingForConfirmation {
+                confirmation: ToolCallConfirmation::Edit { diff, .. },
+                ..
+            }
+            | ToolCallStatus::Allowed {
+                content: Some(ToolCallContent::Diff { diff }),
+                ..
+            } = status
+            {
+                Some(diff.buffer.clone())
+            } else {
+                None
+            }
         } else {
             None
         }
@@ -423,7 +440,13 @@ impl AcpThreadView {
 
         let content = match &tool_call.status {
             ToolCallStatus::WaitingForConfirmation { confirmation, .. } => {
-                Some(self.render_tool_call_confirmation(tool_call.id, confirmation, cx))
+                Some(self.render_tool_call_confirmation(
+                    entry_ix,
+                    tool_call.id,
+                    confirmation,
+                    window,
+                    cx,
+                ))
             }
             ToolCallStatus::Allowed { content, .. } => content.as_ref().map(|content| {
                 div()
@@ -437,15 +460,7 @@ impl AcpThreadView {
                             default_markdown_style(window, cx),
                         )
                         .into_any_element(),
-                        ToolCallContent::Diff { .. } => {
-                            if let Some(Some(ThreadEntryView::Diff { editor })) =
-                                self.thread_entry_views.get(entry_ix)
-                            {
-                                editor.clone().into_any_element()
-                            } else {
-                                Empty.into_any()
-                            }
-                        }
+                        ToolCallContent::Diff { .. } => self.render_diff_editor(entry_ix),
                     })
                     .into_any_element()
             }),
@@ -482,24 +497,25 @@ impl AcpThreadView {
 
     fn render_tool_call_confirmation(
         &self,
+        entry_ix: usize,
         tool_call_id: ToolCallId,
         confirmation: &ToolCallConfirmation,
+        window: &Window,
         cx: &Context<Self>,
     ) -> AnyElement {
         match confirmation {
             ToolCallConfirmation::Edit {
-                file_name,
-                file_diff,
                 description,
+                diff: _,
             } => v_flex()
                 .border_color(cx.theme().colors().border)
                 .border_t_1()
                 .px_2()
                 .py_1p5()
-                // todo! nicer rendering
-                .child(file_name.clone())
-                .child(file_diff.clone())
-                .children(description.clone())
+                .child(self.render_diff_editor(entry_ix))
+                .children(description.clone().map(|description| {
+                    MarkdownElement::new(description, default_markdown_style(window, cx))
+                }))
                 .child(
                     h_flex()
                         .justify_end()
@@ -571,7 +587,9 @@ impl AcpThreadView {
                 .py_1p5()
                 // todo! nicer rendering
                 .child(command.clone())
-                .children(description.clone())
+                .children(description.clone().map(|description| {
+                    MarkdownElement::new(description, default_markdown_style(window, cx))
+                }))
                 .child(
                     h_flex()
                         .justify_end()
@@ -644,7 +662,9 @@ impl AcpThreadView {
                 .py_1p5()
                 // todo! nicer rendering
                 .child(format!("{server_name} - {tool_display_name}"))
-                .children(description.clone())
+                .children(description.clone().map(|description| {
+                    MarkdownElement::new(description, default_markdown_style(window, cx))
+                }))
                 .child(
                     h_flex()
                         .justify_end()
@@ -732,7 +752,9 @@ impl AcpThreadView {
                 .py_1p5()
                 // todo! nicer rendering
                 .children(urls.clone())
-                .children(description.clone())
+                .children(description.clone().map(|description| {
+                    MarkdownElement::new(description, default_markdown_style(window, cx))
+                }))
                 .child(
                     h_flex()
                         .justify_end()
@@ -796,7 +818,10 @@ impl AcpThreadView {
                 .px_2()
                 .py_1p5()
                 // todo! nicer rendering
-                .child(description.clone())
+                .child(MarkdownElement::new(
+                    description.clone(),
+                    default_markdown_style(window, cx),
+                ))
                 .child(
                     h_flex()
                         .justify_end()
@@ -856,6 +881,15 @@ impl AcpThreadView {
                 .into_any(),
         }
     }
+
+    fn render_diff_editor(&self, entry_ix: usize) -> AnyElement {
+        if let Some(Some(ThreadEntryView::Diff { editor })) = self.thread_entry_views.get(entry_ix)
+        {
+            editor.clone().into_any_element()
+        } else {
+            Empty.into_any()
+        }
+    }
 }
 
 impl Focusable for AcpThreadView {