Detailed changes
@@ -2572,6 +2572,7 @@ dependencies = [
"clock",
"collab_ui",
"collections",
+ "context_servers",
"ctor",
"dashmap 6.0.1",
"derive_more",
@@ -2818,7 +2819,6 @@ name = "context_servers"
version = "0.1.0"
dependencies = [
"anyhow",
- "async-trait",
"collections",
"command_palette_hooks",
"futures 0.3.30",
@@ -4216,7 +4216,6 @@ dependencies = [
"assistant_slash_command",
"async-compression",
"async-tar",
- "async-trait",
"client",
"collections",
"context_servers",
@@ -4233,6 +4232,7 @@ dependencies = [
"http_client",
"indexed_docs",
"language",
+ "log",
"lsp",
"node_runtime",
"num-format",
@@ -8,9 +8,8 @@ use anyhow::{anyhow, Context as _, Result};
use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
use clock::ReplicaId;
use collections::HashMap;
-use command_palette_hooks::CommandPaletteFilter;
-use context_servers::manager::{ContextServerManager, ContextServerSettings};
-use context_servers::{ContextServerFactoryRegistry, CONTEXT_SERVERS_NAMESPACE};
+use context_servers::manager::ContextServerManager;
+use context_servers::ContextServerFactoryRegistry;
use fs::Fs;
use futures::StreamExt;
use fuzzy::StringMatchCandidate;
@@ -22,7 +21,6 @@ use paths::contexts_dir;
use project::Project;
use regex::Regex;
use rpc::AnyProtoClient;
-use settings::{Settings as _, SettingsStore};
use std::{
cmp::Reverse,
ffi::OsStr,
@@ -111,7 +109,11 @@ impl ContextStore {
let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
let this = cx.new_model(|cx: &mut ModelContext<Self>| {
- let context_server_manager = cx.new_model(|_cx| ContextServerManager::new());
+ let context_server_factory_registry =
+ ContextServerFactoryRegistry::default_global(cx);
+ let context_server_manager = cx.new_model(|cx| {
+ ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
+ });
let mut this = Self {
contexts: Vec::new(),
contexts_metadata: Vec::new(),
@@ -148,91 +150,16 @@ impl ContextStore {
this.handle_project_changed(project.clone(), cx);
this.synchronize_contexts(cx);
this.register_context_server_handlers(cx);
-
- if project.read(cx).is_local() {
- // TODO: At the time when we construct the `ContextStore` we may not have yet initialized the extensions.
- // In order to register the context servers when the extension is loaded, we're periodically looping to
- // see if there are context servers to register.
- //
- // I tried doing this in a subscription on the `ExtensionStore`, but it never seemed to fire.
- //
- // We should find a more elegant way to do this.
- let context_server_factory_registry =
- ContextServerFactoryRegistry::default_global(cx);
- cx.spawn(|context_store, mut cx| async move {
- loop {
- let mut servers_to_register = Vec::new();
- for (_id, factory) in
- context_server_factory_registry.context_server_factories()
- {
- if let Some(server) = factory(project.clone(), &cx).await.log_err()
- {
- servers_to_register.push(server);
- }
- }
-
- let Some(_) = context_store
- .update(&mut cx, |this, cx| {
- this.context_server_manager.update(cx, |this, cx| {
- for server in servers_to_register {
- this.add_server(server, cx).detach_and_log_err(cx);
- }
- })
- })
- .log_err()
- else {
- break;
- };
-
- smol::Timer::after(Duration::from_millis(100)).await;
- }
-
- anyhow::Ok(())
- })
- .detach_and_log_err(cx);
- }
-
this
})?;
this.update(&mut cx, |this, cx| this.reload(cx))?
.await
.log_err();
- this.update(&mut cx, |this, cx| {
- this.watch_context_server_settings(cx);
- })
- .log_err();
-
Ok(this)
})
}
- fn watch_context_server_settings(&self, cx: &mut ModelContext<Self>) {
- cx.observe_global::<SettingsStore>(move |this, cx| {
- this.context_server_manager.update(cx, |manager, cx| {
- let location = this.project.read(cx).worktrees(cx).next().map(|worktree| {
- settings::SettingsLocation {
- worktree_id: worktree.read(cx).id(),
- path: Path::new(""),
- }
- });
- let settings = ContextServerSettings::get(location, cx);
-
- manager.maintain_servers(settings, cx);
-
- let has_any_context_servers = !manager.servers().is_empty();
- CommandPaletteFilter::update_global(cx, |filter, _cx| {
- if has_any_context_servers {
- filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
- } else {
- filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE);
- }
- });
- })
- })
- .detach();
- }
-
async fn handle_advertise_contexts(
this: Model<Self>,
envelope: TypedEnvelope<proto::AdvertiseContexts>,
@@ -27,7 +27,7 @@ pub struct ContextServerSlashCommand {
impl ContextServerSlashCommand {
pub fn new(
server_manager: Model<ContextServerManager>,
- server: &Arc<dyn ContextServer>,
+ server: &Arc<ContextServer>,
prompt: Prompt,
) -> Self {
Self {
@@ -78,6 +78,7 @@ uuid.workspace = true
[dev-dependencies]
assistant = { workspace = true, features = ["test-support"] }
+context_servers.workspace = true
async-trait.workspace = true
audio.workspace = true
call = { workspace = true, features = ["test-support"] }
@@ -6486,6 +6486,8 @@ async fn test_context_collaboration_with_reconnect(
assert_eq!(project.collaborators().len(), 1);
});
+ cx_a.update(context_servers::init);
+ cx_b.update(context_servers::init);
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context_store_a = cx_a
.update(|cx| {
@@ -39,11 +39,13 @@ impl CommandPaletteFilter {
}
/// Updates the global [`CommandPaletteFilter`] using the given closure.
- pub fn update_global<F, R>(cx: &mut AppContext, update: F) -> R
+ pub fn update_global<F>(cx: &mut AppContext, update: F)
where
- F: FnOnce(&mut Self, &mut AppContext) -> R,
+ F: FnOnce(&mut Self, &mut AppContext),
{
- cx.update_global(|this: &mut GlobalCommandPaletteFilter, cx| update(&mut this.0, cx))
+ if cx.has_global::<GlobalCommandPaletteFilter>() {
+ cx.update_global(|this: &mut GlobalCommandPaletteFilter, cx| update(&mut this.0, cx))
+ }
}
/// Returns whether the given [`Action`] is hidden by the filter.
@@ -13,7 +13,6 @@ path = "src/context_servers.rs"
[dependencies]
anyhow.workspace = true
-async-trait.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
futures.workspace = true
@@ -8,7 +8,6 @@ use command_palette_hooks::CommandPaletteFilter;
use gpui::{actions, AppContext};
use settings::Settings;
-pub use crate::manager::ContextServer;
use crate::manager::ContextServerSettings;
pub use crate::registry::ContextServerFactoryRegistry;
@@ -15,23 +15,23 @@
//! and react to changes in settings.
use std::path::Path;
-use std::pin::Pin;
use std::sync::Arc;
use anyhow::{bail, Result};
-use async_trait::async_trait;
-use collections::{HashMap, HashSet};
-use futures::{Future, FutureExt};
-use gpui::{AsyncAppContext, EventEmitter, ModelContext, Task};
+use collections::HashMap;
+use command_palette_hooks::CommandPaletteFilter;
+use gpui::{AsyncAppContext, EventEmitter, Model, ModelContext, Subscription, Task, WeakModel};
use log;
use parking_lot::RwLock;
+use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
-use settings::{Settings, SettingsSources};
+use settings::{Settings, SettingsSources, SettingsStore};
+use util::ResultExt as _;
use crate::{
client::{self, Client},
- types,
+ types, ContextServerFactoryRegistry, CONTEXT_SERVERS_NAMESPACE,
};
#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
@@ -66,25 +66,13 @@ impl Settings for ContextServerSettings {
}
}
-#[async_trait(?Send)]
-pub trait ContextServer: Send + Sync + 'static {
- fn id(&self) -> Arc<str>;
- fn config(&self) -> Arc<ServerConfig>;
- fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>>;
- fn start<'a>(
- self: Arc<Self>,
- cx: &'a AsyncAppContext,
- ) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>>;
- fn stop(&self) -> Result<()>;
-}
-
-pub struct NativeContextServer {
+pub struct ContextServer {
pub id: Arc<str>,
pub config: Arc<ServerConfig>,
pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
}
-impl NativeContextServer {
+impl ContextServer {
pub fn new(id: Arc<str>, config: Arc<ServerConfig>) -> Self {
Self {
id,
@@ -92,61 +80,52 @@ impl NativeContextServer {
client: RwLock::new(None),
}
}
-}
-#[async_trait(?Send)]
-impl ContextServer for NativeContextServer {
- fn id(&self) -> Arc<str> {
+ pub fn id(&self) -> Arc<str> {
self.id.clone()
}
- fn config(&self) -> Arc<ServerConfig> {
+ pub fn config(&self) -> Arc<ServerConfig> {
self.config.clone()
}
- fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
+ pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
self.client.read().clone()
}
- fn start<'a>(
- self: Arc<Self>,
- cx: &'a AsyncAppContext,
- ) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>> {
- async move {
- log::info!("starting context server {}", self.id);
- let Some(command) = &self.config.command else {
- bail!("no command specified for server {}", self.id);
- };
- let client = Client::new(
- client::ContextServerId(self.id.clone()),
- client::ModelContextServerBinary {
- executable: Path::new(&command.path).to_path_buf(),
- args: command.args.clone(),
- env: command.env.clone(),
- },
- cx.clone(),
- )?;
-
- let protocol = crate::protocol::ModelContextProtocol::new(client);
- let client_info = types::Implementation {
- name: "Zed".to_string(),
- version: env!("CARGO_PKG_VERSION").to_string(),
- };
- let initialized_protocol = protocol.initialize(client_info).await?;
-
- log::debug!(
- "context server {} initialized: {:?}",
- self.id,
- initialized_protocol.initialize,
- );
-
- *self.client.write() = Some(Arc::new(initialized_protocol));
- Ok(())
- }
- .boxed_local()
+ pub async fn start(self: Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
+ log::info!("starting context server {}", self.id);
+ let Some(command) = &self.config.command else {
+ bail!("no command specified for server {}", self.id);
+ };
+ let client = Client::new(
+ client::ContextServerId(self.id.clone()),
+ client::ModelContextServerBinary {
+ executable: Path::new(&command.path).to_path_buf(),
+ args: command.args.clone(),
+ env: command.env.clone(),
+ },
+ cx.clone(),
+ )?;
+
+ let protocol = crate::protocol::ModelContextProtocol::new(client);
+ let client_info = types::Implementation {
+ name: "Zed".to_string(),
+ version: env!("CARGO_PKG_VERSION").to_string(),
+ };
+ let initialized_protocol = protocol.initialize(client_info).await?;
+
+ log::debug!(
+ "context server {} initialized: {:?}",
+ self.id,
+ initialized_protocol.initialize,
+ );
+
+ *self.client.write() = Some(Arc::new(initialized_protocol));
+ Ok(())
}
- fn stop(&self) -> Result<()> {
+ pub fn stop(&self) -> Result<()> {
let mut client = self.client.write();
if let Some(protocol) = client.take() {
drop(protocol);
@@ -155,13 +134,13 @@ impl ContextServer for NativeContextServer {
}
}
-/// A Context server manager manages the starting and stopping
-/// of all servers. To obtain a server to interact with, a crate
-/// must go through the `GlobalContextServerManager` which holds
-/// a model to the ContextServerManager.
pub struct ContextServerManager {
- servers: HashMap<Arc<str>, Arc<dyn ContextServer>>,
- pending_servers: HashSet<Arc<str>>,
+ servers: HashMap<Arc<str>, Arc<ContextServer>>,
+ project: Model<Project>,
+ registry: Model<ContextServerFactoryRegistry>,
+ update_servers_task: Option<Task<Result<()>>>,
+ needs_server_update: bool,
+ _subscriptions: Vec<Subscription>,
}
pub enum Event {
@@ -171,74 +150,66 @@ pub enum Event {
impl EventEmitter<Event> for ContextServerManager {}
-impl Default for ContextServerManager {
- fn default() -> Self {
- Self::new()
- }
-}
-
impl ContextServerManager {
- pub fn new() -> Self {
- Self {
+ pub fn new(
+ registry: Model<ContextServerFactoryRegistry>,
+ project: Model<Project>,
+ cx: &mut ModelContext<Self>,
+ ) -> Self {
+ let mut this = Self {
+ _subscriptions: vec![
+ cx.observe(®istry, |this, _registry, cx| {
+ this.available_context_servers_changed(cx);
+ }),
+ cx.observe_global::<SettingsStore>(|this, cx| {
+ this.available_context_servers_changed(cx);
+ }),
+ ],
+ project,
+ registry,
+ needs_server_update: false,
servers: HashMap::default(),
- pending_servers: HashSet::default(),
- }
+ update_servers_task: None,
+ };
+ this.available_context_servers_changed(cx);
+ this
}
- pub fn add_server(
- &mut self,
- server: Arc<dyn ContextServer>,
- cx: &ModelContext<Self>,
- ) -> Task<anyhow::Result<()>> {
- let server_id = server.id();
+ fn available_context_servers_changed(&mut self, cx: &mut ModelContext<Self>) {
+ if self.update_servers_task.is_some() {
+ self.needs_server_update = true;
+ } else {
+ self.update_servers_task = Some(cx.spawn(|this, mut cx| async move {
+ this.update(&mut cx, |this, _| {
+ this.needs_server_update = false;
+ })?;
- if self.servers.contains_key(&server_id) || self.pending_servers.contains(&server_id) {
- return Task::ready(Ok(()));
- }
+ Self::maintain_servers(this.clone(), cx.clone()).await?;
- let task = {
- let server_id = server_id.clone();
- cx.spawn(|this, mut cx| async move {
- server.clone().start(&cx).await?;
this.update(&mut cx, |this, cx| {
- this.servers.insert(server_id.clone(), server);
- this.pending_servers.remove(&server_id);
- cx.emit(Event::ServerStarted {
- server_id: server_id.clone(),
- });
+ let has_any_context_servers = !this.servers().is_empty();
+ if has_any_context_servers {
+ CommandPaletteFilter::update_global(cx, |filter, _cx| {
+ filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
+ });
+ }
+
+ this.update_servers_task.take();
+ if this.needs_server_update {
+ this.available_context_servers_changed(cx);
+ }
})?;
- Ok(())
- })
- };
- self.pending_servers.insert(server_id);
- task
- }
-
- pub fn get_server(&self, id: &str) -> Option<Arc<dyn ContextServer>> {
- self.servers.get(id).cloned()
+ Ok(())
+ }));
+ }
}
- pub fn remove_server(
- &mut self,
- id: &Arc<str>,
- cx: &ModelContext<Self>,
- ) -> Task<anyhow::Result<()>> {
- let id = id.clone();
- cx.spawn(|this, mut cx| async move {
- if let Some(server) =
- this.update(&mut cx, |this, _cx| this.servers.remove(id.as_ref()))?
- {
- server.stop()?;
- }
- this.update(&mut cx, |this, cx| {
- this.pending_servers.remove(id.as_ref());
- cx.emit(Event::ServerStopped {
- server_id: id.clone(),
- })
- })?;
- Ok(())
- })
+ pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
+ self.servers
+ .get(id)
+ .filter(|server| server.client().is_some())
+ .cloned()
}
pub fn restart_server(
@@ -251,7 +222,7 @@ impl ContextServerManager {
if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
server.stop()?;
let config = server.config();
- let new_server = Arc::new(NativeContextServer::new(id.clone(), config));
+ let new_server = Arc::new(ContextServer::new(id.clone(), config));
new_server.clone().start(&cx).await?;
this.update(&mut cx, |this, cx| {
this.servers.insert(id.clone(), new_server);
@@ -267,45 +238,83 @@ impl ContextServerManager {
})
}
- pub fn servers(&self) -> Vec<Arc<dyn ContextServer>> {
- self.servers.values().cloned().collect()
+ pub fn servers(&self) -> Vec<Arc<ContextServer>> {
+ self.servers
+ .values()
+ .filter(|server| server.client().is_some())
+ .cloned()
+ .collect()
}
- pub fn maintain_servers(&mut self, settings: &ContextServerSettings, cx: &ModelContext<Self>) {
- let current_servers = self
- .servers()
- .into_iter()
- .map(|server| (server.id(), server.config()))
- .collect::<HashMap<_, _>>();
-
- let new_servers = settings
- .context_servers
- .iter()
- .map(|(id, config)| (id.clone(), config.clone()))
- .collect::<HashMap<_, _>>();
-
- let servers_to_add = new_servers
- .iter()
- .filter(|(id, _)| !current_servers.contains_key(id.as_ref()))
- .map(|(id, config)| (id.clone(), config.clone()))
- .collect::<Vec<_>>();
-
- let servers_to_remove = current_servers
- .keys()
- .filter(|id| !new_servers.contains_key(id.as_ref()))
- .cloned()
- .collect::<Vec<_>>();
+ async fn maintain_servers(this: WeakModel<Self>, mut cx: AsyncAppContext) -> Result<()> {
+ let mut desired_servers = HashMap::default();
+
+ let (registry, project) = this.update(&mut cx, |this, cx| {
+ let location = this.project.read(cx).worktrees(cx).next().map(|worktree| {
+ settings::SettingsLocation {
+ worktree_id: worktree.read(cx).id(),
+ path: Path::new(""),
+ }
+ });
+ let settings = ContextServerSettings::get(location, cx);
+ desired_servers = settings.context_servers.clone();
+
+ (this.registry.clone(), this.project.clone())
+ })?;
+
+ for (id, factory) in
+ registry.read_with(&cx, |registry, _| registry.context_server_factories())?
+ {
+ let config = desired_servers.entry(id).or_default();
+ if config.command.is_none() {
+ if let Some(extension_command) = factory(project.clone(), &cx).await.log_err() {
+ config.command = Some(extension_command);
+ }
+ }
+ }
- log::trace!("servers_to_add={:?}", servers_to_add);
- for (id, config) in servers_to_add {
- if config.command.is_some() {
- let server = Arc::new(NativeContextServer::new(id, Arc::new(config)));
- self.add_server(server, cx).detach_and_log_err(cx);
+ let mut servers_to_start = HashMap::default();
+ let mut servers_to_stop = HashMap::default();
+
+ this.update(&mut cx, |this, _cx| {
+ this.servers.retain(|id, server| {
+ if desired_servers.contains_key(id) {
+ true
+ } else {
+ servers_to_stop.insert(id.clone(), server.clone());
+ false
+ }
+ });
+
+ for (id, config) in desired_servers {
+ let existing_config = this.servers.get(&id).map(|server| server.config());
+ if existing_config.as_deref() != Some(&config) {
+ let config = Arc::new(config);
+ let server = Arc::new(ContextServer::new(id.clone(), config));
+ servers_to_start.insert(id.clone(), server.clone());
+ let old_server = this.servers.insert(id.clone(), server);
+ if let Some(old_server) = old_server {
+ servers_to_stop.insert(id, old_server);
+ }
+ }
}
+ })?;
+
+ for (id, server) in servers_to_stop {
+ server.stop().log_err();
+ this.update(&mut cx, |_, cx| {
+ cx.emit(Event::ServerStopped { server_id: id })
+ })?;
}
- for id in servers_to_remove {
- self.remove_server(&id, cx).detach_and_log_err(cx);
+ for (id, server) in servers_to_start {
+ if server.start(&cx).await.log_err().is_some() {
+ this.update(&mut cx, |_, cx| {
+ cx.emit(Event::ServerStarted { server_id: id })
+ })?;
+ }
}
+
+ Ok(())
}
}
@@ -2,75 +2,61 @@ use std::sync::Arc;
use anyhow::Result;
use collections::HashMap;
-use gpui::{AppContext, AsyncAppContext, Global, Model, ReadGlobal, Task};
-use parking_lot::RwLock;
+use gpui::{AppContext, AsyncAppContext, Context, Global, Model, ReadGlobal, Task};
use project::Project;
-use crate::ContextServer;
+use crate::manager::ServerCommand;
pub type ContextServerFactory = Arc<
- dyn Fn(Model<Project>, &AsyncAppContext) -> Task<Result<Arc<dyn ContextServer>>>
- + Send
- + Sync
- + 'static,
+ dyn Fn(Model<Project>, &AsyncAppContext) -> Task<Result<ServerCommand>> + Send + Sync + 'static,
>;
-#[derive(Default)]
-struct GlobalContextServerFactoryRegistry(Arc<ContextServerFactoryRegistry>);
+struct GlobalContextServerFactoryRegistry(Model<ContextServerFactoryRegistry>);
impl Global for GlobalContextServerFactoryRegistry {}
-#[derive(Default)]
-struct ContextServerFactoryRegistryState {
- context_servers: HashMap<Arc<str>, ContextServerFactory>,
-}
-
#[derive(Default)]
pub struct ContextServerFactoryRegistry {
- state: RwLock<ContextServerFactoryRegistryState>,
+ context_servers: HashMap<Arc<str>, ContextServerFactory>,
}
impl ContextServerFactoryRegistry {
/// Returns the global [`ContextServerFactoryRegistry`].
- pub fn global(cx: &AppContext) -> Arc<Self> {
+ pub fn global(cx: &AppContext) -> Model<Self> {
GlobalContextServerFactoryRegistry::global(cx).0.clone()
}
/// Returns the global [`ContextServerFactoryRegistry`].
///
/// Inserts a default [`ContextServerFactoryRegistry`] if one does not yet exist.
- pub fn default_global(cx: &mut AppContext) -> Arc<Self> {
- cx.default_global::<GlobalContextServerFactoryRegistry>()
- .0
- .clone()
+ pub fn default_global(cx: &mut AppContext) -> Model<Self> {
+ if !cx.has_global::<GlobalContextServerFactoryRegistry>() {
+ let registry = cx.new_model(|_| Self::new());
+ cx.set_global(GlobalContextServerFactoryRegistry(registry));
+ }
+ cx.global::<GlobalContextServerFactoryRegistry>().0.clone()
}
- pub fn new() -> Arc<Self> {
- Arc::new(Self {
- state: RwLock::new(ContextServerFactoryRegistryState {
- context_servers: HashMap::default(),
- }),
- })
+ pub fn new() -> Self {
+ Self {
+ context_servers: HashMap::default(),
+ }
}
pub fn context_server_factories(&self) -> Vec<(Arc<str>, ContextServerFactory)> {
- self.state
- .read()
- .context_servers
+ self.context_servers
.iter()
.map(|(id, factory)| (id.clone(), factory.clone()))
.collect()
}
/// Registers the provided [`ContextServerFactory`].
- pub fn register_server_factory(&self, id: Arc<str>, factory: ContextServerFactory) {
- let mut state = self.state.write();
- state.context_servers.insert(id, factory);
+ pub fn register_server_factory(&mut self, id: Arc<str>, factory: ContextServerFactory) {
+ self.context_servers.insert(id, factory);
}
/// Unregisters the [`ContextServerFactory`] for the server with the given ID.
- pub fn unregister_server_factory_by_id(&self, server_id: &str) {
- let mut state = self.state.write();
- state.context_servers.remove(server_id);
+ pub fn unregister_server_factory_by_id(&mut self, server_id: &str) {
+ self.context_servers.remove(server_id);
}
}
@@ -141,7 +141,7 @@ pub trait ExtensionRegistrationHooks: Send + Sync + 'static {
&self,
_id: Arc<str>,
_extension: WasmExtension,
- _host: Arc<WasmHost>,
+ _cx: &mut AppContext,
) {
}
@@ -1266,7 +1266,7 @@ impl ExtensionStore {
this.registration_hooks.register_context_server(
id.clone(),
wasm_extension.clone(),
- this.wasm_host.clone(),
+ cx,
);
}
@@ -17,7 +17,6 @@ test-support = []
[dependencies]
anyhow.workspace = true
assistant_slash_command.workspace = true
-async-trait.workspace = true
client.workspace = true
collections.workspace = true
context_servers.workspace = true
@@ -31,6 +30,7 @@ fuzzy.workspace = true
gpui.workspace = true
indexed_docs.workspace = true
language.workspace = true
+log.workspace = true
lsp.workspace = true
num-format.workspace = true
picker.workspace = true
@@ -1,97 +0,0 @@
-use std::pin::Pin;
-use std::sync::Arc;
-
-use anyhow::{anyhow, Result};
-use async_trait::async_trait;
-use context_servers::manager::{NativeContextServer, ServerCommand, ServerConfig};
-use context_servers::protocol::InitializedContextServerProtocol;
-use context_servers::ContextServer;
-use extension_host::wasm_host::{ExtensionProject, WasmExtension, WasmHost};
-use futures::{Future, FutureExt};
-use gpui::{AsyncAppContext, Model};
-use project::Project;
-use wasmtime_wasi::WasiView as _;
-
-pub struct ExtensionContextServer {
- #[allow(unused)]
- pub(crate) extension: WasmExtension,
- #[allow(unused)]
- pub(crate) host: Arc<WasmHost>,
- id: Arc<str>,
- context_server: Arc<NativeContextServer>,
-}
-
-impl ExtensionContextServer {
- pub async fn new(
- extension: WasmExtension,
- host: Arc<WasmHost>,
- id: Arc<str>,
- project: Model<Project>,
- mut cx: AsyncAppContext,
- ) -> Result<Self> {
- let extension_project = project.update(&mut cx, |project, cx| ExtensionProject {
- worktree_ids: project
- .visible_worktrees(cx)
- .map(|worktree| worktree.read(cx).id().to_proto())
- .collect(),
- })?;
- let command = extension
- .call({
- let id = id.clone();
- |extension, store| {
- async move {
- let project = store.data_mut().table().push(extension_project)?;
- let command = extension
- .call_context_server_command(store, id.clone(), project)
- .await?
- .map_err(|e| anyhow!("{}", e))?;
- anyhow::Ok(command)
- }
- .boxed()
- }
- })
- .await?;
-
- let config = Arc::new(ServerConfig {
- settings: None,
- command: Some(ServerCommand {
- path: command.command,
- args: command.args,
- env: Some(command.env.into_iter().collect()),
- }),
- });
-
- anyhow::Ok(Self {
- extension,
- host,
- id: id.clone(),
- context_server: Arc::new(NativeContextServer::new(id, config)),
- })
- }
-}
-
-#[async_trait(?Send)]
-impl ContextServer for ExtensionContextServer {
- fn id(&self) -> Arc<str> {
- self.id.clone()
- }
-
- fn config(&self) -> Arc<ServerConfig> {
- self.context_server.config()
- }
-
- fn client(&self) -> Option<Arc<InitializedContextServerProtocol>> {
- self.context_server.client()
- }
-
- fn start<'a>(
- self: Arc<Self>,
- cx: &'a AsyncAppContext,
- ) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>> {
- self.context_server.clone().start(cx)
- }
-
- fn stop(&self) -> Result<()> {
- self.context_server.stop()
- }
-}
@@ -1,19 +1,21 @@
use std::{path::PathBuf, sync::Arc};
-use anyhow::Result;
+use anyhow::{anyhow, Result};
use assistant_slash_command::{ExtensionSlashCommand, SlashCommandRegistry};
+use context_servers::manager::ServerCommand;
use context_servers::ContextServerFactoryRegistry;
+use db::smol::future::FutureExt as _;
use extension::Extension;
+use extension_host::wasm_host::ExtensionProject;
use extension_host::{extension_lsp_adapter::ExtensionLspAdapter, wasm_host};
use fs::Fs;
-use gpui::{AppContext, BackgroundExecutor, Task};
+use gpui::{AppContext, BackgroundExecutor, Model, Task};
use indexed_docs::{ExtensionIndexedDocsProvider, IndexedDocsRegistry, ProviderId};
use language::{LanguageRegistry, LanguageServerBinaryStatus, LoadedLanguage};
use snippet_provider::SnippetRegistry;
use theme::{ThemeRegistry, ThemeSettings};
use ui::SharedString;
-
-use crate::extension_context_server::ExtensionContextServer;
+use wasmtime_wasi::WasiView as _;
pub struct ConcreteExtensionRegistrationHooks {
slash_command_registry: Arc<SlashCommandRegistry>,
@@ -21,7 +23,7 @@ pub struct ConcreteExtensionRegistrationHooks {
indexed_docs_registry: Arc<IndexedDocsRegistry>,
snippet_registry: Arc<SnippetRegistry>,
language_registry: Arc<LanguageRegistry>,
- context_server_factory_registry: Arc<ContextServerFactoryRegistry>,
+ context_server_factory_registry: Model<ContextServerFactoryRegistry>,
executor: BackgroundExecutor,
}
@@ -32,7 +34,7 @@ impl ConcreteExtensionRegistrationHooks {
indexed_docs_registry: Arc<IndexedDocsRegistry>,
snippet_registry: Arc<SnippetRegistry>,
language_registry: Arc<LanguageRegistry>,
- context_server_factory_registry: Arc<ContextServerFactoryRegistry>,
+ context_server_factory_registry: Model<ContextServerFactoryRegistry>,
cx: &AppContext,
) -> Arc<dyn extension_host::ExtensionRegistrationHooks> {
Arc::new(Self {
@@ -71,25 +73,66 @@ impl extension_host::ExtensionRegistrationHooks for ConcreteExtensionRegistratio
&self,
id: Arc<str>,
extension: wasm_host::WasmExtension,
- host: Arc<wasm_host::WasmHost>,
+ cx: &mut AppContext,
) {
self.context_server_factory_registry
- .register_server_factory(
- id.clone(),
- Arc::new({
- move |project, cx| {
- let id = id.clone();
- let extension = extension.clone();
- let host = host.clone();
- cx.spawn(|cx| async move {
- let context_server =
- ExtensionContextServer::new(extension, host, id, project, cx)
+ .update(cx, |registry, _| {
+ registry.register_server_factory(
+ id.clone(),
+ Arc::new({
+ move |project, cx| {
+ log::info!(
+ "loading command for context server {id} from extension {}",
+ extension.manifest.id
+ );
+
+ let id = id.clone();
+ let extension = extension.clone();
+ cx.spawn(|mut cx| async move {
+ let extension_project =
+ project.update(&mut cx, |project, cx| ExtensionProject {
+ worktree_ids: project
+ .visible_worktrees(cx)
+ .map(|worktree| worktree.read(cx).id().to_proto())
+ .collect(),
+ })?;
+
+ let command = extension
+ .call({
+ let id = id.clone();
+ |extension, store| {
+ async move {
+ let project = store
+ .data_mut()
+ .table()
+ .push(extension_project)?;
+ let command = extension
+ .call_context_server_command(
+ store,
+ id.clone(),
+ project,
+ )
+ .await?
+ .map_err(|e| anyhow!("{}", e))?;
+ anyhow::Ok(command)
+ }
+ .boxed()
+ }
+ })
.await?;
- anyhow::Ok(Arc::new(context_server) as _)
- })
- }
- }),
- );
+
+ log::info!("loaded command for context server {id}: {command:?}");
+
+ Ok(ServerCommand {
+ path: command.command,
+ args: command.args,
+ env: Some(command.env.into_iter().collect()),
+ })
+ })
+ }
+ }),
+ )
+ });
}
fn register_docs_provider(&self, extension: Arc<dyn Extension>, provider_id: Arc<str>) {
@@ -268,7 +268,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
let slash_command_registry = SlashCommandRegistry::new();
let indexed_docs_registry = Arc::new(IndexedDocsRegistry::new(cx.executor()));
let snippet_registry = Arc::new(SnippetRegistry::new());
- let context_server_factory_registry = ContextServerFactoryRegistry::new();
+ let context_server_factory_registry = cx.new_model(|_| ContextServerFactoryRegistry::new());
let node_runtime = NodeRuntime::unavailable();
let store = cx.new_model(|cx| {
@@ -508,7 +508,7 @@ async fn test_extension_store_with_test_extension(cx: &mut TestAppContext) {
let slash_command_registry = SlashCommandRegistry::new();
let indexed_docs_registry = Arc::new(IndexedDocsRegistry::new(cx.executor()));
let snippet_registry = Arc::new(SnippetRegistry::new());
- let context_server_factory_registry = ContextServerFactoryRegistry::new();
+ let context_server_factory_registry = cx.new_model(|_| ContextServerFactoryRegistry::new());
let node_runtime = NodeRuntime::unavailable();
let mut status_updates = language_registry.language_server_binary_statuses();
@@ -1,5 +1,4 @@
mod components;
-mod extension_context_server;
mod extension_registration_hooks;
mod extension_suggest;
mod extension_version_selector;