wasm_host.rs

  1pub mod wit;
  2
  3use crate::{ExtensionManifest, ExtensionRegistrationHooks};
  4use anyhow::{anyhow, bail, Context as _, Result};
  5use async_trait::async_trait;
  6use extension::KeyValueStoreDelegate;
  7use fs::{normalize_path, Fs};
  8use futures::future::LocalBoxFuture;
  9use futures::{
 10    channel::{
 11        mpsc::{self, UnboundedSender},
 12        oneshot,
 13    },
 14    future::BoxFuture,
 15    Future, FutureExt, StreamExt as _,
 16};
 17use gpui::{AppContext, AsyncAppContext, BackgroundExecutor, Task};
 18use http_client::HttpClient;
 19use node_runtime::NodeRuntime;
 20use release_channel::ReleaseChannel;
 21use semantic_version::SemanticVersion;
 22use std::{
 23    path::{Path, PathBuf},
 24    sync::{Arc, OnceLock},
 25};
 26use wasmtime::{
 27    component::{Component, ResourceTable},
 28    Engine, Store,
 29};
 30use wasmtime_wasi::{self as wasi, WasiView};
 31use wit::Extension;
 32pub use wit::{ExtensionProject, SlashCommand};
 33
/// State shared by every wasm extension: the wasmtime engine, services that
/// extensions may use, and the channel for running work with an app context.
///
/// Created by `WasmHost::new` and held behind an `Arc`; each loaded
/// extension's `WasmState` keeps a clone of that `Arc`.
pub struct WasmHost {
    /// Engine returned by `wasm_engine`, used to compile and instantiate components.
    engine: Engine,
    /// Passed to `Extension::instantiate_async` when loading an extension.
    release_channel: ReleaseChannel,
    // NOTE(review): not read in this file; presumably used by host functions in `wit`.
    http_client: Arc<dyn HttpClient>,
    // NOTE(review): not read in this file; presumably used by host functions in `wit`.
    node_runtime: NodeRuntime,
    pub registration_hooks: Arc<dyn ExtensionRegistrationHooks>,
    /// Used to create extension work directories and to open `extension.wasm` files.
    fs: Arc<dyn Fs>,
    /// Root directory; an extension's own directory is `work_dir/<extension id>`.
    pub work_dir: PathBuf,
    /// Task spawned in `WasmHost::new` that drains `main_thread_message_tx`.
    /// Never read; it is stored so the host owns the task handle.
    _main_thread_message_task: Task<()>,
    /// Sends closures to that task (see `WasmState::on_main_thread`).
    main_thread_message_tx: mpsc::UnboundedSender<MainThreadCall>,
}
 45
/// Handle to a loaded wasm extension, produced by `WasmHost::load_extension`.
///
/// All clones share the same call channel and therefore talk to the same
/// extension instance.
#[derive(Clone)]
pub struct WasmExtension {
    /// Queue of calls serviced by the extension's background task; see `WasmExtension::call`.
    tx: UnboundedSender<ExtensionCall>,
    pub manifest: Arc<ExtensionManifest>,
    /// Directory reported through `extension::Extension::work_dir`.
    pub work_dir: Arc<Path>,
    /// API version read from the binary's `zed:api-version` custom section.
    #[allow(unused)]
    pub zed_api_version: SemanticVersion,
}
 54
 55#[async_trait]
 56impl extension::Extension for WasmExtension {
 57    fn manifest(&self) -> Arc<ExtensionManifest> {
 58        self.manifest.clone()
 59    }
 60
 61    fn work_dir(&self) -> Arc<Path> {
 62        self.work_dir.clone()
 63    }
 64
 65    async fn suggest_docs_packages(&self, provider: Arc<str>) -> Result<Vec<String>> {
 66        self.call(|extension, store| {
 67            async move {
 68                let packages = extension
 69                    .call_suggest_docs_packages(store, provider.as_ref())
 70                    .await?
 71                    .map_err(|err| anyhow!("{err:?}"))?;
 72
 73                Ok(packages)
 74            }
 75            .boxed()
 76        })
 77        .await
 78    }
 79
 80    async fn index_docs(
 81        &self,
 82        provider: Arc<str>,
 83        package_name: Arc<str>,
 84        kv_store: Arc<dyn KeyValueStoreDelegate>,
 85    ) -> Result<()> {
 86        self.call(|extension, store| {
 87            async move {
 88                let kv_store_resource = store.data_mut().table().push(kv_store)?;
 89                extension
 90                    .call_index_docs(
 91                        store,
 92                        provider.as_ref(),
 93                        package_name.as_ref(),
 94                        kv_store_resource,
 95                    )
 96                    .await?
 97                    .map_err(|err| anyhow!("{err:?}"))?;
 98
 99                anyhow::Ok(())
100            }
101            .boxed()
102        })
103        .await
104    }
105}
106
/// Per-extension data stored inside the extension's wasmtime [`Store`].
pub struct WasmState {
    /// Manifest of the extension this store belongs to.
    manifest: Arc<ExtensionManifest>,
    /// Resource table exposed through `WasiView::table`; `index_docs` pushes
    /// the key-value store into it.
    pub table: ResourceTable,
    /// WASI context built by `WasmHost::build_wasi_ctx`.
    ctx: wasi::WasiCtx,
    /// The host that loaded this extension.
    pub host: Arc<WasmHost>,
}
113
/// A closure run with an [`AsyncAppContext`] by the task spawned in
/// `WasmHost::new` (see `WasmState::on_main_thread`). The future it returns
/// does not need to be `Send`.
type MainThreadCall =
    Box<dyn Send + for<'a> FnOnce(&'a mut AsyncAppContext) -> LocalBoxFuture<'a, ()>>;

/// A closure run against an extension's instance and store by the task
/// spawned in `WasmHost::load_extension` (see `WasmExtension::call`).
type ExtensionCall = Box<
    dyn Send + for<'a> FnOnce(&'a mut Extension, &'a mut Store<WasmState>) -> BoxFuture<'a, ()>,
>;
120
121fn wasm_engine() -> wasmtime::Engine {
122    static WASM_ENGINE: OnceLock<wasmtime::Engine> = OnceLock::new();
123
124    WASM_ENGINE
125        .get_or_init(|| {
126            let mut config = wasmtime::Config::new();
127            config.wasm_component_model(true);
128            config.async_support(true);
129            wasmtime::Engine::new(&config).unwrap()
130        })
131        .clone()
132}
133
134impl WasmHost {
135    pub fn new(
136        fs: Arc<dyn Fs>,
137        http_client: Arc<dyn HttpClient>,
138        node_runtime: NodeRuntime,
139        registration_hooks: Arc<dyn ExtensionRegistrationHooks>,
140        work_dir: PathBuf,
141        cx: &mut AppContext,
142    ) -> Arc<Self> {
143        let (tx, mut rx) = mpsc::unbounded::<MainThreadCall>();
144        let task = cx.spawn(|mut cx| async move {
145            while let Some(message) = rx.next().await {
146                message(&mut cx).await;
147            }
148        });
149        Arc::new(Self {
150            engine: wasm_engine(),
151            fs,
152            work_dir,
153            http_client,
154            node_runtime,
155            registration_hooks,
156            release_channel: ReleaseChannel::global(cx),
157            _main_thread_message_task: task,
158            main_thread_message_tx: tx,
159        })
160    }
161
162    pub fn load_extension(
163        self: &Arc<Self>,
164        wasm_bytes: Vec<u8>,
165        manifest: &Arc<ExtensionManifest>,
166        executor: BackgroundExecutor,
167    ) -> Task<Result<WasmExtension>> {
168        let this = self.clone();
169        let manifest = manifest.clone();
170        executor.clone().spawn(async move {
171            let zed_api_version = parse_wasm_extension_version(&manifest.id, &wasm_bytes)?;
172
173            let component = Component::from_binary(&this.engine, &wasm_bytes)
174                .context("failed to compile wasm component")?;
175
176            let mut store = wasmtime::Store::new(
177                &this.engine,
178                WasmState {
179                    ctx: this.build_wasi_ctx(&manifest).await?,
180                    manifest: manifest.clone(),
181                    table: ResourceTable::new(),
182                    host: this.clone(),
183                },
184            );
185
186            let mut extension = Extension::instantiate_async(
187                &mut store,
188                this.release_channel,
189                zed_api_version,
190                &component,
191            )
192            .await?;
193
194            extension
195                .call_init_extension(&mut store)
196                .await
197                .context("failed to initialize wasm extension")?;
198
199            let (tx, mut rx) = mpsc::unbounded::<ExtensionCall>();
200            executor
201                .spawn(async move {
202                    while let Some(call) = rx.next().await {
203                        (call)(&mut extension, &mut store).await;
204                    }
205                })
206                .detach();
207
208            Ok(WasmExtension {
209                manifest: manifest.clone(),
210                work_dir: this.work_dir.clone().into(),
211                tx,
212                zed_api_version,
213            })
214        })
215    }
216
217    async fn build_wasi_ctx(&self, manifest: &Arc<ExtensionManifest>) -> Result<wasi::WasiCtx> {
218        let extension_work_dir = self.work_dir.join(manifest.id.as_ref());
219        self.fs
220            .create_dir(&extension_work_dir)
221            .await
222            .context("failed to create extension work dir")?;
223
224        let file_perms = wasi::FilePerms::all();
225        let dir_perms = wasi::DirPerms::all();
226
227        Ok(wasi::WasiCtxBuilder::new()
228            .inherit_stdio()
229            .preopened_dir(&extension_work_dir, ".", dir_perms, file_perms)?
230            .preopened_dir(
231                &extension_work_dir,
232                extension_work_dir.to_string_lossy(),
233                dir_perms,
234                file_perms,
235            )?
236            .env("PWD", extension_work_dir.to_string_lossy())
237            .env("RUST_BACKTRACE", "full")
238            .build())
239    }
240
241    pub fn path_from_extension(&self, id: &Arc<str>, path: &Path) -> PathBuf {
242        let extension_work_dir = self.work_dir.join(id.as_ref());
243        normalize_path(&extension_work_dir.join(path))
244    }
245
246    pub fn writeable_path_from_extension(&self, id: &Arc<str>, path: &Path) -> Result<PathBuf> {
247        let extension_work_dir = self.work_dir.join(id.as_ref());
248        let path = normalize_path(&extension_work_dir.join(path));
249        if path.starts_with(&extension_work_dir) {
250            Ok(path)
251        } else {
252            Err(anyhow!("cannot write to path {}", path.display()))
253        }
254    }
255}
256
257pub fn parse_wasm_extension_version(
258    extension_id: &str,
259    wasm_bytes: &[u8],
260) -> Result<SemanticVersion> {
261    let mut version = None;
262
263    for part in wasmparser::Parser::new(0).parse_all(wasm_bytes) {
264        if let wasmparser::Payload::CustomSection(s) =
265            part.context("error parsing wasm extension")?
266        {
267            if s.name() == "zed:api-version" {
268                version = parse_wasm_extension_version_custom_section(s.data());
269                if version.is_none() {
270                    bail!(
271                        "extension {} has invalid zed:api-version section: {:?}",
272                        extension_id,
273                        s.data()
274                    );
275                }
276            }
277        }
278    }
279
280    // The reason we wait until we're done parsing all of the Wasm bytes to return the version
281    // is to work around a panic that can happen inside of Wasmtime when the bytes are invalid.
282    //
283    // By parsing the entirety of the Wasm bytes before we return, we're able to detect this problem
284    // earlier as an `Err` rather than as a panic.
285    version.ok_or_else(|| anyhow!("extension {} has no zed:api-version section", extension_id))
286}
287
288fn parse_wasm_extension_version_custom_section(data: &[u8]) -> Option<SemanticVersion> {
289    if data.len() == 6 {
290        Some(SemanticVersion::new(
291            u16::from_be_bytes([data[0], data[1]]) as _,
292            u16::from_be_bytes([data[2], data[3]]) as _,
293            u16::from_be_bytes([data[4], data[5]]) as _,
294        ))
295    } else {
296        None
297    }
298}
299
300impl WasmExtension {
301    pub async fn load(
302        extension_dir: PathBuf,
303        manifest: &Arc<ExtensionManifest>,
304        wasm_host: Arc<WasmHost>,
305        cx: &AsyncAppContext,
306    ) -> Result<Self> {
307        let path = extension_dir.join("extension.wasm");
308
309        let mut wasm_file = wasm_host
310            .fs
311            .open_sync(&path)
312            .await
313            .context("failed to open wasm file")?;
314
315        let mut wasm_bytes = Vec::new();
316        wasm_file
317            .read_to_end(&mut wasm_bytes)
318            .context("failed to read wasm")?;
319
320        wasm_host
321            .load_extension(wasm_bytes, manifest, cx.background_executor().clone())
322            .await
323            .with_context(|| format!("failed to load wasm extension {}", manifest.id))
324    }
325
326    pub async fn call<T, Fn>(&self, f: Fn) -> T
327    where
328        T: 'static + Send,
329        Fn: 'static
330            + Send
331            + for<'a> FnOnce(&'a mut Extension, &'a mut Store<WasmState>) -> BoxFuture<'a, T>,
332    {
333        let (return_tx, return_rx) = oneshot::channel();
334        self.tx
335            .clone()
336            .unbounded_send(Box::new(move |extension, store| {
337                async {
338                    let result = f(extension, store).await;
339                    return_tx.send(result).ok();
340                }
341                .boxed()
342            }))
343            .expect("wasm extension channel should not be closed yet");
344        return_rx.await.expect("wasm extension channel")
345    }
346}
347
348impl WasmState {
349    fn on_main_thread<T, Fn>(&self, f: Fn) -> impl 'static + Future<Output = T>
350    where
351        T: 'static + Send,
352        Fn: 'static + Send + for<'a> FnOnce(&'a mut AsyncAppContext) -> LocalBoxFuture<'a, T>,
353    {
354        let (return_tx, return_rx) = oneshot::channel();
355        self.host
356            .main_thread_message_tx
357            .clone()
358            .unbounded_send(Box::new(move |cx| {
359                async {
360                    let result = f(cx).await;
361                    return_tx.send(result).ok();
362                }
363                .boxed_local()
364            }))
365            .expect("main thread message channel should not be closed yet");
366        async move { return_rx.await.expect("main thread message channel") }
367    }
368
369    fn work_dir(&self) -> PathBuf {
370        self.host.work_dir.join(self.manifest.id.as_ref())
371    }
372}
373
374impl wasi::WasiView for WasmState {
375    fn table(&mut self) -> &mut ResourceTable {
376        &mut self.table
377    }
378
379    fn ctx(&mut self) -> &mut wasi::WasiCtx {
380        &mut self.ctx
381    }
382}