WIP: wrap async closures host-side

Isaac Clayton created

Change summary

crates/plugin_runtime/src/lib.rs  |   8 +
crates/plugin_runtime/src/wasi.rs | 167 ++++++++++++++++++++++++++++++++
2 files changed, 173 insertions(+), 2 deletions(-)

Detailed changes

crates/plugin_runtime/src/lib.rs 🔗

@@ -20,6 +20,12 @@ mod tests {
             imports: WasiFn<u32, u32>,
         }
 
+        async fn half(a: u32) -> u32 {
+            a / 2
+        }
+
+        let x = half;
+
         async {
             let mut runtime = WasiPluginBuilder::new_with_default_ctx()
                 .unwrap()
@@ -31,6 +37,8 @@ mod tests {
                 .unwrap()
                 .host_function("import_swap", |(a, b): (u32, u32)| (b, a))
                 .unwrap()
+                // .host_function_async("import_half", half)
+                // .unwrap()
                 .init(include_bytes!("../../../plugins/bin/test_plugin.wasm"))
                 .await
                 .unwrap();

crates/plugin_runtime/src/wasi.rs 🔗

@@ -79,11 +79,153 @@ impl WasiPluginBuilder {
         Self::new(wasi_ctx)
     }
 
-    pub fn host_function<A: DeserializeOwned + Send, R: Serialize + Send + Sync + Clone>(
+    // pub fn host_function_async<A: DeserializeOwned + Send, R: Serialize, F, Fut>(
+    //     mut self,
+    //     name: &str,
+    //     function: impl Fn(A) -> Pin<Box<dyn Future<Output = R> + Send + Sync>> + Sync + Send + 'static,
+    // ) -> Result<Self, Error>
+    // where
+    //     A: DeserializeOwned + Send,
+    //     R: Serialize + Send,
+    // {
+    //     self.linker.func_wrap1_async(
+    //         "env",
+    //         &format!("__{}", name),
+    //         move |caller: Caller<'_, WasiCtxAlloc>, packed_buffer: u64| {
+    //             // let function = &function;
+    //             Box::new(async move {
+    //                 // grab a handle to the memory
+    //                 let mut plugin_memory = match caller.get_export("memory") {
+    //                     Some(Extern::Memory(mem)) => mem,
+    //                     _ => return Err(Trap::new("Could not grab slice of plugin memory"))?,
+    //                 };
+
+    //                 let buffer = WasiBuffer::from_u64(packed_buffer);
+
+    //                 // get the args passed from Guest
+    //                 let args = Wasi::buffer_to_type(&mut plugin_memory, &mut caller, &buffer)?;
+
+    //                 // Call the Host-side function
+    //                 let result: R = function(args).await;
+
+    //                 // Serialize the result back to guest
+    //                 let result = Wasi::serialize_to_bytes(result).map_err(|_| {
+    //                     Trap::new("Could not serialize value returned from function")
+    //                 })?;
+
+    //                 // Ok((buffer, plugin_memory, result))
+    //                 Wasi::buffer_to_free(caller.data().free_buffer(), &mut caller, buffer).await?;
+
+    //                 let buffer = Wasi::bytes_to_buffer(
+    //                     caller.data().alloc_buffer(),
+    //                     &mut plugin_memory,
+    //                     &mut caller,
+    //                     result,
+    //                 )
+    //                 .await?;
+
+    //                 Ok(buffer.into_u64())
+    //             })
+    //         },
+    //     )?;
+    //     Ok(self)
+    // }
+
+    // pub fn host_function_async<F>(mut self, name: &str, function: F) -> Result<Self, Error>
+    // where
+    //     F: Fn(u64) -> Pin<Box<dyn Future<Output = u64> + Send + Sync + 'static>>
+    //         + Send
+    //         + Sync
+    //         + 'static,
+    // {
+    //     self.linker.func_wrap1_async(
+    //         "env",
+    //         &format!("__{}", name),
+    //         move |_: Caller<'_, WasiCtxAlloc>, _: u64| {
+    //             // let function = &function;
+    //             Box::new(async {
+    //                 let function = function;
+    //                 // Call the Host-side function
+    //                 let result: u64 = function(7).await;
+    //                 Ok(result)
+    //             })
+    //         },
+    //     )?;
+    //     Ok(self)
+    // }
+
+    // pub fn host_function_async<F, A, R>(mut self, name: &str, function: F) -> Result<Self, Error>
+    // where
+    //     F: Fn(A) -> Pin<Box<dyn Future<Output = R> + Send + 'static>> + Send + Sync + 'static,
+    //     A: DeserializeOwned + Send,
+    //     R: Serialize + Send + Sync,
+    // {
+    //     self.linker.func_wrap1_async(
+    //         "env",
+    //         &format!("__{}", name),
+    //         move |mut caller: Caller<'_, WasiCtxAlloc>, packed_buffer: u64| {
+    //             let function = |args: Vec<u8>| {
+    //                 let args = args;
+    //                 let args: A = Wasi::deserialize_to_type(&args)?;
+    //                 Ok(async {
+    //                     let result = function(args);
+    //                     Wasi::serialize_to_bytes(result.await).map_err(|_| {
+    //                         Trap::new("Could not serialize value returned from function").into()
+    //                     })
+    //                 })
+    //             };
+
+    //             // TODO: use try block once avaliable
+    //             let result: Result<(WasiBuffer, Memory, _), Trap> = (|| {
+    //                 // grab a handle to the memory
+    //                 let mut plugin_memory = match caller.get_export("memory") {
+    //                     Some(Extern::Memory(mem)) => mem,
+    //                     _ => return Err(Trap::new("Could not grab slice of plugin memory"))?,
+    //                 };
+
+    //                 let buffer = WasiBuffer::from_u64(packed_buffer);
+
+    //                 // get the args passed from Guest
+    //                 let args = Wasi::buffer_to_bytes(&mut plugin_memory, &mut caller, &buffer)?;
+
+    //                 // Call the Host-side function
+    //                 let result = function(args);
+
+    //                 Ok((buffer, plugin_memory, result))
+    //             })();
+
+    //             Box::new(async move {
+    //                 let (buffer, mut plugin_memory, thingo) = result?;
+    //                 let thingo: Result<_, Error> = thingo;
+    //                 let result: Result<Vec<u8>, Error> = thingo?.await;
+
+    //                 // Wasi::buffer_to_free(caller.data().free_buffer(), &mut caller, buffer).await?;
+
+    //                 // let buffer = Wasi::bytes_to_buffer(
+    //                 //     caller.data().alloc_buffer(),
+    //                 //     &mut plugin_memory,
+    //                 //     &mut caller,
+    //                 //     result,
+    //                 // )
+    //                 // .await?;
+
+    //                 // Ok(buffer.into_u64())
+    //                 Ok(27)
+    //             })
+    //         },
+    //     )?;
+    //     Ok(self)
+    // }
+
+    pub fn host_function<A, R>(
         mut self,
         name: &str,
         function: impl Fn(A) -> R + Send + Sync + 'static,
-    ) -> Result<Self, Error> {
+    ) -> Result<Self, Error>
+    where
+        A: DeserializeOwned + Send,
+        R: Serialize + Send + Sync,
+    {
         self.linker.func_wrap1_async(
             "env",
             &format!("__{}", name),
@@ -335,6 +477,12 @@ impl Wasi {
         Ok(bytes)
     }
 
+    fn deserialize_to_type<R: DeserializeOwned>(bytes: &[u8]) -> Result<R, Error> {
+        // serialize the argument using bincode
+        let bytes = bincode::deserialize(bytes)?;
+        Ok(bytes)
+    }
+
     // fn deserialize<R: DeserializeOwned>(
     //     plugin_memory: &mut Memory,
     //     mut store: impl AsContextMut<Data = WasiCtxAlloc>,
@@ -382,6 +530,21 @@ impl Wasi {
         Ok(result)
     }
 
+    /// Takes a `(ptr, len)` pair and returns the corresponding deserialized buffer.
+    fn buffer_to_bytes<'a>(
+        plugin_memory: &'a Memory,
+        store: impl AsContext<Data = WasiCtxAlloc> + 'a,
+        buffer: &WasiBuffer,
+    ) -> Result<Vec<u8>, Error> {
+        let buffer_start = buffer.ptr as usize;
+        let buffer_end = buffer_start + buffer.len as usize;
+
+        // read the buffer at this point into a byte array
+        // deserialize the byte array into the provided serde type
+        let result = plugin_memory.data(store.as_context())[buffer_start..buffer_end].to_vec();
+        Ok(result)
+    }
+
     async fn buffer_to_free(
         free_buffer: TypedFunc<u64, ()>,
         mut store: impl AsContextMut<Data = WasiCtxAlloc>,