agent: Fix MCP server handler subscription race condition (#32133)

Jonathan LEI , Bennet Bo Fenner , and Bennet Bo Fenner created

Closes #32132

Release Notes:

- Fixed MCP server handler subscription race condition causing tools to
not load.

---------

Co-authored-by: Bennet Bo Fenner <bennet@zed.dev>
Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>

Change summary

crates/agent/src/thread_store.rs                     | 109 +++++++------
crates/assistant_context_editor/src/context_store.rs | 110 +++++++------
2 files changed, 113 insertions(+), 106 deletions(-)

Detailed changes

crates/agent/src/thread_store.rs 🔗

@@ -514,11 +514,14 @@ impl ThreadStore {
     }
 
     fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
-        cx.subscribe(
-            &self.project.read(cx).context_server_store(),
-            Self::handle_context_server_event,
-        )
-        .detach();
+        let context_server_store = self.project.read(cx).context_server_store();
+        cx.subscribe(&context_server_store, Self::handle_context_server_event)
+            .detach();
+
+        // Check for any servers that were already running before the handler was registered
+        for server in context_server_store.read(cx).running_servers() {
+            self.load_context_server_tools(server.id(), context_server_store.clone(), cx);
+        }
     }
 
     fn handle_context_server_event(
@@ -533,55 +536,7 @@ impl ThreadStore {
                 match status {
                     ContextServerStatus::Starting => {}
                     ContextServerStatus::Running => {
-                        if let Some(server) =
-                            context_server_store.read(cx).get_running_server(server_id)
-                        {
-                            let context_server_manager = context_server_store.clone();
-                            cx.spawn({
-                                let server = server.clone();
-                                let server_id = server_id.clone();
-                                async move |this, cx| {
-                                    let Some(protocol) = server.client() else {
-                                        return;
-                                    };
-
-                                    if protocol.capable(context_server::protocol::ServerCapability::Tools) {
-                                        if let Some(tools) = protocol.list_tools().await.log_err() {
-                                            let tool_ids = tool_working_set
-                                                .update(cx, |tool_working_set, _| {
-                                                    tools
-                                                        .tools
-                                                        .into_iter()
-                                                        .map(|tool| {
-                                                            log::info!(
-                                                                "registering context server tool: {:?}",
-                                                                tool.name
-                                                            );
-                                                            tool_working_set.insert(Arc::new(
-                                                                ContextServerTool::new(
-                                                                    context_server_manager.clone(),
-                                                                    server.id(),
-                                                                    tool,
-                                                                ),
-                                                            ))
-                                                        })
-                                                        .collect::<Vec<_>>()
-                                                })
-                                                .log_err();
-
-                                            if let Some(tool_ids) = tool_ids {
-                                                this.update(cx, |this, _| {
-                                                    this.context_server_tool_ids
-                                                        .insert(server_id, tool_ids);
-                                                })
-                                                .log_err();
-                                            }
-                                        }
-                                    }
-                                }
-                            })
-                            .detach();
-                        }
+                        self.load_context_server_tools(server_id.clone(), context_server_store, cx);
                     }
                     ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
                         if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
@@ -594,6 +549,52 @@ impl ThreadStore {
             }
         }
     }
+
+    fn load_context_server_tools(
+        &self,
+        server_id: ContextServerId,
+        context_server_store: Entity<ContextServerStore>,
+        cx: &mut Context<Self>,
+    ) {
+        let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
+            return;
+        };
+        let tool_working_set = self.tools.clone();
+        cx.spawn(async move |this, cx| {
+            let Some(protocol) = server.client() else {
+                return;
+            };
+
+            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
+                if let Some(tools) = protocol.list_tools().await.log_err() {
+                    let tool_ids = tool_working_set
+                        .update(cx, |tool_working_set, _| {
+                            tools
+                                .tools
+                                .into_iter()
+                                .map(|tool| {
+                                    log::info!("registering context server tool: {:?}", tool.name);
+                                    tool_working_set.insert(Arc::new(ContextServerTool::new(
+                                        context_server_store.clone(),
+                                        server.id(),
+                                        tool,
+                                    )))
+                                })
+                                .collect::<Vec<_>>()
+                        })
+                        .log_err();
+
+                    if let Some(tool_ids) = tool_ids {
+                        this.update(cx, |this, _| {
+                            this.context_server_tool_ids.insert(server_id, tool_ids);
+                        })
+                        .log_err();
+                    }
+                }
+            }
+        })
+        .detach();
+    }
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]

crates/assistant_context_editor/src/context_store.rs 🔗

@@ -809,74 +809,37 @@ impl ContextStore {
     }
 
     fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
-        cx.subscribe(
-            &self.project.read(cx).context_server_store(),
-            Self::handle_context_server_event,
-        )
-        .detach();
+        let context_server_store = self.project.read(cx).context_server_store();
+        cx.subscribe(&context_server_store, Self::handle_context_server_event)
+            .detach();
+
+        // Check for any servers that were already running before the handler was registered
+        for server in context_server_store.read(cx).running_servers() {
+            self.load_context_server_slash_commands(server.id(), context_server_store.clone(), cx);
+        }
     }
 
     fn handle_context_server_event(
         &mut self,
-        context_server_manager: Entity<ContextServerStore>,
+        context_server_store: Entity<ContextServerStore>,
         event: &project::context_server_store::Event,
         cx: &mut Context<Self>,
     ) {
-        let slash_command_working_set = self.slash_commands.clone();
         match event {
             project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
                 match status {
                     ContextServerStatus::Running => {
-                        if let Some(server) = context_server_manager
-                            .read(cx)
-                            .get_running_server(server_id)
-                        {
-                            let context_server_manager = context_server_manager.clone();
-                            cx.spawn({
-                                let server = server.clone();
-                                let server_id = server_id.clone();
-                                async move |this, cx| {
-                                    let Some(protocol) = server.client() else {
-                                        return;
-                                    };
-
-                                    if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
-                                        if let Some(prompts) = protocol.list_prompts().await.log_err() {
-                                            let slash_command_ids = prompts
-                                                .into_iter()
-                                                .filter(assistant_slash_commands::acceptable_prompt)
-                                                .map(|prompt| {
-                                                    log::info!(
-                                                        "registering context server command: {:?}",
-                                                        prompt.name
-                                                    );
-                                                    slash_command_working_set.insert(Arc::new(
-                                                        assistant_slash_commands::ContextServerSlashCommand::new(
-                                                            context_server_manager.clone(),
-                                                            server.id(),
-                                                            prompt,
-                                                        ),
-                                                    ))
-                                                })
-                                                .collect::<Vec<_>>();
-
-                                            this.update( cx, |this, _cx| {
-                                                this.context_server_slash_command_ids
-                                                    .insert(server_id.clone(), slash_command_ids);
-                                            })
-                                            .log_err();
-                                        }
-                                    }
-                                }
-                            })
-                            .detach();
-                        }
+                        self.load_context_server_slash_commands(
+                            server_id.clone(),
+                            context_server_store.clone(),
+                            cx,
+                        );
                     }
                     ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
                         if let Some(slash_command_ids) =
                             self.context_server_slash_command_ids.remove(server_id)
                         {
-                            slash_command_working_set.remove(&slash_command_ids);
+                            self.slash_commands.remove(&slash_command_ids);
                         }
                     }
                     _ => {}
@@ -884,4 +847,47 @@ impl ContextStore {
             }
         }
     }
+
+    fn load_context_server_slash_commands(
+        &self,
+        server_id: ContextServerId,
+        context_server_store: Entity<ContextServerStore>,
+        cx: &mut Context<Self>,
+    ) {
+        let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
+            return;
+        };
+        let slash_command_working_set = self.slash_commands.clone();
+        cx.spawn(async move |this, cx| {
+            let Some(protocol) = server.client() else {
+                return;
+            };
+
+            if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
+                if let Some(prompts) = protocol.list_prompts().await.log_err() {
+                    let slash_command_ids = prompts
+                        .into_iter()
+                        .filter(assistant_slash_commands::acceptable_prompt)
+                        .map(|prompt| {
+                            log::info!("registering context server command: {:?}", prompt.name);
+                            slash_command_working_set.insert(Arc::new(
+                                assistant_slash_commands::ContextServerSlashCommand::new(
+                                    context_server_store.clone(),
+                                    server.id(),
+                                    prompt,
+                                ),
+                            ))
+                        })
+                        .collect::<Vec<_>>();
+
+                    this.update(cx, |this, _cx| {
+                        this.context_server_slash_command_ids
+                            .insert(server_id.clone(), slash_command_ids);
+                    })
+                    .log_err();
+                }
+            }
+        })
+        .detach();
+    }
 }