Add API for handling custom requests from the language server

Max Brunsfeld created

Change summary

crates/lsp/src/lsp.rs | 65 ++++++++++++++++++++++++++++++++++++++------
1 file changed, 55 insertions(+), 10 deletions(-)

Detailed changes

crates/lsp/src/lsp.rs 🔗

@@ -4,7 +4,7 @@ use futures::{channel::oneshot, io::BufWriter, AsyncRead, AsyncWrite};
 use gpui::{executor, Task};
 use parking_lot::{Mutex, RwLock};
 use postage::{barrier, prelude::Stream};
-use serde::{Deserialize, Serialize};
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use serde_json::{json, value::RawValue, Value};
 use smol::{
     channel,
@@ -29,7 +29,8 @@ pub use lsp_types::*;
 const JSON_RPC_VERSION: &'static str = "2.0";
 const CONTENT_LEN_HEADER: &'static str = "Content-Length: ";
 
-type NotificationHandler = Box<dyn Send + Sync + FnMut(&str)>;
+type NotificationHandler =
+    Box<dyn Send + Sync + FnMut(Option<usize>, &str, &mut channel::Sender<Vec<u8>>) -> Result<()>>;
 type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
 
 pub struct LanguageServer {
@@ -80,6 +81,12 @@ struct AnyResponse<'a> {
     result: Option<&'a RawValue>,
 }
 
+#[derive(Serialize)]
+struct Response<T> {
+    id: usize,
+    result: T,
+}
+
 #[derive(Serialize, Deserialize)]
 struct Notification<'a, T> {
     #[serde(borrow)]
@@ -91,6 +98,8 @@ struct Notification<'a, T> {
 
 #[derive(Deserialize)]
 struct AnyNotification<'a> {
+    #[serde(default)]
+    id: Option<usize>,
     #[serde(borrow)]
     method: &'a str,
     #[serde(borrow)]
@@ -152,6 +161,7 @@ impl LanguageServer {
             {
                 let notification_handlers = notification_handlers.clone();
                 let response_handlers = response_handlers.clone();
+                let mut outbound_tx = outbound_tx.clone();
                 async move {
                     let _clear_response_handlers = ClearResponseHandlers(response_handlers.clone());
                     let mut buffer = Vec::new();
@@ -168,11 +178,13 @@ impl LanguageServer {
                         buffer.resize(message_len, 0);
                         stdout.read_exact(&mut buffer).await?;
 
-                        if let Ok(AnyNotification { method, params }) =
+                        if let Ok(AnyNotification { id, method, params }) =
                             serde_json::from_slice(&buffer)
                         {
                             if let Some(handler) = notification_handlers.write().get_mut(method) {
-                                handler(params.get());
+                                if let Err(e) = handler(id, params.get(), &mut outbound_tx) {
+                                    log::error!("error handling {} message: {:?}", method, e);
+                                }
                             } else {
                                 log::info!(
                                     "unhandled notification {}:\n{}",
@@ -351,12 +363,11 @@ impl LanguageServer {
     {
         let prev_handler = self.notification_handlers.write().insert(
             T::METHOD,
-            Box::new(
-                move |notification| match serde_json::from_str(notification) {
-                    Ok(notification) => f(notification),
-                    Err(err) => log::error!("error parsing notification {}: {}", T::METHOD, err),
-                },
-            ),
+            Box::new(move |_, params, _| {
+                let params = serde_json::from_str(params)?;
+                f(params);
+                Ok(())
+            }),
         );
 
         assert!(
@@ -370,6 +381,40 @@ impl LanguageServer {
         }
     }
 
+    pub fn on_custom_request<Params, Resp, F>(
+        &mut self,
+        method: &'static str,
+        mut f: F,
+    ) -> Subscription
+    where
+        F: 'static + Send + Sync + FnMut(Params) -> Result<Resp>,
+        Params: DeserializeOwned,
+        Resp: Serialize,
+    {
+        let prev_handler = self.notification_handlers.write().insert(
+            method,
+            Box::new(move |id, params, tx| {
+                if let Some(id) = id {
+                    let params = serde_json::from_str(params)?;
+                    let result = f(params)?;
+                    let response = serde_json::to_vec(&Response { id, result })?;
+                    tx.try_send(response)?;
+                }
+                Ok(())
+            }),
+        );
+
+        assert!(
+            prev_handler.is_none(),
+            "registered multiple handlers for the same notification"
+        );
+
+        Subscription {
+            method,
+            notification_handlers: self.notification_handlers.clone(),
+        }
+    }
+
     pub fn name<'a>(self: &'a Arc<Self>) -> &'a str {
         &self.name
     }