From 6fcad1d1332c4fae039ebb4782fa0c13bd6b965d Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Thu, 2 Apr 2026 14:51:35 +0200 Subject: [PATCH] cleanup --- Cargo.lock | 2 + Cargo.toml | 2 +- crates/acp_tools/src/acp_tools.rs | 58 ++-- crates/agent_servers/src/acp.rs | 307 ++++++++-------------- crates/agent_servers/src/agent_servers.rs | 38 +-- 5 files changed, 163 insertions(+), 244 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2bc256c9adbe5fec01a2ca50e2c9eb6f1a7d16ed..4f52f444867c7ad96b7cd9609fa7cdcb038cc893 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -220,6 +220,7 @@ dependencies = [ [[package]] name = "agent-client-protocol-core" version = "0.1.0" +source = "git+https://github.com/agentclientprotocol/rust-sdk?rev=9fa32bb04e301f4be968a372eb5085dc3a86ef3c#9fa32bb04e301f4be968a372eb5085dc3a86ef3c" dependencies = [ "agent-client-protocol-derive", "agent-client-protocol-schema", @@ -243,6 +244,7 @@ dependencies = [ [[package]] name = "agent-client-protocol-derive" version = "0.1.0" +source = "git+https://github.com/agentclientprotocol/rust-sdk?rev=9fa32bb04e301f4be968a372eb5085dc3a86ef3c#9fa32bb04e301f4be968a372eb5085dc3a86ef3c" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 1c5b124981f998cc548a44a1a0b0c0a7436f066b..2474178be5f7a9512aa989037297aa3daccc14f7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -477,7 +477,7 @@ ztracing_macro = { path = "crates/ztracing_macro" } # External crates # -agent-client-protocol-core = { path = "../acp-rust-sdk/src/agent-client-protocol-core", features = ["unstable"] } +agent-client-protocol-core = { git = "https://github.com/agentclientprotocol/rust-sdk", rev = "9fa32bb04e301f4be968a372eb5085dc3a86ef3c", features = ["unstable"] } aho-corasick = "1.1" alacritty_terminal = { git = "https://github.com/zed-industries/alacritty", rev = "9d9640d4" } any_vec = "0.14" diff --git a/crates/acp_tools/src/acp_tools.rs b/crates/acp_tools/src/acp_tools.rs index debc04d41b04476babcfff42fdc853eaaedcc1e3..201750709e95247b34119d7d6d81dd5adf117143 100644 --- a/crates/acp_tools/src/acp_tools.rs +++ b/crates/acp_tools/src/acp_tools.rs @@ -1,4 +1,4 @@ -use std::{cell::RefCell, collections::HashSet, fmt::Display, sync::Arc}; +use std::{collections::HashSet, fmt::Display, sync::Arc}; use agent_client_protocol_core::schema as acp; use collections::HashMap; @@ -42,6 +42,7 @@ pub enum StreamMessageContent { }, } +#[derive(Clone)] pub struct StreamMessage { pub direction: StreamMessageDirection, pub message: StreamMessageContent, @@ -108,12 +109,10 @@ impl Global for GlobalAcpConnectionRegistry {} #[derive(Default)] pub struct AcpConnectionRegistry { - active_connection: RefCell>, -} - -struct ActiveConnection { - agent_id: AgentId, - messages_rx: smol::channel::Receiver, + active_agent_id: Option, + generation: u64, + subscribers: Vec>, + _broadcast_task: Option>, } impl AcpConnectionRegistry { @@ -128,17 +127,33 @@ impl AcpConnectionRegistry { } pub fn set_active_connection( - &self, + &mut self, agent_id: AgentId, messages_rx: smol::channel::Receiver, cx: &mut Context, ) { - self.active_connection.replace(Some(ActiveConnection { - agent_id, - messages_rx, + self.active_agent_id = Some(agent_id); + self.generation += 1; + self.subscribers.clear(); + + self._broadcast_task = Some(cx.spawn(async move |this, cx| { + while let Ok(message) = messages_rx.recv().await { + this.update(cx, |this, _cx| { + this.subscribers + .retain(|sender| sender.try_send(message.clone()).is_ok()); + }) + .ok(); + } })); + cx.notify(); } + + pub fn subscribe(&mut self) -> smol::channel::Receiver { + let (sender, receiver) = smol::channel::bounded(4096); + self.subscribers.push(sender); + receiver + } } struct AcpTools { @@ -152,6 +167,7 @@ struct AcpTools { struct WatchedConnection { agent_id: AgentId, + generation: u64, messages: Vec, list_state: ListState, incoming_request_methods: HashMap>, @@ -181,18 +197,25 @@ impl AcpTools { } fn update_connection(&mut self, cx: &mut Context) { - let active_connection = self.connection_registry.read(cx).active_connection.borrow(); - let Some(active_connection) = active_connection.as_ref() else { + let (generation, agent_id) = { + let registry = self.connection_registry.read(cx); + (registry.generation, registry.active_agent_id.clone()) + }; + + let Some(agent_id) = agent_id else { return; }; - if let Some(watched_connection) = self.watched_connection.as_ref() { - if watched_connection.agent_id == active_connection.agent_id { + if let Some(watched) = self.watched_connection.as_ref() { + if watched.generation == generation { return; } } - let messages_rx = active_connection.messages_rx.clone(); + let messages_rx = self + .connection_registry + .update(cx, |registry, _cx| registry.subscribe()); + let task = cx.spawn(async move |this, cx| { while let Ok(message) = messages_rx.recv().await { this.update(cx, |this, cx| { @@ -203,7 +226,8 @@ impl AcpTools { }); self.watched_connection = Some(WatchedConnection { - agent_id: active_connection.agent_id.clone(), + agent_id, + generation, messages: vec![], list_state: ListState::new(0, ListAlignment::Bottom, px(2048.)), incoming_request_methods: HashMap::default(), diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs index 80a037a7e891466d85adb4f2ce559b4f4f4549e6..5e8b7aa59e189c10c19e54285c64cd527c56c3d7 100644 --- a/crates/agent_servers/src/acp.rs +++ b/crates/agent_servers/src/acp.rs @@ -203,6 +203,34 @@ pub async fn connect( const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::ProtocolVersion::V1; +macro_rules! dispatch_request_handler { + ($dispatch_tx:expr, $handler:expr) => {{ + let dispatch_tx = $dispatch_tx.clone(); + async move |args, responder, _connection| { + dispatch_tx + .unbounded_send(Box::new(move |cx, ctx| { + $handler(args, responder, cx, ctx); + })) + .log_err(); + Ok(()) + } + }}; +} + +macro_rules! dispatch_notification_handler { + ($dispatch_tx:expr, $handler:expr) => {{ + let dispatch_tx = $dispatch_tx.clone(); + async move |notification, _connection| { + dispatch_tx + .unbounded_send(Box::new(move |cx, ctx| { + $handler(notification, cx, ctx); + })) + .log_err(); + Ok(()) + } + }}; +} + impl AcpConnection { pub async fn stdio( agent_id: AgentId, @@ -263,7 +291,7 @@ impl AcpConnection { // Build a tapped transport that intercepts raw JSON-RPC lines for // the ACP logs panel. We replicate the ByteStreams→Lines conversion // manually so we can wrap the stream and sink with inspection. - let (stream_tap_tx, stream_tap_rx) = smol::channel::unbounded::(); + let (stream_tap_tx, stream_tap_rx) = smol::channel::bounded::(4096); let incoming_lines = futures::io::BufReader::new(stdout).lines(); let tapped_incoming = incoming_lines.inspect({ @@ -314,167 +342,40 @@ impl AcpConnection { .name("zed") // --- Request handlers (agent→client) --- .on_receive_request( - { - let dispatch_tx = dispatch_tx.clone(); - async move |args: acp::RequestPermissionRequest, - responder: agent_client_protocol_core::Responder< - acp::RequestPermissionResponse, - >, - _connection: ConnectionTo< - agent_client_protocol_core::Agent, - >| { - dispatch_tx - .unbounded_send(Box::new(move |cx, ctx| { - handle_request_permission(args, responder, cx, ctx); - })) - .ok(); - Ok(()) - } - }, + dispatch_request_handler!(dispatch_tx, handle_request_permission), agent_client_protocol_core::on_receive_request!(), ) .on_receive_request( - { - let dispatch_tx = dispatch_tx.clone(); - async move |args: acp::WriteTextFileRequest, - responder: agent_client_protocol_core::Responder< - acp::WriteTextFileResponse, - >, - _connection| { - dispatch_tx - .unbounded_send(Box::new(move |cx, ctx| { - handle_write_text_file(args, responder, cx, ctx); - })) - .ok(); - Ok(()) - } - }, + dispatch_request_handler!(dispatch_tx, handle_write_text_file), agent_client_protocol_core::on_receive_request!(), ) .on_receive_request( - { - let dispatch_tx = dispatch_tx.clone(); - async move |args: acp::ReadTextFileRequest, - responder: agent_client_protocol_core::Responder< - acp::ReadTextFileResponse, - >, - _connection| { - dispatch_tx - .unbounded_send(Box::new(move |cx, ctx| { - handle_read_text_file(args, responder, cx, ctx); - })) - .ok(); - Ok(()) - } - }, + dispatch_request_handler!(dispatch_tx, handle_read_text_file), agent_client_protocol_core::on_receive_request!(), ) .on_receive_request( - { - let dispatch_tx = dispatch_tx.clone(); - async move |args: acp::CreateTerminalRequest, - responder: agent_client_protocol_core::Responder< - acp::CreateTerminalResponse, - >, - _connection| { - dispatch_tx - .unbounded_send(Box::new(move |cx, ctx| { - handle_create_terminal(args, responder, cx, ctx); - })) - .ok(); - Ok(()) - } - }, + dispatch_request_handler!(dispatch_tx, handle_create_terminal), agent_client_protocol_core::on_receive_request!(), ) .on_receive_request( - { - let dispatch_tx = dispatch_tx.clone(); - async move |args: acp::KillTerminalRequest, - responder: agent_client_protocol_core::Responder< - acp::KillTerminalResponse, - >, - _connection| { - dispatch_tx - .unbounded_send(Box::new(move |cx, ctx| { - handle_kill_terminal(args, responder, cx, ctx); - })) - .ok(); - Ok(()) - } - }, + dispatch_request_handler!(dispatch_tx, handle_kill_terminal), agent_client_protocol_core::on_receive_request!(), ) .on_receive_request( - { - let dispatch_tx = dispatch_tx.clone(); - async move |args: acp::ReleaseTerminalRequest, - responder: agent_client_protocol_core::Responder< - acp::ReleaseTerminalResponse, - >, - _connection| { - dispatch_tx - .unbounded_send(Box::new(move |cx, ctx| { - handle_release_terminal(args, responder, cx, ctx); - })) - .ok(); - Ok(()) - } - }, + dispatch_request_handler!(dispatch_tx, handle_release_terminal), agent_client_protocol_core::on_receive_request!(), ) .on_receive_request( - { - let dispatch_tx = dispatch_tx.clone(); - async move |args: acp::TerminalOutputRequest, - responder: agent_client_protocol_core::Responder< - acp::TerminalOutputResponse, - >, - _connection| { - dispatch_tx - .unbounded_send(Box::new(move |cx, ctx| { - handle_terminal_output(args, responder, cx, ctx); - })) - .ok(); - Ok(()) - } - }, + dispatch_request_handler!(dispatch_tx, handle_terminal_output), agent_client_protocol_core::on_receive_request!(), ) .on_receive_request( - { - let dispatch_tx = dispatch_tx.clone(); - async move |args: acp::WaitForTerminalExitRequest, - responder: agent_client_protocol_core::Responder< - acp::WaitForTerminalExitResponse, - >, - _connection| { - dispatch_tx - .unbounded_send(Box::new(move |cx, ctx| { - handle_wait_for_terminal_exit(args, responder, cx, ctx); - })) - .ok(); - Ok(()) - } - }, + dispatch_request_handler!(dispatch_tx, handle_wait_for_terminal_exit), agent_client_protocol_core::on_receive_request!(), ) // --- Notification handlers (agent→client) --- .on_receive_notification( - { - let dispatch_tx = dispatch_tx.clone(); - async move |notification: acp::SessionNotification, - _connection: ConnectionTo< - agent_client_protocol_core::Agent, - >| { - dispatch_tx - .unbounded_send(Box::new(move |cx, ctx| { - handle_session_notification(notification, cx, ctx); - })) - .ok(); - Ok(()) - } - }, + dispatch_notification_handler!(dispatch_tx, handle_session_notification), agent_client_protocol_core::on_receive_notification!(), ) .connect_with( @@ -1857,82 +1758,86 @@ fn handle_session_notification( cx: &mut AsyncApp, ctx: &ClientContext, ) { - let sessions = ctx.sessions.borrow(); - let Some(session) = sessions.get(¬ification.session_id) else { - log::warn!( - "Received session notification for unknown session: {:?}", - notification.session_id - ); - return; - }; + let (thread, update_clone) = { + let sessions = ctx.sessions.borrow(); + let Some(session) = sessions.get(¬ification.session_id) else { + log::warn!( + "Received session notification for unknown session: {:?}", + notification.session_id + ); + return; + }; - if let acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate { - current_mode_id, .. - }) = ¬ification.update - { - if let Some(session_modes) = &session.session_modes { - session_modes.borrow_mut().current_mode_id = current_mode_id.clone(); + if let acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate { + current_mode_id, + .. + }) = ¬ification.update + { + if let Some(session_modes) = &session.session_modes { + session_modes.borrow_mut().current_mode_id = current_mode_id.clone(); + } } - } - if let acp::SessionUpdate::ConfigOptionUpdate(acp::ConfigOptionUpdate { - config_options, .. - }) = ¬ification.update - { - if let Some(opts) = &session.config_options { - *opts.config_options.borrow_mut() = config_options.clone(); - opts.tx.borrow_mut().send(()).ok(); + if let acp::SessionUpdate::ConfigOptionUpdate(acp::ConfigOptionUpdate { + config_options, + .. + }) = ¬ification.update + { + if let Some(opts) = &session.config_options { + *opts.config_options.borrow_mut() = config_options.clone(); + opts.tx.borrow_mut().send(()).ok(); + } } - } - if let acp::SessionUpdate::SessionInfoUpdate(info_update) = ¬ification.update - && let Some(session_list) = ctx.session_list.borrow().as_ref() - { - session_list.send_info_update(notification.session_id.clone(), info_update.clone()); - } + if let acp::SessionUpdate::SessionInfoUpdate(info_update) = ¬ification.update + && let Some(session_list) = ctx.session_list.borrow().as_ref() + { + session_list.send_info_update(notification.session_id.clone(), info_update.clone()); + } - let update_clone = notification.update.clone(); - let thread = session.thread.clone(); + let update_clone = notification.update.clone(); + let thread = session.thread.clone(); - // Pre-handle: if a ToolCall carries terminal_info, create/register a display-only terminal. - if let acp::SessionUpdate::ToolCall(tc) = &update_clone { - if let Some(meta) = &tc.meta { - if let Some(terminal_info) = meta.get("terminal_info") { - if let Some(id_str) = terminal_info.get("terminal_id").and_then(|v| v.as_str()) { - let terminal_id = acp::TerminalId::new(id_str); - let cwd = terminal_info - .get("cwd") - .and_then(|v| v.as_str().map(PathBuf::from)); + // Pre-handle: if a ToolCall carries terminal_info, create/register a display-only terminal. + if let acp::SessionUpdate::ToolCall(tc) = &update_clone { + if let Some(meta) = &tc.meta { + if let Some(terminal_info) = meta.get("terminal_info") { + if let Some(id_str) = terminal_info.get("terminal_id").and_then(|v| v.as_str()) + { + let terminal_id = acp::TerminalId::new(id_str); + let cwd = terminal_info + .get("cwd") + .and_then(|v| v.as_str().map(PathBuf::from)); - let _ = thread.update(cx, |thread, cx| { - let builder = TerminalBuilder::new_display_only( - CursorShape::default(), - AlternateScroll::On, - None, - 0, - cx.background_executor(), - thread.project().read(cx).path_style(cx), - )?; - let lower = cx.new(|cx| builder.subscribe(cx)); - thread.on_terminal_provider_event( - TerminalProviderEvent::Created { - terminal_id, - label: tc.title.clone(), - cwd, - output_byte_limit: None, - terminal: lower, - }, - cx, - ); - anyhow::Ok(()) - }); + let _ = thread.update(cx, |thread, cx| { + let builder = TerminalBuilder::new_display_only( + CursorShape::default(), + AlternateScroll::On, + None, + 0, + cx.background_executor(), + thread.project().read(cx).path_style(cx), + )?; + let lower = cx.new(|cx| builder.subscribe(cx)); + thread.on_terminal_provider_event( + TerminalProviderEvent::Created { + terminal_id, + label: tc.title.clone(), + cwd, + output_byte_limit: None, + terminal: lower, + }, + cx, + ); + anyhow::Ok(()) + }); + } } } } - } - // Drop sessions borrow before updating the thread, which may re-borrow. - drop(sessions); + (thread, update_clone) + }; // Forward the update to the acp_thread as usual. if let Err(err) = thread diff --git a/crates/agent_servers/src/agent_servers.rs b/crates/agent_servers/src/agent_servers.rs index 8ebc1c652d993f71f4d7e38046185dcf485b76a6..9bba14dea302b3803334492373830a82820196c4 100644 --- a/crates/agent_servers/src/agent_servers.rs +++ b/crates/agent_servers/src/agent_servers.rs @@ -12,6 +12,9 @@ use http_client::read_no_proxy_from_env; use project::{AgentId, Project, agent_server_store::AgentServerStore}; use acp_thread::AgentConnection; +use agent_client_protocol_core::schema::{ + ModelId, SessionConfigId, SessionConfigValueId, SessionModeId, +}; use anyhow::Result; use gpui::{App, AppContext, Entity, Task}; use settings::SettingsStore; @@ -48,34 +51,19 @@ pub trait AgentServer: Send { fn into_any(self: Rc) -> Rc; - fn default_mode(&self, _cx: &App) -> Option { + fn default_mode(&self, _cx: &App) -> Option { None } - fn set_default_mode( - &self, - _mode_id: Option, - _fs: Arc, - _cx: &mut App, - ) { - } + fn set_default_mode(&self, _mode_id: Option, _fs: Arc, _cx: &mut App) {} - fn default_model(&self, _cx: &App) -> Option { + fn default_model(&self, _cx: &App) -> Option { None } - fn set_default_model( - &self, - _model_id: Option, - _fs: Arc, - _cx: &mut App, - ) { - } + fn set_default_model(&self, _model_id: Option, _fs: Arc, _cx: &mut App) {} - fn favorite_model_ids( - &self, - _cx: &mut App, - ) -> HashSet { + fn favorite_model_ids(&self, _cx: &mut App) -> HashSet { HashSet::default() } @@ -94,16 +82,16 @@ pub trait AgentServer: Send { fn favorite_config_option_value_ids( &self, - _config_id: &agent_client_protocol_core::schema::SessionConfigId, + _config_id: &SessionConfigId, _cx: &mut App, - ) -> HashSet { + ) -> HashSet { HashSet::default() } fn toggle_favorite_config_option_value( &self, - _config_id: agent_client_protocol_core::schema::SessionConfigId, - _value_id: agent_client_protocol_core::schema::SessionConfigValueId, + _config_id: SessionConfigId, + _value_id: SessionConfigValueId, _should_be_favorite: bool, _fs: Arc, _cx: &App, @@ -112,7 +100,7 @@ pub trait AgentServer: Send { fn toggle_favorite_model( &self, - _model_id: agent_client_protocol_core::schema::ModelId, + _model_id: ModelId, _should_be_favorite: bool, _fs: Arc, _cx: &App,