Improve context server lifecycle management (#20622)

Max Brunsfeld , Marshall , and Marshall Bowers created

This optimizes and fixes bugs in our logic for maintaining a set of
running context servers, based on the combination of the user's
`context_servers` settings and their installed extensions.

Release Notes:

- N/A

---------

Co-authored-by: Marshall <marshall@zed.dev>
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>

Change summary

Cargo.lock                                                   |   4 
crates/assistant/src/context_store.rs                        |  87 -
crates/assistant/src/slash_command/context_server_command.rs |   2 
crates/collab/Cargo.toml                                     |   1 
crates/collab/src/tests/integration_tests.rs                 |   2 
crates/command_palette_hooks/src/command_palette_hooks.rs    |   8 
crates/context_servers/Cargo.toml                            |   1 
crates/context_servers/src/context_servers.rs                |   1 
crates/context_servers/src/manager.rs                        | 331 +++--
crates/context_servers/src/registry.rs                       |  56 
crates/extension_host/src/extension_host.rs                  |   4 
crates/extensions_ui/Cargo.toml                              |   2 
crates/extensions_ui/src/extension_context_server.rs         |  97 -
crates/extensions_ui/src/extension_registration_hooks.rs     |  87 +
crates/extensions_ui/src/extension_store_test.rs             |   4 
crates/extensions_ui/src/extensions_ui.rs                    |   1 
16 files changed, 279 insertions(+), 409 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -2572,6 +2572,7 @@ dependencies = [
  "clock",
  "collab_ui",
  "collections",
+ "context_servers",
  "ctor",
  "dashmap 6.0.1",
  "derive_more",
@@ -2818,7 +2819,6 @@ name = "context_servers"
 version = "0.1.0"
 dependencies = [
  "anyhow",
- "async-trait",
  "collections",
  "command_palette_hooks",
  "futures 0.3.30",
@@ -4205,7 +4205,6 @@ dependencies = [
  "assistant_slash_command",
  "async-compression",
  "async-tar",
- "async-trait",
  "client",
  "collections",
  "context_servers",
@@ -4222,6 +4221,7 @@ dependencies = [
  "http_client",
  "indexed_docs",
  "language",
+ "log",
  "lsp",
  "node_runtime",
  "num-format",

crates/assistant/src/context_store.rs 🔗

@@ -8,9 +8,8 @@ use anyhow::{anyhow, Context as _, Result};
 use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
 use clock::ReplicaId;
 use collections::HashMap;
-use command_palette_hooks::CommandPaletteFilter;
-use context_servers::manager::{ContextServerManager, ContextServerSettings};
-use context_servers::{ContextServerFactoryRegistry, CONTEXT_SERVERS_NAMESPACE};
+use context_servers::manager::ContextServerManager;
+use context_servers::ContextServerFactoryRegistry;
 use fs::Fs;
 use futures::StreamExt;
 use fuzzy::StringMatchCandidate;
@@ -22,7 +21,6 @@ use paths::contexts_dir;
 use project::Project;
 use regex::Regex;
 use rpc::AnyProtoClient;
-use settings::{Settings as _, SettingsStore};
 use std::{
     cmp::Reverse,
     ffi::OsStr,
@@ -111,7 +109,11 @@ impl ContextStore {
             let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
 
             let this = cx.new_model(|cx: &mut ModelContext<Self>| {
-                let context_server_manager = cx.new_model(|_cx| ContextServerManager::new());
+                let context_server_factory_registry =
+                    ContextServerFactoryRegistry::default_global(cx);
+                let context_server_manager = cx.new_model(|cx| {
+                    ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
+                });
                 let mut this = Self {
                     contexts: Vec::new(),
                     contexts_metadata: Vec::new(),
@@ -148,91 +150,16 @@ impl ContextStore {
                 this.handle_project_changed(project.clone(), cx);
                 this.synchronize_contexts(cx);
                 this.register_context_server_handlers(cx);
-
-                if project.read(cx).is_local() {
-                    // TODO: At the time when we construct the `ContextStore` we may not have yet initialized the extensions.
-                    // In order to register the context servers when the extension is loaded, we're periodically looping to
-                    // see if there are context servers to register.
-                    //
-                    // I tried doing this in a subscription on the `ExtensionStore`, but it never seemed to fire.
-                    //
-                    // We should find a more elegant way to do this.
-                    let context_server_factory_registry =
-                        ContextServerFactoryRegistry::default_global(cx);
-                    cx.spawn(|context_store, mut cx| async move {
-                        loop {
-                            let mut servers_to_register = Vec::new();
-                            for (_id, factory) in
-                                context_server_factory_registry.context_server_factories()
-                            {
-                                if let Some(server) = factory(project.clone(), &cx).await.log_err()
-                                {
-                                    servers_to_register.push(server);
-                                }
-                            }
-
-                            let Some(_) = context_store
-                                .update(&mut cx, |this, cx| {
-                                    this.context_server_manager.update(cx, |this, cx| {
-                                        for server in servers_to_register {
-                                            this.add_server(server, cx).detach_and_log_err(cx);
-                                        }
-                                    })
-                                })
-                                .log_err()
-                            else {
-                                break;
-                            };
-
-                            smol::Timer::after(Duration::from_millis(100)).await;
-                        }
-
-                        anyhow::Ok(())
-                    })
-                    .detach_and_log_err(cx);
-                }
-
                 this
             })?;
             this.update(&mut cx, |this, cx| this.reload(cx))?
                 .await
                 .log_err();
 
-            this.update(&mut cx, |this, cx| {
-                this.watch_context_server_settings(cx);
-            })
-            .log_err();
-
             Ok(this)
         })
     }
 
-    fn watch_context_server_settings(&self, cx: &mut ModelContext<Self>) {
-        cx.observe_global::<SettingsStore>(move |this, cx| {
-            this.context_server_manager.update(cx, |manager, cx| {
-                let location = this.project.read(cx).worktrees(cx).next().map(|worktree| {
-                    settings::SettingsLocation {
-                        worktree_id: worktree.read(cx).id(),
-                        path: Path::new(""),
-                    }
-                });
-                let settings = ContextServerSettings::get(location, cx);
-
-                manager.maintain_servers(settings, cx);
-
-                let has_any_context_servers = !manager.servers().is_empty();
-                CommandPaletteFilter::update_global(cx, |filter, _cx| {
-                    if has_any_context_servers {
-                        filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
-                    } else {
-                        filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE);
-                    }
-                });
-            })
-        })
-        .detach();
-    }
-
     async fn handle_advertise_contexts(
         this: Model<Self>,
         envelope: TypedEnvelope<proto::AdvertiseContexts>,

crates/collab/Cargo.toml 🔗

@@ -78,6 +78,7 @@ uuid.workspace = true
 
 [dev-dependencies]
 assistant = { workspace = true, features = ["test-support"] }
+context_servers.workspace = true
 async-trait.workspace = true
 audio.workspace = true
 call = { workspace = true, features = ["test-support"] }

crates/collab/src/tests/integration_tests.rs 🔗

@@ -6486,6 +6486,8 @@ async fn test_context_collaboration_with_reconnect(
         assert_eq!(project.collaborators().len(), 1);
     });
 
+    cx_a.update(context_servers::init);
+    cx_b.update(context_servers::init);
     let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
     let context_store_a = cx_a
         .update(|cx| {

crates/command_palette_hooks/src/command_palette_hooks.rs 🔗

@@ -39,11 +39,13 @@ impl CommandPaletteFilter {
     }
 
     /// Updates the global [`CommandPaletteFilter`] using the given closure.
-    pub fn update_global<F, R>(cx: &mut AppContext, update: F) -> R
+    pub fn update_global<F>(cx: &mut AppContext, update: F)
     where
-        F: FnOnce(&mut Self, &mut AppContext) -> R,
+        F: FnOnce(&mut Self, &mut AppContext),
     {
-        cx.update_global(|this: &mut GlobalCommandPaletteFilter, cx| update(&mut this.0, cx))
+        if cx.has_global::<GlobalCommandPaletteFilter>() {
+            cx.update_global(|this: &mut GlobalCommandPaletteFilter, cx| update(&mut this.0, cx))
+        }
     }
 
     /// Returns whether the given [`Action`] is hidden by the filter.

crates/context_servers/Cargo.toml 🔗

@@ -13,7 +13,6 @@ path = "src/context_servers.rs"
 
 [dependencies]
 anyhow.workspace = true
-async-trait.workspace = true
 collections.workspace = true
 command_palette_hooks.workspace = true
 futures.workspace = true

crates/context_servers/src/context_servers.rs 🔗

@@ -8,7 +8,6 @@ use command_palette_hooks::CommandPaletteFilter;
 use gpui::{actions, AppContext};
 use settings::Settings;
 
-pub use crate::manager::ContextServer;
 use crate::manager::ContextServerSettings;
 pub use crate::registry::ContextServerFactoryRegistry;
 

crates/context_servers/src/manager.rs 🔗

@@ -15,23 +15,23 @@
 //! and react to changes in settings.
 
 use std::path::Path;
-use std::pin::Pin;
 use std::sync::Arc;
 
 use anyhow::{bail, Result};
-use async_trait::async_trait;
-use collections::{HashMap, HashSet};
-use futures::{Future, FutureExt};
-use gpui::{AsyncAppContext, EventEmitter, ModelContext, Task};
+use collections::HashMap;
+use command_palette_hooks::CommandPaletteFilter;
+use gpui::{AsyncAppContext, EventEmitter, Model, ModelContext, Subscription, Task, WeakModel};
 use log;
 use parking_lot::RwLock;
+use project::Project;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
-use settings::{Settings, SettingsSources};
+use settings::{Settings, SettingsSources, SettingsStore};
+use util::ResultExt as _;
 
 use crate::{
     client::{self, Client},
-    types,
+    types, ContextServerFactoryRegistry, CONTEXT_SERVERS_NAMESPACE,
 };
 
 #[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
@@ -66,25 +66,13 @@ impl Settings for ContextServerSettings {
     }
 }
 
-#[async_trait(?Send)]
-pub trait ContextServer: Send + Sync + 'static {
-    fn id(&self) -> Arc<str>;
-    fn config(&self) -> Arc<ServerConfig>;
-    fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>>;
-    fn start<'a>(
-        self: Arc<Self>,
-        cx: &'a AsyncAppContext,
-    ) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>>;
-    fn stop(&self) -> Result<()>;
-}
-
-pub struct NativeContextServer {
+pub struct ContextServer {
     pub id: Arc<str>,
     pub config: Arc<ServerConfig>,
     pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
 }
 
-impl NativeContextServer {
+impl ContextServer {
     pub fn new(id: Arc<str>, config: Arc<ServerConfig>) -> Self {
         Self {
             id,
@@ -92,61 +80,52 @@ impl NativeContextServer {
             client: RwLock::new(None),
         }
     }
-}
 
-#[async_trait(?Send)]
-impl ContextServer for NativeContextServer {
-    fn id(&self) -> Arc<str> {
+    pub fn id(&self) -> Arc<str> {
         self.id.clone()
     }
 
-    fn config(&self) -> Arc<ServerConfig> {
+    pub fn config(&self) -> Arc<ServerConfig> {
         self.config.clone()
     }
 
-    fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
+    pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
         self.client.read().clone()
     }
 
-    fn start<'a>(
-        self: Arc<Self>,
-        cx: &'a AsyncAppContext,
-    ) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>> {
-        async move {
-            log::info!("starting context server {}", self.id);
-            let Some(command) = &self.config.command else {
-                bail!("no command specified for server {}", self.id);
-            };
-            let client = Client::new(
-                client::ContextServerId(self.id.clone()),
-                client::ModelContextServerBinary {
-                    executable: Path::new(&command.path).to_path_buf(),
-                    args: command.args.clone(),
-                    env: command.env.clone(),
-                },
-                cx.clone(),
-            )?;
-
-            let protocol = crate::protocol::ModelContextProtocol::new(client);
-            let client_info = types::Implementation {
-                name: "Zed".to_string(),
-                version: env!("CARGO_PKG_VERSION").to_string(),
-            };
-            let initialized_protocol = protocol.initialize(client_info).await?;
-
-            log::debug!(
-                "context server {} initialized: {:?}",
-                self.id,
-                initialized_protocol.initialize,
-            );
-
-            *self.client.write() = Some(Arc::new(initialized_protocol));
-            Ok(())
-        }
-        .boxed_local()
+    pub async fn start(self: Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
+        log::info!("starting context server {}", self.id);
+        let Some(command) = &self.config.command else {
+            bail!("no command specified for server {}", self.id);
+        };
+        let client = Client::new(
+            client::ContextServerId(self.id.clone()),
+            client::ModelContextServerBinary {
+                executable: Path::new(&command.path).to_path_buf(),
+                args: command.args.clone(),
+                env: command.env.clone(),
+            },
+            cx.clone(),
+        )?;
+
+        let protocol = crate::protocol::ModelContextProtocol::new(client);
+        let client_info = types::Implementation {
+            name: "Zed".to_string(),
+            version: env!("CARGO_PKG_VERSION").to_string(),
+        };
+        let initialized_protocol = protocol.initialize(client_info).await?;
+
+        log::debug!(
+            "context server {} initialized: {:?}",
+            self.id,
+            initialized_protocol.initialize,
+        );
+
+        *self.client.write() = Some(Arc::new(initialized_protocol));
+        Ok(())
     }
 
-    fn stop(&self) -> Result<()> {
+    pub fn stop(&self) -> Result<()> {
         let mut client = self.client.write();
         if let Some(protocol) = client.take() {
             drop(protocol);
@@ -155,13 +134,13 @@ impl ContextServer for NativeContextServer {
     }
 }
 
-/// A Context server manager manages the starting and stopping
-/// of all servers. To obtain a server to interact with, a crate
-/// must go through the `GlobalContextServerManager` which holds
-/// a model to the ContextServerManager.
 pub struct ContextServerManager {
-    servers: HashMap<Arc<str>, Arc<dyn ContextServer>>,
-    pending_servers: HashSet<Arc<str>>,
+    servers: HashMap<Arc<str>, Arc<ContextServer>>,
+    project: Model<Project>,
+    registry: Model<ContextServerFactoryRegistry>,
+    update_servers_task: Option<Task<Result<()>>>,
+    needs_server_update: bool,
+    _subscriptions: Vec<Subscription>,
 }
 
 pub enum Event {
@@ -171,74 +150,66 @@ pub enum Event {
 
 impl EventEmitter<Event> for ContextServerManager {}
 
-impl Default for ContextServerManager {
-    fn default() -> Self {
-        Self::new()
-    }
-}
-
 impl ContextServerManager {
-    pub fn new() -> Self {
-        Self {
+    pub fn new(
+        registry: Model<ContextServerFactoryRegistry>,
+        project: Model<Project>,
+        cx: &mut ModelContext<Self>,
+    ) -> Self {
+        let mut this = Self {
+            _subscriptions: vec![
+                cx.observe(&registry, |this, _registry, cx| {
+                    this.available_context_servers_changed(cx);
+                }),
+                cx.observe_global::<SettingsStore>(|this, cx| {
+                    this.available_context_servers_changed(cx);
+                }),
+            ],
+            project,
+            registry,
+            needs_server_update: false,
             servers: HashMap::default(),
-            pending_servers: HashSet::default(),
-        }
+            update_servers_task: None,
+        };
+        this.available_context_servers_changed(cx);
+        this
     }
 
-    pub fn add_server(
-        &mut self,
-        server: Arc<dyn ContextServer>,
-        cx: &ModelContext<Self>,
-    ) -> Task<anyhow::Result<()>> {
-        let server_id = server.id();
+    fn available_context_servers_changed(&mut self, cx: &mut ModelContext<Self>) {
+        if self.update_servers_task.is_some() {
+            self.needs_server_update = true;
+        } else {
+            self.update_servers_task = Some(cx.spawn(|this, mut cx| async move {
+                this.update(&mut cx, |this, _| {
+                    this.needs_server_update = false;
+                })?;
 
-        if self.servers.contains_key(&server_id) || self.pending_servers.contains(&server_id) {
-            return Task::ready(Ok(()));
-        }
+                Self::maintain_servers(this.clone(), cx.clone()).await?;
 
-        let task = {
-            let server_id = server_id.clone();
-            cx.spawn(|this, mut cx| async move {
-                server.clone().start(&cx).await?;
                 this.update(&mut cx, |this, cx| {
-                    this.servers.insert(server_id.clone(), server);
-                    this.pending_servers.remove(&server_id);
-                    cx.emit(Event::ServerStarted {
-                        server_id: server_id.clone(),
-                    });
+                    let has_any_context_servers = !this.servers().is_empty();
+                    if has_any_context_servers {
+                        CommandPaletteFilter::update_global(cx, |filter, _cx| {
+                            filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
+                        });
+                    }
+
+                    this.update_servers_task.take();
+                    if this.needs_server_update {
+                        this.available_context_servers_changed(cx);
+                    }
                 })?;
-                Ok(())
-            })
-        };
 
-        self.pending_servers.insert(server_id);
-        task
-    }
-
-    pub fn get_server(&self, id: &str) -> Option<Arc<dyn ContextServer>> {
-        self.servers.get(id).cloned()
+                Ok(())
+            }));
+        }
     }
 
-    pub fn remove_server(
-        &mut self,
-        id: &Arc<str>,
-        cx: &ModelContext<Self>,
-    ) -> Task<anyhow::Result<()>> {
-        let id = id.clone();
-        cx.spawn(|this, mut cx| async move {
-            if let Some(server) =
-                this.update(&mut cx, |this, _cx| this.servers.remove(id.as_ref()))?
-            {
-                server.stop()?;
-            }
-            this.update(&mut cx, |this, cx| {
-                this.pending_servers.remove(id.as_ref());
-                cx.emit(Event::ServerStopped {
-                    server_id: id.clone(),
-                })
-            })?;
-            Ok(())
-        })
+    pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
+        self.servers
+            .get(id)
+            .filter(|server| server.client().is_some())
+            .cloned()
     }
 
     pub fn restart_server(
@@ -251,7 +222,7 @@ impl ContextServerManager {
             if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
                 server.stop()?;
                 let config = server.config();
-                let new_server = Arc::new(NativeContextServer::new(id.clone(), config));
+                let new_server = Arc::new(ContextServer::new(id.clone(), config));
                 new_server.clone().start(&cx).await?;
                 this.update(&mut cx, |this, cx| {
                     this.servers.insert(id.clone(), new_server);
@@ -267,45 +238,83 @@ impl ContextServerManager {
         })
     }
 
-    pub fn servers(&self) -> Vec<Arc<dyn ContextServer>> {
-        self.servers.values().cloned().collect()
+    pub fn servers(&self) -> Vec<Arc<ContextServer>> {
+        self.servers
+            .values()
+            .filter(|server| server.client().is_some())
+            .cloned()
+            .collect()
     }
 
-    pub fn maintain_servers(&mut self, settings: &ContextServerSettings, cx: &ModelContext<Self>) {
-        let current_servers = self
-            .servers()
-            .into_iter()
-            .map(|server| (server.id(), server.config()))
-            .collect::<HashMap<_, _>>();
-
-        let new_servers = settings
-            .context_servers
-            .iter()
-            .map(|(id, config)| (id.clone(), config.clone()))
-            .collect::<HashMap<_, _>>();
-
-        let servers_to_add = new_servers
-            .iter()
-            .filter(|(id, _)| !current_servers.contains_key(id.as_ref()))
-            .map(|(id, config)| (id.clone(), config.clone()))
-            .collect::<Vec<_>>();
-
-        let servers_to_remove = current_servers
-            .keys()
-            .filter(|id| !new_servers.contains_key(id.as_ref()))
-            .cloned()
-            .collect::<Vec<_>>();
+    async fn maintain_servers(this: WeakModel<Self>, mut cx: AsyncAppContext) -> Result<()> {
+        let mut desired_servers = HashMap::default();
+
+        let (registry, project) = this.update(&mut cx, |this, cx| {
+            let location = this.project.read(cx).worktrees(cx).next().map(|worktree| {
+                settings::SettingsLocation {
+                    worktree_id: worktree.read(cx).id(),
+                    path: Path::new(""),
+                }
+            });
+            let settings = ContextServerSettings::get(location, cx);
+            desired_servers = settings.context_servers.clone();
+
+            (this.registry.clone(), this.project.clone())
+        })?;
+
+        for (id, factory) in
+            registry.read_with(&cx, |registry, _| registry.context_server_factories())?
+        {
+            let config = desired_servers.entry(id).or_default();
+            if config.command.is_none() {
+                if let Some(extension_command) = factory(project.clone(), &cx).await.log_err() {
+                    config.command = Some(extension_command);
+                }
+            }
+        }
 
-        log::trace!("servers_to_add={:?}", servers_to_add);
-        for (id, config) in servers_to_add {
-            if config.command.is_some() {
-                let server = Arc::new(NativeContextServer::new(id, Arc::new(config)));
-                self.add_server(server, cx).detach_and_log_err(cx);
+        let mut servers_to_start = HashMap::default();
+        let mut servers_to_stop = HashMap::default();
+
+        this.update(&mut cx, |this, _cx| {
+            this.servers.retain(|id, server| {
+                if desired_servers.contains_key(id) {
+                    true
+                } else {
+                    servers_to_stop.insert(id.clone(), server.clone());
+                    false
+                }
+            });
+
+            for (id, config) in desired_servers {
+                let existing_config = this.servers.get(&id).map(|server| server.config());
+                if existing_config.as_deref() != Some(&config) {
+                    let config = Arc::new(config);
+                    let server = Arc::new(ContextServer::new(id.clone(), config));
+                    servers_to_start.insert(id.clone(), server.clone());
+                    let old_server = this.servers.insert(id.clone(), server);
+                    if let Some(old_server) = old_server {
+                        servers_to_stop.insert(id, old_server);
+                    }
+                }
             }
+        })?;
+
+        for (id, server) in servers_to_stop {
+            server.stop().log_err();
+            this.update(&mut cx, |_, cx| {
+                cx.emit(Event::ServerStopped { server_id: id })
+            })?;
         }
 
-        for id in servers_to_remove {
-            self.remove_server(&id, cx).detach_and_log_err(cx);
+        for (id, server) in servers_to_start {
+            if server.start(&cx).await.log_err().is_some() {
+                this.update(&mut cx, |_, cx| {
+                    cx.emit(Event::ServerStarted { server_id: id })
+                })?;
+            }
         }
+
+        Ok(())
     }
 }

crates/context_servers/src/registry.rs 🔗

@@ -2,75 +2,61 @@ use std::sync::Arc;
 
 use anyhow::Result;
 use collections::HashMap;
-use gpui::{AppContext, AsyncAppContext, Global, Model, ReadGlobal, Task};
-use parking_lot::RwLock;
+use gpui::{AppContext, AsyncAppContext, Context, Global, Model, ReadGlobal, Task};
 use project::Project;
 
-use crate::ContextServer;
+use crate::manager::ServerCommand;
 
 pub type ContextServerFactory = Arc<
-    dyn Fn(Model<Project>, &AsyncAppContext) -> Task<Result<Arc<dyn ContextServer>>>
-        + Send
-        + Sync
-        + 'static,
+    dyn Fn(Model<Project>, &AsyncAppContext) -> Task<Result<ServerCommand>> + Send + Sync + 'static,
 >;
 
-#[derive(Default)]
-struct GlobalContextServerFactoryRegistry(Arc<ContextServerFactoryRegistry>);
+struct GlobalContextServerFactoryRegistry(Model<ContextServerFactoryRegistry>);
 
 impl Global for GlobalContextServerFactoryRegistry {}
 
-#[derive(Default)]
-struct ContextServerFactoryRegistryState {
-    context_servers: HashMap<Arc<str>, ContextServerFactory>,
-}
-
 #[derive(Default)]
 pub struct ContextServerFactoryRegistry {
-    state: RwLock<ContextServerFactoryRegistryState>,
+    context_servers: HashMap<Arc<str>, ContextServerFactory>,
 }
 
 impl ContextServerFactoryRegistry {
     /// Returns the global [`ContextServerFactoryRegistry`].
-    pub fn global(cx: &AppContext) -> Arc<Self> {
+    pub fn global(cx: &AppContext) -> Model<Self> {
         GlobalContextServerFactoryRegistry::global(cx).0.clone()
     }
 
     /// Returns the global [`ContextServerFactoryRegistry`].
     ///
     /// Inserts a default [`ContextServerFactoryRegistry`] if one does not yet exist.
-    pub fn default_global(cx: &mut AppContext) -> Arc<Self> {
-        cx.default_global::<GlobalContextServerFactoryRegistry>()
-            .0
-            .clone()
+    pub fn default_global(cx: &mut AppContext) -> Model<Self> {
+        if !cx.has_global::<GlobalContextServerFactoryRegistry>() {
+            let registry = cx.new_model(|_| Self::new());
+            cx.set_global(GlobalContextServerFactoryRegistry(registry));
+        }
+        cx.global::<GlobalContextServerFactoryRegistry>().0.clone()
     }
 
-    pub fn new() -> Arc<Self> {
-        Arc::new(Self {
-            state: RwLock::new(ContextServerFactoryRegistryState {
-                context_servers: HashMap::default(),
-            }),
-        })
+    pub fn new() -> Self {
+        Self {
+            context_servers: HashMap::default(),
+        }
     }
 
     pub fn context_server_factories(&self) -> Vec<(Arc<str>, ContextServerFactory)> {
-        self.state
-            .read()
-            .context_servers
+        self.context_servers
             .iter()
             .map(|(id, factory)| (id.clone(), factory.clone()))
             .collect()
     }
 
     /// Registers the provided [`ContextServerFactory`].
-    pub fn register_server_factory(&self, id: Arc<str>, factory: ContextServerFactory) {
-        let mut state = self.state.write();
-        state.context_servers.insert(id, factory);
+    pub fn register_server_factory(&mut self, id: Arc<str>, factory: ContextServerFactory) {
+        self.context_servers.insert(id, factory);
     }
 
     /// Unregisters the [`ContextServerFactory`] for the server with the given ID.
-    pub fn unregister_server_factory_by_id(&self, server_id: &str) {
-        let mut state = self.state.write();
-        state.context_servers.remove(server_id);
+    pub fn unregister_server_factory_by_id(&mut self, server_id: &str) {
+        self.context_servers.remove(server_id);
     }
 }

crates/extension_host/src/extension_host.rs 🔗

@@ -141,7 +141,7 @@ pub trait ExtensionRegistrationHooks: Send + Sync + 'static {
         &self,
         _id: Arc<str>,
         _extension: WasmExtension,
-        _host: Arc<WasmHost>,
+        _cx: &mut AppContext,
     ) {
     }
 
@@ -1266,7 +1266,7 @@ impl ExtensionStore {
                         this.registration_hooks.register_context_server(
                             id.clone(),
                             wasm_extension.clone(),
-                            this.wasm_host.clone(),
+                            cx,
                         );
                     }
 

crates/extensions_ui/Cargo.toml 🔗

@@ -17,7 +17,6 @@ test-support = []
 [dependencies]
 anyhow.workspace = true
 assistant_slash_command.workspace = true
-async-trait.workspace = true
 client.workspace = true
 collections.workspace = true
 context_servers.workspace = true
@@ -31,6 +30,7 @@ fuzzy.workspace = true
 gpui.workspace = true
 indexed_docs.workspace = true
 language.workspace = true
+log.workspace = true
 lsp.workspace = true
 num-format.workspace = true
 picker.workspace = true

crates/extensions_ui/src/extension_context_server.rs 🔗

@@ -1,97 +0,0 @@
-use std::pin::Pin;
-use std::sync::Arc;
-
-use anyhow::{anyhow, Result};
-use async_trait::async_trait;
-use context_servers::manager::{NativeContextServer, ServerCommand, ServerConfig};
-use context_servers::protocol::InitializedContextServerProtocol;
-use context_servers::ContextServer;
-use extension_host::wasm_host::{ExtensionProject, WasmExtension, WasmHost};
-use futures::{Future, FutureExt};
-use gpui::{AsyncAppContext, Model};
-use project::Project;
-use wasmtime_wasi::WasiView as _;
-
-pub struct ExtensionContextServer {
-    #[allow(unused)]
-    pub(crate) extension: WasmExtension,
-    #[allow(unused)]
-    pub(crate) host: Arc<WasmHost>,
-    id: Arc<str>,
-    context_server: Arc<NativeContextServer>,
-}
-
-impl ExtensionContextServer {
-    pub async fn new(
-        extension: WasmExtension,
-        host: Arc<WasmHost>,
-        id: Arc<str>,
-        project: Model<Project>,
-        mut cx: AsyncAppContext,
-    ) -> Result<Self> {
-        let extension_project = project.update(&mut cx, |project, cx| ExtensionProject {
-            worktree_ids: project
-                .visible_worktrees(cx)
-                .map(|worktree| worktree.read(cx).id().to_proto())
-                .collect(),
-        })?;
-        let command = extension
-            .call({
-                let id = id.clone();
-                |extension, store| {
-                    async move {
-                        let project = store.data_mut().table().push(extension_project)?;
-                        let command = extension
-                            .call_context_server_command(store, id.clone(), project)
-                            .await?
-                            .map_err(|e| anyhow!("{}", e))?;
-                        anyhow::Ok(command)
-                    }
-                    .boxed()
-                }
-            })
-            .await?;
-
-        let config = Arc::new(ServerConfig {
-            settings: None,
-            command: Some(ServerCommand {
-                path: command.command,
-                args: command.args,
-                env: Some(command.env.into_iter().collect()),
-            }),
-        });
-
-        anyhow::Ok(Self {
-            extension,
-            host,
-            id: id.clone(),
-            context_server: Arc::new(NativeContextServer::new(id, config)),
-        })
-    }
-}
-
-#[async_trait(?Send)]
-impl ContextServer for ExtensionContextServer {
-    fn id(&self) -> Arc<str> {
-        self.id.clone()
-    }
-
-    fn config(&self) -> Arc<ServerConfig> {
-        self.context_server.config()
-    }
-
-    fn client(&self) -> Option<Arc<InitializedContextServerProtocol>> {
-        self.context_server.client()
-    }
-
-    fn start<'a>(
-        self: Arc<Self>,
-        cx: &'a AsyncAppContext,
-    ) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>> {
-        self.context_server.clone().start(cx)
-    }
-
-    fn stop(&self) -> Result<()> {
-        self.context_server.stop()
-    }
-}

crates/extensions_ui/src/extension_registration_hooks.rs 🔗

@@ -1,19 +1,21 @@
 use std::{path::PathBuf, sync::Arc};
 
-use anyhow::Result;
+use anyhow::{anyhow, Result};
 use assistant_slash_command::{ExtensionSlashCommand, SlashCommandRegistry};
+use context_servers::manager::ServerCommand;
 use context_servers::ContextServerFactoryRegistry;
+use db::smol::future::FutureExt as _;
 use extension::Extension;
+use extension_host::wasm_host::ExtensionProject;
 use extension_host::{extension_lsp_adapter::ExtensionLspAdapter, wasm_host};
 use fs::Fs;
-use gpui::{AppContext, BackgroundExecutor, Task};
+use gpui::{AppContext, BackgroundExecutor, Model, Task};
 use indexed_docs::{ExtensionIndexedDocsProvider, IndexedDocsRegistry, ProviderId};
 use language::{LanguageRegistry, LanguageServerBinaryStatus, LoadedLanguage};
 use snippet_provider::SnippetRegistry;
 use theme::{ThemeRegistry, ThemeSettings};
 use ui::SharedString;
-
-use crate::extension_context_server::ExtensionContextServer;
+use wasmtime_wasi::WasiView as _;
 
 pub struct ConcreteExtensionRegistrationHooks {
     slash_command_registry: Arc<SlashCommandRegistry>,
@@ -21,7 +23,7 @@ pub struct ConcreteExtensionRegistrationHooks {
     indexed_docs_registry: Arc<IndexedDocsRegistry>,
     snippet_registry: Arc<SnippetRegistry>,
     language_registry: Arc<LanguageRegistry>,
-    context_server_factory_registry: Arc<ContextServerFactoryRegistry>,
+    context_server_factory_registry: Model<ContextServerFactoryRegistry>,
     executor: BackgroundExecutor,
 }
 
@@ -32,7 +34,7 @@ impl ConcreteExtensionRegistrationHooks {
         indexed_docs_registry: Arc<IndexedDocsRegistry>,
         snippet_registry: Arc<SnippetRegistry>,
         language_registry: Arc<LanguageRegistry>,
-        context_server_factory_registry: Arc<ContextServerFactoryRegistry>,
+        context_server_factory_registry: Model<ContextServerFactoryRegistry>,
         cx: &AppContext,
     ) -> Arc<dyn extension_host::ExtensionRegistrationHooks> {
         Arc::new(Self {
@@ -71,25 +73,66 @@ impl extension_host::ExtensionRegistrationHooks for ConcreteExtensionRegistratio
         &self,
         id: Arc<str>,
         extension: wasm_host::WasmExtension,
-        host: Arc<wasm_host::WasmHost>,
+        cx: &mut AppContext,
     ) {
         self.context_server_factory_registry
-            .register_server_factory(
-                id.clone(),
-                Arc::new({
-                    move |project, cx| {
-                        let id = id.clone();
-                        let extension = extension.clone();
-                        let host = host.clone();
-                        cx.spawn(|cx| async move {
-                            let context_server =
-                                ExtensionContextServer::new(extension, host, id, project, cx)
+            .update(cx, |registry, _| {
+                registry.register_server_factory(
+                    id.clone(),
+                    Arc::new({
+                        move |project, cx| {
+                            log::info!(
+                                "loading command for context server {id} from extension {}",
+                                extension.manifest.id
+                            );
+
+                            let id = id.clone();
+                            let extension = extension.clone();
+                            cx.spawn(|mut cx| async move {
+                                let extension_project =
+                                    project.update(&mut cx, |project, cx| ExtensionProject {
+                                        worktree_ids: project
+                                            .visible_worktrees(cx)
+                                            .map(|worktree| worktree.read(cx).id().to_proto())
+                                            .collect(),
+                                    })?;
+
+                                let command = extension
+                                    .call({
+                                        let id = id.clone();
+                                        |extension, store| {
+                                            async move {
+                                                let project = store
+                                                    .data_mut()
+                                                    .table()
+                                                    .push(extension_project)?;
+                                                let command = extension
+                                                    .call_context_server_command(
+                                                        store,
+                                                        id.clone(),
+                                                        project,
+                                                    )
+                                                    .await?
+                                                    .map_err(|e| anyhow!("{}", e))?;
+                                                anyhow::Ok(command)
+                                            }
+                                            .boxed()
+                                        }
+                                    })
                                     .await?;
-                            anyhow::Ok(Arc::new(context_server) as _)
-                        })
-                    }
-                }),
-            );
+
+                                log::info!("loaded command for context server {id}: {command:?}");
+
+                                Ok(ServerCommand {
+                                    path: command.command,
+                                    args: command.args,
+                                    env: Some(command.env.into_iter().collect()),
+                                })
+                            })
+                        }
+                    }),
+                )
+            });
     }
 
     fn register_docs_provider(&self, extension: Arc<dyn Extension>, provider_id: Arc<str>) {

crates/extensions_ui/src/extension_store_test.rs 🔗

@@ -268,7 +268,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
     let slash_command_registry = SlashCommandRegistry::new();
     let indexed_docs_registry = Arc::new(IndexedDocsRegistry::new(cx.executor()));
     let snippet_registry = Arc::new(SnippetRegistry::new());
-    let context_server_factory_registry = ContextServerFactoryRegistry::new();
+    let context_server_factory_registry = cx.new_model(|_| ContextServerFactoryRegistry::new());
     let node_runtime = NodeRuntime::unavailable();
 
     let store = cx.new_model(|cx| {
@@ -508,7 +508,7 @@ async fn test_extension_store_with_test_extension(cx: &mut TestAppContext) {
     let slash_command_registry = SlashCommandRegistry::new();
     let indexed_docs_registry = Arc::new(IndexedDocsRegistry::new(cx.executor()));
     let snippet_registry = Arc::new(SnippetRegistry::new());
-    let context_server_factory_registry = ContextServerFactoryRegistry::new();
+    let context_server_factory_registry = cx.new_model(|_| ContextServerFactoryRegistry::new());
     let node_runtime = NodeRuntime::unavailable();
 
     let mut status_updates = language_registry.language_server_binary_statuses();