Split out lifecycle of serialization, buffer is freed now

Isaac Clayton created

Change summary

crates/plugin_runtime/src/wasi.rs | 93 +++++++++++++++++---------------
1 file changed, 48 insertions(+), 45 deletions(-)

Detailed changes

crates/plugin_runtime/src/wasi.rs 🔗

@@ -89,34 +89,35 @@ impl WasiPluginBuilder {
             &format!("__{}", name),
             move |mut caller: Caller<'_, WasiCtxAlloc>, packed_buffer: u64| {
                 // TODO: use try block once avaliable
-                let result: Result<(Memory, Vec<u8>), Trap> = (|| {
+                let result: Result<(WasiBuffer, Memory, Vec<u8>), Trap> = (|| {
                     // grab a handle to the memory
                     let mut plugin_memory = match caller.get_export("memory") {
                         Some(Extern::Memory(mem)) => mem,
                         _ => return Err(Trap::new("Could not grab slice of plugin memory"))?,
                     };
 
+                    let buffer = WasiBuffer::from_u64(packed_buffer);
+
                     // get the args passed from Guest
-                    let args = Wasi::deserialize_from_buffer(
-                        &mut plugin_memory,
-                        &caller,
-                        WasiBuffer::from_u64(packed_buffer),
-                    )?;
+                    let args = Wasi::buffer_to_type(&mut plugin_memory, &mut caller, &buffer)?;
 
                     // Call the Host-side function
                     let result: R = function(args);
 
                     // Serialize the result back to guest
-                    let result = Wasi::serialize(result).map_err(|_| {
+                    let result = Wasi::serialize_to_bytes(result).map_err(|_| {
                         Trap::new("Could not serialize value returned from function")
                     })?;
-                    Ok((plugin_memory, result))
+
+                    Ok((buffer, plugin_memory, result))
                 })();
 
                 Box::new(async move {
-                    let (mut plugin_memory, result) = result?;
+                    let (buffer, mut plugin_memory, result) = result?;
 
-                    let buffer = Wasi::serialize_to_buffer(
+                    Wasi::buffer_to_free(caller.data().free_buffer(), &mut caller, buffer).await?;
+
+                    let buffer = Wasi::bytes_to_buffer(
                         caller.data().alloc_buffer(),
                         &mut plugin_memory,
                         &mut caller,
@@ -328,15 +329,30 @@ impl Wasi {
     // This isn't a problem because Wasm stops executing after the function returns,
     // so the heap is still valid for our inspection when we want to pull things out.
 
-    fn serialize<A: Serialize>(item: A) -> Result<Vec<u8>, Error> {
+    fn serialize_to_bytes<A: Serialize>(item: A) -> Result<Vec<u8>, Error> {
         // serialize the argument using bincode
-        let item = bincode::serialize(&item)?;
-        Ok(item)
+        let bytes = bincode::serialize(&item)?;
+        Ok(bytes)
     }
 
+    // fn deserialize<R: DeserializeOwned>(
+    //     plugin_memory: &mut Memory,
+    //     mut store: impl AsContextMut<Data = WasiCtxAlloc>,
+    //     buffer: WasiBuffer,
+    // ) -> Result<R, Error> {
+    //     let buffer_start = buffer.ptr as usize;
+    //     let buffer_end = buffer_start + buffer.len as usize;
+
+    //     // read the buffer at this point into a byte array
+    //     // deserialize the byte array into the provided serde type
+    //     let item = &plugin_memory.data(store.as_context())[buffer_start..buffer_end];
+    //     let item = bincode::deserialize(bytes)?;
+    //     Ok(item)
+    // }
+
     /// Takes an item, allocates a buffer, serializes the argument to that buffer,
     /// and returns a (ptr, len) pair to that buffer.
-    async fn serialize_to_buffer(
+    async fn bytes_to_buffer(
         alloc_buffer: TypedFunc<u32, u32>,
         plugin_memory: &mut Memory,
         mut store: impl AsContextMut<Data = WasiCtxAlloc>,
@@ -349,31 +365,11 @@ impl Wasi {
         Ok(WasiBuffer { ptr, len })
     }
 
-    // /// Takes `ptr to a `(ptr, len)` pair, and returns `(ptr, len)`.
-    // fn deref_buffer(
-    //     plugin_memory: &mut Memory,
-    //     store: impl AsContext<Data = WasiCtxAlloc>,
-    //     buffer: u32,
-    // ) -> Result<(u32, u32), Error> {
-    //     // create a buffer to read the (ptr, length) pair into
-    //     // this is a total of 4 + 4 = 8 bytes.
-    //     let raw_buffer = &mut [0; 8];
-    //     plugin_memory.read(store, buffer as usize, raw_buffer)?;
-
-    //     // use these bytes (wasm stores things little-endian)
-    //     // to get a pointer to the buffer and its length
-    //     let b = raw_buffer;
-    //     let buffer_ptr = u32::from_le_bytes([b[0], b[1], b[2], b[3]]);
-    //     let buffer_len = u32::from_le_bytes([b[4], b[5], b[6], b[7]]);
-
-    //     return Ok((buffer_ptr, buffer_len));
-    // }
-
     /// Takes a `(ptr, len)` pair and returns the corresponding deserialized buffer.
-    fn deserialize_from_buffer<R: DeserializeOwned>(
-        plugin_memory: &mut Memory,
+    fn buffer_to_type<R: DeserializeOwned>(
+        plugin_memory: &Memory,
         store: impl AsContext<Data = WasiCtxAlloc>,
-        buffer: WasiBuffer,
+        buffer: &WasiBuffer,
     ) -> Result<R, Error> {
         let buffer_start = buffer.ptr as usize;
         let buffer_end = buffer_start + buffer.len as usize;
@@ -383,13 +379,20 @@ impl Wasi {
         let result = &plugin_memory.data(store.as_context())[buffer_start..buffer_end];
         let result = bincode::deserialize(result)?;
 
-        // TODO: this is handled wasm-side
-        // // deallocate the argument buffer
-        // self.free_buffer.call(&mut self.store, arg_buffer);
-
         Ok(result)
     }
 
+    async fn buffer_to_free(
+        free_buffer: TypedFunc<u64, ()>,
+        mut store: impl AsContextMut<Data = WasiCtxAlloc>,
+        buffer: WasiBuffer,
+    ) -> Result<(), Error> {
+        // deallocate the argument buffer
+        Ok(free_buffer
+            .call_async(&mut store, buffer.into_u64())
+            .await?)
+    }
+
     /// Retrieves the handle to a function of a given type.
     pub fn function<A: Serialize, R: DeserializeOwned, T: AsRef<str>>(
         &mut self,
@@ -422,11 +425,11 @@ impl Wasi {
 
         // write the argument to linear memory
         // this returns a (ptr, lentgh) pair
-        let arg_buffer = Self::serialize_to_buffer(
+        let arg_buffer = Self::bytes_to_buffer(
             self.store.data().alloc_buffer(),
             &mut plugin_memory,
             &mut self.store,
-            Self::serialize(arg)?,
+            Self::serialize_to_bytes(arg)?,
         )
         .await?;
 
@@ -437,10 +440,10 @@ impl Wasi {
             .call_async(&mut self.store, arg_buffer.into_u64())
             .await?;
 
-        Self::deserialize_from_buffer(
+        Self::buffer_to_type(
             &mut plugin_memory,
             &mut self.store,
-            WasiBuffer::from_u64(result_buffer),
+            &WasiBuffer::from_u64(result_buffer),
         )
     }
 }