Merge branch 'mcp-acp-gemini' of github.com:zed-industries/zed into mcp-acp-gemini

Agus Zubiaga created

Change summary

crates/agent_servers/src/acp_connection.rs  | 54 ++++++++++++----------
crates/context_server/src/client.rs         | 14 ++---
crates/context_server/src/context_server.rs | 27 +++++++++-
crates/context_server/src/protocol.rs       |  9 ++-
4 files changed, 63 insertions(+), 41 deletions(-)

Detailed changes

crates/agent_servers/src/acp_connection.rs 🔗

@@ -22,7 +22,7 @@ use acp_thread::{AcpThread, AgentConnection, AuthRequired};
 pub struct AcpConnection {
     auth_methods: Rc<RefCell<Vec<acp::AuthMethod>>>,
     server_name: &'static str,
-    client: Arc<context_server::ContextServer>,
+    context_server: Arc<context_server::ContextServer>,
     sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
     _session_update_task: Task<()>,
 }
@@ -34,7 +34,7 @@ impl AcpConnection {
         working_directory: Option<Arc<Path>>,
         cx: &mut AsyncApp,
     ) -> Result<Self> {
-        let client: Arc<ContextServer> = ContextServer::stdio(
+        let context_server: Arc<ContextServer> = ContextServer::stdio(
             ContextServerId(format!("{}-mcp-server", server_name).into()),
             ContextServerCommand {
                 path: command.path,
@@ -44,26 +44,8 @@ impl AcpConnection {
             working_directory,
         )
         .into();
-        ContextServer::start(client.clone(), cx).await?;
-
-        let mcp_client = client.client().context("Failed to subscribe")?;
 
         let (notification_tx, mut notification_rx) = mpsc::unbounded();
-        mcp_client.on_notification(acp::AGENT_METHODS.session_update, {
-            move |notification, _cx| {
-                let notification_tx = notification_tx.clone();
-                log::trace!(
-                    "ACP Notification: {}",
-                    serde_json::to_string_pretty(&notification).unwrap()
-                );
-
-                if let Some(notification) =
-                    serde_json::from_value::<acp::SessionNotification>(notification).log_err()
-                {
-                    notification_tx.unbounded_send(notification).ok();
-                }
-            }
-        });
 
         let sessions = Rc::new(RefCell::new(HashMap::default()));
 
@@ -76,10 +58,32 @@ impl AcpConnection {
             }
         });
 
+        context_server
+            .start_with_handlers(
+                vec![(acp::AGENT_METHODS.session_update, {
+                    Box::new(move |notification, _cx| {
+                        let notification_tx = notification_tx.clone();
+                        log::trace!(
+                            "ACP Notification: {}",
+                            serde_json::to_string_pretty(&notification).unwrap()
+                        );
+
+                        if let Some(notification) =
+                            serde_json::from_value::<acp::SessionNotification>(notification)
+                                .log_err()
+                        {
+                            notification_tx.unbounded_send(notification).ok();
+                        }
+                    })
+                })],
+                cx,
+            )
+            .await?;
+
         Ok(Self {
             auth_methods: Default::default(),
             server_name,
-            client,
+            context_server,
             sessions,
             _session_update_task: session_update_handler_task,
         })
@@ -123,7 +127,7 @@ impl AgentConnection for AcpConnection {
         cwd: &Path,
         cx: &mut AsyncApp,
     ) -> Task<Result<Entity<AcpThread>>> {
-        let client = self.client.client();
+        let client = self.context_server.client();
         let sessions = self.sessions.clone();
         let auth_methods = self.auth_methods.clone();
         let cwd = cwd.to_path_buf();
@@ -200,7 +204,7 @@ impl AgentConnection for AcpConnection {
     }
 
     fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
-        let client = self.client.client();
+        let client = self.context_server.client();
         cx.foreground_executor().spawn(async move {
             let params = acp::AuthenticateArguments { method_id };
 
@@ -226,7 +230,7 @@ impl AgentConnection for AcpConnection {
         params: agent_client_protocol::PromptArguments,
         cx: &mut App,
     ) -> Task<Result<()>> {
-        let client = self.client.client();
+        let client = self.context_server.client();
         let sessions = self.sessions.clone();
 
         cx.foreground_executor().spawn(async move {
@@ -283,6 +287,6 @@ impl AgentConnection for AcpConnection {
 
 impl Drop for AcpConnection {
     fn drop(&mut self) {
-        self.client.stop().log_err();
+        self.context_server.stop().log_err();
     }
 }

crates/context_server/src/client.rs 🔗

@@ -441,14 +441,12 @@ impl Client {
         Ok(())
     }
 
-    #[allow(unused)]
-    pub fn on_notification<F>(&self, method: &'static str, f: F)
-    where
-        F: 'static + Send + FnMut(Value, AsyncApp),
-    {
-        self.notification_handlers
-            .lock()
-            .insert(method, Box::new(f));
+    pub fn on_notification(
+        &self,
+        method: &'static str,
+        f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
+    ) {
+        self.notification_handlers.lock().insert(method, f);
     }
 }
 

crates/context_server/src/context_server.rs 🔗

@@ -95,8 +95,28 @@ impl ContextServer {
         self.client.read().clone()
     }
 
-    pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
-        let client = match &self.configuration {
+    pub async fn start(&self, cx: &AsyncApp) -> Result<()> {
+        self.initialize(self.new_client(cx)?).await
+    }
+
+    /// Starts the context server, making sure handlers are registered before initialization happens
+    pub async fn start_with_handlers(
+        &self,
+        notification_handlers: Vec<(
+            &'static str,
+            Box<dyn 'static + Send + FnMut(serde_json::Value, AsyncApp)>,
+        )>,
+        cx: &AsyncApp,
+    ) -> Result<()> {
+        let client = self.new_client(cx)?;
+        for (method, handler) in notification_handlers {
+            client.on_notification(method, handler);
+        }
+        self.initialize(client).await
+    }
+
+    fn new_client(&self, cx: &AsyncApp) -> Result<Client> {
+        Ok(match &self.configuration {
             ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
                 client::ContextServerId(self.id.0.clone()),
                 client::ModelContextServerBinary {
@@ -113,8 +133,7 @@ impl ContextServer {
                 transport.clone(),
                 cx.clone(),
             )?,
-        };
-        self.initialize(client).await
+        })
     }
 
     async fn initialize(&self, client: Client) -> Result<()> {

crates/context_server/src/protocol.rs 🔗

@@ -115,10 +115,11 @@ impl InitializedContextServerProtocol {
         self.inner.notify(T::METHOD, params)
     }
 
-    pub fn on_notification<F>(&self, method: &'static str, f: F)
-    where
-        F: 'static + Send + FnMut(Value, AsyncApp),
-    {
+    pub fn on_notification(
+        &self,
+        method: &'static str,
+        f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
+    ) {
         self.inner.on_notification(method, f);
     }
 }