diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index 4b66c4eba62e9ecec7500c934eee3ae776f58cef..19eb651e30170aa395ac3eeba51c54c6f37f2055 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -667,7 +667,7 @@ impl NativeAgent { fn handle_context_server_store_updated( &mut self, _store: Entity, - _event: &project::context_server_store::Event, + _event: &project::context_server_store::ServerStatusChangedEvent, cx: &mut Context, ) { self.update_available_commands(cx); diff --git a/crates/agent/src/tools/context_server_registry.rs b/crates/agent/src/tools/context_server_registry.rs index c2c398c28e8bac6ff6de85cea6acfe4308515922..12ad642cfca6d87aa29f219951e45d402d98943d 100644 --- a/crates/agent/src/tools/context_server_registry.rs +++ b/crates/agent/src/tools/context_server_registry.rs @@ -247,31 +247,29 @@ impl ContextServerRegistry { fn handle_context_server_store_event( &mut self, _: Entity, - event: &project::context_server_store::Event, + event: &project::context_server_store::ServerStatusChangedEvent, cx: &mut Context, ) { - match event { - project::context_server_store::Event::ServerStatusChanged { server_id, status } => { - match status { - ContextServerStatus::Starting => {} - ContextServerStatus::Running => { - self.reload_tools_for_server(server_id.clone(), cx); - self.reload_prompts_for_server(server_id.clone(), cx); + let project::context_server_store::ServerStatusChangedEvent { server_id, status } = event; + + match status { + ContextServerStatus::Starting => {} + ContextServerStatus::Running => { + self.reload_tools_for_server(server_id.clone(), cx); + self.reload_prompts_for_server(server_id.clone(), cx); + } + ContextServerStatus::Stopped | ContextServerStatus::Error(_) => { + if let Some(registered_server) = self.registered_servers.remove(server_id) { + if !registered_server.tools.is_empty() { + cx.emit(ContextServerRegistryEvent::ToolsChanged); } - ContextServerStatus::Stopped | ContextServerStatus::Error(_) => { - if let Some(registered_server) = self.registered_servers.remove(server_id) { - if !registered_server.tools.is_empty() { - cx.emit(ContextServerRegistryEvent::ToolsChanged); - } - if !registered_server.prompts.is_empty() { - cx.emit(ContextServerRegistryEvent::PromptsChanged); - } - } - cx.notify(); + if !registered_server.prompts.is_empty() { + cx.emit(ContextServerRegistryEvent::PromptsChanged); } } + cx.notify(); } - } + }; } } diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs index 9e101b9e8d8de5993eaaedbe43bd3a04e4217d37..dec56a789a6fa0a1d52da9ea981c17f66bf596f2 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ -518,25 +518,7 @@ impl AgentConfiguration { window: &mut Window, cx: &mut Context, ) -> impl IntoElement { - let mut context_server_ids = self - .context_server_store - .read(cx) - .server_ids(cx) - .into_iter() - .collect::>(); - - // Sort context servers: ones without mcp-server- prefix first, then prefixed ones - context_server_ids.sort_by(|a, b| { - const MCP_PREFIX: &str = "mcp-server-"; - match (a.0.strip_prefix(MCP_PREFIX), b.0.strip_prefix(MCP_PREFIX)) { - // If one has mcp-server- prefix and other doesn't, non-mcp comes first - (Some(_), None) => std::cmp::Ordering::Greater, - (None, Some(_)) => std::cmp::Ordering::Less, - // If both have same prefix status, sort by appropriate key - (Some(a), Some(b)) => a.cmp(b), - (None, None) => a.0.cmp(&b.0), - } - }); + let context_server_ids = self.context_server_store.read(cx).server_ids(); let add_server_popover = PopoverMenu::new("add-server-popover") .trigger( @@ -594,7 +576,7 @@ impl AgentConfiguration { .pr_5() .w_full() .gap_1() - .map(|mut parent| { + .map(|parent| { if context_server_ids.is_empty() { parent.child( h_flex() @@ -611,23 +593,17 @@ impl AgentConfiguration { ), ) } else { - for (index, context_server_id) in - context_server_ids.into_iter().enumerate() - { - if index > 0 { - parent = parent.child( - Divider::horizontal() - .color(DividerColor::BorderFaded) - .into_any_element(), - ); - } - parent = parent.child(self.render_context_server( - context_server_id, - window, - cx, - )); - } - parent + parent.children(itertools::intersperse_with( + context_server_ids.iter().cloned().map(|context_server_id| { + self.render_context_server(context_server_id, window, cx) + .into_any_element() + }), + || { + Divider::horizontal() + .color(DividerColor::BorderFaded) + .into_any_element() + }, + )) } }), ) @@ -637,7 +613,7 @@ impl AgentConfiguration { &self, context_server_id: ContextServerId, window: &mut Window, - cx: &mut Context, + cx: &Context, ) -> impl use<> + IntoElement { let server_status = self .context_server_store diff --git a/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs b/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs index 3069d55718e4b11f87a2ce73703a593d1f7acf4c..78c032a565522a7eac145add3f65568d559ceb24 100644 --- a/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs +++ b/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs @@ -880,32 +880,32 @@ fn wait_for_context_server( let (tx, rx) = futures::channel::oneshot::channel(); let tx = Arc::new(Mutex::new(Some(tx))); - let subscription = cx.subscribe(context_server_store, move |_, event, _cx| match event { - project::context_server_store::Event::ServerStatusChanged { server_id, status } => { - match status { - ContextServerStatus::Running => { - if server_id == &context_server_id - && let Some(tx) = tx.lock().unwrap().take() - { - let _ = tx.send(Ok(())); - } + let subscription = cx.subscribe(context_server_store, move |_, event, _cx| { + let project::context_server_store::ServerStatusChangedEvent { server_id, status } = event; + + match status { + ContextServerStatus::Running => { + if server_id == &context_server_id + && let Some(tx) = tx.lock().unwrap().take() + { + let _ = tx.send(Ok(())); } - ContextServerStatus::Stopped => { - if server_id == &context_server_id - && let Some(tx) = tx.lock().unwrap().take() - { - let _ = tx.send(Err("Context server stopped running".into())); - } + } + ContextServerStatus::Stopped => { + if server_id == &context_server_id + && let Some(tx) = tx.lock().unwrap().take() + { + let _ = tx.send(Err("Context server stopped running".into())); } - ContextServerStatus::Error(error) => { - if server_id == &context_server_id - && let Some(tx) = tx.lock().unwrap().take() - { - let _ = tx.send(Err(error.clone())); - } + } + ContextServerStatus::Error(error) => { + if server_id == &context_server_id + && let Some(tx) = tx.lock().unwrap().take() + { + let _ = tx.send(Err(error.clone())); } - _ => {} } + _ => {} } }); diff --git a/crates/assistant_text_thread/src/text_thread_store.rs b/crates/assistant_text_thread/src/text_thread_store.rs index 248a57d6861ccf2af30a80d1c62687f943542c12..8a9a34cf65958d32545677d77f44297fe9d3ade8 100644 --- a/crates/assistant_text_thread/src/text_thread_store.rs +++ b/crates/assistant_text_thread/src/text_thread_store.rs @@ -888,29 +888,27 @@ impl TextThreadStore { fn handle_context_server_event( &mut self, context_server_store: Entity, - event: &project::context_server_store::Event, + event: &project::context_server_store::ServerStatusChangedEvent, cx: &mut Context, ) { - match event { - project::context_server_store::Event::ServerStatusChanged { server_id, status } => { - match status { - ContextServerStatus::Running => { - self.load_context_server_slash_commands( - server_id.clone(), - context_server_store, - cx, - ); - } - ContextServerStatus::Stopped | ContextServerStatus::Error(_) => { - if let Some(slash_command_ids) = - self.context_server_slash_command_ids.remove(server_id) - { - self.slash_commands.remove(&slash_command_ids); - } - } - _ => {} + let project::context_server_store::ServerStatusChangedEvent { server_id, status } = event; + + match status { + ContextServerStatus::Running => { + self.load_context_server_slash_commands( + server_id.clone(), + context_server_store, + cx, + ); + } + ContextServerStatus::Stopped | ContextServerStatus::Error(_) => { + if let Some(slash_command_ids) = + self.context_server_slash_command_ids.remove(server_id) + { + self.slash_commands.remove(&slash_command_ids); } } + _ => {} } } diff --git a/crates/project/src/context_server_store.rs b/crates/project/src/context_server_store.rs index 3eda664c4d422007782dd1d2fc91062ff4a8638e..93528e6c9bc9f8a9e3492e137c3a7a386f9d0a60 100644 --- a/crates/project/src/context_server_store.rs +++ b/crates/project/src/context_server_store.rs @@ -10,6 +10,7 @@ use collections::{HashMap, HashSet}; use context_server::{ContextServer, ContextServerCommand, ContextServerId}; use futures::{FutureExt as _, future::join_all}; use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions}; +use itertools::Itertools; use registry::ContextServerDescriptorRegistry; use remote::RemoteClient; use rpc::{AnyProtoClient, TypedEnvelope, proto}; @@ -203,6 +204,7 @@ pub struct ContextServerStore { state: ContextServerStoreState, context_server_settings: HashMap, ContextServerSettings>, servers: HashMap, + server_ids: Vec, worktree_store: Entity, project: Option>, registry: Entity, @@ -212,14 +214,12 @@ pub struct ContextServerStore { _subscriptions: Vec, } -pub enum Event { - ServerStatusChanged { - server_id: ContextServerId, - status: ContextServerStatus, - }, +pub struct ServerStatusChangedEvent { + pub server_id: ContextServerId, + pub status: ContextServerStatus, } -impl EventEmitter for ContextServerStore {} +impl EventEmitter for ContextServerStore {} impl ContextServerStore { pub fn local( @@ -394,6 +394,7 @@ impl ContextServerStore { registry, needs_server_update: false, servers: HashMap::default(), + server_ids: Default::default(), update_servers_task: None, context_server_factory, }; @@ -426,8 +427,16 @@ impl ContextServerStore { self.servers.get(id).map(|state| state.configuration()) } - pub fn server_ids(&self, cx: &App) -> HashSet { - self.servers + /// Returns a sorted slice of available unique context server IDs. Within the + /// slice, context servers which have `mcp-server-` as a prefix in their ID will + /// appear after servers that do not have this prefix in their ID. + pub fn server_ids(&self) -> &[ContextServerId] { + self.server_ids.as_slice() + } + + fn populate_server_ids(&mut self, cx: &App) { + self.server_ids = self + .servers .keys() .cloned() .chain( @@ -437,7 +446,27 @@ impl ContextServerStore { .into_iter() .map(|(id, _)| ContextServerId(id)), ) - .collect() + .chain( + self.context_server_settings + .keys() + .map(|id| ContextServerId(id.clone())), + ) + .unique() + .sorted_unstable_by( + // Sort context servers: ones without mcp-server- prefix first, then prefixed ones + |a, b| { + const MCP_PREFIX: &str = "mcp-server-"; + match (a.0.strip_prefix(MCP_PREFIX), b.0.strip_prefix(MCP_PREFIX)) { + // If one has mcp-server- prefix and other doesn't, non-mcp comes first + (Some(_), None) => std::cmp::Ordering::Greater, + (None, Some(_)) => std::cmp::Ordering::Less, + // If both have same prefix status, sort by appropriate key + (Some(a), Some(b)) => a.cmp(b), + (None, None) => a.0.cmp(&b.0), + } + }, + ) + .collect(); } pub fn running_servers(&self) -> Vec> { @@ -591,7 +620,7 @@ impl ContextServerStore { .remove(id) .context("Context server not found")?; drop(state); - cx.emit(Event::ServerStatusChanged { + cx.emit(ServerStatusChangedEvent { server_id: id.clone(), status: ContextServerStatus::Stopped, }); @@ -808,7 +837,7 @@ impl ContextServerStore { ) { let status = ContextServerStatus::from_state(&state); self.servers.insert(id.clone(), state); - cx.emit(Event::ServerStatusChanged { + cx.emit(ServerStatusChangedEvent { server_id: id, status, }); @@ -825,6 +854,7 @@ impl ContextServerStore { } this.update(cx, |this, cx| { + this.populate_server_ids(cx); this.update_servers_task.take(); if this.needs_server_update { this.available_context_servers_changed(cx); diff --git a/crates/project/tests/integration/context_server_store.rs b/crates/project/tests/integration/context_server_store.rs index ca54abe05f9f4d51382dfdc5898c58de3d0d8bbb..56bdaed41cd77b665d316491e051582c7ccc078a 100644 --- a/crates/project/tests/integration/context_server_store.rs +++ b/crates/project/tests/integration/context_server_store.rs @@ -553,6 +553,92 @@ async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) { } } +#[gpui::test] +async fn test_server_ids_includes_disabled_servers(cx: &mut TestAppContext) { + const ENABLED_SERVER_ID: &str = "enabled-server"; + const DISABLED_SERVER_ID: &str = "disabled-server"; + + let enabled_server_id = ContextServerId(ENABLED_SERVER_ID.into()); + let disabled_server_id = ContextServerId(DISABLED_SERVER_ID.into()); + + let (_fs, project) = setup_context_server_test(cx, json!({"code.rs": ""}), vec![]).await; + + let executor = cx.executor(); + let store = project.read_with(cx, |project, _| project.context_server_store()); + store.update(cx, |store, _| { + store.set_context_server_factory(Box::new(move |id, _| { + Arc::new(ContextServer::new( + id.clone(), + Arc::new(create_fake_transport(id.0.to_string(), executor.clone())), + )) + })); + }); + + // Configure one enabled and one disabled server + set_context_server_configuration( + vec![ + ( + enabled_server_id.0.clone(), + settings::ContextServerSettingsContent::Stdio { + enabled: true, + remote: false, + command: ContextServerCommand { + path: "somebinary".into(), + args: vec![], + env: None, + timeout: None, + }, + }, + ), + ( + disabled_server_id.0.clone(), + settings::ContextServerSettingsContent::Stdio { + enabled: false, + remote: false, + command: ContextServerCommand { + path: "somebinary".into(), + args: vec![], + env: None, + timeout: None, + }, + }, + ), + ], + cx, + ); + + cx.run_until_parked(); + + // Verify that server_ids includes both enabled and disabled servers + cx.update(|cx| { + let server_ids = store.read(cx).server_ids().to_vec(); + assert!( + server_ids.contains(&enabled_server_id), + "server_ids should include enabled server" + ); + assert!( + server_ids.contains(&disabled_server_id), + "server_ids should include disabled server" + ); + }); + + // Verify that the enabled server is running and the disabled server is not + cx.read(|cx| { + assert_eq!( + store.read(cx).status_for_server(&enabled_server_id), + Some(ContextServerStatus::Running), + "enabled server should be running" + ); + // Disabled server should not be in the servers map (status returns None) + // but should still be in server_ids + assert_eq!( + store.read(cx).status_for_server(&disabled_server_id), + None, + "disabled server should not have a status (not in servers map)" + ); + }); +} + fn set_context_server_configuration( context_servers: Vec<(Arc, settings::ContextServerSettingsContent)>, cx: &mut TestAppContext, @@ -796,26 +882,25 @@ fn assert_server_events( let expected_event_count = expected_events.len(); let subscription = cx.subscribe(store, { let received_event_count = received_event_count.clone(); - move |_, event, _| match event { - Event::ServerStatusChanged { + move |_, event, _| { + let ServerStatusChangedEvent { server_id: actual_server_id, status: actual_status, - } => { - let (expected_server_id, expected_status) = &expected_events[ix]; - - assert_eq!( - actual_server_id, expected_server_id, - "Expected different server id at index {}", - ix - ); - assert_eq!( - actual_status, expected_status, - "Expected different status at index {}", - ix - ); - ix += 1; - *received_event_count.borrow_mut() += 1; - } + } = event; + let (expected_server_id, expected_status) = &expected_events[ix]; + + assert_eq!( + actual_server_id, expected_server_id, + "Expected different server id at index {}", + ix + ); + assert_eq!( + actual_status, expected_status, + "Expected different status at index {}", + ix + ); + ix += 1; + *received_event_count.borrow_mut() += 1; } }); ServerEvents {