diff --git a/Cargo.lock b/Cargo.lock index c6225699f1c882839624cc493e5c130a2cf4c647..ca1be6fe6dbfbdca985b8696a254bcfc55e0566c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -51,6 +51,7 @@ name = "acp_tools" version = "0.1.0" dependencies = [ "agent-client-protocol", + "agent_ui", "collections", "gpui", "language", @@ -252,7 +253,6 @@ name = "agent_servers" version = "0.1.0" dependencies = [ "acp_thread", - "acp_tools", "action_log", "agent-client-protocol", "anyhow", diff --git a/crates/acp_tools/Cargo.toml b/crates/acp_tools/Cargo.toml index 8f14b1f93b32c6df521ea13ebf3f0f73e7ed755c..3774961da23173139e4327c81b94c206f4569c00 100644 --- a/crates/acp_tools/Cargo.toml +++ b/crates/acp_tools/Cargo.toml @@ -15,6 +15,7 @@ doctest = false [dependencies] agent-client-protocol.workspace = true +agent_ui.workspace = true collections.workspace = true gpui.workspace = true language.workspace= true diff --git a/crates/acp_tools/src/acp_tools.rs b/crates/acp_tools/src/acp_tools.rs index ae8a39c8df4f73ae8be6b748694dbde5d2a0c102..b439324c2d5f99c1a0976fb4e0207df9cb8be86f 100644 --- a/crates/acp_tools/src/acp_tools.rs +++ b/crates/acp_tools/src/acp_tools.rs @@ -1,5 +1,4 @@ use std::{ - cell::RefCell, collections::HashSet, fmt::Display, rc::{Rc, Weak}, @@ -7,17 +6,22 @@ use std::{ }; use agent_client_protocol as acp; +use agent_ui::{Agent, AgentConnectionStore, AgentPanel}; use collections::HashMap; use gpui::{ - App, Empty, Entity, EventEmitter, FocusHandle, Focusable, Global, ListAlignment, ListState, - StyleRefinement, Subscription, Task, TextStyleRefinement, Window, actions, list, prelude::*, + App, Empty, Entity, EventEmitter, FocusHandle, Focusable, ListAlignment, ListState, + SharedString, StyleRefinement, Subscription, Task, TextStyleRefinement, WeakEntity, Window, + actions, list, prelude::*, }; use language::LanguageRegistry; use markdown::{CodeBlockRenderer, CopyButtonVisibility, Markdown, MarkdownElement, MarkdownStyle}; use project::{AgentId, Project}; use settings::Settings; use theme_settings::ThemeSettings; -use ui::{CopyButton, Tooltip, WithScrollbar, prelude::*}; +use ui::{ + ContextMenu, CopyButton, DropdownMenu, DropdownStyle, IconPosition, Tooltip, WithScrollbar, + prelude::*, +}; use util::ResultExt as _; use workspace::{ Item, ItemHandle, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, Workspace, @@ -29,8 +33,17 @@ pub fn init(cx: &mut App) { cx.observe_new( |workspace: &mut Workspace, _window, _cx: &mut Context| { workspace.register_action(|workspace, _: &OpenAcpLogs, window, cx| { - let acp_tools = - Box::new(cx.new(|cx| AcpTools::new(workspace.project().clone(), cx))); + let connection_store = workspace + .panel::(cx) + .map(|panel| panel.read(cx).connection_store().clone()); + let acp_tools = Box::new(cx.new(|cx| { + AcpTools::new( + workspace.weak_handle(), + workspace.project().clone(), + connection_store.clone(), + cx, + ) + })); workspace.add_item_to_active_pane(acp_tools, None, true, window, cx); }); }, @@ -38,52 +51,16 @@ pub fn init(cx: &mut App) { .detach(); } -struct GlobalAcpConnectionRegistry(Entity); - -impl Global for GlobalAcpConnectionRegistry {} - -#[derive(Default)] -pub struct AcpConnectionRegistry { - active_connection: RefCell>, -} - -struct ActiveConnection { - agent_id: AgentId, - connection: Weak, -} - -impl AcpConnectionRegistry { - pub fn default_global(cx: &mut App) -> Entity { - if cx.has_global::() { - cx.global::().0.clone() - } else { - let registry = cx.new(|_cx| AcpConnectionRegistry::default()); - cx.set_global(GlobalAcpConnectionRegistry(registry.clone())); - registry - } - } - - pub fn set_active_connection( - &self, - agent_id: AgentId, - connection: &Rc, - cx: &mut Context, - ) { - self.active_connection.replace(Some(ActiveConnection { - agent_id, - connection: Rc::downgrade(connection), - })); - cx.notify(); - } -} - struct AcpTools { + workspace: WeakEntity, project: Entity, focus_handle: FocusHandle, expanded: HashSet, - watched_connection: Option, - connection_registry: Entity, - _subscription: Subscription, + watched_connections: HashMap, + selected_connection: Option, + connection_store: Option>, + _workspace_subscription: Option, + _connection_store_subscription: Option, } struct WatchedConnection { @@ -97,66 +74,231 @@ struct WatchedConnection { } impl AcpTools { - fn new(project: Entity, cx: &mut Context) -> Self { - let connection_registry = AcpConnectionRegistry::default_global(cx); - - let subscription = cx.observe(&connection_registry, |this, _, cx| { - this.update_connection(cx); - cx.notify(); + fn new( + workspace: WeakEntity, + project: Entity, + connection_store: Option>, + cx: &mut Context, + ) -> Self { + let workspace_subscription = workspace.upgrade().map(|workspace| { + cx.observe(&workspace, |this, _, cx| { + this.update_connection_store(cx); + }) }); let mut this = Self { + workspace, project, focus_handle: cx.focus_handle(), expanded: HashSet::default(), - watched_connection: None, - connection_registry, - _subscription: subscription, + watched_connections: HashMap::default(), + selected_connection: None, + connection_store: None, + _workspace_subscription: workspace_subscription, + _connection_store_subscription: None, }; - this.update_connection(cx); + this.set_connection_store(connection_store, cx); this } - fn update_connection(&mut self, cx: &mut Context) { - let active_connection = self.connection_registry.read(cx).active_connection.borrow(); - let Some(active_connection) = active_connection.as_ref() else { + fn set_connection_store( + &mut self, + connection_store: Option>, + cx: &mut Context, + ) { + if self.connection_store == connection_store { return; - }; + } - if let Some(watched_connection) = self.watched_connection.as_ref() { - if Weak::ptr_eq( - &watched_connection.connection, - &active_connection.connection, - ) { - return; + self.connection_store = connection_store.clone(); + self._connection_store_subscription = connection_store.as_ref().map(|connection_store| { + cx.observe(connection_store, |this, _, cx| { + this.refresh_connections(cx); + }) + }); + self.refresh_connections(cx); + } + + fn update_connection_store(&mut self, cx: &mut Context) { + let connection_store = self.workspace.upgrade().and_then(|workspace| { + workspace + .read(cx) + .panel::(cx) + .map(|panel| panel.read(cx).connection_store().clone()) + }); + + self.set_connection_store(connection_store, cx); + } + + fn refresh_connections(&mut self, cx: &mut Context) { + let mut did_change = false; + let active_connections = self + .connection_store + .as_ref() + .map(|connection_store| connection_store.read(cx).active_acp_connections(cx)) + .unwrap_or_default(); + + self.watched_connections + .retain(|agent_id, watched_connection| { + let should_retain = active_connections.iter().any(|active_connection| { + &active_connection.agent_id == agent_id + && Rc::ptr_eq( + &active_connection.connection, + &watched_connection + .connection + .upgrade() + .unwrap_or_else(|| active_connection.connection.clone()), + ) + }); + + if !should_retain { + did_change = true; + } + + should_retain + }); + + for active_connection in active_connections { + let should_create_watcher = self + .watched_connections + .get(&active_connection.agent_id) + .is_none_or(|watched_connection| { + watched_connection + .connection + .upgrade() + .is_none_or(|connection| { + !Rc::ptr_eq(&connection, &active_connection.connection) + }) + }); + + if !should_create_watcher { + continue; } - } - if let Some(connection) = active_connection.connection.upgrade() { + let agent_id = active_connection.agent_id.clone(); + let connection = active_connection.connection; let mut receiver = connection.subscribe(); - let task = cx.spawn(async move |this, cx| { - while let Ok(message) = receiver.recv().await { - this.update(cx, |this, cx| { - this.push_stream_message(message, cx); - }) - .ok(); + let task = cx.spawn({ + let agent_id = agent_id.clone(); + async move |this, cx| { + while let Ok(message) = receiver.recv().await { + this.update(cx, |this, cx| { + this.push_stream_message(&agent_id, message, cx); + }) + .ok(); + } } }); - self.watched_connection = Some(WatchedConnection { - agent_id: active_connection.agent_id.clone(), - messages: vec![], - list_state: ListState::new(0, ListAlignment::Bottom, px(2048.)), - connection: active_connection.connection.clone(), - incoming_request_methods: HashMap::default(), - outgoing_request_methods: HashMap::default(), - _task: task, - }); + self.watched_connections.insert( + agent_id.clone(), + WatchedConnection { + agent_id, + messages: vec![], + list_state: ListState::new(0, ListAlignment::Bottom, px(2048.)), + connection: Rc::downgrade(&connection), + incoming_request_methods: HashMap::default(), + outgoing_request_methods: HashMap::default(), + _task: task, + }, + ); + did_change = true; + } + + let previous_selected_connection = self.selected_connection.clone(); + self.selected_connection = self + .selected_connection + .clone() + .filter(|agent_id| self.watched_connections.contains_key(agent_id)) + .or_else(|| self.watched_connections.keys().next().cloned()); + + if self.selected_connection != previous_selected_connection { + self.expanded.clear(); + did_change = true; + } + + if did_change { + cx.notify(); } } - fn push_stream_message(&mut self, stream_message: acp::StreamMessage, cx: &mut Context) { - let Some(connection) = self.watched_connection.as_mut() else { + fn select_connection(&mut self, agent_id: Option, cx: &mut Context) { + if self.selected_connection == agent_id { + return; + } + + self.selected_connection = agent_id; + self.expanded.clear(); + cx.notify(); + } + + fn restart_selected_connection(&mut self, cx: &mut Context) { + let Some(agent_id) = self.selected_connection.clone() else { + return; + }; + let Some(workspace) = self.workspace.upgrade() else { + return; + }; + + workspace.update(cx, |workspace, cx| { + let Some(panel) = workspace.panel::(cx) else { + return; + }; + + let fs = workspace.app_state().fs.clone(); + let (thread_store, connection_store) = { + let panel = panel.read(cx); + ( + panel.thread_store().clone(), + panel.connection_store().clone(), + ) + }; + let agent = Agent::from(agent_id.clone()); + let server = agent.server(fs, thread_store); + + connection_store.update(cx, |store, cx| { + store.restart_connection(agent, server, cx); + }); + }); + } + + fn selected_connection_status( + &self, + cx: &App, + ) -> Option { + let agent_id = self.selected_connection.clone()?; + let connection_store = self.connection_store.as_ref()?; + let agent = Agent::from(agent_id); + Some(connection_store.read(cx).connection_status(&agent, cx)) + } + + fn selected_watched_connection(&self) -> Option<&WatchedConnection> { + let selected_connection = self.selected_connection.as_ref()?; + self.watched_connections.get(selected_connection) + } + + fn selected_watched_connection_mut(&mut self) -> Option<&mut WatchedConnection> { + let selected_connection = self.selected_connection.clone()?; + self.watched_connections.get_mut(&selected_connection) + } + + fn connection_menu_entries(&self) -> Vec { + let mut entries: Vec<_> = self + .watched_connections + .values() + .map(|connection| SharedString::from(connection.agent_id.0.clone())) + .collect(); + entries.sort(); + entries + } + + fn push_stream_message( + &mut self, + agent_id: &AgentId, + stream_message: acp::StreamMessage, + cx: &mut Context, + ) { + let Some(connection) = self.watched_connections.get_mut(agent_id) else { return; }; let language_registry = self.project.read(cx).languages().clone(); @@ -230,7 +372,7 @@ impl AcpTools { } fn serialize_observed_messages(&self) -> Option { - let connection = self.watched_connection.as_ref()?; + let connection = self.selected_watched_connection()?; let messages: Vec = connection .messages @@ -258,7 +400,7 @@ impl AcpTools { } fn clear_messages(&mut self, cx: &mut Context) { - if let Some(connection) = self.watched_connection.as_mut() { + if let Some(connection) = self.selected_watched_connection_mut() { connection.messages.clear(); connection.list_state.reset(0); self.expanded.clear(); @@ -266,13 +408,55 @@ impl AcpTools { } } + fn selected_connection_label(&self) -> SharedString { + self.selected_connection + .as_ref() + .map(|agent_id| SharedString::from(agent_id.0.clone())) + .unwrap_or_else(|| SharedString::from("No connection selected")) + } + + fn connection_menu(&self, window: &mut Window, cx: &mut Context) -> Entity { + let entries = self.connection_menu_entries(); + let selected_connection = self.selected_connection.clone(); + let acp_tools = cx.entity().downgrade(); + + ContextMenu::build(window, cx, move |mut menu, _window, _cx| { + if entries.is_empty() { + return menu.entry("No active connections", None, |_, _| {}); + } + + for entry in &entries { + let label = entry.clone(); + let selected = selected_connection + .as_ref() + .is_some_and(|agent_id| agent_id.0.as_ref() == label.as_ref()); + let weak_acp_tools = acp_tools.clone(); + menu = menu.toggleable_entry( + label.clone(), + selected, + IconPosition::Start, + None, + move |_window, cx| { + weak_acp_tools + .update(cx, |this, cx| { + this.select_connection(Some(AgentId(label.clone())), cx); + }) + .ok(); + }, + ); + } + + menu + }) + } + fn render_message( &mut self, index: usize, window: &mut Window, cx: &mut Context, ) -> AnyElement { - let Some(connection) = self.watched_connection.as_ref() else { + let Some(connection) = self.selected_watched_connection() else { return Empty.into_any(); }; @@ -314,13 +498,14 @@ impl AcpTools { this.expanded.remove(&index); } else { this.expanded.insert(index); - let Some(connection) = &mut this.watched_connection else { + let project = this.project.clone(); + let Some(connection) = this.selected_watched_connection_mut() else { return; }; let Some(message) = connection.messages.get_mut(index) else { return; }; - message.expanded(this.project.read(cx).languages().clone(), cx); + message.expanded(project.read(cx).languages().clone(), cx); connection.list_state.scroll_to_reveal_item(index); } cx.notify() @@ -485,8 +670,7 @@ impl Item for AcpTools { fn tab_content_text(&self, _detail: usize, _cx: &App) -> ui::SharedString { format!( "ACP: {}", - self.watched_connection - .as_ref() + self.selected_watched_connection() .map_or("Disconnected", |connection| connection.agent_id.0.as_ref()) ) .into() @@ -509,7 +693,23 @@ impl Render for AcpTools { .track_focus(&self.focus_handle) .size_full() .bg(cx.theme().colors().editor_background) - .child(match self.watched_connection.as_ref() { + .child( + h_flex() + .px_3() + .py_2() + .border_b_1() + .border_color(cx.theme().colors().border) + .child( + DropdownMenu::new( + "acp-connection-selector", + self.selected_connection_label(), + self.connection_menu(window, cx), + ) + .style(DropdownStyle::Subtle) + .disabled(self.watched_connections.is_empty()), + ), + ) + .child(match self.selected_watched_connection() { Some(connection) => { if connection.messages.is_empty() { h_flex() @@ -561,14 +761,33 @@ impl Render for AcpToolsToolbarItemView { }; let acp_tools = acp_tools.clone(); - let has_messages = acp_tools - .read(cx) - .watched_connection - .as_ref() - .is_some_and(|connection| !connection.messages.is_empty()); + let (has_messages, can_restart) = { + let acp_tools = acp_tools.read(cx); + ( + acp_tools + .selected_watched_connection() + .is_some_and(|connection| !connection.messages.is_empty()), + acp_tools.selected_connection_status(cx) + != Some(agent_ui::agent_connection_store::AgentConnectionStatus::Connecting), + ) + }; h_flex() .gap_2() + .child( + IconButton::new("restart_connection", IconName::RotateCw) + .icon_size(IconSize::Small) + .tooltip(Tooltip::text("Restart Connection")) + .disabled(!can_restart) + .on_click(cx.listener({ + let acp_tools = acp_tools.clone(); + move |_this, _, _window, cx| { + acp_tools.update(cx, |acp_tools, cx| { + acp_tools.restart_selected_connection(cx); + }); + } + })), + ) .child({ let message = acp_tools .read(cx) diff --git a/crates/agent_servers/Cargo.toml b/crates/agent_servers/Cargo.toml index 7151f0084b1cb7d9b206f57551ce715ef67483f7..87c276eccceb0b17db5844f3800d12704593e57e 100644 --- a/crates/agent_servers/Cargo.toml +++ b/crates/agent_servers/Cargo.toml @@ -17,7 +17,6 @@ path = "src/agent_servers.rs" doctest = false [dependencies] -acp_tools.workspace = true acp_thread.workspace = true action_log.workspace = true agent-client-protocol.workspace = true diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs index e56db9df927ab3cdf838587f1cb4f9514eb5a758..b61e75a8e95421dffa59081a826f9de7b7c3756a 100644 --- a/crates/agent_servers/src/acp.rs +++ b/crates/agent_servers/src/acp.rs @@ -2,7 +2,7 @@ use acp_thread::{ AgentConnection, AgentSessionInfo, AgentSessionList, AgentSessionListRequest, AgentSessionListResponse, }; -use acp_tools::AcpConnectionRegistry; + use action_log::ActionLog; use agent_client_protocol::{self as acp, Agent as _, ErrorCode}; use anyhow::anyhow; @@ -188,6 +188,10 @@ pub async fn connect( const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::ProtocolVersion::V1; impl AcpConnection { + pub fn client_connection(&self) -> &Rc { + &self.connection + } + pub async fn stdio( agent_id: AgentId, project: Entity, @@ -287,12 +291,6 @@ impl AcpConnection { let connection = Rc::new(connection); - cx.update(|cx| { - AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| { - registry.set_active_connection(agent_id.clone(), &connection, cx) - }); - }); - let response = connection .initialize( acp::InitializeRequest::new(acp::ProtocolVersion::V1) diff --git a/crates/agent_ui/src/agent_connection_store.rs b/crates/agent_ui/src/agent_connection_store.rs index f19a2aa2626d4cbf1fc1ddd0878c5c029d403818..6fa13762cce022b066413020ecd8c86e6a9aa305 100644 --- a/crates/agent_ui/src/agent_connection_store.rs +++ b/crates/agent_ui/src/agent_connection_store.rs @@ -1,6 +1,7 @@ use std::rc::Rc; use acp_thread::{AgentConnection, LoadError}; +use agent_client_protocol as acp; use agent_servers::{AgentServer, AgentServerDelegate}; use anyhow::Result; use collections::HashMap; @@ -66,6 +67,12 @@ pub enum AgentConnectionEntryEvent { impl EventEmitter for AgentConnectionEntry {} +#[derive(Clone)] +pub struct ActiveAcpConnection { + pub agent_id: project::AgentId, + pub connection: Rc, +} + pub struct AgentConnectionStore { project: Entity, entries: HashMap>, @@ -98,6 +105,25 @@ impl AgentConnectionStore { .unwrap_or(AgentConnectionStatus::Disconnected) } + pub fn active_acp_connections(&self, cx: &App) -> Vec { + self.entries + .values() + .filter_map(|entry| match entry.read(cx) { + AgentConnectionEntry::Connected(state) => state + .connection + .clone() + .downcast::() + .map(|connection| ActiveAcpConnection { + agent_id: state.connection.agent_id(), + connection: connection.client_connection().clone(), + }), + AgentConnectionEntry::Connecting { .. } | AgentConnectionEntry::Error { .. } => { + None + } + }) + .collect() + } + pub fn restart_connection( &mut self, key: Agent, @@ -152,6 +178,7 @@ impl AgentConnectionStore { } }) .ok(); + cx.notify(); }) .ok(); } @@ -217,10 +244,14 @@ impl AgentConnectionStore { _: &AgentServersUpdated, cx: &mut Context, ) { - let store = store.read(cx); + let external_agent_ids: Vec<_> = { + let store = store.read(cx); + store.external_agents.keys().cloned().collect() + }; + self.entries.retain(|key, _| match key { Agent::NativeAgent => true, - Agent::Custom { id } => store.external_agents.contains_key(id), + Agent::Custom { id } => external_agent_ids.contains(id), }); cx.notify(); } diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 291a5ff0d9da2c2c48f705358e159bc0cbfe7fcb..09747b717449b451862c4ba57a2b7cff2f74cc4f 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -2456,6 +2456,11 @@ impl AgentPanel { window: &mut Window, cx: &mut Context, ) { + if self.selected_agent != agent { + self.selected_agent = agent.clone(); + self.serialize(cx); + } + if let Some(store) = ThreadMetadataStore::try_global(cx) { store.update(cx, |store, cx| store.unarchive(&session_id, cx)); } @@ -3585,7 +3590,7 @@ impl AgentPanel { .action("Toggle Threads Sidebar", Box::new(ToggleWorkspaceSidebar)); if has_auth_methods { - menu = menu.action("Reauthenticate", Box::new(ReauthenticateAgent)) + menu = menu.action("Reauthenticate", Box::new(ReauthenticateAgent)); } menu diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index 037821d56d00851100488d68b0b44cee0aecbd53..03411ca6b1ae693767f98b523fdd9c3189cf7c04 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -1,5 +1,5 @@ mod agent_configuration; -pub(crate) mod agent_connection_store; +pub mod agent_connection_store; mod agent_diff; mod agent_model_selector; mod agent_panel; @@ -64,6 +64,7 @@ use std::any::TypeId; use workspace::Workspace; use crate::agent_configuration::{ConfigureContextServerModal, ManageProfilesModal}; +pub use crate::agent_connection_store::{ActiveAcpConnection, AgentConnectionStore}; pub use crate::agent_panel::{AgentPanel, AgentPanelEvent, WorktreeCreationStatus}; use crate::agent_registry_ui::AgentRegistryPage; pub use crate::inline_assistant::InlineAssistant; diff --git a/crates/agent_ui/src/conversation_view.rs b/crates/agent_ui/src/conversation_view.rs index 80190858151b2cf79500290a95ee0d0b6a4e8c97..57ad644e2cd4b426fa572bd937930040c6acbf4f 100644 --- a/crates/agent_ui/src/conversation_view.rs +++ b/crates/agent_ui/src/conversation_view.rs @@ -343,6 +343,10 @@ impl ConversationView { } } + pub fn connection_key(&self) -> &Agent { + &self.connection_key + } + pub fn pending_tool_call<'a>( &'a self, cx: &'a App,