@@ -1,5 +1,4 @@
use std::{
- cell::RefCell,
collections::HashSet,
fmt::Display,
rc::{Rc, Weak},
@@ -7,17 +6,22 @@ use std::{
};
use agent_client_protocol as acp;
+use agent_ui::{Agent, AgentConnectionStore, AgentPanel};
use collections::HashMap;
use gpui::{
- App, Empty, Entity, EventEmitter, FocusHandle, Focusable, Global, ListAlignment, ListState,
- StyleRefinement, Subscription, Task, TextStyleRefinement, Window, actions, list, prelude::*,
+ App, Empty, Entity, EventEmitter, FocusHandle, Focusable, ListAlignment, ListState,
+ SharedString, StyleRefinement, Subscription, Task, TextStyleRefinement, WeakEntity, Window,
+ actions, list, prelude::*,
};
use language::LanguageRegistry;
use markdown::{CodeBlockRenderer, CopyButtonVisibility, Markdown, MarkdownElement, MarkdownStyle};
use project::{AgentId, Project};
use settings::Settings;
use theme_settings::ThemeSettings;
-use ui::{CopyButton, Tooltip, WithScrollbar, prelude::*};
+use ui::{
+ ContextMenu, CopyButton, DropdownMenu, DropdownStyle, IconPosition, Tooltip, WithScrollbar,
+ prelude::*,
+};
use util::ResultExt as _;
use workspace::{
Item, ItemHandle, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, Workspace,
@@ -29,8 +33,17 @@ pub fn init(cx: &mut App) {
cx.observe_new(
|workspace: &mut Workspace, _window, _cx: &mut Context<Workspace>| {
workspace.register_action(|workspace, _: &OpenAcpLogs, window, cx| {
- let acp_tools =
- Box::new(cx.new(|cx| AcpTools::new(workspace.project().clone(), cx)));
+ let connection_store = workspace
+ .panel::<AgentPanel>(cx)
+ .map(|panel| panel.read(cx).connection_store().clone());
+ let acp_tools = Box::new(cx.new(|cx| {
+ AcpTools::new(
+ workspace.weak_handle(),
+ workspace.project().clone(),
+ connection_store.clone(),
+ cx,
+ )
+ }));
workspace.add_item_to_active_pane(acp_tools, None, true, window, cx);
});
},
@@ -38,52 +51,16 @@ pub fn init(cx: &mut App) {
.detach();
}
-struct GlobalAcpConnectionRegistry(Entity<AcpConnectionRegistry>);
-
-impl Global for GlobalAcpConnectionRegistry {}
-
-#[derive(Default)]
-pub struct AcpConnectionRegistry {
- active_connection: RefCell<Option<ActiveConnection>>,
-}
-
-struct ActiveConnection {
- agent_id: AgentId,
- connection: Weak<acp::ClientSideConnection>,
-}
-
-impl AcpConnectionRegistry {
- pub fn default_global(cx: &mut App) -> Entity<Self> {
- if cx.has_global::<GlobalAcpConnectionRegistry>() {
- cx.global::<GlobalAcpConnectionRegistry>().0.clone()
- } else {
- let registry = cx.new(|_cx| AcpConnectionRegistry::default());
- cx.set_global(GlobalAcpConnectionRegistry(registry.clone()));
- registry
- }
- }
-
- pub fn set_active_connection(
- &self,
- agent_id: AgentId,
- connection: &Rc<acp::ClientSideConnection>,
- cx: &mut Context<Self>,
- ) {
- self.active_connection.replace(Some(ActiveConnection {
- agent_id,
- connection: Rc::downgrade(connection),
- }));
- cx.notify();
- }
-}
-
struct AcpTools {
+ workspace: WeakEntity<Workspace>,
project: Entity<Project>,
focus_handle: FocusHandle,
expanded: HashSet<usize>,
- watched_connection: Option<WatchedConnection>,
- connection_registry: Entity<AcpConnectionRegistry>,
- _subscription: Subscription,
+ watched_connections: HashMap<AgentId, WatchedConnection>,
+ selected_connection: Option<AgentId>,
+ connection_store: Option<Entity<AgentConnectionStore>>,
+ _workspace_subscription: Option<Subscription>,
+ _connection_store_subscription: Option<Subscription>,
}
struct WatchedConnection {
@@ -97,66 +74,231 @@ struct WatchedConnection {
}
impl AcpTools {
- fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
- let connection_registry = AcpConnectionRegistry::default_global(cx);
-
- let subscription = cx.observe(&connection_registry, |this, _, cx| {
- this.update_connection(cx);
- cx.notify();
+ fn new(
+ workspace: WeakEntity<Workspace>,
+ project: Entity<Project>,
+ connection_store: Option<Entity<AgentConnectionStore>>,
+ cx: &mut Context<Self>,
+ ) -> Self {
+ let workspace_subscription = workspace.upgrade().map(|workspace| {
+ cx.observe(&workspace, |this, _, cx| {
+ this.update_connection_store(cx);
+ })
});
let mut this = Self {
+ workspace,
project,
focus_handle: cx.focus_handle(),
expanded: HashSet::default(),
- watched_connection: None,
- connection_registry,
- _subscription: subscription,
+ watched_connections: HashMap::default(),
+ selected_connection: None,
+ connection_store: None,
+ _workspace_subscription: workspace_subscription,
+ _connection_store_subscription: None,
};
- this.update_connection(cx);
+ this.set_connection_store(connection_store, cx);
this
}
- fn update_connection(&mut self, cx: &mut Context<Self>) {
- let active_connection = self.connection_registry.read(cx).active_connection.borrow();
- let Some(active_connection) = active_connection.as_ref() else {
+ fn set_connection_store(
+ &mut self,
+ connection_store: Option<Entity<AgentConnectionStore>>,
+ cx: &mut Context<Self>,
+ ) {
+ if self.connection_store == connection_store {
return;
- };
+ }
- if let Some(watched_connection) = self.watched_connection.as_ref() {
- if Weak::ptr_eq(
- &watched_connection.connection,
- &active_connection.connection,
- ) {
- return;
+ self.connection_store = connection_store.clone();
+ self._connection_store_subscription = connection_store.as_ref().map(|connection_store| {
+ cx.observe(connection_store, |this, _, cx| {
+ this.refresh_connections(cx);
+ })
+ });
+ self.refresh_connections(cx);
+ }
+
+ fn update_connection_store(&mut self, cx: &mut Context<Self>) {
+ let connection_store = self.workspace.upgrade().and_then(|workspace| {
+ workspace
+ .read(cx)
+ .panel::<AgentPanel>(cx)
+ .map(|panel| panel.read(cx).connection_store().clone())
+ });
+
+ self.set_connection_store(connection_store, cx);
+ }
+
+ fn refresh_connections(&mut self, cx: &mut Context<Self>) {
+ let mut did_change = false;
+ let active_connections = self
+ .connection_store
+ .as_ref()
+ .map(|connection_store| connection_store.read(cx).active_acp_connections(cx))
+ .unwrap_or_default();
+
+ self.watched_connections
+ .retain(|agent_id, watched_connection| {
+ let should_retain = active_connections.iter().any(|active_connection| {
+ &active_connection.agent_id == agent_id
+ && Rc::ptr_eq(
+ &active_connection.connection,
+ &watched_connection
+ .connection
+ .upgrade()
+ .unwrap_or_else(|| active_connection.connection.clone()),
+ )
+ });
+
+ if !should_retain {
+ did_change = true;
+ }
+
+ should_retain
+ });
+
+ for active_connection in active_connections {
+ let should_create_watcher = self
+ .watched_connections
+ .get(&active_connection.agent_id)
+ .is_none_or(|watched_connection| {
+ watched_connection
+ .connection
+ .upgrade()
+ .is_none_or(|connection| {
+ !Rc::ptr_eq(&connection, &active_connection.connection)
+ })
+ });
+
+ if !should_create_watcher {
+ continue;
}
- }
- if let Some(connection) = active_connection.connection.upgrade() {
+ let agent_id = active_connection.agent_id.clone();
+ let connection = active_connection.connection;
let mut receiver = connection.subscribe();
- let task = cx.spawn(async move |this, cx| {
- while let Ok(message) = receiver.recv().await {
- this.update(cx, |this, cx| {
- this.push_stream_message(message, cx);
- })
- .ok();
+ let task = cx.spawn({
+ let agent_id = agent_id.clone();
+ async move |this, cx| {
+ while let Ok(message) = receiver.recv().await {
+ this.update(cx, |this, cx| {
+ this.push_stream_message(&agent_id, message, cx);
+ })
+ .ok();
+ }
}
});
- self.watched_connection = Some(WatchedConnection {
- agent_id: active_connection.agent_id.clone(),
- messages: vec![],
- list_state: ListState::new(0, ListAlignment::Bottom, px(2048.)),
- connection: active_connection.connection.clone(),
- incoming_request_methods: HashMap::default(),
- outgoing_request_methods: HashMap::default(),
- _task: task,
- });
+ self.watched_connections.insert(
+ agent_id.clone(),
+ WatchedConnection {
+ agent_id,
+ messages: vec![],
+ list_state: ListState::new(0, ListAlignment::Bottom, px(2048.)),
+ connection: Rc::downgrade(&connection),
+ incoming_request_methods: HashMap::default(),
+ outgoing_request_methods: HashMap::default(),
+ _task: task,
+ },
+ );
+ did_change = true;
+ }
+
+ let previous_selected_connection = self.selected_connection.clone();
+ self.selected_connection = self
+ .selected_connection
+ .clone()
+ .filter(|agent_id| self.watched_connections.contains_key(agent_id))
+ .or_else(|| self.watched_connections.keys().next().cloned());
+
+ if self.selected_connection != previous_selected_connection {
+ self.expanded.clear();
+ did_change = true;
+ }
+
+ if did_change {
+ cx.notify();
}
}
- fn push_stream_message(&mut self, stream_message: acp::StreamMessage, cx: &mut Context<Self>) {
- let Some(connection) = self.watched_connection.as_mut() else {
+ fn select_connection(&mut self, agent_id: Option<AgentId>, cx: &mut Context<Self>) {
+ if self.selected_connection == agent_id {
+ return;
+ }
+
+ self.selected_connection = agent_id;
+ self.expanded.clear();
+ cx.notify();
+ }
+
+ fn restart_selected_connection(&mut self, cx: &mut Context<Self>) {
+ let Some(agent_id) = self.selected_connection.clone() else {
+ return;
+ };
+ let Some(workspace) = self.workspace.upgrade() else {
+ return;
+ };
+
+ workspace.update(cx, |workspace, cx| {
+ let Some(panel) = workspace.panel::<AgentPanel>(cx) else {
+ return;
+ };
+
+ let fs = workspace.app_state().fs.clone();
+ let (thread_store, connection_store) = {
+ let panel = panel.read(cx);
+ (
+ panel.thread_store().clone(),
+ panel.connection_store().clone(),
+ )
+ };
+ let agent = Agent::from(agent_id.clone());
+ let server = agent.server(fs, thread_store);
+
+ connection_store.update(cx, |store, cx| {
+ store.restart_connection(agent, server, cx);
+ });
+ });
+ }
+
+ fn selected_connection_status(
+ &self,
+ cx: &App,
+ ) -> Option<agent_ui::agent_connection_store::AgentConnectionStatus> {
+ let agent_id = self.selected_connection.clone()?;
+ let connection_store = self.connection_store.as_ref()?;
+ let agent = Agent::from(agent_id);
+ Some(connection_store.read(cx).connection_status(&agent, cx))
+ }
+
+ fn selected_watched_connection(&self) -> Option<&WatchedConnection> {
+ let selected_connection = self.selected_connection.as_ref()?;
+ self.watched_connections.get(selected_connection)
+ }
+
+ fn selected_watched_connection_mut(&mut self) -> Option<&mut WatchedConnection> {
+ let selected_connection = self.selected_connection.clone()?;
+ self.watched_connections.get_mut(&selected_connection)
+ }
+
+ fn connection_menu_entries(&self) -> Vec<SharedString> {
+ let mut entries: Vec<_> = self
+ .watched_connections
+ .values()
+ .map(|connection| SharedString::from(connection.agent_id.0.clone()))
+ .collect();
+ entries.sort();
+ entries
+ }
+
+ fn push_stream_message(
+ &mut self,
+ agent_id: &AgentId,
+ stream_message: acp::StreamMessage,
+ cx: &mut Context<Self>,
+ ) {
+ let Some(connection) = self.watched_connections.get_mut(agent_id) else {
return;
};
let language_registry = self.project.read(cx).languages().clone();
@@ -230,7 +372,7 @@ impl AcpTools {
}
fn serialize_observed_messages(&self) -> Option<String> {
- let connection = self.watched_connection.as_ref()?;
+ let connection = self.selected_watched_connection()?;
let messages: Vec<serde_json::Value> = connection
.messages
@@ -258,7 +400,7 @@ impl AcpTools {
}
fn clear_messages(&mut self, cx: &mut Context<Self>) {
- if let Some(connection) = self.watched_connection.as_mut() {
+ if let Some(connection) = self.selected_watched_connection_mut() {
connection.messages.clear();
connection.list_state.reset(0);
self.expanded.clear();
@@ -266,13 +408,55 @@ impl AcpTools {
}
}
+ fn selected_connection_label(&self) -> SharedString {
+ self.selected_connection
+ .as_ref()
+ .map(|agent_id| SharedString::from(agent_id.0.clone()))
+ .unwrap_or_else(|| SharedString::from("No connection selected"))
+ }
+
+ fn connection_menu(&self, window: &mut Window, cx: &mut Context<Self>) -> Entity<ContextMenu> {
+ let entries = self.connection_menu_entries();
+ let selected_connection = self.selected_connection.clone();
+ let acp_tools = cx.entity().downgrade();
+
+ ContextMenu::build(window, cx, move |mut menu, _window, _cx| {
+ if entries.is_empty() {
+ return menu.entry("No active connections", None, |_, _| {});
+ }
+
+ for entry in &entries {
+ let label = entry.clone();
+ let selected = selected_connection
+ .as_ref()
+ .is_some_and(|agent_id| agent_id.0.as_ref() == label.as_ref());
+ let weak_acp_tools = acp_tools.clone();
+ menu = menu.toggleable_entry(
+ label.clone(),
+ selected,
+ IconPosition::Start,
+ None,
+ move |_window, cx| {
+ weak_acp_tools
+ .update(cx, |this, cx| {
+ this.select_connection(Some(AgentId(label.clone())), cx);
+ })
+ .ok();
+ },
+ );
+ }
+
+ menu
+ })
+ }
+
fn render_message(
&mut self,
index: usize,
window: &mut Window,
cx: &mut Context<Self>,
) -> AnyElement {
- let Some(connection) = self.watched_connection.as_ref() else {
+ let Some(connection) = self.selected_watched_connection() else {
return Empty.into_any();
};
@@ -314,13 +498,14 @@ impl AcpTools {
this.expanded.remove(&index);
} else {
this.expanded.insert(index);
- let Some(connection) = &mut this.watched_connection else {
+ let project = this.project.clone();
+ let Some(connection) = this.selected_watched_connection_mut() else {
return;
};
let Some(message) = connection.messages.get_mut(index) else {
return;
};
- message.expanded(this.project.read(cx).languages().clone(), cx);
+ message.expanded(project.read(cx).languages().clone(), cx);
connection.list_state.scroll_to_reveal_item(index);
}
cx.notify()
@@ -485,8 +670,7 @@ impl Item for AcpTools {
fn tab_content_text(&self, _detail: usize, _cx: &App) -> ui::SharedString {
format!(
"ACP: {}",
- self.watched_connection
- .as_ref()
+ self.selected_watched_connection()
.map_or("Disconnected", |connection| connection.agent_id.0.as_ref())
)
.into()
@@ -509,7 +693,23 @@ impl Render for AcpTools {
.track_focus(&self.focus_handle)
.size_full()
.bg(cx.theme().colors().editor_background)
- .child(match self.watched_connection.as_ref() {
+ .child(
+ h_flex()
+ .px_3()
+ .py_2()
+ .border_b_1()
+ .border_color(cx.theme().colors().border)
+ .child(
+ DropdownMenu::new(
+ "acp-connection-selector",
+ self.selected_connection_label(),
+ self.connection_menu(window, cx),
+ )
+ .style(DropdownStyle::Subtle)
+ .disabled(self.watched_connections.is_empty()),
+ ),
+ )
+ .child(match self.selected_watched_connection() {
Some(connection) => {
if connection.messages.is_empty() {
h_flex()
@@ -561,14 +761,33 @@ impl Render for AcpToolsToolbarItemView {
};
let acp_tools = acp_tools.clone();
- let has_messages = acp_tools
- .read(cx)
- .watched_connection
- .as_ref()
- .is_some_and(|connection| !connection.messages.is_empty());
+ let (has_messages, can_restart) = {
+ let acp_tools = acp_tools.read(cx);
+ (
+ acp_tools
+ .selected_watched_connection()
+ .is_some_and(|connection| !connection.messages.is_empty()),
+ acp_tools.selected_connection_status(cx)
+ != Some(agent_ui::agent_connection_store::AgentConnectionStatus::Connecting),
+ )
+ };
h_flex()
.gap_2()
+ .child(
+ IconButton::new("restart_connection", IconName::RotateCw)
+ .icon_size(IconSize::Small)
+ .tooltip(Tooltip::text("Restart Connection"))
+ .disabled(!can_restart)
+ .on_click(cx.listener({
+ let acp_tools = acp_tools.clone();
+ move |_this, _, _window, cx| {
+ acp_tools.update(cx, |acp_tools, cx| {
+ acp_tools.restart_selected_connection(cx);
+ });
+ }
+ })),
+ )
.child({
let message = acp_tools
.read(cx)