Use `Extension` trait when registering extension context servers (#21070)

Marshall Bowers created

This PR updates the extension context server registration to go through
the `Extension` trait for interacting with extensions rather than going
through the `WasmHost` directly.

Release Notes:

- N/A

Change summary

Cargo.lock                                               |  1 
crates/extension/src/extension.rs                        | 10 +
crates/extension/src/types.rs                            |  1 
crates/extension_host/src/extension_host.rs              |  4 
crates/extension_host/src/wasm_host.rs                   | 22 +++
crates/extension_host/src/wasm_host/wit/since_v0_2_0.rs  |  9 -
crates/extensions_ui/Cargo.toml                          |  1 
crates/extensions_ui/src/extension_registration_hooks.rs | 57 ++++-----
8 files changed, 59 insertions(+), 46 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -4237,7 +4237,6 @@ dependencies = [
  "ui",
  "util",
  "vim_mode_setting",
- "wasmtime-wasi",
  "workspace",
  "zed_actions",
 ]

crates/extension/src/extension.rs 🔗

@@ -25,6 +25,10 @@ pub trait WorktreeDelegate: Send + Sync + 'static {
     async fn shell_env(&self) -> Vec<(String, String)>;
 }
 
+pub trait ProjectDelegate: Send + Sync + 'static {
+    fn worktree_ids(&self) -> Vec<u64>;
+}
+
 pub trait KeyValueStoreDelegate: Send + Sync + 'static {
     fn insert(&self, key: String, docs: String) -> Task<Result<()>>;
 }
@@ -87,6 +91,12 @@ pub trait Extension: Send + Sync + 'static {
         worktree: Option<Arc<dyn WorktreeDelegate>>,
     ) -> Result<SlashCommandOutput>;
 
+    async fn context_server_command(
+        &self,
+        context_server_id: Arc<str>,
+        project: Arc<dyn ProjectDelegate>,
+    ) -> Result<Command>;
+
     async fn suggest_docs_packages(&self, provider: Arc<str>) -> Result<Vec<String>>;
 
     async fn index_docs(

crates/extension/src/types.rs 🔗

@@ -10,6 +10,7 @@ pub use slash_command::*;
 pub type EnvVars = Vec<(String, String)>;
 
 /// A command.
+#[derive(Debug)]
 pub struct Command {
     /// The command to execute.
     pub command: String,

crates/extension_host/src/extension_host.rs 🔗

@@ -149,8 +149,8 @@ pub trait ExtensionRegistrationHooks: Send + Sync + 'static {
 
     fn register_context_server(
         &self,
+        _extension: Arc<dyn Extension>,
         _id: Arc<str>,
-        _extension: WasmExtension,
         _cx: &mut AppContext,
     ) {
     }
@@ -1284,8 +1284,8 @@ impl ExtensionStore {
 
                     for (id, _context_server_entry) in &manifest.context_servers {
                         this.registration_hooks.register_context_server(
+                            extension.clone(),
                             id.clone(),
-                            wasm_extension.clone(),
                             cx,
                         );
                     }

crates/extension_host/src/wasm_host.rs 🔗

@@ -4,7 +4,7 @@ use crate::{ExtensionManifest, ExtensionRegistrationHooks};
 use anyhow::{anyhow, bail, Context as _, Result};
 use async_trait::async_trait;
 use extension::{
-    CodeLabel, Command, Completion, KeyValueStoreDelegate, SlashCommand,
+    CodeLabel, Command, Completion, KeyValueStoreDelegate, ProjectDelegate, SlashCommand,
     SlashCommandArgumentCompletion, SlashCommandOutput, Symbol, WorktreeDelegate,
 };
 use fs::{normalize_path, Fs};
@@ -34,7 +34,6 @@ use wasmtime::{
 };
 use wasmtime_wasi::{self as wasi, WasiView};
 use wit::Extension;
-pub use wit::ExtensionProject;
 
 pub struct WasmHost {
     engine: Engine,
@@ -238,6 +237,25 @@ impl extension::Extension for WasmExtension {
         .await
     }
 
+    async fn context_server_command(
+        &self,
+        context_server_id: Arc<str>,
+        project: Arc<dyn ProjectDelegate>,
+    ) -> Result<Command> {
+        self.call(|extension, store| {
+            async move {
+                let project_resource = store.data_mut().table().push(project)?;
+                let command = extension
+                    .call_context_server_command(store, context_server_id.clone(), project_resource)
+                    .await?
+                    .map_err(|err| anyhow!("{err}"))?;
+                anyhow::Ok(command.into())
+            }
+            .boxed()
+        })
+        .await
+    }
+
     async fn suggest_docs_packages(&self, provider: Arc<str>) -> Result<Vec<String>> {
         self.call(|extension, store| {
             async move {

crates/extension_host/src/wasm_host/wit/since_v0_2_0.rs 🔗

@@ -8,7 +8,7 @@ use async_compression::futures::bufread::GzipDecoder;
 use async_tar::Archive;
 use async_trait::async_trait;
 use context_servers::manager::ContextServerSettings;
-use extension::{KeyValueStoreDelegate, WorktreeDelegate};
+use extension::{KeyValueStoreDelegate, ProjectDelegate, WorktreeDelegate};
 use futures::{io::BufReader, FutureExt as _};
 use futures::{lock::Mutex, AsyncReadExt};
 use language::{language_settings::AllLanguageSettings, LanguageName, LanguageServerBinaryStatus};
@@ -44,13 +44,10 @@ mod settings {
 }
 
 pub type ExtensionWorktree = Arc<dyn WorktreeDelegate>;
+pub type ExtensionProject = Arc<dyn ProjectDelegate>;
 pub type ExtensionKeyValueStore = Arc<dyn KeyValueStoreDelegate>;
 pub type ExtensionHttpResponseStream = Arc<Mutex<::http_client::Response<AsyncBody>>>;
 
-pub struct ExtensionProject {
-    pub worktree_ids: Vec<u64>,
-}
-
 pub fn linker() -> &'static Linker<WasmState> {
     static LINKER: OnceLock<Linker<WasmState>> = OnceLock::new();
     LINKER.get_or_init(|| super::new_linker(Extension::add_to_linker))
@@ -273,7 +270,7 @@ impl HostProject for WasmState {
         project: Resource<ExtensionProject>,
     ) -> wasmtime::Result<Vec<u64>> {
         let project = self.table.get(&project)?;
-        Ok(project.worktree_ids.clone())
+        Ok(project.worktree_ids())
     }
 
     fn drop(&mut self, _project: Resource<Project>) -> Result<()> {

crates/extensions_ui/Cargo.toml 🔗

@@ -41,7 +41,6 @@ theme.workspace = true
 ui.workspace = true
 util.workspace = true
 vim_mode_setting.workspace = true
-wasmtime-wasi.workspace = true
 workspace.workspace = true
 zed_actions.workspace = true
 

crates/extensions_ui/src/extension_registration_hooks.rs 🔗

@@ -1,13 +1,11 @@
 use std::{path::PathBuf, sync::Arc};
 
-use anyhow::{anyhow, Result};
+use anyhow::Result;
 use assistant_slash_command::{ExtensionSlashCommand, SlashCommandRegistry};
 use context_servers::manager::ServerCommand;
 use context_servers::ContextServerFactoryRegistry;
-use db::smol::future::FutureExt as _;
-use extension::Extension;
-use extension_host::wasm_host::ExtensionProject;
-use extension_host::{extension_lsp_adapter::ExtensionLspAdapter, wasm_host};
+use extension::{Extension, ProjectDelegate};
+use extension_host::extension_lsp_adapter::ExtensionLspAdapter;
 use fs::Fs;
 use gpui::{AppContext, BackgroundExecutor, Model, Task};
 use indexed_docs::{ExtensionIndexedDocsProvider, IndexedDocsRegistry, ProviderId};
@@ -16,7 +14,16 @@ use lsp::LanguageServerName;
 use snippet_provider::SnippetRegistry;
 use theme::{ThemeRegistry, ThemeSettings};
 use ui::SharedString;
-use wasmtime_wasi::WasiView as _;
+
+struct ExtensionProject {
+    worktree_ids: Vec<u64>,
+}
+
+impl ProjectDelegate for ExtensionProject {
+    fn worktree_ids(&self) -> Vec<u64> {
+        self.worktree_ids.clone()
+    }
+}
 
 pub struct ConcreteExtensionRegistrationHooks {
     slash_command_registry: Arc<SlashCommandRegistry>,
@@ -72,8 +79,8 @@ impl extension_host::ExtensionRegistrationHooks for ConcreteExtensionRegistratio
 
     fn register_context_server(
         &self,
+        extension: Arc<dyn Extension>,
         id: Arc<str>,
-        extension: wasm_host::WasmExtension,
         cx: &mut AppContext,
     ) {
         self.context_server_factory_registry
@@ -84,42 +91,24 @@ impl extension_host::ExtensionRegistrationHooks for ConcreteExtensionRegistratio
                         move |project, cx| {
                             log::info!(
                                 "loading command for context server {id} from extension {}",
-                                extension.manifest.id
+                                extension.manifest().id
                             );
 
                             let id = id.clone();
                             let extension = extension.clone();
                             cx.spawn(|mut cx| async move {
                                 let extension_project =
-                                    project.update(&mut cx, |project, cx| ExtensionProject {
-                                        worktree_ids: project
-                                            .visible_worktrees(cx)
-                                            .map(|worktree| worktree.read(cx).id().to_proto())
-                                            .collect(),
+                                    project.update(&mut cx, |project, cx| {
+                                        Arc::new(ExtensionProject {
+                                            worktree_ids: project
+                                                .visible_worktrees(cx)
+                                                .map(|worktree| worktree.read(cx).id().to_proto())
+                                                .collect(),
+                                        })
                                     })?;
 
                                 let command = extension
-                                    .call({
-                                        let id = id.clone();
-                                        |extension, store| {
-                                            async move {
-                                                let project = store
-                                                    .data_mut()
-                                                    .table()
-                                                    .push(extension_project)?;
-                                                let command = extension
-                                                    .call_context_server_command(
-                                                        store,
-                                                        id.clone(),
-                                                        project,
-                                                    )
-                                                    .await?
-                                                    .map_err(|e| anyhow!("{}", e))?;
-                                                anyhow::Ok(command)
-                                            }
-                                            .boxed()
-                                        }
-                                    })
+                                    .context_server_command(id.clone(), extension_project)
                                     .await?;
 
                                 log::info!("loaded command for context server {id}: {command:?}");