Factor out serialization code

Isaac Clayton created

Change summary

crates/plugin_runtime/src/wasi.rs | 73 +++++++++++++++++++++-----------
1 file changed, 48 insertions(+), 25 deletions(-)

Detailed changes

crates/plugin_runtime/src/wasi.rs 🔗

@@ -91,31 +91,51 @@ impl WasiPluginBuilder {
         Self::new(wasi_ctx)
     }
 
-    pub fn host_function<A: DeserializeOwned, R: Serialize>(
+    pub fn host_function<A: DeserializeOwned + Send, R: Serialize + Send + Sync + Clone>(
         mut self,
         name: &str,
         function: impl Fn(A) -> R + Send + Sync + 'static,
     ) -> Result<Self, Error> {
-        self.linker.func_wrap(
+        self.linker.func_wrap2_async(
             "env",
             name,
             move |mut caller: Caller<'_, WasiCtxAlloc>, ptr: u32, len: u32| {
-                let mut plugin_memory = match caller.get_export("memory") {
-                    Some(Extern::Memory(mem)) => mem,
-                    _ => return Err(Trap::new("Could not grab slice of plugin memory")),
-                };
-                let args = Wasi::deserialize_from_buffer(&mut plugin_memory, &caller, ptr, len)?;
-
-                let result = function(args);
-                let buffer = Wasi::serialize_to_buffer(
-                    caller.data().alloc_buffer(),
-                    &mut plugin_memory,
-                    &mut caller,
-                    result,
-                )
-                .await;
-
-                Ok(7u32)
+                // TODO: use try block once avaliable
+                let result: Result<(Memory, Vec<u8>), Trap> = (|| {
+                    // grab a handle to the memory
+                    let mut plugin_memory = match caller.get_export("memory") {
+                        Some(Extern::Memory(mem)) => mem,
+                        _ => return Err(Trap::new("Could not grab slice of plugin memory"))?,
+                    };
+
+                    // get the args passed from Guest
+                    let args =
+                        Wasi::deserialize_from_buffer(&mut plugin_memory, &caller, ptr, len)?;
+
+                    // Call the Host-side function
+                    let result: R = function(args);
+
+                    // Serialize the result back to guest
+                    let result = Wasi::serialize(result).map_err(|_| {
+                        Trap::new("Could not serialize value returned from function")
+                    })?;
+                    Ok((plugin_memory, result))
+                })();
+
+                Box::new(async move {
+                    let (mut plugin_memory, result) = result?;
+
+                    // todo!();
+                    let (ptr, len) = Wasi::serialize_to_buffer(
+                        caller.data().alloc_buffer(),
+                        &mut plugin_memory,
+                        &mut caller,
+                        result,
+                    )
+                    .await?;
+
+                    Ok(7u32)
+                })
             },
         )?;
         Ok(self)
@@ -318,19 +338,22 @@ impl Wasi {
     // This isn't a problem because Wasm stops executing after the function returns,
     // so the heap is still valid for our inspection when we want to pull things out.
 
+    fn serialize<A: Serialize>(item: A) -> Result<Vec<u8>, Error> {
+        // serialize the argument using bincode
+        let item = bincode::serialize(&item)?;
+        Ok(item)
+    }
+
     /// Takes an item, allocates a buffer, serializes the argument to that buffer,
     /// and returns a (ptr, len) pair to that buffer.
-    async fn serialize_to_buffer<A: Serialize>(
+    async fn serialize_to_buffer(
         alloc_buffer: TypedFunc<u32, u32>,
         plugin_memory: &mut Memory,
         mut store: impl AsContextMut<Data = WasiCtxAlloc>,
-        item: A,
+        item: Vec<u8>,
     ) -> Result<(u32, u32), Error> {
-        // serialize the argument using bincode
-        let item = bincode::serialize(&item)?;
-        let buffer_len = item.len() as u32;
-
         // allocate a buffer and write the argument to that buffer
+        let buffer_len = item.len() as u32;
         let buffer_ptr = alloc_buffer.call_async(&mut store, buffer_len).await?;
         plugin_memory.write(&mut store, buffer_ptr as usize, &item)?;
         Ok((buffer_ptr, buffer_len))
@@ -413,7 +436,7 @@ impl Wasi {
             self.store.data().alloc_buffer(),
             &mut plugin_memory,
             &mut self.store,
-            arg,
+            Self::serialize(arg)?,
         )
         .await?;