Detailed changes
@@ -4212,6 +4212,7 @@ dependencies = [
"async-trait",
"client",
"collections",
+ "context_servers",
"ctor",
"db",
"editor",
@@ -15353,6 +15354,7 @@ dependencies = [
"collections",
"command_palette",
"command_palette_hooks",
+ "context_servers",
"copilot",
"db",
"diagnostics",
@@ -10,7 +10,7 @@ use clock::ReplicaId;
use collections::HashMap;
use command_palette_hooks::CommandPaletteFilter;
use context_servers::manager::{ContextServerManager, ContextServerSettings};
-use context_servers::CONTEXT_SERVERS_NAMESPACE;
+use context_servers::{ContextServerFactoryRegistry, CONTEXT_SERVERS_NAMESPACE};
use fs::Fs;
use futures::StreamExt;
use fuzzy::StringMatchCandidate;
@@ -51,8 +51,8 @@ pub struct ContextStore {
contexts: Vec<ContextHandle>,
contexts_metadata: Vec<SavedContextMetadata>,
context_server_manager: Model<ContextServerManager>,
- context_server_slash_command_ids: HashMap<String, Vec<SlashCommandId>>,
- context_server_tool_ids: HashMap<String, Vec<ToolId>>,
+ context_server_slash_command_ids: HashMap<Arc<str>, Vec<SlashCommandId>>,
+ context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
host_contexts: Vec<RemoteContextMetadata>,
fs: Arc<dyn Fs>,
languages: Arc<LanguageRegistry>,
@@ -148,6 +148,47 @@ impl ContextStore {
this.handle_project_changed(project, cx);
this.synchronize_contexts(cx);
this.register_context_server_handlers(cx);
+
+ // TODO: At the time when we construct the `ContextStore` we may not have yet initialized the extensions.
+ // In order to register the context servers when the extension is loaded, we're periodically looping to
+ // see if there are context servers to register.
+ //
+ // I tried doing this in a subscription on the `ExtensionStore`, but it never seemed to fire.
+ //
+ // We should find a more elegant way to do this.
+ let context_server_factory_registry =
+ ContextServerFactoryRegistry::default_global(cx);
+ cx.spawn(|context_store, mut cx| async move {
+ loop {
+ let mut servers_to_register = Vec::new();
+ for (_id, factory) in
+ context_server_factory_registry.context_server_factories()
+ {
+ if let Some(server) = factory(&cx).await.log_err() {
+ servers_to_register.push(server);
+ }
+ }
+
+ let Some(_) = context_store
+ .update(&mut cx, |this, cx| {
+ this.context_server_manager.update(cx, |this, cx| {
+ for server in servers_to_register {
+ this.add_server(server, cx).detach_and_log_err(cx);
+ }
+ })
+ })
+ .log_err()
+ else {
+ break;
+ };
+
+ smol::Timer::after(Duration::from_millis(100)).await;
+ }
+
+ anyhow::Ok(())
+ })
+ .detach_and_log_err(cx);
+
this
})?;
this.update(&mut cx, |this, cx| this.reload(cx))?
@@ -1,13 +1,16 @@
+pub mod client;
+pub mod manager;
+pub mod protocol;
+mod registry;
+pub mod types;
+
use command_palette_hooks::CommandPaletteFilter;
use gpui::{actions, AppContext};
use settings::Settings;
+pub use crate::manager::ContextServer;
use crate::manager::ContextServerSettings;
-
-pub mod client;
-pub mod manager;
-pub mod protocol;
-pub mod types;
+pub use crate::registry::ContextServerFactoryRegistry;
actions!(context_servers, [Restart]);
@@ -16,6 +19,7 @@ pub const CONTEXT_SERVERS_NAMESPACE: &'static str = "context_servers";
pub fn init(cx: &mut AppContext) {
ContextServerSettings::register(cx);
+ ContextServerFactoryRegistry::default_global(cx);
CommandPaletteFilter::update_global(cx, |filter, _cx| {
filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE);
@@ -79,7 +79,7 @@ pub struct NativeContextServer {
}
impl NativeContextServer {
- fn new(config: Arc<ServerConfig>) -> Self {
+ pub fn new(config: Arc<ServerConfig>) -> Self {
Self {
id: config.id.clone().into(),
config,
@@ -151,13 +151,13 @@ impl ContextServer for NativeContextServer {
/// must go through the `GlobalContextServerManager` which holds
/// a model to the ContextServerManager.
pub struct ContextServerManager {
- servers: HashMap<String, Arc<dyn ContextServer>>,
- pending_servers: HashSet<String>,
+ servers: HashMap<Arc<str>, Arc<dyn ContextServer>>,
+ pending_servers: HashSet<Arc<str>>,
}
pub enum Event {
- ServerStarted { server_id: String },
- ServerStopped { server_id: String },
+ ServerStarted { server_id: Arc<str> },
+ ServerStopped { server_id: Arc<str> },
}
impl EventEmitter<Event> for ContextServerManager {}
@@ -178,10 +178,10 @@ impl ContextServerManager {
pub fn add_server(
&mut self,
- config: Arc<ServerConfig>,
+ server: Arc<dyn ContextServer>,
cx: &ModelContext<Self>,
) -> Task<anyhow::Result<()>> {
- let server_id = config.id.clone();
+ let server_id = server.id();
if self.servers.contains_key(&server_id) || self.pending_servers.contains(&server_id) {
return Task::ready(Ok(()));
@@ -190,7 +190,6 @@ impl ContextServerManager {
let task = {
let server_id = server_id.clone();
cx.spawn(|this, mut cx| async move {
- let server = Arc::new(NativeContextServer::new(config));
server.clone().start(&cx).await?;
this.update(&mut cx, |this, cx| {
this.servers.insert(server_id.clone(), server);
@@ -211,14 +210,20 @@ impl ContextServerManager {
self.servers.get(id).cloned()
}
- pub fn remove_server(&mut self, id: &str, cx: &ModelContext<Self>) -> Task<anyhow::Result<()>> {
- let id = id.to_string();
+ pub fn remove_server(
+ &mut self,
+ id: &Arc<str>,
+ cx: &ModelContext<Self>,
+ ) -> Task<anyhow::Result<()>> {
+ let id = id.clone();
cx.spawn(|this, mut cx| async move {
- if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
+ if let Some(server) =
+ this.update(&mut cx, |this, _cx| this.servers.remove(id.as_ref()))?
+ {
server.stop()?;
}
this.update(&mut cx, |this, cx| {
- this.pending_servers.remove(&id);
+ this.pending_servers.remove(id.as_ref());
cx.emit(Event::ServerStopped {
server_id: id.clone(),
})
@@ -232,7 +237,7 @@ impl ContextServerManager {
id: &Arc<str>,
cx: &mut ModelContext<Self>,
) -> Task<anyhow::Result<()>> {
- let id = id.to_string();
+ let id = id.clone();
cx.spawn(|this, mut cx| async move {
if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
server.stop()?;
@@ -284,7 +289,8 @@ impl ContextServerManager {
log::trace!("servers_to_add={:?}", servers_to_add);
for config in servers_to_add {
- self.add_server(Arc::new(config), cx).detach_and_log_err(cx);
+ let server = Arc::new(NativeContextServer::new(Arc::new(config)));
+ self.add_server(server, cx).detach_and_log_err(cx);
}
for id in servers_to_remove {
@@ -0,0 +1,72 @@
+use std::sync::Arc;
+
+use anyhow::Result;
+use collections::HashMap;
+use gpui::{AppContext, AsyncAppContext, ReadGlobal};
+use gpui::{Global, Task};
+use parking_lot::RwLock;
+
+use crate::ContextServer;
+
+pub type ContextServerFactory =
+ Arc<dyn Fn(&AsyncAppContext) -> Task<Result<Arc<dyn ContextServer>>> + Send + Sync + 'static>;
+
+#[derive(Default)]
+struct GlobalContextServerFactoryRegistry(Arc<ContextServerFactoryRegistry>);
+
+impl Global for GlobalContextServerFactoryRegistry {}
+
+#[derive(Default)]
+struct ContextServerFactoryRegistryState {
+ context_servers: HashMap<Arc<str>, ContextServerFactory>,
+}
+
+#[derive(Default)]
+pub struct ContextServerFactoryRegistry {
+ state: RwLock<ContextServerFactoryRegistryState>,
+}
+
+impl ContextServerFactoryRegistry {
+ /// Returns the global [`ContextServerFactoryRegistry`].
+ pub fn global(cx: &AppContext) -> Arc<Self> {
+ GlobalContextServerFactoryRegistry::global(cx).0.clone()
+ }
+
+ /// Returns the global [`ContextServerFactoryRegistry`].
+ ///
+ /// Inserts a default [`ContextServerFactoryRegistry`] if one does not yet exist.
+ pub fn default_global(cx: &mut AppContext) -> Arc<Self> {
+ cx.default_global::<GlobalContextServerFactoryRegistry>()
+ .0
+ .clone()
+ }
+
+ pub fn new() -> Arc<Self> {
+ Arc::new(Self {
+ state: RwLock::new(ContextServerFactoryRegistryState {
+ context_servers: HashMap::default(),
+ }),
+ })
+ }
+
+ pub fn context_server_factories(&self) -> Vec<(Arc<str>, ContextServerFactory)> {
+ self.state
+ .read()
+ .context_servers
+ .iter()
+ .map(|(id, factory)| (id.clone(), factory.clone()))
+ .collect()
+ }
+
+ /// Registers the provided [`ContextServerFactory`].
+ pub fn register_server_factory(&self, id: Arc<str>, factory: ContextServerFactory) {
+ let mut state = self.state.write();
+ state.context_servers.insert(id, factory);
+ }
+
+ /// Unregisters the [`ContextServerFactory`] for the server with the given ID.
+ pub fn unregister_server_factory_by_id(&self, server_id: &str) {
+ let mut state = self.state.write();
+ state.context_servers.remove(server_id);
+ }
+}
@@ -75,6 +75,8 @@ pub struct ExtensionManifest {
#[serde(default)]
pub language_servers: BTreeMap<LanguageServerName, LanguageServerManifestEntry>,
#[serde(default)]
+ pub context_servers: BTreeMap<Arc<str>, ContextServerManifestEntry>,
+ #[serde(default)]
pub slash_commands: BTreeMap<Arc<str>, SlashCommandManifestEntry>,
#[serde(default)]
pub indexed_docs_providers: BTreeMap<Arc<str>, IndexedDocsProviderEntry>,
@@ -134,6 +136,9 @@ impl LanguageServerManifestEntry {
}
}
+#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
+pub struct ContextServerManifestEntry {}
+
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
pub struct SlashCommandManifestEntry {
pub description: String,
@@ -205,6 +210,7 @@ fn manifest_from_old_manifest(
.map(|grammar_name| (grammar_name, Default::default()))
.collect(),
language_servers: Default::default(),
+ context_servers: BTreeMap::default(),
slash_commands: BTreeMap::default(),
indexed_docs_providers: BTreeMap::default(),
snippets: None,
@@ -129,6 +129,11 @@ pub trait Extension: Send + Sync {
Err("`run_slash_command` not implemented".to_string())
}
+ /// Returns the command used to start a context server.
+ fn context_server_command(&mut self, _context_server_id: &ContextServerId) -> Result<Command> {
+ Err("`context_server_command` not implemented".to_string())
+ }
+
/// Returns a list of package names as suggestions to be included in the
/// search results of the `/docs` slash command.
///
@@ -270,6 +275,11 @@ impl wit::Guest for Component {
extension().run_slash_command(command, args, worktree)
}
+ fn context_server_command(context_server_id: String) -> Result<wit::Command> {
+ let context_server_id = ContextServerId(context_server_id);
+ extension().context_server_command(&context_server_id)
+ }
+
fn suggest_docs_packages(provider: String) -> Result<Vec<String>, String> {
extension().suggest_docs_packages(provider)
}
@@ -299,6 +309,22 @@ impl fmt::Display for LanguageServerId {
}
}
+/// The ID of a context server.
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
+pub struct ContextServerId(String);
+
+impl AsRef<str> for ContextServerId {
+ fn as_ref(&self) -> &str {
+ &self.0
+ }
+}
+
+impl fmt::Display for ContextServerId {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "{}", self.0)
+ }
+}
+
impl CodeLabelSpan {
/// Returns a [`CodeLabelSpan::CodeRange`].
pub fn code_range(range: impl Into<wit::Range>) -> Self {
@@ -135,6 +135,9 @@ world extension {
/// Returns the output from running the provided slash command.
export run-slash-command: func(command: slash-command, args: list<string>, worktree: option<borrow<worktree>>) -> result<slash-command-output, string>;
+ /// Returns the command used to start up a context server.
+ export context-server-command: func(context-server-id: string) -> result<command, string>;
+
/// Returns a list of packages as suggestions to be included in the `/docs`
/// search results.
///
@@ -145,6 +145,14 @@ pub trait ExtensionRegistrationHooks: Send + Sync + 'static {
) {
}
+ fn register_context_server(
+ &self,
+ _id: Arc<str>,
+ _extension: WasmExtension,
+ _host: Arc<WasmHost>,
+ ) {
+ }
+
fn register_docs_provider(
&self,
_extension: WasmExtension,
@@ -1267,6 +1275,14 @@ impl ExtensionStore {
);
}
+ for (id, _context_server_entry) in &manifest.context_servers {
+ this.registration_hooks.register_context_server(
+ id.clone(),
+ wasm_extension.clone(),
+ this.wasm_host.clone(),
+ );
+ }
+
for (provider_id, _provider) in &manifest.indexed_docs_providers {
this.registration_hooks.register_docs_provider(
wasm_extension.clone(),
@@ -384,6 +384,24 @@ impl Extension {
}
}
+ pub async fn call_context_server_command(
+ &self,
+ store: &mut Store<WasmState>,
+ context_server_id: Arc<str>,
+ ) -> Result<Result<Command, String>> {
+ match self {
+ Extension::V020(ext) => {
+ ext.call_context_server_command(store, &context_server_id)
+ .await
+ }
+ Extension::V001(_) | Extension::V004(_) | Extension::V006(_) | Extension::V010(_) => {
+ Err(anyhow!(
+ "`context_server_command` not available prior to v0.2.0"
+ ))
+ }
+ }
+ }
+
pub async fn call_suggest_docs_packages(
&self,
store: &mut Store<WasmState>,
@@ -20,6 +20,7 @@ assistant_slash_command.workspace = true
async-trait.workspace = true
client.workspace = true
collections.workspace = true
+context_servers.workspace = true
db.workspace = true
editor.workspace = true
extension_host.workspace = true
@@ -0,0 +1,80 @@
+use std::pin::Pin;
+use std::sync::Arc;
+
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+use context_servers::manager::{NativeContextServer, ServerConfig};
+use context_servers::protocol::InitializedContextServerProtocol;
+use context_servers::ContextServer;
+use extension_host::wasm_host::{WasmExtension, WasmHost};
+use futures::{Future, FutureExt};
+use gpui::AsyncAppContext;
+
+pub struct ExtensionContextServer {
+ #[allow(unused)]
+ pub(crate) extension: WasmExtension,
+ #[allow(unused)]
+ pub(crate) host: Arc<WasmHost>,
+ id: Arc<str>,
+ context_server: Arc<NativeContextServer>,
+}
+
+impl ExtensionContextServer {
+ pub async fn new(extension: WasmExtension, host: Arc<WasmHost>, id: Arc<str>) -> Result<Self> {
+ let command = extension
+ .call({
+ let id = id.clone();
+ |extension, store| {
+ async move {
+ let command = extension
+ .call_context_server_command(store, id.clone())
+ .await?
+ .map_err(|e| anyhow!("{}", e))?;
+ anyhow::Ok(command)
+ }
+ .boxed()
+ }
+ })
+ .await?;
+
+ let config = Arc::new(ServerConfig {
+ id: id.to_string(),
+ executable: command.command,
+ args: command.args,
+ env: Some(command.env.into_iter().collect()),
+ });
+
+ anyhow::Ok(Self {
+ extension,
+ host,
+ id,
+ context_server: Arc::new(NativeContextServer::new(config)),
+ })
+ }
+}
+
+#[async_trait(?Send)]
+impl ContextServer for ExtensionContextServer {
+ fn id(&self) -> Arc<str> {
+ self.id.clone()
+ }
+
+ fn config(&self) -> Arc<ServerConfig> {
+ self.context_server.config()
+ }
+
+ fn client(&self) -> Option<Arc<InitializedContextServerProtocol>> {
+ self.context_server.client()
+ }
+
+ fn start<'a>(
+ self: Arc<Self>,
+ cx: &'a AsyncAppContext,
+ ) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>> {
+ self.context_server.clone().start(cx)
+ }
+
+ fn stop(&self) -> Result<()> {
+ self.context_server.stop()
+ }
+}
@@ -2,6 +2,7 @@ use std::{path::PathBuf, sync::Arc};
use anyhow::Result;
use assistant_slash_command::SlashCommandRegistry;
+use context_servers::ContextServerFactoryRegistry;
use extension_host::{extension_lsp_adapter::ExtensionLspAdapter, wasm_host};
use fs::Fs;
use gpui::{AppContext, BackgroundExecutor, Task};
@@ -11,6 +12,7 @@ use snippet_provider::SnippetRegistry;
use theme::{ThemeRegistry, ThemeSettings};
use ui::SharedString;
+use crate::extension_context_server::ExtensionContextServer;
use crate::{extension_indexed_docs_provider, extension_slash_command::ExtensionSlashCommand};
pub struct ConcreteExtensionRegistrationHooks {
@@ -19,6 +21,7 @@ pub struct ConcreteExtensionRegistrationHooks {
indexed_docs_registry: Arc<IndexedDocsRegistry>,
snippet_registry: Arc<SnippetRegistry>,
language_registry: Arc<LanguageRegistry>,
+ context_server_factory_registry: Arc<ContextServerFactoryRegistry>,
executor: BackgroundExecutor,
}
@@ -29,6 +32,7 @@ impl ConcreteExtensionRegistrationHooks {
indexed_docs_registry: Arc<IndexedDocsRegistry>,
snippet_registry: Arc<SnippetRegistry>,
language_registry: Arc<LanguageRegistry>,
+ context_server_factory_registry: Arc<ContextServerFactoryRegistry>,
cx: &AppContext,
) -> Arc<dyn extension_host::ExtensionRegistrationHooks> {
Arc::new(Self {
@@ -37,6 +41,7 @@ impl ConcreteExtensionRegistrationHooks {
indexed_docs_registry,
snippet_registry,
language_registry,
+ context_server_factory_registry,
executor: cx.background_executor().clone(),
})
}
@@ -69,6 +74,31 @@ impl extension_host::ExtensionRegistrationHooks for ConcreteExtensionRegistratio
)
}
+ fn register_context_server(
+ &self,
+ id: Arc<str>,
+ extension: wasm_host::WasmExtension,
+ host: Arc<wasm_host::WasmHost>,
+ ) {
+ self.context_server_factory_registry
+ .register_server_factory(
+ id.clone(),
+ Arc::new({
+ move |cx| {
+ let id = id.clone();
+ let extension = extension.clone();
+ let host = host.clone();
+ cx.spawn(|_cx| async move {
+ let context_server =
+ ExtensionContextServer::new(extension, host, id).await?;
+
+ anyhow::Ok(Arc::new(context_server) as _)
+ })
+ }
+ }),
+ );
+ }
+
fn register_docs_provider(
&self,
extension: wasm_host::WasmExtension,
@@ -1,6 +1,7 @@
use assistant_slash_command::SlashCommandRegistry;
use async_compression::futures::bufread::GzipEncoder;
use collections::BTreeMap;
+use context_servers::ContextServerFactoryRegistry;
use extension_host::ExtensionSettings;
use extension_host::SchemaVersion;
use extension_host::{
@@ -161,6 +162,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
.into_iter()
.collect(),
language_servers: BTreeMap::default(),
+ context_servers: BTreeMap::default(),
slash_commands: BTreeMap::default(),
indexed_docs_providers: BTreeMap::default(),
snippets: None,
@@ -187,6 +189,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
languages: Default::default(),
grammars: BTreeMap::default(),
language_servers: BTreeMap::default(),
+ context_servers: BTreeMap::default(),
slash_commands: BTreeMap::default(),
indexed_docs_providers: BTreeMap::default(),
snippets: None,
@@ -264,6 +267,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
let slash_command_registry = SlashCommandRegistry::new();
let indexed_docs_registry = Arc::new(IndexedDocsRegistry::new(cx.executor()));
let snippet_registry = Arc::new(SnippetRegistry::new());
+ let context_server_factory_registry = ContextServerFactoryRegistry::new();
let node_runtime = NodeRuntime::unavailable();
let store = cx.new_model(|cx| {
@@ -273,6 +277,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
indexed_docs_registry.clone(),
snippet_registry.clone(),
language_registry.clone(),
+ context_server_factory_registry.clone(),
cx,
);
@@ -356,6 +361,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
languages: Default::default(),
grammars: BTreeMap::default(),
language_servers: BTreeMap::default(),
+ context_servers: BTreeMap::default(),
slash_commands: BTreeMap::default(),
indexed_docs_providers: BTreeMap::default(),
snippets: None,
@@ -406,6 +412,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
indexed_docs_registry,
snippet_registry,
language_registry.clone(),
+ context_server_factory_registry.clone(),
cx,
);
@@ -500,6 +507,7 @@ async fn test_extension_store_with_test_extension(cx: &mut TestAppContext) {
let slash_command_registry = SlashCommandRegistry::new();
let indexed_docs_registry = Arc::new(IndexedDocsRegistry::new(cx.executor()));
let snippet_registry = Arc::new(SnippetRegistry::new());
+ let context_server_factory_registry = ContextServerFactoryRegistry::new();
let node_runtime = NodeRuntime::unavailable();
let mut status_updates = language_registry.language_server_binary_statuses();
@@ -596,6 +604,7 @@ async fn test_extension_store_with_test_extension(cx: &mut TestAppContext) {
indexed_docs_registry,
snippet_registry,
language_registry.clone(),
+ context_server_factory_registry.clone(),
cx,
);
ExtensionStore::new(
@@ -1,4 +1,5 @@
mod components;
+mod extension_context_server;
mod extension_indexed_docs_provider;
mod extension_registration_hooks;
mod extension_slash_command;
@@ -35,6 +35,7 @@ collab_ui.workspace = true
collections.workspace = true
command_palette.workspace = true
command_palette_hooks.workspace = true
+context_servers.workspace = true
copilot.workspace = true
db.workspace = true
diagnostics.workspace = true
@@ -13,6 +13,7 @@ use clap::{command, Parser};
use cli::FORCE_CLI_MODE_ENV_VAR_NAME;
use client::{parse_zed_link, Client, ProxySettings, UserStore};
use collab_ui::channel_view::ChannelView;
+use context_servers::ContextServerFactoryRegistry;
use db::kvp::{GLOBAL_KEY_VALUE_STORE, KEY_VALUE_STORE};
use editor::Editor;
use env_logger::Builder;
@@ -411,6 +412,7 @@ fn main() {
IndexedDocsRegistry::global(cx),
SnippetRegistry::global(cx),
app_state.languages.clone(),
+ ContextServerFactoryRegistry::global(cx),
cx,
);
extension_host::init(