Start working on host-side functions

Isaac Clayton created

Change summary

crates/plugin/src/lib.rs                    |   2 
crates/plugin_macros/src/lib.rs             |  74 +++++++++++++++
crates/plugin_runtime/src/wasi.rs           | 103 ++++++++++++++++++++--
crates/zed/src/languages/language_plugin.rs |  17 ++-
plugins/json_language/src/lib.rs            |  66 ++++++++++++--
5 files changed, 232 insertions(+), 30 deletions(-)

Detailed changes

crates/plugin/src/lib.rs 🔗

@@ -49,5 +49,5 @@ impl __Buffer {
 
 pub mod prelude {
     pub use super::{__Buffer, __alloc_buffer};
-    pub use plugin_macros::bind;
+    pub use plugin_macros::{export, import};
 }

crates/plugin_macros/src/lib.rs 🔗

@@ -5,15 +5,15 @@ use quote::{format_ident, quote};
 use syn::{parse_macro_input, FnArg, ItemFn, Type, Visibility};
 
 #[proc_macro_attribute]
-pub fn bind(args: TokenStream, function: TokenStream) -> TokenStream {
+pub fn export(args: TokenStream, function: TokenStream) -> TokenStream {
     if !args.is_empty() {
-        panic!("The bind attribute does not take any arguments");
+        panic!("The export attribute does not take any arguments");
     }
 
     let inner_fn = parse_macro_input!(function as ItemFn);
     if let Visibility::Public(_) = inner_fn.vis {
     } else {
-        panic!("The bind attribute only works for public functions");
+        panic!("The export attribute only works for public functions");
     }
 
     let inner_fn_name = format_ident!("{}", inner_fn.sig.ident);
@@ -53,6 +53,7 @@ pub fn bind(args: TokenStream, function: TokenStream) -> TokenStream {
         #inner_fn
 
         #[no_mangle]
+        // TODO: switch len from usize to u32?
         pub extern "C" fn #outer_fn_name(ptr: *const u8, len: usize) -> *const ::plugin::__Buffer {
             // setup
             let buffer = ::plugin::__Buffer { ptr, len };
@@ -73,3 +74,70 @@ pub fn bind(args: TokenStream, function: TokenStream) -> TokenStream {
         }
     })
 }
+
+#[proc_macro_attribute]
+pub fn import(args: TokenStream, function: TokenStream) -> TokenStream {
+    todo!()
+    //     if !args.is_empty() {
+    //         panic!("The import attribute does not take any arguments");
+    //     }
+
+    //     let inner_fn = parse_macro_input!(function as ItemFn);
+
+    //     let inner_fn_name = format_ident!("{}", inner_fn.sig.ident);
+    //     // let outer_fn_name = format_ident!("__{}", inner_fn_name);
+
+    //     let variadic = inner_fn.sig.inputs.len();
+    //     let i = (0..variadic).map(syn::Index::from);
+    //     let t: Vec<Type> = inner_fn
+    //         .sig
+    //         .inputs
+    //         .iter()
+    //         .map(|x| match x {
+    //             FnArg::Receiver(_) => {
+    //                 panic!("all arguments must have specified types, no `self` allowed")
+    //             }
+    //             FnArg::Typed(item) => *item.ty.clone(),
+    //         })
+    //         .collect();
+
+    //     // this is cursed...
+    //     let (args, ty) = if variadic != 1 {
+    //         (
+    //             quote! {
+    //                 #( data.#i ),*
+    //             },
+    //             quote! {
+    //                 ( #( #t ),* )
+    //             },
+    //         )
+    //     } else {
+    //         let ty = &t[0];
+    //         (quote! { data }, quote! { #ty })
+    //     };
+
+    //     TokenStream::from(quote! {
+    //         #[no_mangle]
+    //         #inner_fn
+
+    //         #[no_mangle]
+    //         pub extern "C" fn #outer_fn_name(ptr: *const u8, len: usize) -> *const ::plugin::__Buffer {
+    //             // setup
+    //             let buffer = ::plugin::__Buffer { ptr, len };
+    //             let data = unsafe { buffer.to_vec() };
+
+    //             // operation
+    //             let data: #ty = match ::plugin::bincode::deserialize(&data) {
+    //                 Ok(d) => d,
+    //                 Err(e) => panic!("Data passed to function not deserializable."),
+    //             };
+    //             let result = #inner_fn_name(#args);
+    //             let new_data: Result<Vec<u8>, _> = ::plugin::bincode::serialize(&result);
+    //             let new_data = new_data.unwrap();
+
+    //             // teardown
+    //             let new_buffer = unsafe { ::plugin::__Buffer::from_vec(new_data) };
+    //             return new_buffer.leak_to_heap();
+    //         }
+    //     })
+}

crates/plugin_runtime/src/wasi.rs 🔗

@@ -1,10 +1,13 @@
-use std::{fs::File, marker::PhantomData, path::Path};
+use std::{
+    collections::HashMap, fs::File, future::Future, marker::PhantomData, path::Path, pin::Pin,
+};
 
 use anyhow::{anyhow, Error};
 use serde::{de::DeserializeOwned, Serialize};
 
 use wasi_common::{dir, file};
-use wasmtime::{Config, Engine, Instance, Linker, Module, Store, TypedFunc};
+use wasmtime::IntoFunc;
+use wasmtime::{Caller, Config, Engine, Instance, Linker, Module, Store, TypedFunc};
 use wasmtime_wasi::{Dir, WasiCtx, WasiCtxBuilder};
 
 pub struct WasiResource(u32);
@@ -41,9 +44,93 @@ pub struct Wasi {
     // free_buffer: TypedFunc<(u32, u32), ()>,
 }
 
+// type signature derived from:
+// https://docs.rs/wasmtime/latest/wasmtime/struct.Linker.html#method.func_wrap2_async
+// macro_rules! dynHostFunction {
+//     () => {
+//         Box<
+//             dyn for<'a> Fn(Caller<'a, WasiCtx>, u32, u32)
+//                 -> Box<dyn Future<Output = u32> + Send + 'a>
+//                     + Send
+//                     + Sync
+//                     + 'static
+//         >
+//     };
+// }
+
+// macro_rules! implHostFunction {
+//     () => {
+//         impl for<'a> Fn(Caller<'a, WasiCtx>, u32, u32)
+//             -> Box<dyn Future<Output = u32> + Send + 'a>
+//                 + Send
+//                 + Sync
+//                 + 'static
+//     };
+// }
+
+// This type signature goodness gracious
+pub type HostFunction = Box<dyn IntoFunc<WasiCtx, (u32, u32), u32>>;
+
+pub struct WasiPluginBuilder {
+    host_functions: HashMap<String, HostFunction>,
+    wasi_ctx_builder: WasiCtxBuilder,
+}
+
+impl WasiPluginBuilder {
+    pub fn new() -> Self {
+        WasiPluginBuilder {
+            host_functions: HashMap::new(),
+            wasi_ctx_builder: WasiCtxBuilder::new(),
+        }
+    }
+
+    pub fn new_with_default_ctx() -> WasiPluginBuilder {
+        let mut this = Self::new();
+        this.wasi_ctx_builder = this.wasi_ctx_builder.inherit_stdin().inherit_stderr();
+        this
+    }
+
+    fn wrap_host_function<A: Serialize, R: DeserializeOwned>(
+        function: impl Fn(A) -> R + Send + Sync + 'static,
+    ) -> HostFunction {
+        Box::new(move |ptr, len| {
+            function(todo!());
+            todo!()
+        })
+    }
+
+    pub fn host_function<A: Serialize, R: DeserializeOwned>(
+        mut self,
+        name: &str,
+        function: impl Fn(A) -> R + Send + Sync + 'static,
+    ) -> Self {
+        self.host_functions
+            .insert(name.to_string(), Self::wrap_host_function(function));
+        self
+    }
+
+    pub fn wasi_ctx(mut self, config: impl FnOnce(WasiCtxBuilder) -> WasiCtxBuilder) -> Self {
+        self.wasi_ctx_builder = config(self.wasi_ctx_builder);
+        self
+    }
+
+    pub async fn init<T: AsRef<[u8]>>(self, module: T) -> Result<Wasi, Error> {
+        let plugin = WasiPlugin {
+            module: module.as_ref().to_vec(),
+            wasi_ctx: self.wasi_ctx_builder.build(),
+            host_functions: self.host_functions,
+        };
+
+        Wasi::init(plugin).await
+    }
+}
+
+/// Represents a to-be-initialized plugin.
+/// Please use [`WasiPluginBuilder`], don't use this directly.
 pub struct WasiPlugin {
     pub module: Vec<u8>,
     pub wasi_ctx: WasiCtx,
+    pub host_functions: HashMap<String, HostFunction>,
 }
 
 impl Wasi {
@@ -66,19 +153,15 @@ impl Wasi {
 }
 
 impl Wasi {
-    pub fn default_ctx() -> WasiCtx {
-        WasiCtxBuilder::new()
-            .inherit_stdout()
-            .inherit_stderr()
-            .build()
-    }
-
-    pub async fn init(plugin: WasiPlugin) -> Result<Self, Error> {
+    async fn init(plugin: WasiPlugin) -> Result<Self, Error> {
         let mut config = Config::default();
         config.async_support(true);
         let engine = Engine::new(&config)?;
         let mut linker = Linker::new(&engine);
 
+        linker
+            .func_wrap("env", "__command", |x: u32, y: u32| x + y)
+            .unwrap();
         linker.func_wrap("env", "__hello", |x: u32| x * 2).unwrap();
         linker.func_wrap("env", "__bye", |x: u32| x / 2).unwrap();
 

crates/zed/src/languages/language_plugin.rs 🔗

@@ -6,17 +6,21 @@ use futures::{future::BoxFuture, FutureExt, StreamExt};
 use gpui::executor::{self, Background};
 use isahc::http::version;
 use language::{LanguageServerName, LspAdapter};
-use plugin_runtime::{Wasi, WasiFn, WasiPlugin};
+use plugin_runtime::{Wasi, WasiFn, WasiPlugin, WasiPluginBuilder};
 use serde_json::json;
 use std::fs;
 use std::{any::Any, path::PathBuf, sync::Arc};
 use util::{ResultExt, TryFutureExt};
 
 pub async fn new_json(executor: Arc<Background>) -> Result<PluginLspAdapter> {
-    let plugin = WasiPlugin {
-        module: include_bytes!("../../../../plugins/bin/json_language.wasm").to_vec(),
-        wasi_ctx: Wasi::default_ctx(),
-    };
+    let plugin = WasiPluginBuilder::new_with_default_ctx()
+        .host_function("command", |command: String| {
+            // TODO: actual thing
+            std::process::Command::new(command).output().unwrap();
+            Some("Hello".to_string())
+        })
+        .init(include_bytes!("../../../../plugins/bin/json_language.wasm"))
+        .await?;
     PluginLspAdapter::new(plugin, executor).await
 }
 
@@ -33,8 +37,7 @@ pub struct PluginLspAdapter {
 }
 
 impl PluginLspAdapter {
-    pub async fn new(plugin: WasiPlugin, executor: Arc<Background>) -> Result<Self> {
-        let mut plugin = Wasi::init(plugin).await?;
+    pub async fn new(mut plugin: Wasi, executor: Arc<Background>) -> Result<Self> {
         Ok(Self {
             name: plugin.function("name")?,
             server_args: plugin.function("server_args")?,

plugins/json_language/src/lib.rs 🔗

@@ -5,8 +5,56 @@ use std::fs;
 use std::path::PathBuf;
 
 // #[import]
+// fn command(string: &str) -> Option<String>;
+
+extern "C" {
+    #[no_mangle]
+    fn __command(ptr: *const u8, len: usize) -> *const ::plugin::__Buffer;
+}
+
+// #[no_mangle]
+// // TODO: switch len from usize to u32?
+// pub extern "C" fn #outer_fn_name(ptr: *const u8, len: usize) -> *const ::plugin::__Buffer {
+//     // setup
+//     let buffer = ::plugin::__Buffer { ptr, len };
+//     let data = unsafe { buffer.to_vec() };
+
+//     // operation
+//     let data: #ty = match ::plugin::bincode::deserialize(&data) {
+//         Ok(d) => d,
+//         Err(e) => panic!("Data passed to function not deserializable."),
+//     };
+//     let result = #inner_fn_name(#args);
+//     let new_data: Result<Vec<u8>, _> = ::plugin::bincode::serialize(&result);
+//     let new_data = new_data.unwrap();
+
+//     // teardown
+//     let new_buffer = unsafe { ::plugin::__Buffer::from_vec(new_data) };
+//     return new_buffer.leak_to_heap();
+// }
+
+#[no_mangle]
 fn command(string: &str) -> Option<String> {
-    None
+    println!("executing command: {}", string);
+    // serialize data
+    let data = string;
+    let data = ::plugin::bincode::serialize(&data).unwrap();
+    let buffer = unsafe { ::plugin::__Buffer::from_vec(data) };
+    let ptr = buffer.ptr;
+    let len = buffer.len;
+    // leak data to heap
+    buffer.leak_to_heap();
+    // call extern function
+    let result = unsafe { __command(ptr, len) };
+    // get result
+    let result = todo!(); // convert into box
+
+    // deserialize data
+    let data: Option<String> = match ::plugin::bincode::deserialize(&data) {
+        Ok(d) => d,
+        Err(e) => panic!("Data passed to function not deserializable."),
+    };
+    return data;
 }
 
 // TODO: some sort of macro to generate ABI bindings
@@ -30,7 +78,7 @@ extern "C" {
 const BIN_PATH: &'static str =
     "node_modules/vscode-json-languageserver/bin/vscode-json-languageserver";
 
-#[bind]
+#[export]
 pub fn name() -> &'static str {
     // let number = unsafe { hello(27) };
     // println!("got: {}", number);
@@ -39,12 +87,12 @@ pub fn name() -> &'static str {
     "vscode-json-languageserver"
 }
 
-#[bind]
+#[export]
 pub fn server_args() -> Vec<String> {
     vec!["--stdio".into()]
 }
 
-#[bind]
+#[export]
 pub fn fetch_latest_server_version() -> Option<String> {
     #[derive(Deserialize)]
     struct NpmInfo {
@@ -61,7 +109,7 @@ pub fn fetch_latest_server_version() -> Option<String> {
     info.versions.pop()
 }
 
-#[bind]
+#[export]
 pub fn fetch_server_binary(container_dir: PathBuf, version: String) -> Result<PathBuf, String> {
     let version_dir = container_dir.join(version.as_str());
     fs::create_dir_all(&version_dir)
@@ -92,7 +140,7 @@ pub fn fetch_server_binary(container_dir: PathBuf, version: String) -> Result<Pa
     Ok(binary_path)
 }
 
-#[bind]
+#[export]
 pub fn cached_server_binary(container_dir: PathBuf) -> Option<PathBuf> {
     let mut last_version_dir = None;
     let mut entries = fs::read_dir(&container_dir).ok()?;
@@ -113,17 +161,17 @@ pub fn cached_server_binary(container_dir: PathBuf) -> Option<PathBuf> {
     }
 }
 
-#[bind]
+#[export]
 pub fn label_for_completion(label: String) -> Option<String> {
     None
 }
 
-#[bind]
+#[export]
 pub fn initialization_options() -> Option<String> {
     Some("{ \"provideFormatter\": true }".to_string())
 }
 
-#[bind]
+#[export]
 pub fn id_for_language(name: String) -> Option<String> {
     if name == "JSON" {
         Some("jsonc".into())