wasm_host.rs

  1use crate::ExtensionManifest;
  2use anyhow::{anyhow, bail, Context as _, Result};
  3use async_compression::futures::bufread::GzipDecoder;
  4use async_tar::Archive;
  5use async_trait::async_trait;
  6use fs::Fs;
  7use futures::{
  8    channel::{mpsc::UnboundedSender, oneshot},
  9    future::BoxFuture,
 10    io::BufReader,
 11    Future, FutureExt, StreamExt as _,
 12};
 13use gpui::BackgroundExecutor;
 14use language::{LanguageRegistry, LanguageServerBinaryStatus, LspAdapterDelegate};
 15use node_runtime::NodeRuntime;
 16use std::{
 17    path::PathBuf,
 18    sync::{Arc, OnceLock},
 19};
 20use util::{http::HttpClient, SemanticVersion};
 21use wasmtime::{
 22    component::{Component, Linker, Resource, ResourceTable},
 23    Engine, Store,
 24};
 25use wasmtime_wasi::preview2::{command as wasi_command, WasiCtx, WasiCtxBuilder, WasiView};
 26
/// Host-side bindings for the extension API, generated by
/// `wasmtime::component::bindgen!` from the WIT definitions found at
/// `../extension_api/wit`.
pub mod wit {
    wasmtime::component::bindgen!({
        // Generate async bindings; the engine is configured with async support.
        async: true,
        path: "../extension_api/wit",
        with: {
             // Back the WIT `worktree` resource with the host's own worktree type.
             "worktree": super::ExtensionWorktree,
        },
    });
}
 36
/// The host type standing behind the WIT `worktree` resource: a shared
/// LSP adapter delegate.
pub type ExtensionWorktree = Arc<dyn LspAdapterDelegate>;
 38
/// Shared host environment used to load and run WebAssembly extensions.
///
/// Holds the wasmtime engine and linker plus the services that host
/// functions expose to extensions.
pub(crate) struct WasmHost {
    /// Engine used to compile components and create stores.
    engine: Engine,
    /// Linker populated with the WASI command world and the extension imports.
    linker: Arc<wasmtime::component::Linker<WasmState>>,
    /// Client used for GitHub release lookups and file downloads.
    http_client: Arc<dyn HttpClient>,
    /// Used to look up npm package versions on behalf of extensions.
    node_runtime: Arc<dyn NodeRuntime>,
    /// Receives language-server installation status updates from extensions.
    language_registry: Arc<LanguageRegistry>,
    /// Filesystem used to store files downloaded by extensions.
    fs: Arc<dyn Fs>,
    /// Root directory under which each extension gets a subdirectory named
    /// after its manifest id.
    pub(crate) work_dir: PathBuf,
}
 48
/// Cloneable handle to a loaded extension.
///
/// Work is submitted over `tx` and executed by the task that
/// `WasmHost::load_extension` spawns for the extension.
#[derive(Clone)]
pub struct WasmExtension {
    /// Sending half of the queue of calls to run against the extension.
    tx: UnboundedSender<ExtensionCall>,
    /// API version read from the component's `zed:api-version` custom section.
    #[allow(unused)]
    zed_api_version: SemanticVersion,
}
 55
/// Per-extension data stored inside the wasmtime `Store`.
///
/// Host functions imported by the extension are implemented on this type.
pub(crate) struct WasmState {
    /// Manifest of the extension this state belongs to.
    manifest: Arc<ExtensionManifest>,
    /// Resource table handed to wasmtime via `WasiView::table`.
    table: ResourceTable,
    /// WASI context handed to wasmtime via `WasiView::ctx`.
    ctx: WasiCtx,
    /// The host that loaded this extension.
    host: Arc<WasmHost>,
}
 62
/// A unit of work queued for an extension: a one-shot closure that receives
/// the extension bindings and its store, and returns a future borrowing both.
type ExtensionCall = Box<
    dyn Send
        + for<'a> FnOnce(&'a mut wit::Extension, &'a mut Store<WasmState>) -> BoxFuture<'a, ()>,
>;
 67
/// Process-wide wasmtime engine, created on first use by `WasmHost::new`
/// and cloned into every host.
static WASM_ENGINE: OnceLock<wasmtime::Engine> = OnceLock::new();
 69
 70impl WasmHost {
 71    pub fn new(
 72        fs: Arc<dyn Fs>,
 73        http_client: Arc<dyn HttpClient>,
 74        node_runtime: Arc<dyn NodeRuntime>,
 75        language_registry: Arc<LanguageRegistry>,
 76        work_dir: PathBuf,
 77    ) -> Arc<Self> {
 78        let engine = WASM_ENGINE
 79            .get_or_init(|| {
 80                let mut config = wasmtime::Config::new();
 81                config.wasm_component_model(true);
 82                config.async_support(true);
 83                wasmtime::Engine::new(&config).unwrap()
 84            })
 85            .clone();
 86        let mut linker = Linker::new(&engine);
 87        wasi_command::add_to_linker(&mut linker).unwrap();
 88        wit::Extension::add_to_linker(&mut linker, |state: &mut WasmState| state).unwrap();
 89        Arc::new(Self {
 90            engine,
 91            linker: Arc::new(linker),
 92            fs,
 93            work_dir,
 94            http_client,
 95            node_runtime,
 96            language_registry,
 97        })
 98    }
 99
100    pub fn load_extension(
101        self: &Arc<Self>,
102        wasm_bytes: Vec<u8>,
103        manifest: Arc<ExtensionManifest>,
104        executor: BackgroundExecutor,
105    ) -> impl 'static + Future<Output = Result<WasmExtension>> {
106        let this = self.clone();
107        async move {
108            let component = Component::from_binary(&this.engine, &wasm_bytes)
109                .context("failed to compile wasm component")?;
110
111            let mut zed_api_version = None;
112            for part in wasmparser::Parser::new(0).parse_all(&wasm_bytes) {
113                if let wasmparser::Payload::CustomSection(s) = part? {
114                    if s.name() == "zed:api-version" {
115                        if s.data().len() != 6 {
116                            bail!(
117                                "extension {} has invalid zed:api-version section: {:?}",
118                                manifest.id,
119                                s.data()
120                            );
121                        }
122
123                        let major = u16::from_be_bytes(s.data()[0..2].try_into().unwrap()) as _;
124                        let minor = u16::from_be_bytes(s.data()[2..4].try_into().unwrap()) as _;
125                        let patch = u16::from_be_bytes(s.data()[4..6].try_into().unwrap()) as _;
126                        zed_api_version = Some(SemanticVersion {
127                            major,
128                            minor,
129                            patch,
130                        })
131                    }
132                }
133            }
134
135            let Some(zed_api_version) = zed_api_version else {
136                bail!("extension {} has no zed:api-version section", manifest.id);
137            };
138
139            let mut store = wasmtime::Store::new(
140                &this.engine,
141                WasmState {
142                    manifest,
143                    table: ResourceTable::new(),
144                    ctx: WasiCtxBuilder::new()
145                        .inherit_stdio()
146                        .env("RUST_BACKTRACE", "1")
147                        .build(),
148                    host: this.clone(),
149                },
150            );
151            let (mut extension, instance) =
152                wit::Extension::instantiate_async(&mut store, &component, &this.linker)
153                    .await
154                    .context("failed to instantiate wasm component")?;
155            let (tx, mut rx) = futures::channel::mpsc::unbounded::<ExtensionCall>();
156            executor
157                .spawn(async move {
158                    extension.call_init_extension(&mut store).await.unwrap();
159
160                    let _instance = instance;
161                    while let Some(call) = rx.next().await {
162                        (call)(&mut extension, &mut store).await;
163                    }
164                })
165                .detach();
166            Ok(WasmExtension {
167                tx,
168                zed_api_version,
169            })
170        }
171    }
172}
173
174impl WasmExtension {
175    pub async fn call<T, Fn>(&self, f: Fn) -> T
176    where
177        T: 'static + Send,
178        Fn: 'static
179            + Send
180            + for<'a> FnOnce(&'a mut wit::Extension, &'a mut Store<WasmState>) -> BoxFuture<'a, T>,
181    {
182        let (return_tx, return_rx) = oneshot::channel();
183        self.tx
184            .clone()
185            .unbounded_send(Box::new(move |extension, store| {
186                async {
187                    let result = f(extension, store).await;
188                    return_tx.send(result).ok();
189                }
190                .boxed()
191            }))
192            .expect("wasm extension channel should not be closed yet");
193        return_rx.await.expect("wasm extension channel")
194    }
195}
196
197#[async_trait]
198impl wit::HostWorktree for WasmState {
199    async fn read_text_file(
200        &mut self,
201        delegate: Resource<Arc<dyn LspAdapterDelegate>>,
202        path: String,
203    ) -> wasmtime::Result<Result<String, String>> {
204        let delegate = self.table().get(&delegate)?;
205        Ok(delegate
206            .read_text_file(path.into())
207            .await
208            .map_err(|error| error.to_string()))
209    }
210
211    fn drop(&mut self, _worktree: Resource<wit::Worktree>) -> Result<()> {
212        // we only ever hand out borrows of worktrees
213        Ok(())
214    }
215}
216
217#[async_trait]
218impl wit::ExtensionImports for WasmState {
219    async fn npm_package_latest_version(
220        &mut self,
221        package_name: String,
222    ) -> wasmtime::Result<Result<String, String>> {
223        async fn inner(this: &mut WasmState, package_name: String) -> anyhow::Result<String> {
224            this.host
225                .node_runtime
226                .npm_package_latest_version(&package_name)
227                .await
228        }
229
230        Ok(inner(self, package_name)
231            .await
232            .map_err(|err| err.to_string()))
233    }
234
235    async fn latest_github_release(
236        &mut self,
237        repo: String,
238        options: wit::GithubReleaseOptions,
239    ) -> wasmtime::Result<Result<wit::GithubRelease, String>> {
240        async fn inner(
241            this: &mut WasmState,
242            repo: String,
243            options: wit::GithubReleaseOptions,
244        ) -> anyhow::Result<wit::GithubRelease> {
245            let release = util::github::latest_github_release(
246                &repo,
247                options.require_assets,
248                options.pre_release,
249                this.host.http_client.clone(),
250            )
251            .await?;
252            Ok(wit::GithubRelease {
253                version: release.tag_name,
254                assets: release
255                    .assets
256                    .into_iter()
257                    .map(|asset| wit::GithubReleaseAsset {
258                        name: asset.name,
259                        download_url: asset.browser_download_url,
260                    })
261                    .collect(),
262            })
263        }
264
265        Ok(inner(self, repo, options)
266            .await
267            .map_err(|err| err.to_string()))
268    }
269
270    async fn current_platform(&mut self) -> Result<(wit::Os, wit::Architecture)> {
271        Ok((
272            match std::env::consts::OS {
273                "macos" => wit::Os::Mac,
274                "linux" => wit::Os::Linux,
275                "windows" => wit::Os::Windows,
276                _ => panic!("unsupported os"),
277            },
278            match std::env::consts::ARCH {
279                "aarch64" => wit::Architecture::Aarch64,
280                "x86" => wit::Architecture::X86,
281                "x86_64" => wit::Architecture::X8664,
282                _ => panic!("unsupported architecture"),
283            },
284        ))
285    }
286
287    async fn set_language_server_installation_status(
288        &mut self,
289        server_name: String,
290        status: wit::LanguageServerInstallationStatus,
291    ) -> wasmtime::Result<()> {
292        let status = match status {
293            wit::LanguageServerInstallationStatus::CheckingForUpdate => {
294                LanguageServerBinaryStatus::CheckingForUpdate
295            }
296            wit::LanguageServerInstallationStatus::Downloading => {
297                LanguageServerBinaryStatus::Downloading
298            }
299            wit::LanguageServerInstallationStatus::Downloaded => {
300                LanguageServerBinaryStatus::Downloaded
301            }
302            wit::LanguageServerInstallationStatus::Cached => LanguageServerBinaryStatus::Cached,
303            wit::LanguageServerInstallationStatus::Failed(error) => {
304                LanguageServerBinaryStatus::Failed { error }
305            }
306        };
307
308        self.host
309            .language_registry
310            .update_lsp_status(language::LanguageServerName(server_name.into()), status);
311        Ok(())
312    }
313
314    async fn download_file(
315        &mut self,
316        url: String,
317        filename: String,
318        file_type: wit::DownloadedFileType,
319    ) -> wasmtime::Result<Result<(), String>> {
320        async fn inner(
321            this: &mut WasmState,
322            url: String,
323            filename: String,
324            file_type: wit::DownloadedFileType,
325        ) -> anyhow::Result<()> {
326            this.host.fs.create_dir(&this.host.work_dir).await?;
327            let container_dir = this.host.work_dir.join(this.manifest.id.as_ref());
328            let destination_path = container_dir.join(&filename);
329
330            let mut response = this
331                .host
332                .http_client
333                .get(&url, Default::default(), true)
334                .await
335                .map_err(|err| anyhow!("error downloading release: {}", err))?;
336
337            if !response.status().is_success() {
338                Err(anyhow!(
339                    "download failed with status {}",
340                    response.status().to_string()
341                ))?;
342            }
343            let body = BufReader::new(response.body_mut());
344
345            match file_type {
346                wit::DownloadedFileType::Uncompressed => {
347                    futures::pin_mut!(body);
348                    this.host
349                        .fs
350                        .create_file_with(&destination_path, body)
351                        .await?;
352                }
353                wit::DownloadedFileType::Gzip => {
354                    let body = GzipDecoder::new(body);
355                    futures::pin_mut!(body);
356                    this.host
357                        .fs
358                        .create_file_with(&destination_path, body)
359                        .await?;
360                }
361                wit::DownloadedFileType::GzipTar => {
362                    let body = GzipDecoder::new(body);
363                    futures::pin_mut!(body);
364                    this.host
365                        .fs
366                        .extract_tar_file(&destination_path, Archive::new(body))
367                        .await?;
368                }
369                wit::DownloadedFileType::Zip => {
370                    let zip_filename = format!("{filename}.zip");
371                    let mut zip_path = destination_path.clone();
372                    zip_path.set_file_name(zip_filename);
373                    futures::pin_mut!(body);
374                    this.host.fs.create_file_with(&zip_path, body).await?;
375
376                    let unzip_status = std::process::Command::new("unzip")
377                        .current_dir(&container_dir)
378                        .arg(&zip_path)
379                        .output()?
380                        .status;
381                    if !unzip_status.success() {
382                        Err(anyhow!("failed to unzip {filename} archive"))?;
383                    }
384                }
385            }
386
387            Ok(())
388        }
389
390        Ok(inner(self, url, filename, file_type)
391            .await
392            .map(|_| ())
393            .map_err(|err| err.to_string()))
394    }
395}
396
397impl WasiView for WasmState {
398    fn table(&mut self) -> &mut ResourceTable {
399        &mut self.table
400    }
401
402    fn ctx(&mut self) -> &mut WasiCtx {
403        &mut self.ctx
404    }
405}