Put context servers behind a trait (#20432)

Marshall Bowers created

This PR puts context servers behind the `ContextServer` trait to allow
us to provide context servers from an extension.

Release Notes:

- N/A

Change summary

Cargo.lock                                                   |   1 
crates/assistant/src/context_store.rs                        |   6 
crates/assistant/src/slash_command/context_server_command.rs |  10 
crates/assistant/src/tools/context_server_tool.rs            |   8 
crates/context_servers/Cargo.toml                            |   1 
crates/context_servers/src/manager.rs                        | 133 +++--
6 files changed, 100 insertions(+), 59 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -2815,6 +2815,7 @@ name = "context_servers"
 version = "0.1.0"
 dependencies = [
  "anyhow",
+ "async-trait",
  "collections",
  "command_palette_hooks",
  "futures 0.3.30",

crates/assistant/src/context_store.rs 🔗

@@ -819,7 +819,7 @@ impl ContextStore {
             |context_server_manager, cx| {
                 for server in context_server_manager.servers() {
                     context_server_manager
-                        .restart_server(&server.id, cx)
+                        .restart_server(&server.id(), cx)
                         .detach_and_log_err(cx);
                 }
             },
@@ -850,7 +850,7 @@ impl ContextStore {
                         let server = server.clone();
                         let server_id = server_id.clone();
                         |this, mut cx| async move {
-                            let Some(protocol) = server.client.read().clone() else {
+                            let Some(protocol) = server.client() else {
                                 return;
                             };
 
@@ -889,7 +889,7 @@ impl ContextStore {
                                         tool_working_set.insert(
                                             Arc::new(tools::context_server_tool::ContextServerTool::new(
                                                 context_server_manager.clone(),
-                                                server.id.clone(),
+                                                server.id(),
                                                 tool,
                                             )),
                                         )

crates/assistant/src/slash_command/context_server_command.rs 🔗

@@ -20,18 +20,18 @@ use crate::slash_command::create_label_for_command;
 
 pub struct ContextServerSlashCommand {
     server_manager: Model<ContextServerManager>,
-    server_id: String,
+    server_id: Arc<str>,
     prompt: Prompt,
 }
 
 impl ContextServerSlashCommand {
     pub fn new(
         server_manager: Model<ContextServerManager>,
-        server: &Arc<ContextServer>,
+        server: &Arc<dyn ContextServer>,
         prompt: Prompt,
     ) -> Self {
         Self {
-            server_id: server.id.clone(),
+            server_id: server.id(),
             prompt,
             server_manager,
         }
@@ -89,7 +89,7 @@ impl SlashCommand for ContextServerSlashCommand {
 
         if let Some(server) = self.server_manager.read(cx).get_server(&server_id) {
             cx.foreground_executor().spawn(async move {
-                let Some(protocol) = server.client.read().clone() else {
+                let Some(protocol) = server.client() else {
                     return Err(anyhow!("Context server not initialized"));
                 };
 
@@ -143,7 +143,7 @@ impl SlashCommand for ContextServerSlashCommand {
         let manager = self.server_manager.read(cx);
         if let Some(server) = manager.get_server(&server_id) {
             cx.foreground_executor().spawn(async move {
-                let Some(protocol) = server.client.read().clone() else {
+                let Some(protocol) = server.client() else {
                     return Err(anyhow!("Context server not initialized"));
                 };
                 let result = protocol.run_prompt(&prompt_name, prompt_args).await?;

crates/assistant/src/tools/context_server_tool.rs 🔗

@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
 use anyhow::{anyhow, bail};
 use assistant_tool::Tool;
 use context_servers::manager::ContextServerManager;
@@ -6,14 +8,14 @@ use gpui::{Model, Task};
 
 pub struct ContextServerTool {
     server_manager: Model<ContextServerManager>,
-    server_id: String,
+    server_id: Arc<str>,
     tool: types::Tool,
 }
 
 impl ContextServerTool {
     pub fn new(
         server_manager: Model<ContextServerManager>,
-        server_id: impl Into<String>,
+        server_id: impl Into<Arc<str>>,
         tool: types::Tool,
     ) -> Self {
         Self {
@@ -55,7 +57,7 @@ impl Tool for ContextServerTool {
             cx.foreground_executor().spawn({
                 let tool_name = self.tool.name.clone();
                 async move {
-                    let Some(protocol) = server.client.read().clone() else {
+                    let Some(protocol) = server.client() else {
                         bail!("Context server not initialized");
                     };
 

crates/context_servers/Cargo.toml 🔗

@@ -13,6 +13,7 @@ path = "src/context_servers.rs"
 
 [dependencies]
 anyhow.workspace = true
+async-trait.workspace = true
 collections.workspace = true
 command_palette_hooks.workspace = true
 futures.workspace = true

crates/context_servers/src/manager.rs 🔗

@@ -15,9 +15,13 @@
 //! and react to changes in settings.
 
 use std::path::Path;
+use std::pin::Pin;
 use std::sync::Arc;
 
+use anyhow::Result;
+use async_trait::async_trait;
 use collections::{HashMap, HashSet};
+use futures::{Future, FutureExt};
 use gpui::{AsyncAppContext, EventEmitter, ModelContext, Task};
 use log;
 use parking_lot::RwLock;
@@ -56,51 +60,84 @@ impl Settings for ContextServerSettings {
     }
 }
 
-pub struct ContextServer {
-    pub id: String,
-    pub config: ServerConfig,
+#[async_trait(?Send)]
+pub trait ContextServer: Send + Sync + 'static {
+    fn id(&self) -> Arc<str>;
+    fn config(&self) -> Arc<ServerConfig>;
+    fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>>;
+    fn start<'a>(
+        self: Arc<Self>,
+        cx: &'a AsyncAppContext,
+    ) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>>;
+    fn stop(&self) -> Result<()>;
+}
+
+pub struct NativeContextServer {
+    pub id: Arc<str>,
+    pub config: Arc<ServerConfig>,
     pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
 }
 
-impl ContextServer {
-    fn new(config: ServerConfig) -> Self {
+impl NativeContextServer {
+    fn new(config: Arc<ServerConfig>) -> Self {
         Self {
-            id: config.id.clone(),
+            id: config.id.clone().into(),
             config,
             client: RwLock::new(None),
         }
     }
+}
 
-    async fn start(&self, cx: &AsyncAppContext) -> anyhow::Result<()> {
-        log::info!("starting context server {}", self.config.id,);
-        let client = Client::new(
-            client::ContextServerId(self.config.id.clone()),
-            client::ModelContextServerBinary {
-                executable: Path::new(&self.config.executable).to_path_buf(),
-                args: self.config.args.clone(),
-                env: self.config.env.clone(),
-            },
-            cx.clone(),
-        )?;
-
-        let protocol = crate::protocol::ModelContextProtocol::new(client);
-        let client_info = types::Implementation {
-            name: "Zed".to_string(),
-            version: env!("CARGO_PKG_VERSION").to_string(),
-        };
-        let initialized_protocol = protocol.initialize(client_info).await?;
+#[async_trait(?Send)]
+impl ContextServer for NativeContextServer {
+    fn id(&self) -> Arc<str> {
+        self.id.clone()
+    }
 
-        log::debug!(
-            "context server {} initialized: {:?}",
-            self.config.id,
-            initialized_protocol.initialize,
-        );
+    fn config(&self) -> Arc<ServerConfig> {
+        self.config.clone()
+    }
 
-        *self.client.write() = Some(Arc::new(initialized_protocol));
-        Ok(())
+    fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
+        self.client.read().clone()
+    }
+
+    fn start<'a>(
+        self: Arc<Self>,
+        cx: &'a AsyncAppContext,
+    ) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>> {
+        async move {
+            log::info!("starting context server {}", self.config.id,);
+            let client = Client::new(
+                client::ContextServerId(self.config.id.clone()),
+                client::ModelContextServerBinary {
+                    executable: Path::new(&self.config.executable).to_path_buf(),
+                    args: self.config.args.clone(),
+                    env: self.config.env.clone(),
+                },
+                cx.clone(),
+            )?;
+
+            let protocol = crate::protocol::ModelContextProtocol::new(client);
+            let client_info = types::Implementation {
+                name: "Zed".to_string(),
+                version: env!("CARGO_PKG_VERSION").to_string(),
+            };
+            let initialized_protocol = protocol.initialize(client_info).await?;
+
+            log::debug!(
+                "context server {} initialized: {:?}",
+                self.config.id,
+                initialized_protocol.initialize,
+            );
+
+            *self.client.write() = Some(Arc::new(initialized_protocol));
+            Ok(())
+        }
+        .boxed_local()
     }
 
-    async fn stop(&self) -> anyhow::Result<()> {
+    fn stop(&self) -> Result<()> {
         let mut client = self.client.write();
         if let Some(protocol) = client.take() {
             drop(protocol);
@@ -114,7 +151,7 @@ impl ContextServer {
 /// must go through the `GlobalContextServerManager` which holds
 /// a model to the ContextServerManager.
 pub struct ContextServerManager {
-    servers: HashMap<String, Arc<ContextServer>>,
+    servers: HashMap<String, Arc<dyn ContextServer>>,
     pending_servers: HashSet<String>,
 }
 
@@ -141,7 +178,7 @@ impl ContextServerManager {
 
     pub fn add_server(
         &mut self,
-        config: ServerConfig,
+        config: Arc<ServerConfig>,
         cx: &ModelContext<Self>,
     ) -> Task<anyhow::Result<()>> {
         let server_id = config.id.clone();
@@ -153,8 +190,8 @@ impl ContextServerManager {
         let task = {
             let server_id = server_id.clone();
             cx.spawn(|this, mut cx| async move {
-                let server = Arc::new(ContextServer::new(config));
-                server.start(&cx).await?;
+                let server = Arc::new(NativeContextServer::new(config));
+                server.clone().start(&cx).await?;
                 this.update(&mut cx, |this, cx| {
                     this.servers.insert(server_id.clone(), server);
                     this.pending_servers.remove(&server_id);
@@ -170,7 +207,7 @@ impl ContextServerManager {
         task
     }
 
-    pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
+    pub fn get_server(&self, id: &str) -> Option<Arc<dyn ContextServer>> {
         self.servers.get(id).cloned()
     }
 
@@ -178,7 +215,7 @@ impl ContextServerManager {
         let id = id.to_string();
         cx.spawn(|this, mut cx| async move {
             if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
-                server.stop().await?;
+                server.stop()?;
             }
             this.update(&mut cx, |this, cx| {
                 this.pending_servers.remove(&id);
@@ -192,16 +229,16 @@ impl ContextServerManager {
 
     pub fn restart_server(
         &mut self,
-        id: &str,
+        id: &Arc<str>,
         cx: &mut ModelContext<Self>,
     ) -> Task<anyhow::Result<()>> {
         let id = id.to_string();
         cx.spawn(|this, mut cx| async move {
             if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
-                server.stop().await?;
-                let config = server.config.clone();
-                let new_server = Arc::new(ContextServer::new(config));
-                new_server.start(&cx).await?;
+                server.stop()?;
+                let config = server.config();
+                let new_server = Arc::new(NativeContextServer::new(config));
+                new_server.clone().start(&cx).await?;
                 this.update(&mut cx, |this, cx| {
                     this.servers.insert(id.clone(), new_server);
                     cx.emit(Event::ServerStopped {
@@ -216,7 +253,7 @@ impl ContextServerManager {
         })
     }
 
-    pub fn servers(&self) -> Vec<Arc<ContextServer>> {
+    pub fn servers(&self) -> Vec<Arc<dyn ContextServer>> {
         self.servers.values().cloned().collect()
     }
 
@@ -224,7 +261,7 @@ impl ContextServerManager {
         let current_servers = self
             .servers()
             .into_iter()
-            .map(|server| (server.id.clone(), server.config.clone()))
+            .map(|server| (server.id(), server.config()))
             .collect::<HashMap<_, _>>();
 
         let new_servers = settings
@@ -235,19 +272,19 @@ impl ContextServerManager {
 
         let servers_to_add = new_servers
             .values()
-            .filter(|config| !current_servers.contains_key(&config.id))
+            .filter(|config| !current_servers.contains_key(config.id.as_str()))
             .cloned()
             .collect::<Vec<_>>();
 
         let servers_to_remove = current_servers
             .keys()
-            .filter(|id| !new_servers.contains_key(*id))
+            .filter(|id| !new_servers.contains_key(id.as_ref()))
             .cloned()
             .collect::<Vec<_>>();
 
         log::trace!("servers_to_add={:?}", servers_to_add);
         for config in servers_to_add {
-            self.add_server(config, cx).detach_and_log_err(cx);
+            self.add_server(Arc::new(config), cx).detach_and_log_err(cx);
         }
 
         for id in servers_to_remove {