Add async host functions

Isaac Clayton created

Change summary

crates/plugin_macros/src/lib.rs     |   7 -
crates/plugin_runtime/build.rs      |   6 +
crates/plugin_runtime/src/lib.rs    |  11 +-
crates/plugin_runtime/src/plugin.rs | 119 +++++++++++++++---------------
plugins/json_language/src/lib.rs    |  48 ++++++------
plugins/test_plugin/src/lib.rs      |   8 ++
6 files changed, 102 insertions(+), 97 deletions(-)

Detailed changes

crates/plugin_macros/src/lib.rs 🔗

@@ -115,11 +115,8 @@ pub fn import(args: TokenStream, function: TokenStream) -> TokenStream {
         })
         .unzip();
 
-    dbg!("hello");
-
     let body = TokenStream::from(quote! {
         {
-            // dbg!("executing imported function");
             // setup
             let data: (#( #tys ),*) = (#( #args ),*);
             let data = ::plugin::bincode::serialize(&data).unwrap();
@@ -137,12 +134,8 @@ pub fn import(args: TokenStream, function: TokenStream) -> TokenStream {
         }
     });
 
-    dbg!("hello2");
-
     let block = parse_macro_input!(body as Block);
 
-    dbg!("hello {:?}", &block);
-
     let inner_fn = ItemFn {
         attrs: fn_declare.attrs,
         vis: fn_declare.vis,

crates/plugin_runtime/build.rs 🔗

@@ -10,7 +10,7 @@ fn main() {
     let _ =
         std::fs::create_dir_all(base.join("bin")).expect("Could not make plugins bin directory");
 
-    std::process::Command::new("cargo")
+    let build_successful = std::process::Command::new("cargo")
         .args([
             "build",
             "--release",
@@ -20,7 +20,9 @@ fn main() {
             base.join("Cargo.toml").to_str().unwrap(),
         ])
         .status()
-        .expect("Could not build plugins");
+        .expect("Could not build plugins")
+        .success();
+    assert!(build_successful);
 
     let binaries = std::fs::read_dir(base.join("target/wasm32-wasi/release"))
         .expect("Could not find compiled plugins in target");

crates/plugin_runtime/src/lib.rs 🔗

@@ -18,12 +18,9 @@ mod tests {
             print: WasiFn<String, ()>,
             and_back: WasiFn<u32, u32>,
             imports: WasiFn<u32, u32>,
+            half_async: WasiFn<u32, u32>,
         }
 
-        // async fn half(a: u32) -> u32 {
-        //     a / 2
-        // }
-
         async {
             let mut runtime = PluginBuilder::new_with_default_ctx()
                 .unwrap()
@@ -35,8 +32,8 @@ mod tests {
                 .unwrap()
                 .host_function("import_swap", |(a, b): (u32, u32)| (b, a))
                 .unwrap()
-                // .host_function_async("import_half", half)
-                // .unwrap()
+                .host_function_async("import_half", |a: u32| async move { a / 2 })
+                .unwrap()
                 .init(include_bytes!("../../../plugins/bin/test_plugin.wasm"))
                 .await
                 .unwrap();
@@ -51,6 +48,7 @@ mod tests {
                 print: runtime.function("print").unwrap(),
                 and_back: runtime.function("and_back").unwrap(),
                 imports: runtime.function("imports").unwrap(),
+                half_async: runtime.function("half_async").unwrap(),
             };
 
             let unsorted = vec![1, 3, 4, 2, 5];
@@ -65,6 +63,7 @@ mod tests {
             assert_eq!(runtime.call(&plugin.print, "Hi!".into()).await.unwrap(), ());
             assert_eq!(runtime.call(&plugin.and_back, 1).await.unwrap(), 8);
             assert_eq!(runtime.call(&plugin.imports, 1).await.unwrap(), 8);
+            assert_eq!(runtime.call(&plugin.half_async, 4).await.unwrap(), 2);
 
             // dbg!("{}", runtime.call(&plugin.and_back, 1).await.unwrap());
         }

crates/plugin_runtime/src/plugin.rs 🔗

@@ -1,3 +1,5 @@
+use std::future::Future;
+use std::pin::Pin;
 use std::{fs::File, marker::PhantomData, path::Path};
 
 use anyhow::{anyhow, Error};
@@ -142,7 +144,7 @@ impl PluginBuilder {
     //         move |_: Caller<'_, WasiCtxAlloc>, _: u64| {
     //             // let function = &function;
     //             Box::new(async {
-    //                 let function = function;
+    //                 // let function = function;
     //                 // Call the Host-side function
     //                 let result: u64 = function(7).await;
     //                 Ok(result)
@@ -152,68 +154,69 @@ impl PluginBuilder {
     //     Ok(self)
     // }
 
-    // pub fn host_function_async<F, A, R>(mut self, name: &str, function: F) -> Result<Self, Error>
-    // where
-    //     F: Fn(A) -> Pin<Box<dyn Future<Output = R> + Send + 'static>> + Send + Sync + 'static,
-    //     A: DeserializeOwned + Send,
-    //     R: Serialize + Send + Sync,
-    // {
-    //     self.linker.func_wrap1_async(
-    //         "env",
-    //         &format!("__{}", name),
-    //         move |mut caller: Caller<'_, WasiCtxAlloc>, packed_buffer: u64| {
-    //             let function = |args: Vec<u8>| {
-    //                 let args = args;
-    //                 let args: A = Wasi::deserialize_to_type(&args)?;
-    //                 Ok(async {
-    //                     let result = function(args);
-    //                     Wasi::serialize_to_bytes(result.await).map_err(|_| {
-    //                         Trap::new("Could not serialize value returned from function").into()
-    //                     })
-    //                 })
-    //             };
-
-    //             // TODO: use try block once avaliable
-    //             let result: Result<(WasiBuffer, Memory, _), Trap> = (|| {
-    //                 // grab a handle to the memory
-    //                 let mut plugin_memory = match caller.get_export("memory") {
-    //                     Some(Extern::Memory(mem)) => mem,
-    //                     _ => return Err(Trap::new("Could not grab slice of plugin memory"))?,
-    //                 };
+    pub fn host_function_async<F, A, R, Fut>(
+        mut self,
+        name: &str,
+        function: F,
+    ) -> Result<Self, Error>
+    where
+        F: Fn(A) -> Fut + Send + Sync + 'static,
+        Fut: Future<Output = R> + Send + 'static,
+        A: DeserializeOwned + Send + 'static,
+        R: Serialize + Send + Sync + 'static,
+    {
+        self.linker.func_wrap1_async(
+            "env",
+            &format!("__{}", name),
+            move |mut caller: Caller<'_, WasiCtxAlloc>, packed_buffer: u64| {
+                // TODO: use try block once avaliable
+                let result: Result<(WasiBuffer, Memory, _), Trap> = (|| {
+                    // grab a handle to the memory
+                    let mut plugin_memory = match caller.get_export("memory") {
+                        Some(Extern::Memory(mem)) => mem,
+                        _ => return Err(Trap::new("Could not grab slice of plugin memory"))?,
+                    };
 
-    //                 let buffer = WasiBuffer::from_u64(packed_buffer);
+                    let buffer = WasiBuffer::from_u64(packed_buffer);
 
-    //                 // get the args passed from Guest
-    //                 let args = Wasi::buffer_to_bytes(&mut plugin_memory, &mut caller, &buffer)?;
+                    // get the args passed from Guest
+                    let args = Plugin::buffer_to_bytes(&mut plugin_memory, &mut caller, &buffer)?;
 
-    //                 // Call the Host-side function
-    //                 let result = function(args);
+                    let args: A = Plugin::deserialize_to_type(&args)?;
 
-    //                 Ok((buffer, plugin_memory, result))
-    //             })();
+                    // Call the Host-side function
+                    let result = function(args);
 
-    //             Box::new(async move {
-    //                 let (buffer, mut plugin_memory, thingo) = result?;
-    //                 let thingo: Result<_, Error> = thingo;
-    //                 let result: Result<Vec<u8>, Error> = thingo?.await;
-
-    //                 // Wasi::buffer_to_free(caller.data().free_buffer(), &mut caller, buffer).await?;
-
-    //                 // let buffer = Wasi::bytes_to_buffer(
-    //                 //     caller.data().alloc_buffer(),
-    //                 //     &mut plugin_memory,
-    //                 //     &mut caller,
-    //                 //     result,
-    //                 // )
-    //                 // .await?;
-
-    //                 // Ok(buffer.into_u64())
-    //                 Ok(27)
-    //             })
-    //         },
-    //     )?;
-    //     Ok(self)
-    // }
+                    Ok((buffer, plugin_memory, result))
+                })();
+
+                Box::new(async move {
+                    let (buffer, mut plugin_memory, future) = result?;
+
+                    let result: R = future.await;
+                    let result: Result<Vec<u8>, Error> = Plugin::serialize_to_bytes(result)
+                        .map_err(|_| {
+                            Trap::new("Could not serialize value returned from function").into()
+                        });
+                    let result = result?;
+
+                    Plugin::buffer_to_free(caller.data().free_buffer(), &mut caller, buffer)
+                        .await?;
+
+                    let buffer = Plugin::bytes_to_buffer(
+                        caller.data().alloc_buffer(),
+                        &mut plugin_memory,
+                        &mut caller,
+                        result,
+                    )
+                    .await?;
+
+                    Ok(buffer.into_u64())
+                })
+            },
+        )?;
+        Ok(self)
+    }
 
     pub fn host_function<A, R>(
         mut self,

plugins/json_language/src/lib.rs 🔗

@@ -4,8 +4,8 @@ use serde_json::json;
 use std::fs;
 use std::path::PathBuf;
 
-// #[import]
-// fn command(string: &str) -> Option<String>;
+#[import]
+fn command(string: &str) -> Option<String>;
 
 // #[no_mangle]
 // // TODO: switch len from usize to u32?
@@ -28,29 +28,29 @@ use std::path::PathBuf;
 //     return new_buffer.leak_to_heap();
 // }
 
-extern "C" {
-    fn __command(buffer: u64) -> u64;
-}
+// extern "C" {
+//     fn __command(buffer: u64) -> u64;
+// }
 
-#[no_mangle]
-fn command(string: &str) -> Option<Vec<u8>> {
-    dbg!("executing command: {}", string);
-    // setup
-    let data = string;
-    let data = ::plugin::bincode::serialize(&data).unwrap();
-    let buffer = unsafe { ::plugin::__Buffer::from_vec(data) };
-
-    // operation
-    let new_buffer = unsafe { __command(buffer.into_u64()) };
-    let new_data = unsafe { ::plugin::__Buffer::from_u64(new_buffer).to_vec() };
-    let new_data: Option<Vec<u8>> = match ::plugin::bincode::deserialize(&new_data) {
-        Ok(d) => d,
-        Err(e) => panic!("Data returned from function not deserializable."),
-    };
-
-    // teardown
-    return new_data;
-}
+// #[no_mangle]
+// fn command(string: &str) -> Option<Vec<u8>> {
+//     dbg!("executing command: {}", string);
+//     // setup
+//     let data = string;
+//     let data = ::plugin::bincode::serialize(&data).unwrap();
+//     let buffer = unsafe { ::plugin::__Buffer::from_vec(data) };
+
+//     // operation
+//     let new_buffer = unsafe { __command(buffer.into_u64()) };
+//     let new_data = unsafe { ::plugin::__Buffer::from_u64(new_buffer).to_vec() };
+//     let new_data: Option<Vec<u8>> = match ::plugin::bincode::deserialize(&new_data) {
+//         Ok(d) => d,
+//         Err(e) => panic!("Data returned from function not deserializable."),
+//     };
+
+//     // teardown
+//     return new_data;
+// }
 
 // TODO: some sort of macro to generate ABI bindings
 // extern "C" {

plugins/test_plugin/src/lib.rs 🔗

@@ -61,3 +61,11 @@ pub fn imports(x: u32) -> u32 {
     assert_eq!(x, b);
     a + b // should be 7 + x
 }
+
+#[import]
+fn import_half(a: u32) -> u32;
+
+#[export]
+pub fn half_async(a: u32) -> u32 {
+    import_half(a)
+}