Remove dependency on self in call-related functions

Isaac Clayton created

Change summary

crates/plugin_runtime/src/wasi.rs | 95 ++++++++++++++++++++------------
1 file changed, 59 insertions(+), 36 deletions(-)

Detailed changes

crates/plugin_runtime/src/wasi.rs 🔗

@@ -6,8 +6,8 @@ use anyhow::{anyhow, Error};
 use serde::{de::DeserializeOwned, Serialize};
 
 use wasi_common::{dir, file};
-use wasmtime::IntoFunc;
 use wasmtime::{Caller, Config, Engine, Instance, Linker, Module, Store, TypedFunc};
+use wasmtime::{IntoFunc, Memory};
 use wasmtime_wasi::{Dir, WasiCtx, WasiCtxBuilder};
 
 pub struct WasiResource(u32);
@@ -106,12 +106,15 @@ impl WasiPluginBuilder {
         name: &str,
         function: impl Fn(A) -> R + Send + Sync + 'static,
     ) -> Result<Self, Error> {
-        self.linker
-            .func_wrap("env", name, move |ptr: u32, len: u32| {
+        self.linker.func_wrap(
+            "env",
+            name,
+            move |ctx: Caller<'_, WasiCtx>, ptr: u32, len: u32| {
                 // TODO: insert serialization code
                 function(todo!());
                 7u32
-            })?;
+            },
+        )?;
         Ok(self)
     }
 
@@ -266,45 +269,56 @@ impl Wasi {
 
     /// Takes an item, allocates a buffer, serializes the argument to that buffer,
     /// and returns a (ptr, len) pair to that buffer.
-    async fn serialize_to_buffer<T: Serialize>(&mut self, item: T) -> Result<(u32, u32), Error> {
+    async fn serialize_to_buffer<T: Serialize>(
+        alloc_buffer: TypedFunc<u32, u32>,
+        plugin_memory: &mut Memory,
+        mut store: &mut Store<WasiCtx>,
+        item: T,
+    ) -> Result<(u32, u32), Error> {
         // serialize the argument using bincode
         let item = bincode::serialize(&item)?;
         let buffer_len = item.len() as u32;
 
         // allocate a buffer and write the argument to that buffer
-        let buffer_ptr = self
-            .alloc_buffer
-            .call_async(&mut self.store, buffer_len)
-            .await?;
-        let plugin_memory = self
-            .instance
-            .get_memory(&mut self.store, "memory")
-            .ok_or_else(|| anyhow!("Could not grab slice of plugin memory"))?;
-        plugin_memory.write(&mut self.store, buffer_ptr as usize, &item)?;
+        let buffer_ptr = alloc_buffer.call_async(&mut store, buffer_len).await?;
+        plugin_memory.write(&mut store, buffer_ptr as usize, &item)?;
         Ok((buffer_ptr, buffer_len))
     }
 
-    /// Takes a ptr to a (ptr, len) pair and returns the corresponding deserialized buffer
-    fn deserialize_from_buffer<R: DeserializeOwned>(&mut self, buffer: u32) -> Result<R, Error> {
+    /// Takes `ptr to a `(ptr, len)` pair, and returns `(ptr, len)`.
+    fn deref_buffer(
+        plugin_memory: &mut Memory,
+        store: &mut Store<WasiCtx>,
+        buffer: u32,
+    ) -> Result<(u32, u32), Error> {
         // create a buffer to read the (ptr, length) pair into
         // this is a total of 4 + 4 = 8 bytes.
         let raw_buffer = &mut [0; 8];
-        let plugin_memory = self
-            .instance
-            .get_memory(&mut self.store, "memory")
-            .ok_or_else(|| anyhow!("Could not grab slice of plugin memory"))?;
-        plugin_memory.read(&mut self.store, buffer as usize, raw_buffer)?;
+        plugin_memory.read(store, buffer as usize, raw_buffer)?;
 
         // use these bytes (wasm stores things little-endian)
         // to get a pointer to the buffer and its length
         let b = raw_buffer;
-        let buffer_ptr = u32::from_le_bytes([b[0], b[1], b[2], b[3]]) as usize;
-        let buffer_len = u32::from_le_bytes([b[4], b[5], b[6], b[7]]) as usize;
+        let buffer_ptr = u32::from_le_bytes([b[0], b[1], b[2], b[3]]);
+        let buffer_len = u32::from_le_bytes([b[4], b[5], b[6], b[7]]);
+
+        return Ok((buffer_ptr, buffer_len));
+    }
+
+    /// Takes a `(ptr, len)` pair and returns the corresponding deserialized buffer.
+    fn deserialize_from_buffer<R: DeserializeOwned>(
+        plugin_memory: &mut Memory,
+        store: &mut Store<WasiCtx>,
+        buffer_ptr: u32,
+        buffer_len: u32,
+    ) -> Result<R, Error> {
+        let buffer_ptr = buffer_ptr as usize;
+        let buffer_len = buffer_len as usize;
         let buffer_end = buffer_ptr + buffer_len;
 
         // read the buffer at this point into a byte array
         // deserialize the byte array into the provided serde type
-        let result = &plugin_memory.data(&mut self.store)[buffer_ptr..buffer_end];
+        let result = &plugin_memory.data(store)[buffer_ptr..buffer_end];
         let result = bincode::deserialize(result)?;
 
         // TODO: this is handled wasm-side, but I'd like to double-check
@@ -337,22 +351,31 @@ impl Wasi {
         // dbg!(&handle.name);
         // dbg!(serde_json::to_string(&arg)).unwrap();
 
+        let mut plugin_memory = self
+            .instance
+            .get_memory(&mut self.store, "memory")
+            .ok_or_else(|| anyhow!("Could not grab slice of plugin memory"))?;
+
         // write the argument to linear memory
         // this returns a (ptr, lentgh) pair
-        let arg_buffer = self.serialize_to_buffer(arg).await?;
-
-        // get the webassembly function we want to actually call
-        // TODO: precompute handle
-        // let fun_name = format!("__{}", handle);
-        // let fun = self
-        //     .instance
-        //     .get_typed_func::<(u32, u32), u32, _>(&mut self.store, &fun_name)?;
-        let fun = handle.function;
+        let arg_buffer =
+            Self::serialize_to_buffer(self.alloc_buffer, &mut plugin_memory, &mut self.store, arg)
+                .await?;
 
         // call the function, passing in the buffer and its length
         // this returns a ptr to a (ptr, lentgh) pair
-        let result_buffer = fun.call_async(&mut self.store, arg_buffer).await?;
-
-        self.deserialize_from_buffer(result_buffer)
+        let result_buffer = handle
+            .function
+            .call_async(&mut self.store, arg_buffer)
+            .await?;
+        let (result_buffer_ptr, result_buffer_len) =
+            Self::deref_buffer(&mut plugin_memory, &mut self.store, result_buffer)?;
+
+        Self::deserialize_from_buffer(
+            &mut plugin_memory,
+            &mut self.store,
+            result_buffer_ptr,
+            result_buffer_len,
+        )
     }
 }