diff --git a/Cargo.lock b/Cargo.lock index 48db1977efa9772c1d253e9382ef788664056b7a..eb5df527243bf130f7bd11735f767ed51807bc6b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5883,6 +5883,7 @@ dependencies = [ "fs", "futures 0.3.31", "gpui", + "gpui_tokio", "http_client", "language", "language_extension", diff --git a/crates/extension_host/Cargo.toml b/crates/extension_host/Cargo.toml index 42dcdd3a6f3fd7da3d40bc4cc5437ffdfcd688c5..16cbd9ac0c0ef938322f2b57789c7542549a570a 100644 --- a/crates/extension_host/Cargo.toml +++ b/crates/extension_host/Cargo.toml @@ -27,6 +27,7 @@ extension.workspace = true fs.workspace = true futures.workspace = true gpui.workspace = true +gpui_tokio.workspace = true http_client.workspace = true language.workspace = true log.workspace = true diff --git a/crates/extension_host/benches/extension_compilation_benchmark.rs b/crates/extension_host/benches/extension_compilation_benchmark.rs index 6f0897af6edbb38acef305ff03b76569a741aca5..309e089758eab8bed1139e2d813bc99b1febb594 100644 --- a/crates/extension_host/benches/extension_compilation_benchmark.rs +++ b/crates/extension_host/benches/extension_compilation_benchmark.rs @@ -19,6 +19,7 @@ use util::test::TempTree; fn extension_benchmarks(c: &mut Criterion) { let cx = init(); + cx.update(gpui_tokio::init); let mut group = c.benchmark_group("load"); @@ -37,7 +38,7 @@ fn extension_benchmarks(c: &mut Criterion) { |wasm_bytes| { let _extension = cx .executor() - .block(wasm_host.load_extension(wasm_bytes, &manifest, cx.executor())) + .block(wasm_host.load_extension(wasm_bytes, &manifest, &cx.to_async())) .unwrap(); }, BatchSize::SmallInput, diff --git a/crates/extension_host/src/extension_store_test.rs b/crates/extension_host/src/extension_store_test.rs index 855077bcf87c58fb8e751d6477921d7e8bba8ad9..509edc6845c6e99745a4b94944cf5f2b68ff9b93 100644 --- a/crates/extension_host/src/extension_store_test.rs +++ b/crates/extension_host/src/extension_store_test.rs @@ -868,5 +868,6 @@ fn init_test(cx: &mut TestAppContext) { Project::init_settings(cx); ExtensionSettings::register(cx); language::init(cx); + gpui_tokio::init(cx); }); } diff --git a/crates/extension_host/src/wasm_host.rs b/crates/extension_host/src/wasm_host.rs index f77258e8957fa1be7579b931de82fd633a0f6ae4..22d11732a743d56e651ce71279fe4d276f269640 100644 --- a/crates/extension_host/src/wasm_host.rs +++ b/crates/extension_host/src/wasm_host.rs @@ -591,11 +591,12 @@ impl WasmHost { self: &Arc, wasm_bytes: Vec, manifest: &Arc, - executor: BackgroundExecutor, + cx: &AsyncApp, ) -> Task> { let this = self.clone(); let manifest = manifest.clone(); - executor.clone().spawn(async move { + let executor = cx.background_executor().clone(); + let load_extension_task = async move { let zed_api_version = parse_wasm_extension_version(&manifest.id, &wasm_bytes)?; let component = Component::from_binary(&this.engine, &wasm_bytes) @@ -632,20 +633,29 @@ impl WasmHost { .context("failed to initialize wasm extension")?; let (tx, mut rx) = mpsc::unbounded::(); - executor - .spawn(async move { - while let Some(call) = rx.next().await { - (call)(&mut extension, &mut store).await; - } - }) - .detach(); + let extension_task = async move { + while let Some(call) = rx.next().await { + (call)(&mut extension, &mut store).await; + } + }; - Ok(WasmExtension { - manifest: manifest.clone(), - work_dir: this.work_dir.join(manifest.id.as_ref()).into(), - tx, - zed_api_version, - }) + anyhow::Ok(( + extension_task, + WasmExtension { + manifest: manifest.clone(), + work_dir: this.work_dir.join(manifest.id.as_ref()).into(), + tx, + zed_api_version, + }, + )) + }; + cx.spawn(async move |cx| { + let (extension_task, extension) = load_extension_task.await?; + // we need to run run the task in an extension context as wasmtime_wasi may + // call into tokio, accessing its runtime handle + gpui_tokio::Tokio::spawn(cx, extension_task)?.detach(); + + Ok(extension) }) } @@ -747,7 +757,7 @@ impl WasmExtension { .context("failed to read wasm")?; wasm_host - .load_extension(wasm_bytes, manifest, cx.background_executor().clone()) + .load_extension(wasm_bytes, manifest, cx) .await .with_context(|| format!("failed to load wasm extension {}", manifest.id)) }