From eba811a127170a0fba4ebdf47bf52865ee5f5dc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torstein=20S=C3=B8rnes?= Date: Tue, 16 Dec 2025 22:44:39 +0100 Subject: [PATCH] Add support for MCP tools/list_changed notification (#42453) ## Summary This PR adds support for the MCP (Model Context Protocol) `notifications/tools/list_changed` notification, enabling dynamic tool discovery when MCP servers add, remove, or modify their available tools at runtime. ## Release Notes: - Improved: MCP tools are now automatically reloaded when a context server sends a `tools/list_changed` notification, eliminating the need to restart the server to discover new tools. ## Changes - Register a notification handler for `notifications/tools/list_changed` in `ContextServerRegistry` - Automatically reload tools when the notification is received - Handler is registered both on initial server startup and when a server transitions to `Running` status ## Motivation The MCP specification includes a `notifications/tools/list_changed` notification to inform clients when the list of available tools has changed. Previously, Zed's agent would only load tools once when a context server started. This meant that: 1. If an MCP server dynamically registered new tools after initialization, they would not be available to the agent 2. The only way to refresh tools was to restart the entire context server 3. Tools that were removed or modified would remain in the old state until restart ## Implementation Details The implementation follows these steps: 1. When a context server transitions to `Running` status, register a notification handler for `notifications/tools/list_changed` 2. The handler captures a weak reference to the `ContextServerRegistry` entity 3. When the notification is received, spawn a task that calls `reload_tools_for_server` with the server ID 4. The existing `reload_tools_for_server` method handles fetching the updated tool list and notifying observers This approach is minimal and reuses existing tool-loading infrastructure. ## Testing - [x] Code compiles with `./script/clippy -p agent` - The notification handler infrastructure already exists and is tested in the codebase - The `reload_tools_for_server` method is already tested and working ## Benefits - Improves developer experience by enabling hot-reloading of MCP tools - Aligns with the MCP specification's capability negotiation system - No breaking changes to existing functionality - Enables more flexible and dynamic MCP server implementations ## Related Issues This implements part of the MCP specification that was already defined in the type system but not wired up to actually handle the notifications. --------- Co-authored-by: Agus Zubiaga --- Cargo.lock | 1 + .../src/tools/context_server_registry.rs | 68 +++++++++--- crates/context_server/Cargo.toml | 1 + crates/context_server/src/client.rs | 103 +++++++++++++++--- crates/context_server/src/context_server.rs | 16 --- crates/context_server/src/protocol.rs | 6 +- 6 files changed, 148 insertions(+), 47 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6908a8ed5185ea71cc51a34d63990decaaf082d9..080a6a4cf4183fb5cade03ba36072b448ab4b70a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3623,6 +3623,7 @@ dependencies = [ "serde", "serde_json", "settings", + "slotmap", "smol", "tempfile", "terminal", diff --git a/crates/agent/src/tools/context_server_registry.rs b/crates/agent/src/tools/context_server_registry.rs index 735a47ae9fb99decbf97beb74a590f13f8f74878..3b01b2feb7dd36615a8ba7c63d81a81694e0d268 100644 --- a/crates/agent/src/tools/context_server_registry.rs +++ b/crates/agent/src/tools/context_server_registry.rs @@ -2,7 +2,7 @@ use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream}; use agent_client_protocol::ToolKind; use anyhow::{Result, anyhow, bail}; use collections::{BTreeMap, HashMap}; -use context_server::ContextServerId; +use context_server::{ContextServerId, client::NotificationSubscription}; use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task}; use project::context_server_store::{ContextServerStatus, ContextServerStore}; use std::sync::Arc; @@ -31,17 +31,7 @@ struct RegisteredContextServer { prompts: BTreeMap, load_tools: Task>, load_prompts: Task>, -} - -impl RegisteredContextServer { - fn new() -> Self { - Self { - tools: BTreeMap::default(), - prompts: BTreeMap::default(), - load_tools: Task::ready(Ok(())), - load_prompts: Task::ready(Ok(())), - } - } + _tools_updated_subscription: Option, } impl ContextServerRegistry { @@ -111,10 +101,57 @@ impl ContextServerRegistry { fn get_or_register_server( &mut self, server_id: &ContextServerId, + cx: &mut Context, ) -> &mut RegisteredContextServer { self.registered_servers .entry(server_id.clone()) - .or_insert_with(RegisteredContextServer::new) + .or_insert_with(|| Self::init_registered_server(server_id, &self.server_store, cx)) + } + + fn init_registered_server( + server_id: &ContextServerId, + server_store: &Entity, + cx: &mut Context, + ) -> RegisteredContextServer { + let tools_updated_subscription = server_store + .read(cx) + .get_running_server(server_id) + .and_then(|server| { + let client = server.client()?; + + if !client.capable(context_server::protocol::ServerCapability::Tools) { + return None; + } + + let server_id = server.id(); + let this = cx.entity().downgrade(); + + Some(client.on_notification( + "notifications/tools/list_changed", + Box::new(move |_params, cx: AsyncApp| { + let server_id = server_id.clone(); + let this = this.clone(); + cx.spawn(async move |cx| { + this.update(cx, |this, cx| { + log::info!( + "Received tools/list_changed notification for server {}", + server_id + ); + this.reload_tools_for_server(server_id, cx); + }) + }) + .detach(); + }), + )) + }); + + RegisteredContextServer { + tools: BTreeMap::default(), + prompts: BTreeMap::default(), + load_tools: Task::ready(Ok(())), + load_prompts: Task::ready(Ok(())), + _tools_updated_subscription: tools_updated_subscription, + } } fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context) { @@ -124,11 +161,12 @@ impl ContextServerRegistry { let Some(client) = server.client() else { return; }; + if !client.capable(context_server::protocol::ServerCapability::Tools) { return; } - let registered_server = self.get_or_register_server(&server_id); + let registered_server = self.get_or_register_server(&server_id, cx); registered_server.load_tools = cx.spawn(async move |this, cx| { let response = client .request::(()) @@ -167,7 +205,7 @@ impl ContextServerRegistry { return; } - let registered_server = self.get_or_register_server(&server_id); + let registered_server = self.get_or_register_server(&server_id, cx); registered_server.load_prompts = cx.spawn(async move |this, cx| { let response = client diff --git a/crates/context_server/Cargo.toml b/crates/context_server/Cargo.toml index cb48b7e6f7d000ed7f2db7aaf3cfe4d6317fe278..539b873c3527b5a01f1dfcf7b768f0758dc869b5 100644 --- a/crates/context_server/Cargo.toml +++ b/crates/context_server/Cargo.toml @@ -29,6 +29,7 @@ schemars.workspace = true serde_json.workspace = true serde.workspace = true settings.workspace = true +slotmap.workspace = true smol.workspace = true tempfile.workspace = true url = { workspace = true, features = ["serde"] } diff --git a/crates/context_server/src/client.rs b/crates/context_server/src/client.rs index f891e96250f3334540aa859fe438c87297fc0100..605f24178916faa5173c32c28be6c80ee625cb6c 100644 --- a/crates/context_server/src/client.rs +++ b/crates/context_server/src/client.rs @@ -6,6 +6,7 @@ use parking_lot::Mutex; use postage::barrier; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use serde_json::{Value, value::RawValue}; +use slotmap::SlotMap; use smol::channel; use std::{ fmt, @@ -50,7 +51,7 @@ pub(crate) struct Client { next_id: AtomicI32, outbound_tx: channel::Sender, name: Arc, - notification_handlers: Arc>>, + subscription_set: Arc>, response_handlers: Arc>>>, #[allow(clippy::type_complexity)] #[allow(dead_code)] @@ -191,21 +192,20 @@ impl Client { let (outbound_tx, outbound_rx) = channel::unbounded::(); let (output_done_tx, output_done_rx) = barrier::channel(); - let notification_handlers = - Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default())); + let subscription_set = Arc::new(Mutex::new(NotificationSubscriptionSet::default())); let response_handlers = Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default()))); let request_handlers = Arc::new(Mutex::new(HashMap::<_, RequestHandler>::default())); let receive_input_task = cx.spawn({ - let notification_handlers = notification_handlers.clone(); + let subscription_set = subscription_set.clone(); let response_handlers = response_handlers.clone(); let request_handlers = request_handlers.clone(); let transport = transport.clone(); async move |cx| { Self::handle_input( transport, - notification_handlers, + subscription_set, request_handlers, response_handlers, cx, @@ -236,7 +236,7 @@ impl Client { Ok(Self { server_id, - notification_handlers, + subscription_set, response_handlers, name: server_name, next_id: Default::default(), @@ -257,7 +257,7 @@ impl Client { /// to pending requests) and notifications (which trigger registered handlers). async fn handle_input( transport: Arc, - notification_handlers: Arc>>, + subscription_set: Arc>, request_handlers: Arc>>, response_handlers: Arc>>>, cx: &mut AsyncApp, @@ -282,10 +282,11 @@ impl Client { handler(Ok(message.to_string())); } } else if let Ok(notification) = serde_json::from_str::(&message) { - let mut notification_handlers = notification_handlers.lock(); - if let Some(handler) = notification_handlers.get_mut(notification.method.as_str()) { - handler(notification.params.unwrap_or(Value::Null), cx.clone()); - } + subscription_set.lock().notify( + ¬ification.method, + notification.params.unwrap_or(Value::Null), + cx, + ) } else { log::error!("Unhandled JSON from context_server: {}", message); } @@ -451,12 +452,18 @@ impl Client { Ok(()) } + #[must_use] pub fn on_notification( &self, method: &'static str, f: Box, - ) { - self.notification_handlers.lock().insert(method, f); + ) -> NotificationSubscription { + let mut notification_subscriptions = self.subscription_set.lock(); + + NotificationSubscription { + id: notification_subscriptions.add_handler(method, f), + set: self.subscription_set.clone(), + } } } @@ -485,3 +492,73 @@ impl fmt::Debug for Client { .finish_non_exhaustive() } } + +slotmap::new_key_type! { + struct NotificationSubscriptionId; +} + +#[derive(Default)] +pub struct NotificationSubscriptionSet { + // we have very few subscriptions at the moment + methods: Vec<(&'static str, Vec)>, + handlers: SlotMap, +} + +impl NotificationSubscriptionSet { + #[must_use] + fn add_handler( + &mut self, + method: &'static str, + handler: NotificationHandler, + ) -> NotificationSubscriptionId { + let id = self.handlers.insert(handler); + if let Some((_, handler_ids)) = self + .methods + .iter_mut() + .find(|(probe_method, _)| method == *probe_method) + { + debug_assert!( + handler_ids.len() < 20, + "Too many MCP handlers for {}. Consider using a different data structure.", + method + ); + + handler_ids.push(id); + } else { + self.methods.push((method, vec![id])); + }; + id + } + + fn notify(&mut self, method: &str, payload: Value, cx: &mut AsyncApp) { + let Some((_, handler_ids)) = self + .methods + .iter_mut() + .find(|(probe_method, _)| method == *probe_method) + else { + return; + }; + + for handler_id in handler_ids { + if let Some(handler) = self.handlers.get_mut(*handler_id) { + handler(payload.clone(), cx.clone()); + } + } + } +} + +pub struct NotificationSubscription { + id: NotificationSubscriptionId, + set: Arc>, +} + +impl Drop for NotificationSubscription { + fn drop(&mut self) { + let mut set = self.set.lock(); + set.handlers.remove(self.id); + set.methods.retain_mut(|(_, handler_ids)| { + handler_ids.retain(|id| *id != self.id); + !handler_ids.is_empty() + }); + } +} diff --git a/crates/context_server/src/context_server.rs b/crates/context_server/src/context_server.rs index 553e845df87a2fec30b1afbffa05b970d5d672f6..92804549c69b01dd3729efb3a0b47905cd73d813 100644 --- a/crates/context_server/src/context_server.rs +++ b/crates/context_server/src/context_server.rs @@ -96,22 +96,6 @@ impl ContextServer { self.initialize(self.new_client(cx)?).await } - /// Starts the context server, making sure handlers are registered before initialization happens - pub async fn start_with_handlers( - &self, - notification_handlers: Vec<( - &'static str, - Box, - )>, - cx: &AsyncApp, - ) -> Result<()> { - let client = self.new_client(cx)?; - for (method, handler) in notification_handlers { - client.on_notification(method, handler); - } - self.initialize(client).await - } - fn new_client(&self, cx: &AsyncApp) -> Result { Ok(match &self.configuration { ContextServerTransport::Stdio(command, working_directory) => Client::stdio( diff --git a/crates/context_server/src/protocol.rs b/crates/context_server/src/protocol.rs index 5355f20f620b5bed76bf945e863fdb5cbcc2ff43..a218a8a3e0e6352997e4152214077cb3851317b3 100644 --- a/crates/context_server/src/protocol.rs +++ b/crates/context_server/src/protocol.rs @@ -12,7 +12,7 @@ use futures::channel::oneshot; use gpui::AsyncApp; use serde_json::Value; -use crate::client::Client; +use crate::client::{Client, NotificationSubscription}; use crate::types::{self, Notification, Request}; pub struct ModelContextProtocol { @@ -119,7 +119,7 @@ impl InitializedContextServerProtocol { &self, method: &'static str, f: Box, - ) { - self.inner.on_notification(method, f); + ) -> NotificationSubscription { + self.inner.on_notification(method, f) } }