Tool authorization

Agus Zubiaga created

Change summary

Cargo.lock                         |   1 
crates/acp/Cargo.toml              |   1 
crates/acp/src/acp.rs              | 124 ++++++++++++++++++++++++++++---
crates/acp/src/server.rs           |  40 ++++++++++
crates/acp/src/thread_view.rs      |  59 ++++++++++++++-
crates/ui/src/traits/styled_ext.rs |   2 
6 files changed, 206 insertions(+), 21 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -17,6 +17,7 @@ dependencies = [
  "futures 0.3.31",
  "gpui",
  "language",
+ "log",
  "markdown",
  "parking_lot",
  "project",

crates/acp/Cargo.toml 🔗

@@ -26,6 +26,7 @@ editor.workspace = true
 futures.workspace = true
 gpui.workspace = true
 language.workspace = true
+log.workspace = true
 markdown.workspace = true
 parking_lot.workspace = true
 project.workspace = true

crates/acp/src/acp.rs 🔗

@@ -4,12 +4,14 @@ mod thread_view;
 use agentic_coding_protocol::{self as acp, Role};
 use anyhow::Result;
 use chrono::{DateTime, Utc};
+use futures::channel::oneshot;
 use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
 use language::LanguageRegistry;
 use markdown::Markdown;
 use project::Project;
-use std::{ops::Range, path::PathBuf, sync::Arc};
+use std::{mem, ops::Range, path::PathBuf, sync::Arc};
 use ui::App;
+use util::{ResultExt, debug_panic};
 
 pub use server::AcpServer;
 pub use thread_view::AcpThreadView;
@@ -112,14 +114,32 @@ impl MessageChunk {
     }
 }
 
-#[derive(Clone, Debug, Eq, PartialEq)]
+#[derive(Debug)]
 pub enum AgentThreadEntryContent {
     Message(Message),
     ReadFile { path: PathBuf, content: String },
+    ToolCall(ToolCall),
 }
 
+#[derive(Debug)]
+pub enum ToolCall {
+    WaitingForConfirmation {
+        id: ToolCallId,
+        tool_name: Entity<Markdown>,
+        description: Entity<Markdown>,
+        respond_tx: oneshot::Sender<bool>,
+    },
+    // todo! Running?
+    Allowed,
+    Rejected,
+}
+
+/// A `ThreadEntryId` that is known to be a ToolCall
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
+pub struct ToolCallId(ThreadEntryId);
+
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
-pub struct ThreadEntryId(usize);
+pub struct ThreadEntryId(pub u64);
 
 impl ThreadEntryId {
     pub fn post_inc(&mut self) -> Self {
@@ -146,7 +166,7 @@ pub struct AcpThread {
 
 enum AcpThreadEvent {
     NewEntry,
-    LastEntryUpdated,
+    EntryUpdated(usize),
 }
 
 impl EventEmitter<AcpThreadEvent> for AcpThread {}
@@ -184,22 +204,26 @@ impl AcpThread {
         &self.entries
     }
 
-    pub fn push_entry(&mut self, entry: AgentThreadEntryContent, cx: &mut Context<Self>) {
-        self.entries.push(ThreadEntry {
-            id: self.next_entry_id.post_inc(),
-            content: entry,
-        });
-        cx.emit(AcpThreadEvent::NewEntry)
+    pub fn push_entry(
+        &mut self,
+        entry: AgentThreadEntryContent,
+        cx: &mut Context<Self>,
+    ) -> ThreadEntryId {
+        let id = self.next_entry_id.post_inc();
+        self.entries.push(ThreadEntry { id, content: entry });
+        cx.emit(AcpThreadEvent::NewEntry);
+        id
     }
 
     pub fn push_assistant_chunk(&mut self, chunk: acp::MessageChunk, cx: &mut Context<Self>) {
+        let entries_len = self.entries.len();
         if let Some(last_entry) = self.entries.last_mut()
             && let AgentThreadEntryContent::Message(Message {
                 ref mut chunks,
                 role: Role::Assistant,
             }) = last_entry.content
         {
-            cx.emit(AcpThreadEvent::LastEntryUpdated);
+            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 
             if let (
                 Some(MessageChunk::Text { chunk: old_chunk }),
@@ -231,6 +255,74 @@ impl AcpThread {
         );
     }
 
+    pub fn push_tool_call(
+        &mut self,
+        title: String,
+        description: String,
+        respond_tx: oneshot::Sender<bool>,
+        cx: &mut Context<Self>,
+    ) -> ToolCallId {
+        let language_registry = self.project.read(cx).languages().clone();
+
+        let entry_id = self.push_entry(
+            AgentThreadEntryContent::ToolCall(ToolCall::WaitingForConfirmation {
+                // todo! clean up id creation
+                id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
+                tool_name: cx.new(|cx| {
+                    Markdown::new(title.into(), Some(language_registry.clone()), None, cx)
+                }),
+                description: cx.new(|cx| {
+                    Markdown::new(
+                        description.into(),
+                        Some(language_registry.clone()),
+                        None,
+                        cx,
+                    )
+                }),
+                respond_tx,
+            }),
+            cx,
+        );
+
+        ToolCallId(entry_id)
+    }
+
+    pub fn authorize_tool_call(&mut self, id: ToolCallId, allowed: bool, cx: &mut Context<Self>) {
+        let Some(entry) = self.entry_mut(id.0) else {
+            return;
+        };
+
+        let AgentThreadEntryContent::ToolCall(call) = &mut entry.content else {
+            debug_panic!("expected ToolCall");
+            return;
+        };
+
+        let new_state = if allowed {
+            ToolCall::Allowed
+        } else {
+            ToolCall::Rejected
+        };
+
+        let call = mem::replace(call, new_state);
+
+        if let ToolCall::WaitingForConfirmation { respond_tx, .. } = call {
+            respond_tx.send(allowed).log_err();
+        } else {
+            debug_panic!("tried to authorize an already authorized tool call");
+        }
+
+        cx.emit(AcpThreadEvent::EntryUpdated(id.0.0 as usize));
+    }
+
+    fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> {
+        let entry = self.entries.get_mut(id.0 as usize);
+        debug_assert!(
+            entry.is_some(),
+            "We shouldn't give out ids to entries that don't exist"
+        );
+        entry
+    }
+
     pub fn send(&mut self, message: &str, cx: &mut Context<Self>) -> Task<Result<()>> {
         let agent = self.server.clone();
         let id = self.id.clone();
@@ -303,11 +395,13 @@ mod tests {
             ));
             assert!(
                 thread.entries().iter().any(|entry| {
-                    entry.content
-                        == AgentThreadEntryContent::ReadFile {
-                            path: "/private/tmp/foo".into(),
-                            content: "Lorem ipsum dolor".into(),
+                    match &entry.content {
+                        AgentThreadEntryContent::ReadFile { path, content } => {
+                            path.to_string_lossy().to_string() == "/private/tmp/foo"
+                                && content == "Lorem ipsum dolor"
                         }
+                        _ => false,
+                    }
                 }),
                 "Thread does not contain entry. Actual: {:?}",
                 thread.entries()

crates/acp/src/server.rs 🔗

@@ -1,8 +1,9 @@
-use crate::{AcpThread, AgentThreadEntryContent, ThreadEntryId, ThreadId};
+use crate::{AcpThread, AgentThreadEntryContent, ThreadEntryId, ThreadId, ToolCallId};
 use agentic_coding_protocol as acp;
 use anyhow::{Context as _, Result};
 use async_trait::async_trait;
 use collections::HashMap;
+use futures::channel::oneshot;
 use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
 use parking_lot::Mutex;
 use project::Project;
@@ -185,6 +186,31 @@ impl acp::Client for AcpClientDelegate {
     ) -> Result<acp::GlobSearchResponse> {
         todo!()
     }
+
+    async fn request_tool_call(
+        &self,
+        request: acp::RequestToolCallParams,
+    ) -> Result<acp::RequestToolCallResponse> {
+        let (tx, rx) = oneshot::channel();
+
+        let cx = &mut self.cx.clone();
+        let entry_id = cx
+            .update(|cx| {
+                self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
+                    // todo! tools that don't require confirmation
+                    thread.push_tool_call(request.tool_name, request.description, tx, cx)
+                })
+            })?
+            .context("Failed to update thread")?;
+
+        if dbg!(rx.await)? {
+            Ok(acp::RequestToolCallResponse::Allowed {
+                id: entry_id.into(),
+            })
+        } else {
+            Ok(acp::RequestToolCallResponse::Rejected)
+        }
+    }
 }
 
 impl AcpServer {
@@ -258,3 +284,15 @@ impl From<ThreadId> for acp::ThreadId {
         acp::ThreadId(thread_id.0.to_string())
     }
 }
+
+impl From<acp::ToolCallId> for ToolCallId {
+    fn from(tool_call_id: acp::ToolCallId) -> Self {
+        Self(ThreadEntryId(tool_call_id.0.into()))
+    }
+}
+
+impl From<ToolCallId> for acp::ToolCallId {
+    fn from(tool_call_id: ToolCallId) -> Self {
+        acp::ToolCallId(tool_call_id.0.0)
+    }
+}

crates/acp/src/thread_view.rs 🔗

@@ -13,13 +13,14 @@ use markdown::{HeadingLevelStyles, MarkdownElement, MarkdownStyle};
 use project::Project;
 use settings::Settings as _;
 use theme::ThemeSettings;
-use ui::Tooltip;
 use ui::prelude::*;
+use ui::{Button, Tooltip};
 use util::ResultExt;
 use zed_actions::agent::Chat;
 
 use crate::{
     AcpServer, AcpThread, AcpThreadEvent, AgentThreadEntryContent, MessageChunk, Role, ThreadEntry,
+    ToolCall, ToolCallId,
 };
 
 pub struct AcpThreadView {
@@ -100,8 +101,8 @@ impl AcpThreadView {
                                 AcpThreadEvent::NewEntry => {
                                     this.list_state.splice(count..count, 1);
                                 }
-                                AcpThreadEvent::LastEntryUpdated => {
-                                    this.list_state.splice(count - 1..count, 1);
+                                AcpThreadEvent::EntryUpdated(index) => {
+                                    this.list_state.splice(*index..*index + 1, 1);
                                 }
                             }
                             cx.notify();
@@ -149,7 +150,7 @@ impl AcpThreadView {
     fn thread(&self) -> Option<&Entity<AcpThread>> {
         match &self.thread_state {
             ThreadState::Ready { thread, .. } => Some(thread),
-            _ => None,
+            ThreadState::Loading { .. } | ThreadState::LoadError(..) => None,
         }
     }
 
@@ -187,6 +188,16 @@ impl AcpThreadView {
         });
     }
 
+    fn authorize_tool_call(&mut self, id: ToolCallId, allowed: bool, cx: &mut Context<Self>) {
+        let Some(thread) = self.thread() else {
+            return;
+        };
+        thread.update(cx, |thread, cx| {
+            thread.authorize_tool_call(id, allowed, cx);
+        });
+        cx.notify();
+    }
+
     fn render_entry(
         &self,
         entry: &ThreadEntry,
@@ -236,6 +247,46 @@ impl AcpThreadView {
                     .child(format!("<Reading file {}>", path.display()))
                     .into_any()
             }
+            AgentThreadEntryContent::ToolCall(tool_call) => match tool_call {
+                ToolCall::WaitingForConfirmation {
+                    id,
+                    tool_name,
+                    description,
+                    ..
+                } => {
+                    let id = *id;
+                    v_flex()
+                        .elevation_1(cx)
+                        .child(MarkdownElement::new(
+                            tool_name.clone(),
+                            default_markdown_style(window, cx),
+                        ))
+                        .child(MarkdownElement::new(
+                            description.clone(),
+                            default_markdown_style(window, cx),
+                        ))
+                        .child(
+                            h_flex()
+                                .child(Button::new(("allow", id.0.0), "Allow").on_click(
+                                    cx.listener({
+                                        move |this, _, _, cx| {
+                                            this.authorize_tool_call(id, true, cx);
+                                        }
+                                    }),
+                                ))
+                                .child(Button::new(("reject", id.0.0), "Reject").on_click(
+                                    cx.listener({
+                                        move |this, _, _, cx| {
+                                            this.authorize_tool_call(id, false, cx);
+                                        }
+                                    }),
+                                )),
+                        )
+                        .into_any()
+                }
+                ToolCall::Allowed => div().child("Allowed!").into_any(),
+                ToolCall::Rejected => div().child("Rejected!").into_any(),
+            },
         }
     }
 }

crates/ui/src/traits/styled_ext.rs 🔗

@@ -39,7 +39,7 @@ pub trait StyledExt: Styled + Sized {
     /// Sets `bg()`, `rounded_lg()`, `border()`, `border_color()`, `shadow()`
     ///
     /// Example Elements: Title Bar, Panel, Tab Bar, Editor
-    fn elevation_1(self, cx: &mut App) -> Self {
+    fn elevation_1(self, cx: &App) -> Self {
         elevated(self, cx, ElevationIndex::Surface)
     }