Factor out buffer code

Isaac Clayton created

Change summary

crates/plugin_macros/src/lib.rs             |   6 -
crates/plugin_runtime/src/wasi.rs           | 104 +++++++++++---------
crates/zed/src/languages/language_plugin.rs |  15 +-
plugins/json_language/src/lib.rs            | 114 ++++++++++------------
4 files changed, 115 insertions(+), 124 deletions(-)

Detailed changes

crates/plugin_macros/src/lib.rs 🔗

@@ -61,11 +61,7 @@ pub fn bind(args: TokenStream, function: TokenStream) -> TokenStream {
             // operation
             let data: #ty = match ::plugin::bincode::deserialize(&data) {
                 Ok(d) => d,
-                Err(e) => {
-                    println!("data: {:?}", data);
-                    println!("error: {}", e);
-                    panic!("Data passed to function not deserializable.")
-                },
+                Err(e) => panic!("Data passed to function not deserializable."),
             };
             let result = #inner_fn_name(#args);
             let new_data: Result<Vec<u8>, _> = ::plugin::bincode::serialize(&result);

crates/plugin_runtime/src/wasi.rs 🔗

@@ -14,8 +14,8 @@ pub struct Wasi {
     module: Module,
     store: Store<WasiCtx>,
     instance: Instance,
-    alloc_buffer: TypedFunc<i32, i32>,
-    // free_buffer: TypedFunc<(i32, i32), ()>,
+    alloc_buffer: TypedFunc<u32, u32>,
+    // free_buffer: TypedFunc<(u32, u32), ()>,
 }
 
 pub struct WasiPlugin {
@@ -54,27 +54,20 @@ impl Wasi {
         let engine = Engine::default();
         let mut linker = Linker::new(&engine);
 
-        linker.func_wrap("env", "hello", |x: u32| x * 2).unwrap();
-        linker.func_wrap("env", "bye", |x: u32| x / 2).unwrap();
+        linker.func_wrap("env", "__hello", |x: u32| x * 2).unwrap();
+        linker.func_wrap("env", "__bye", |x: u32| x / 2).unwrap();
 
-        println!("linking");
         wasmtime_wasi::add_to_linker(&mut linker, |s| s)?;
 
-        println!("linked");
         let mut store: Store<_> = Store::new(&engine, plugin.wasi_ctx);
-        println!("moduling");
         let module = Module::new(&engine, plugin.module)?;
-        println!("moduled");
 
         linker.module(&mut store, "", &module)?;
-        println!("linked again");
-
         let instance = linker.instantiate(&mut store, &module)?;
-        println!("instantiated");
 
         let alloc_buffer = instance.get_typed_func(&mut store, "__alloc_buffer")?;
         // let free_buffer = instance.get_typed_func(&mut store, "__free_buffer")?;
-        println!("can alloc");
+
         Ok(Wasi {
             engine,
             module,
@@ -99,7 +92,6 @@ impl Wasi {
 
         // grab an empty file descriptor, specify capabilities
         let fd = ctx.table().push(Box::new(()))?;
-        dbg!(fd);
         let caps = dir::DirCaps::all();
         let file_caps = file::FileCaps::all();
 
@@ -172,62 +164,78 @@ impl Wasi {
     // This isn't a problem because Wasm stops executing after the function returns,
     // so the heap is still valid for our inspection when we want to pull things out.
 
-    // TODO: dont' use as for conversions
-    pub fn call<A: Serialize, R: DeserializeOwned>(
-        &mut self,
-        handle: &str,
-        arg: A,
-    ) -> Result<R, Error> {
-        dbg!(&handle);
-        // dbg!(serde_json::to_string(&arg)).unwrap();
-
+    /// Takes an item, allocates a buffer, serializes the argument to that buffer,
+    /// and returns a (ptr, len) pair to that buffer.
+    fn serialize_to_buffer<T: Serialize>(&mut self, item: T) -> Result<(u32, u32), Error> {
         // serialize the argument using bincode
-        let arg = bincode::serialize(&arg)?;
-        let arg_buffer_len = arg.len();
+        let item = bincode::serialize(&item)?;
+        let buffer_len = item.len() as u32;
 
         // allocate a buffer and write the argument to that buffer
-        let arg_buffer_ptr = self
-            .alloc_buffer
-            .call(&mut self.store, arg_buffer_len as i32)?;
+        let buffer_ptr = self.alloc_buffer.call(&mut self.store, buffer_len)?;
         let plugin_memory = self
             .instance
             .get_memory(&mut self.store, "memory")
             .ok_or_else(|| anyhow!("Could not grab slice of plugin memory"))?;
-        plugin_memory.write(&mut self.store, arg_buffer_ptr as usize, &arg)?;
-
-        // get the webassembly function we want to actually call
-        // TODO: precompute handle
-        let fun_name = format!("__{}", handle);
-        let fun = self
-            .instance
-            .get_typed_func::<(i32, i32), i32, _>(&mut self.store, &fun_name)?;
-
-        // call the function, passing in the buffer and its length
-        // this should return a pointer to a (ptr, lentgh) pair
-        let arg_buffer = (arg_buffer_ptr, arg_buffer_len as i32);
-        let result_buffer = fun.call(&mut self.store, arg_buffer)?;
+        plugin_memory.write(&mut self.store, buffer_ptr as usize, &item)?;
+        Ok((buffer_ptr, buffer_len))
+    }
 
+    /// Takes a ptr to a (ptr, len) pair and returns the corresponding deserialized buffer
+    fn deserialize_from_buffer<R: DeserializeOwned>(&mut self, buffer: u32) -> Result<R, Error> {
         // create a buffer to read the (ptr, length) pair into
         // this is a total of 4 + 4 = 8 bytes.
-        let buffer = &mut [0; 8];
-        plugin_memory.read(&mut self.store, result_buffer as usize, buffer)?;
+        let raw_buffer = &mut [0; 8];
+        let plugin_memory = self
+            .instance
+            .get_memory(&mut self.store, "memory")
+            .ok_or_else(|| anyhow!("Could not grab slice of plugin memory"))?;
+        plugin_memory.read(&mut self.store, buffer as usize, raw_buffer)?;
 
         // use these bytes (wasm stores things little-endian)
         // to get a pointer to the buffer and its length
-        let b = buffer;
-        let result_buffer_ptr = u32::from_le_bytes([b[0], b[1], b[2], b[3]]) as usize;
-        let result_buffer_len = u32::from_le_bytes([b[4], b[5], b[6], b[7]]) as usize;
-        let result_buffer_end = result_buffer_ptr + result_buffer_len;
+        let b = raw_buffer;
+        let buffer_ptr = u32::from_le_bytes([b[0], b[1], b[2], b[3]]) as usize;
+        let buffer_len = u32::from_le_bytes([b[4], b[5], b[6], b[7]]) as usize;
+        let buffer_end = buffer_ptr + buffer_len;
 
         // read the buffer at this point into a byte array
         // deserialize the byte array into the provided serde type
-        let result = &plugin_memory.data(&mut self.store)[result_buffer_ptr..result_buffer_end];
+        let result = &plugin_memory.data(&mut self.store)[buffer_ptr..buffer_end];
         let result = bincode::deserialize(result)?;
 
         // TODO: this is handled wasm-side, but I'd like to double-check
         // // deallocate the argument buffer
         // self.free_buffer.call(&mut self.store, arg_buffer);
 
-        return Ok(result);
+        Ok(result)
+    }
+
+    // TODO: dont' use as for conversions
+    pub fn call<A: Serialize, R: DeserializeOwned>(
+        &mut self,
+        handle: &str,
+        arg: A,
+    ) -> Result<R, Error> {
+        let start = std::time::Instant::now();
+        dbg!(&handle);
+        // dbg!(serde_json::to_string(&arg)).unwrap();
+
+        // write the argument to linear memory
+        // this returns a (ptr, lentgh) pair
+        let arg_buffer = self.serialize_to_buffer(arg)?;
+
+        // get the webassembly function we want to actually call
+        // TODO: precompute handle
+        let fun_name = format!("__{}", handle);
+        let fun = self
+            .instance
+            .get_typed_func::<(u32, u32), u32, _>(&mut self.store, &fun_name)?;
+
+        // call the function, passing in the buffer and its length
+        // this returns a ptr to a (ptr, lentgh) pair
+        let result_buffer = fun.call(&mut self.store, arg_buffer)?;
+
+        self.deserialize_from_buffer(result_buffer)
     }
 }

crates/zed/src/languages/language_plugin.rs 🔗

@@ -125,13 +125,12 @@ impl LspAdapter for LanguagePluginLspAdapter {
     }
 
     fn initialization_options(&self) -> Option<serde_json::Value> {
-        // self.runtime
-        //     .lock()
-        //     .call::<_, Option<serde_json::Value>>("initialization_options", ())
-        //     .unwrap()
-
-        Some(json!({
-            "provideFormatter": true
-        }))
+        let string = self
+            .runtime
+            .lock()
+            .call::<_, Option<String>>("initialization_options", ())
+            .unwrap()?;
+
+        serde_json::from_str(&string).ok()
     }
 }

plugins/json_language/src/lib.rs 🔗

@@ -17,12 +17,15 @@ extern "C" {
 
 // }
 
+const BIN_PATH: &'static str =
+    "node_modules/vscode-json-languageserver/bin/vscode-json-languageserver";
+
 #[bind]
 pub fn name() -> &'static str {
-    let number = unsafe { hello(27) };
-    println!("got: {}", number);
-    let number = unsafe { bye(28) };
-    println!("got: {}", number);
+    // let number = unsafe { hello(27) };
+    // println!("got: {}", number);
+    // let number = unsafe { bye(28) };
+    // println!("got: {}", number);
     "vscode-json-languageserver"
 }
 
@@ -33,79 +36,66 @@ pub fn server_args() -> Vec<String> {
 
 #[bind]
 pub fn fetch_latest_server_version() -> Option<String> {
-    // #[derive(Deserialize)]
-    // struct NpmInfo {
-    //     versions: Vec<String>,
-    // }
-
-    // let output = command("npm info vscode-json-languageserver --json")?;
-    // if !output.status.success() {
-    //     return None;
-    // }
-
-    // let mut info: NpmInfo = serde_json::from_slice(&output.stdout)?;
-    // info.versions.pop()
-    println!("fetching server version");
-    Some("1.3.4".into())
+    #[derive(Deserialize)]
+    struct NpmInfo {
+        versions: Vec<String>,
+    }
+
+    let output = command("npm info vscode-json-languageserver --json")?;
+    if !output.status.success() {
+        return None;
+    }
+
+    let mut info: NpmInfo = serde_json::from_slice(&output.stdout)?;
+    info.versions.pop()
 }
 
 #[bind]
-pub fn fetch_server_binary(container_dir: PathBuf, version: String) -> Option<PathBuf> {
-    println!("Fetching server binary");
-    return None;
-    // let version_dir = container_dir.join(version.as_str());
-    // fs::create_dir_all(&version_dir)
-    //     .await
-    //     .context("failed to create version directory")?;
-    // let binary_path = version_dir.join(Self::BIN_PATH);
-
-    // if fs::metadata(&binary_path).await.is_err() {
-    //     let output = smol::process::Command::new("npm")
-    //         .current_dir(&version_dir)
-    //         .arg("install")
-    //         .arg(format!("vscode-json-languageserver@{}", version))
-    //         .output()
-    //         .await
-    //         .context("failed to run npm install")?;
-    //     if !output.status.success() {
-    //         Err(anyhow!("failed to install vscode-json-languageserver"))?;
-    //     }
-
-    //     if let Some(mut entries) = fs::read_dir(&container_dir).await.log_err() {
-    //         while let Some(entry) = entries.next().await {
-    //             if let Some(entry) = entry.log_err() {
-    //                 let entry_path = entry.path();
-    //                 if entry_path.as_path() != version_dir {
-    //                     fs::remove_dir_all(&entry_path).await.log_err();
-    //                 }
-    //             }
-    //         }
-    //     }
-    // }
-
-    // Ok(binary_path)
-}
+pub fn fetch_server_binary(container_dir: PathBuf, version: String) -> Result<PathBuf, String> {
+    let version_dir = container_dir.join(version.as_str());
+    fs::create_dir_all(&version_dir)
+        .or_or_else(|| "failed to create version directory".to_string())?;
+    let binary_path = version_dir.join(Self::BIN_PATH);
+
+    if fs::metadata(&binary_path).await.is_err() {
+        let output = command(format!(
+            "npm install vscode-json-languageserver@{}",
+            version
+        ));
+        if !output.status.success() {
+            Err(anyhow!("failed to install vscode-json-languageserver"))?;
+        }
 
-const BIN_PATH: &'static str =
-    "node_modules/vscode-json-languageserver/bin/vscode-json-languageserver";
+        if let Some(mut entries) = fs::read_dir(&container_dir).await.log_err() {
+            while let Some(entry) = entries.next().await {
+                if let Some(entry) = entry.log_err() {
+                    let entry_path = entry.path();
+                    if entry_path.as_path() != version_dir {
+                        fs::remove_dir_all(&entry_path).await.log_err();
+                    }
+                }
+            }
+        }
+    }
+
+    Ok(binary_path)
+}
 
 #[bind]
 pub fn cached_server_binary(container_dir: PathBuf) -> Option<PathBuf> {
-    println!("Finding cached server binary...");
     let mut last_version_dir = None;
-    println!("{}", container_dir.exists());
     let mut entries = fs::read_dir(&container_dir).ok()?;
-    println!("Read Entries...");
+
     while let Some(entry) = entries.next() {
         let entry = entry.ok()?;
         if entry.file_type().ok()?.is_dir() {
             last_version_dir = Some(entry.path());
         }
     }
+
     let last_version_dir = last_version_dir?;
     let bin_path = last_version_dir.join(BIN_PATH);
     if bin_path.exists() {
-        println!("this is the path: {}", bin_path.display());
         Some(bin_path)
     } else {
         None
@@ -113,10 +103,8 @@ pub fn cached_server_binary(container_dir: PathBuf) -> Option<PathBuf> {
 }
 
 #[bind]
-pub fn initialization_options() -> Option<serde_json::Value> {
-    Some(json!({
-        "provideFormatter": true
-    }))
+pub fn initialization_options() -> Option<String> {
+    Some("{ \"provideFormatter\": true }".to_string())
 }
 
 #[bind]