Add functions with multiple arguments to import macro, add test cases

Isaac Clayton created

Change summary

crates/plugin_macros/src/lib.rs  | 137 +++++++++++++++++----------------
crates/plugin_runtime/src/lib.rs |  17 +++
plugins/test_plugin/src/lib.rs   |  31 ++++++-
3 files changed, 108 insertions(+), 77 deletions(-)

Detailed changes

crates/plugin_macros/src/lib.rs 🔗

@@ -2,7 +2,9 @@ use core::panic;
 
 use proc_macro::TokenStream;
 use quote::{format_ident, quote};
-use syn::{parse_macro_input, FnArg, ForeignItemFn, ItemFn, Type, Visibility};
+use syn::{
+    parse_macro_input, Block, FnArg, ForeignItemFn, Ident, ItemFn, Pat, PatIdent, Type, Visibility,
+};
 
 #[proc_macro_attribute]
 pub fn export(args: TokenStream, function: TokenStream) -> TokenStream {
@@ -32,7 +34,7 @@ pub fn export(args: TokenStream, function: TokenStream) -> TokenStream {
         .iter()
         .map(|x| match x {
             FnArg::Receiver(_) => {
-                panic!("all arguments must have specified types, no `self` allowed")
+                panic!("All arguments must have specified types, no `self` allowed")
             }
             FnArg::Typed(item) => *item.ty.clone(),
         })
@@ -91,70 +93,69 @@ pub fn import(args: TokenStream, function: TokenStream) -> TokenStream {
         panic!("Exported functions can not take generic parameters");
     }
 
-    dbg!(&fn_declare.sig);
-
-    // let inner_fn = ItemFn {
-    //     attrs: fn_declare.attrs,
-    //     vis: fn_declare.vis,
-    //     sig: fn_declare.sig,
-    //     block: todo!(),
-    // };
-
-    let outer_fn_name = format_ident!("{}", fn_declare.sig.ident);
-    let inner_fn_name = format_ident!("__{}", outer_fn_name);
-
-    //     let variadic = inner_fn.sig.inputs.len();
-    //     let i = (0..variadic).map(syn::Index::from);
-    //     let t: Vec<Type> = inner_fn
-    //         .sig
-    //         .inputs
-    //         .iter()
-    //         .map(|x| match x {
-    //             FnArg::Receiver(_) => {
-    //                 panic!("all arguments must have specified types, no `self` allowed")
-    //             }
-    //             FnArg::Typed(item) => *item.ty.clone(),
-    //         })
-    //         .collect();
-
-    //     // this is cursed...
-    //     let (args, ty) = if variadic != 1 {
-    //         (
-    //             quote! {
-    //                 #( data.#i ),*
-    //             },
-    //             quote! {
-    //                 ( #( #t ),* )
-    //             },
-    //         )
-    //     } else {
-    //         let ty = &t[0];
-    //         (quote! { data }, quote! { #ty })
-    //     };
-
-    // TokenStream::from(quote! {
-    //     extern "C" {
-    //         fn #inner_fn_name(buffer: u64) -> u64;
-    //     }
-
-    //     #[no_mangle]
-    //     fn #outer_fn_name #args /* (string: &str) */ -> #return_type /* Option<Vec<u8>> */ {
-    //         dbg!("executing command: {}", string);
-    //         // setup
-    //         let data = #args_collect;
-    //         let data = ::plugin::bincode::serialize(&data).unwrap();
-    //         let buffer = unsafe { ::plugin::__Buffer::from_vec(data) };
-
-    //         // operation
-    //         let new_buffer = unsafe { #inner_fn_name(buffer.into_u64()) };
-    //         let new_data = unsafe { ::plugin::__Buffer::from_u64(new_buffer).to_vec() };
-
-    //         // teardown
-    //         match ::plugin::bincode::deserialize(&new_data) {
-    //             Ok(d) => d,
-    //             Err(e) => panic!("Data returned from function not deserializable."),
-    //         }
-    //     }
-    // })
-    todo!()
+    // let inner_fn_name = format_ident!("{}", fn_declare.sig.ident);
+    let extern_fn_name = format_ident!("__{}", fn_declare.sig.ident);
+
+    let (args, tys): (Vec<Ident>, Vec<Type>) = fn_declare
+        .sig
+        .inputs
+        .clone()
+        .into_iter()
+        .map(|x| match x {
+            FnArg::Receiver(_) => {
+                panic!("All arguments must have specified types, no `self` allowed")
+            }
+            FnArg::Typed(t) => {
+                if let Pat::Ident(i) = *t.pat {
+                    (i.ident, *t.ty)
+                } else {
+                    panic!("All function arguments must be identifiers");
+                }
+            }
+        })
+        .unzip();
+
+    dbg!("hello");
+
+    let body = TokenStream::from(quote! {
+        {
+            // dbg!("executing imported function");
+            // setup
+            let data: (#( #tys ),*) = (#( #args ),*);
+            let data = ::plugin::bincode::serialize(&data).unwrap();
+            let buffer = unsafe { ::plugin::__Buffer::from_vec(data) };
+
+            // operation
+            let new_buffer = unsafe { #extern_fn_name(buffer.into_u64()) };
+            let new_data = unsafe { ::plugin::__Buffer::from_u64(new_buffer).to_vec() };
+
+            // teardown
+            match ::plugin::bincode::deserialize(&new_data) {
+                Ok(d) => d,
+                Err(e) => panic!("Data returned from function not deserializable."),
+            }
+        }
+    });
+
+    dbg!("hello2");
+
+    let block = parse_macro_input!(body as Block);
+
+    dbg!("hello {:?}", &block);
+
+    let inner_fn = ItemFn {
+        attrs: fn_declare.attrs,
+        vis: fn_declare.vis,
+        sig: fn_declare.sig,
+        block: Box::new(block),
+    };
+
+    TokenStream::from(quote! {
+        extern "C" {
+            fn #extern_fn_name(buffer: u64) -> u64;
+        }
+
+        #[no_mangle]
+        #inner_fn
+    })
 }

crates/plugin_runtime/src/lib.rs 🔗

@@ -16,7 +16,8 @@ mod tests {
             swap: WasiFn<(u32, u32), (u32, u32)>,
             sort: WasiFn<Vec<u32>, Vec<u32>>,
             print: WasiFn<String, ()>,
-            // and_back: WasiFn<u32, u32>,
+            and_back: WasiFn<u32, u32>,
+            imports: WasiFn<u32, u32>,
         }
 
         async {
@@ -24,6 +25,12 @@ mod tests {
                 .unwrap()
                 .host_function("mystery_number", |input: u32| input + 7)
                 .unwrap()
+                .host_function("import_noop", |_: ()| ())
+                .unwrap()
+                .host_function("import_identity", |input: u32| input)
+                .unwrap()
+                .host_function("import_swap", |(a, b): (u32, u32)| (b, a))
+                .unwrap()
                 .init(include_bytes!("../../../plugins/bin/test_plugin.wasm"))
                 .await
                 .unwrap();
@@ -36,7 +43,8 @@ mod tests {
                 swap: runtime.function("swap").unwrap(),
                 sort: runtime.function("sort").unwrap(),
                 print: runtime.function("print").unwrap(),
-                // and_back: runtime.function("and_back").unwrap(),
+                and_back: runtime.function("and_back").unwrap(),
+                imports: runtime.function("imports").unwrap(),
             };
 
             let unsorted = vec![1, 3, 4, 2, 5];
@@ -49,7 +57,10 @@ mod tests {
             assert_eq!(runtime.call(&plugin.swap, (1, 2)).await.unwrap(), (2, 1));
             assert_eq!(runtime.call(&plugin.sort, unsorted).await.unwrap(), sorted);
             assert_eq!(runtime.call(&plugin.print, "Hi!".into()).await.unwrap(), ());
-            // assert_eq!(runtime.call(&plugin.and_back, 1).await.unwrap(), 8);
+            assert_eq!(runtime.call(&plugin.and_back, 1).await.unwrap(), 8);
+            assert_eq!(runtime.call(&plugin.imports, 1).await.unwrap(), 8);
+
+            // dbg!("{}", runtime.call(&plugin.and_back, 1).await.unwrap());
         }
         .block_on()
     }

plugins/test_plugin/src/lib.rs 🔗

@@ -35,10 +35,29 @@ pub fn print(string: String) {
     eprintln!("to stderr: {}", string);
 }
 
-// #[import]
-// fn mystery_number(input: u32) -> u32;
+#[import]
+fn mystery_number(input: u32) -> u32;
 
-// #[export]
-// pub fn and_back(secret: u32) -> u32 {
-//     mystery_number(secret)
-// }
+#[export]
+pub fn and_back(secret: u32) -> u32 {
+    mystery_number(secret)
+}
+
+#[import]
+fn import_noop() -> ();
+
+#[import]
+fn import_identity(i: u32) -> u32;
+
+#[import]
+fn import_swap(a: u32, b: u32) -> (u32, u32);
+
+#[export]
+pub fn imports(x: u32) -> u32 {
+    let a = import_identity(7);
+    import_noop();
+    let (b, c) = import_swap(a, x);
+    assert_eq!(a, c);
+    assert_eq!(x, b);
+    a + b // should be 7 + x
+}