Refactor handling of ContextServer notifications

Ben Brandt created

The notification handler registration is now more explicit, with
handlers set up before server initialization to avoid potential race
conditions.

Change summary

crates/agent_servers/src/acp_connection.rs  | 85 ++++++++++++----------
crates/context_server/src/client.rs         | 14 +--
crates/context_server/src/context_server.rs | 27 ++++++-
crates/context_server/src/protocol.rs       |  9 +-
4 files changed, 79 insertions(+), 56 deletions(-)

Detailed changes

crates/agent_servers/src/acp_connection.rs 🔗

@@ -22,7 +22,7 @@ use acp_thread::{AcpThread, AgentConnection};
 pub struct AcpConnection {
     agent_state: Rc<RefCell<acp::AgentState>>,
     server_name: &'static str,
-    client: Arc<context_server::ContextServer>,
+    context_server: Arc<context_server::ContextServer>,
     sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
     _agent_state_task: Task<()>,
     _session_update_task: Task<()>,
@@ -35,7 +35,7 @@ impl AcpConnection {
         working_directory: Option<Arc<Path>>,
         cx: &mut AsyncApp,
     ) -> Result<Self> {
-        let client: Arc<ContextServer> = ContextServer::stdio(
+        let context_server: Arc<ContextServer> = ContextServer::stdio(
             ContextServerId(format!("{}-mcp-server", server_name).into()),
             ContextServerCommand {
                 path: command.path,
@@ -45,42 +45,9 @@ impl AcpConnection {
             working_directory,
         )
         .into();
-        ContextServer::start(client.clone(), cx).await?;
 
         let (mut state_tx, mut state_rx) = watch::channel(acp::AgentState::default());
-        let mcp_client = client.client().context("Failed to subscribe")?;
-
-        mcp_client.on_notification(acp::AGENT_METHODS.agent_state, {
-            move |notification, _cx| {
-                log::trace!(
-                    "ACP Notification: {}",
-                    serde_json::to_string_pretty(&notification).unwrap()
-                );
-
-                if let Some(state) =
-                    serde_json::from_value::<acp::AgentState>(notification).log_err()
-                {
-                    state_tx.send(state).log_err();
-                }
-            }
-        });
-
         let (notification_tx, mut notification_rx) = mpsc::unbounded();
-        mcp_client.on_notification(acp::AGENT_METHODS.session_update, {
-            move |notification, _cx| {
-                let notification_tx = notification_tx.clone();
-                log::trace!(
-                    "ACP Notification: {}",
-                    serde_json::to_string_pretty(&notification).unwrap()
-                );
-
-                if let Some(notification) =
-                    serde_json::from_value::<acp::SessionNotification>(notification).log_err()
-                {
-                    notification_tx.unbounded_send(notification).ok();
-                }
-            }
-        });
 
         let sessions = Rc::new(RefCell::new(HashMap::default()));
         let initial_state = state_rx.recv().await?;
@@ -104,9 +71,47 @@ impl AcpConnection {
             }
         });
 
+        context_server
+            .start_with_handlers(
+                vec![
+                    (acp::AGENT_METHODS.agent_state, {
+                        Box::new(move |notification, _cx| {
+                            log::trace!(
+                                "ACP Notification: {}",
+                                serde_json::to_string_pretty(&notification).unwrap()
+                            );
+
+                            if let Some(state) =
+                                serde_json::from_value::<acp::AgentState>(notification).log_err()
+                            {
+                                state_tx.send(state).log_err();
+                            }
+                        })
+                    }),
+                    (acp::AGENT_METHODS.session_update, {
+                        Box::new(move |notification, _cx| {
+                            let notification_tx = notification_tx.clone();
+                            log::trace!(
+                                "ACP Notification: {}",
+                                serde_json::to_string_pretty(&notification).unwrap()
+                            );
+
+                            if let Some(notification) =
+                                serde_json::from_value::<acp::SessionNotification>(notification)
+                                    .log_err()
+                            {
+                                notification_tx.unbounded_send(notification).ok();
+                            }
+                        })
+                    }),
+                ],
+                cx,
+            )
+            .await?;
+
         Ok(Self {
             server_name,
-            client,
+            context_server,
             sessions,
             agent_state,
             _agent_state_task: agent_state_task,
@@ -152,7 +157,7 @@ impl AgentConnection for AcpConnection {
         cwd: &Path,
         cx: &mut AsyncApp,
     ) -> Task<Result<Entity<AcpThread>>> {
-        let client = self.client.client();
+        let client = self.context_server.client();
         let sessions = self.sessions.clone();
         let cwd = cwd.to_path_buf();
         cx.spawn(async move |cx| {
@@ -222,7 +227,7 @@ impl AgentConnection for AcpConnection {
     }
 
     fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
-        let client = self.client.client();
+        let client = self.context_server.client();
         cx.foreground_executor().spawn(async move {
             let params = acp::AuthenticateArguments { method_id };
 
@@ -248,7 +253,7 @@ impl AgentConnection for AcpConnection {
         params: agent_client_protocol::PromptArguments,
         cx: &mut App,
     ) -> Task<Result<()>> {
-        let client = self.client.client();
+        let client = self.context_server.client();
         let sessions = self.sessions.clone();
 
         cx.foreground_executor().spawn(async move {
@@ -305,6 +310,6 @@ impl AgentConnection for AcpConnection {
 
 impl Drop for AcpConnection {
     fn drop(&mut self) {
-        self.client.stop().log_err();
+        self.context_server.stop().log_err();
     }
 }

crates/context_server/src/client.rs 🔗

@@ -441,14 +441,12 @@ impl Client {
         Ok(())
     }
 
-    #[allow(unused)]
-    pub fn on_notification<F>(&self, method: &'static str, f: F)
-    where
-        F: 'static + Send + FnMut(Value, AsyncApp),
-    {
-        self.notification_handlers
-            .lock()
-            .insert(method, Box::new(f));
+    pub fn on_notification(
+        &self,
+        method: &'static str,
+        f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
+    ) {
+        self.notification_handlers.lock().insert(method, f);
     }
 }
 

crates/context_server/src/context_server.rs 🔗

@@ -95,8 +95,28 @@ impl ContextServer {
         self.client.read().clone()
     }
 
-    pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
-        let client = match &self.configuration {
+    pub async fn start(&self, cx: &AsyncApp) -> Result<()> {
+        self.initialize(self.new_client(cx)?).await
+    }
+
+    /// Starts the context server, making sure handlers are registered before initialization happens
+    pub async fn start_with_handlers(
+        &self,
+        notification_handlers: Vec<(
+            &'static str,
+            Box<dyn 'static + Send + FnMut(serde_json::Value, AsyncApp)>,
+        )>,
+        cx: &AsyncApp,
+    ) -> Result<()> {
+        let client = self.new_client(cx)?;
+        for (method, handler) in notification_handlers {
+            client.on_notification(method, handler);
+        }
+        self.initialize(client).await
+    }
+
+    fn new_client(&self, cx: &AsyncApp) -> Result<Client> {
+        Ok(match &self.configuration {
             ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
                 client::ContextServerId(self.id.0.clone()),
                 client::ModelContextServerBinary {
@@ -113,8 +133,7 @@ impl ContextServer {
                 transport.clone(),
                 cx.clone(),
             )?,
-        };
-        self.initialize(client).await
+        })
     }
 
     async fn initialize(&self, client: Client) -> Result<()> {

crates/context_server/src/protocol.rs 🔗

@@ -115,10 +115,11 @@ impl InitializedContextServerProtocol {
         self.inner.notify(T::METHOD, params)
     }
 
-    pub fn on_notification<F>(&self, method: &'static str, f: F)
-    where
-        F: 'static + Send + FnMut(Value, AsyncApp),
-    {
+    pub fn on_notification(
+        &self,
+        method: &'static str,
+        f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
+    ) {
         self.inner.on_notification(method, f);
     }
 }