From 3ceeefe46056daab32f8f2ae308c6a4bc718f27a Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Tue, 1 Jul 2025 20:32:21 -0300 Subject: [PATCH] Tool authorization --- Cargo.lock | 1 + crates/acp/Cargo.toml | 1 + crates/acp/src/acp.rs | 124 +++++++++++++++++++++++++---- crates/acp/src/server.rs | 40 +++++++++- crates/acp/src/thread_view.rs | 59 +++++++++++++- crates/ui/src/traits/styled_ext.rs | 2 +- 6 files changed, 206 insertions(+), 21 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e076350c08ecfaeca23a64bfd53479e1ae963cff..d4e7d08dee3c88b763f990d29e3e258eb437d68d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,7 @@ dependencies = [ "futures 0.3.31", "gpui", "language", + "log", "markdown", "parking_lot", "project", diff --git a/crates/acp/Cargo.toml b/crates/acp/Cargo.toml index 3d85c3bd4239285c487ebabd7356846f528ee18e..2b81dc03e997aaac44feb87fc7c8b65ee19a59f2 100644 --- a/crates/acp/Cargo.toml +++ b/crates/acp/Cargo.toml @@ -26,6 +26,7 @@ editor.workspace = true futures.workspace = true gpui.workspace = true language.workspace = true +log.workspace = true markdown.workspace = true parking_lot.workspace = true project.workspace = true diff --git a/crates/acp/src/acp.rs b/crates/acp/src/acp.rs index 77f4097a728e8ee25d9fa977686f5eff453867ab..b28e204c89eb5b1ef0998e818830a1ef6bf84641 100644 --- a/crates/acp/src/acp.rs +++ b/crates/acp/src/acp.rs @@ -4,12 +4,14 @@ mod thread_view; use agentic_coding_protocol::{self as acp, Role}; use anyhow::Result; use chrono::{DateTime, Utc}; +use futures::channel::oneshot; use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task}; use language::LanguageRegistry; use markdown::Markdown; use project::Project; -use std::{ops::Range, path::PathBuf, sync::Arc}; +use std::{mem, ops::Range, path::PathBuf, sync::Arc}; use ui::App; +use util::{ResultExt, debug_panic}; pub use server::AcpServer; pub use thread_view::AcpThreadView; @@ -112,14 +114,32 @@ impl MessageChunk { } } -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Debug)] pub enum AgentThreadEntryContent { Message(Message), ReadFile { path: PathBuf, content: String }, + ToolCall(ToolCall), } +#[derive(Debug)] +pub enum ToolCall { + WaitingForConfirmation { + id: ToolCallId, + tool_name: Entity, + description: Entity, + respond_tx: oneshot::Sender, + }, + // todo! Running? + Allowed, + Rejected, +} + +/// A `ThreadEntryId` that is known to be a ToolCall +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct ToolCallId(ThreadEntryId); + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct ThreadEntryId(usize); +pub struct ThreadEntryId(pub u64); impl ThreadEntryId { pub fn post_inc(&mut self) -> Self { @@ -146,7 +166,7 @@ pub struct AcpThread { enum AcpThreadEvent { NewEntry, - LastEntryUpdated, + EntryUpdated(usize), } impl EventEmitter for AcpThread {} @@ -184,22 +204,26 @@ impl AcpThread { &self.entries } - pub fn push_entry(&mut self, entry: AgentThreadEntryContent, cx: &mut Context) { - self.entries.push(ThreadEntry { - id: self.next_entry_id.post_inc(), - content: entry, - }); - cx.emit(AcpThreadEvent::NewEntry) + pub fn push_entry( + &mut self, + entry: AgentThreadEntryContent, + cx: &mut Context, + ) -> ThreadEntryId { + let id = self.next_entry_id.post_inc(); + self.entries.push(ThreadEntry { id, content: entry }); + cx.emit(AcpThreadEvent::NewEntry); + id } pub fn push_assistant_chunk(&mut self, chunk: acp::MessageChunk, cx: &mut Context) { + let entries_len = self.entries.len(); if let Some(last_entry) = self.entries.last_mut() && let AgentThreadEntryContent::Message(Message { ref mut chunks, role: Role::Assistant, }) = last_entry.content { - cx.emit(AcpThreadEvent::LastEntryUpdated); + cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1)); if let ( Some(MessageChunk::Text { chunk: old_chunk }), @@ -231,6 +255,74 @@ impl AcpThread { ); } + pub fn push_tool_call( + &mut self, + title: String, + description: String, + respond_tx: oneshot::Sender, + cx: &mut Context, + ) -> ToolCallId { + let language_registry = self.project.read(cx).languages().clone(); + + let entry_id = self.push_entry( + AgentThreadEntryContent::ToolCall(ToolCall::WaitingForConfirmation { + // todo! clean up id creation + id: ToolCallId(ThreadEntryId(self.entries.len() as u64)), + tool_name: cx.new(|cx| { + Markdown::new(title.into(), Some(language_registry.clone()), None, cx) + }), + description: cx.new(|cx| { + Markdown::new( + description.into(), + Some(language_registry.clone()), + None, + cx, + ) + }), + respond_tx, + }), + cx, + ); + + ToolCallId(entry_id) + } + + pub fn authorize_tool_call(&mut self, id: ToolCallId, allowed: bool, cx: &mut Context) { + let Some(entry) = self.entry_mut(id.0) else { + return; + }; + + let AgentThreadEntryContent::ToolCall(call) = &mut entry.content else { + debug_panic!("expected ToolCall"); + return; + }; + + let new_state = if allowed { + ToolCall::Allowed + } else { + ToolCall::Rejected + }; + + let call = mem::replace(call, new_state); + + if let ToolCall::WaitingForConfirmation { respond_tx, .. } = call { + respond_tx.send(allowed).log_err(); + } else { + debug_panic!("tried to authorize an already authorized tool call"); + } + + cx.emit(AcpThreadEvent::EntryUpdated(id.0.0 as usize)); + } + + fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> { + let entry = self.entries.get_mut(id.0 as usize); + debug_assert!( + entry.is_some(), + "We shouldn't give out ids to entries that don't exist" + ); + entry + } + pub fn send(&mut self, message: &str, cx: &mut Context) -> Task> { let agent = self.server.clone(); let id = self.id.clone(); @@ -303,11 +395,13 @@ mod tests { )); assert!( thread.entries().iter().any(|entry| { - entry.content - == AgentThreadEntryContent::ReadFile { - path: "/private/tmp/foo".into(), - content: "Lorem ipsum dolor".into(), + match &entry.content { + AgentThreadEntryContent::ReadFile { path, content } => { + path.to_string_lossy().to_string() == "/private/tmp/foo" + && content == "Lorem ipsum dolor" } + _ => false, + } }), "Thread does not contain entry. Actual: {:?}", thread.entries() diff --git a/crates/acp/src/server.rs b/crates/acp/src/server.rs index 323f6bf2d0f496f4fa322b73ca240b28c4196d0b..44b5acc3e6569253c77080ac0bd0f546b788014e 100644 --- a/crates/acp/src/server.rs +++ b/crates/acp/src/server.rs @@ -1,8 +1,9 @@ -use crate::{AcpThread, AgentThreadEntryContent, ThreadEntryId, ThreadId}; +use crate::{AcpThread, AgentThreadEntryContent, ThreadEntryId, ThreadId, ToolCallId}; use agentic_coding_protocol as acp; use anyhow::{Context as _, Result}; use async_trait::async_trait; use collections::HashMap; +use futures::channel::oneshot; use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity}; use parking_lot::Mutex; use project::Project; @@ -185,6 +186,31 @@ impl acp::Client for AcpClientDelegate { ) -> Result { todo!() } + + async fn request_tool_call( + &self, + request: acp::RequestToolCallParams, + ) -> Result { + let (tx, rx) = oneshot::channel(); + + let cx = &mut self.cx.clone(); + let entry_id = cx + .update(|cx| { + self.update_thread(&request.thread_id.into(), cx, |thread, cx| { + // todo! tools that don't require confirmation + thread.push_tool_call(request.tool_name, request.description, tx, cx) + }) + })? + .context("Failed to update thread")?; + + if dbg!(rx.await)? { + Ok(acp::RequestToolCallResponse::Allowed { + id: entry_id.into(), + }) + } else { + Ok(acp::RequestToolCallResponse::Rejected) + } + } } impl AcpServer { @@ -258,3 +284,15 @@ impl From for acp::ThreadId { acp::ThreadId(thread_id.0.to_string()) } } + +impl From for ToolCallId { + fn from(tool_call_id: acp::ToolCallId) -> Self { + Self(ThreadEntryId(tool_call_id.0.into())) + } +} + +impl From for acp::ToolCallId { + fn from(tool_call_id: ToolCallId) -> Self { + acp::ToolCallId(tool_call_id.0.0) + } +} diff --git a/crates/acp/src/thread_view.rs b/crates/acp/src/thread_view.rs index 7ad76caabb160da84437a9c101a8b3884612aad4..cddefeb9647a8223070253098db8560c8aa4f8a1 100644 --- a/crates/acp/src/thread_view.rs +++ b/crates/acp/src/thread_view.rs @@ -13,13 +13,14 @@ use markdown::{HeadingLevelStyles, MarkdownElement, MarkdownStyle}; use project::Project; use settings::Settings as _; use theme::ThemeSettings; -use ui::Tooltip; use ui::prelude::*; +use ui::{Button, Tooltip}; use util::ResultExt; use zed_actions::agent::Chat; use crate::{ AcpServer, AcpThread, AcpThreadEvent, AgentThreadEntryContent, MessageChunk, Role, ThreadEntry, + ToolCall, ToolCallId, }; pub struct AcpThreadView { @@ -100,8 +101,8 @@ impl AcpThreadView { AcpThreadEvent::NewEntry => { this.list_state.splice(count..count, 1); } - AcpThreadEvent::LastEntryUpdated => { - this.list_state.splice(count - 1..count, 1); + AcpThreadEvent::EntryUpdated(index) => { + this.list_state.splice(*index..*index + 1, 1); } } cx.notify(); @@ -149,7 +150,7 @@ impl AcpThreadView { fn thread(&self) -> Option<&Entity> { match &self.thread_state { ThreadState::Ready { thread, .. } => Some(thread), - _ => None, + ThreadState::Loading { .. } | ThreadState::LoadError(..) => None, } } @@ -187,6 +188,16 @@ impl AcpThreadView { }); } + fn authorize_tool_call(&mut self, id: ToolCallId, allowed: bool, cx: &mut Context) { + let Some(thread) = self.thread() else { + return; + }; + thread.update(cx, |thread, cx| { + thread.authorize_tool_call(id, allowed, cx); + }); + cx.notify(); + } + fn render_entry( &self, entry: &ThreadEntry, @@ -236,6 +247,46 @@ impl AcpThreadView { .child(format!("", path.display())) .into_any() } + AgentThreadEntryContent::ToolCall(tool_call) => match tool_call { + ToolCall::WaitingForConfirmation { + id, + tool_name, + description, + .. + } => { + let id = *id; + v_flex() + .elevation_1(cx) + .child(MarkdownElement::new( + tool_name.clone(), + default_markdown_style(window, cx), + )) + .child(MarkdownElement::new( + description.clone(), + default_markdown_style(window, cx), + )) + .child( + h_flex() + .child(Button::new(("allow", id.0.0), "Allow").on_click( + cx.listener({ + move |this, _, _, cx| { + this.authorize_tool_call(id, true, cx); + } + }), + )) + .child(Button::new(("reject", id.0.0), "Reject").on_click( + cx.listener({ + move |this, _, _, cx| { + this.authorize_tool_call(id, false, cx); + } + }), + )), + ) + .into_any() + } + ToolCall::Allowed => div().child("Allowed!").into_any(), + ToolCall::Rejected => div().child("Rejected!").into_any(), + }, } } } diff --git a/crates/ui/src/traits/styled_ext.rs b/crates/ui/src/traits/styled_ext.rs index 63926070c8c6a9c81e758ffb0bf6fa9ba3d87874..cf452a2826e75bd88910b605a90fe34aa0ea62bd 100644 --- a/crates/ui/src/traits/styled_ext.rs +++ b/crates/ui/src/traits/styled_ext.rs @@ -39,7 +39,7 @@ pub trait StyledExt: Styled + Sized { /// Sets `bg()`, `rounded_lg()`, `border()`, `border_color()`, `shadow()` /// /// Example Elements: Title Bar, Panel, Tab Bar, Editor - fn elevation_1(self, cx: &mut App) -> Self { + fn elevation_1(self, cx: &App) -> Self { elevated(self, cx, ElevationIndex::Surface) }