Add LanguageServer::on_io method, for observing JSON sent back and forth

Max Brunsfeld created

Change summary

crates/lsp/src/lsp.rs | 110 +++++++++++++++++++++++++++++++++-----------
1 file changed, 82 insertions(+), 28 deletions(-)

Detailed changes

crates/lsp/src/lsp.rs 🔗

@@ -20,10 +20,10 @@ use std::{
     future::Future,
     io::Write,
     path::PathBuf,
-    str::FromStr,
+    str::{self, FromStr as _},
     sync::{
         atomic::{AtomicUsize, Ordering::SeqCst},
-        Arc,
+        Arc, Weak,
     },
 };
 use std::{path::Path, process::Stdio};
@@ -34,16 +34,18 @@ const CONTENT_LEN_HEADER: &str = "Content-Length: ";
 
 type NotificationHandler = Box<dyn Send + FnMut(Option<usize>, &str, AsyncAppContext)>;
 type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
+type IoHandler = Box<dyn Send + FnMut(bool, &str)>;
 
 pub struct LanguageServer {
     server_id: LanguageServerId,
     next_id: AtomicUsize,
-    outbound_tx: channel::Sender<Vec<u8>>,
+    outbound_tx: channel::Sender<String>,
     name: String,
     capabilities: ServerCapabilities,
     code_action_kinds: Option<Vec<CodeActionKind>>,
     notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
     response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
+    io_handlers: Arc<Mutex<HashMap<usize, IoHandler>>>,
     executor: Arc<executor::Background>,
     #[allow(clippy::type_complexity)]
     io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
@@ -56,9 +58,16 @@ pub struct LanguageServer {
 #[repr(transparent)]
 pub struct LanguageServerId(pub usize);
 
-pub struct Subscription {
-    method: &'static str,
-    notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
+pub enum Subscription {
+    Detached,
+    Notification {
+        method: &'static str,
+        notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
+    },
+    Io {
+        id: usize,
+        io_handlers: Weak<Mutex<HashMap<usize, IoHandler>>>,
+    },
 }
 
 #[derive(Serialize, Deserialize)]
@@ -177,33 +186,40 @@ impl LanguageServer {
         Stdout: AsyncRead + Unpin + Send + 'static,
         F: FnMut(AnyNotification) + 'static + Send,
     {
-        let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
+        let (outbound_tx, outbound_rx) = channel::unbounded::<String>();
+        let (output_done_tx, output_done_rx) = barrier::channel();
         let notification_handlers =
             Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
         let response_handlers =
             Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
+        let io_handlers = Arc::new(Mutex::new(HashMap::default()));
         let input_task = cx.spawn(|cx| {
-            let notification_handlers = notification_handlers.clone();
-            let response_handlers = response_handlers.clone();
             Self::handle_input(
                 stdout,
                 on_unhandled_notification,
-                notification_handlers,
-                response_handlers,
+                notification_handlers.clone(),
+                response_handlers.clone(),
+                io_handlers.clone(),
                 cx,
             )
             .log_err()
         });
-        let (output_done_tx, output_done_rx) = barrier::channel();
         let output_task = cx.background().spawn({
-            let response_handlers = response_handlers.clone();
-            Self::handle_output(stdin, outbound_rx, output_done_tx, response_handlers).log_err()
+            Self::handle_output(
+                stdin,
+                outbound_rx,
+                output_done_tx,
+                response_handlers.clone(),
+                io_handlers.clone(),
+            )
+            .log_err()
         });
 
         Self {
             server_id,
             notification_handlers,
             response_handlers,
+            io_handlers,
             name: Default::default(),
             capabilities: Default::default(),
             code_action_kinds,
@@ -226,6 +242,7 @@ impl LanguageServer {
         mut on_unhandled_notification: F,
         notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
         response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
+        io_handlers: Arc<Mutex<HashMap<usize, IoHandler>>>,
         cx: AsyncAppContext,
     ) -> anyhow::Result<()>
     where
@@ -252,7 +269,13 @@ impl LanguageServer {
 
             buffer.resize(message_len, 0);
             stdout.read_exact(&mut buffer).await?;
-            log::trace!("incoming message:{}", String::from_utf8_lossy(&buffer));
+
+            if let Ok(message) = str::from_utf8(&buffer) {
+                log::trace!("incoming message:{}", message);
+                for handler in io_handlers.lock().values_mut() {
+                    handler(true, message);
+                }
+            }
 
             if let Ok(msg) = serde_json::from_slice::<AnyNotification>(&buffer) {
                 if let Some(handler) = notification_handlers.lock().get_mut(msg.method) {
@@ -291,9 +314,10 @@ impl LanguageServer {
 
     async fn handle_output<Stdin>(
         stdin: Stdin,
-        outbound_rx: channel::Receiver<Vec<u8>>,
+        outbound_rx: channel::Receiver<String>,
         output_done_tx: barrier::Sender,
         response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
+        io_handlers: Arc<Mutex<HashMap<usize, IoHandler>>>,
     ) -> anyhow::Result<()>
     where
         Stdin: AsyncWrite + Unpin + Send + 'static,
@@ -307,13 +331,17 @@ impl LanguageServer {
         });
         let mut content_len_buffer = Vec::new();
         while let Ok(message) = outbound_rx.recv().await {
-            log::trace!("outgoing message:{}", String::from_utf8_lossy(&message));
+            log::trace!("outgoing message:{}", message);
+            for handler in io_handlers.lock().values_mut() {
+                handler(false, &message);
+            }
+
             content_len_buffer.clear();
             write!(content_len_buffer, "{}", message.len()).unwrap();
             stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
             stdin.write_all(&content_len_buffer).await?;
             stdin.write_all("\r\n\r\n".as_bytes()).await?;
-            stdin.write_all(&message).await?;
+            stdin.write_all(message.as_bytes()).await?;
             stdin.flush().await?;
         }
         drop(output_done_tx);
@@ -464,6 +492,19 @@ impl LanguageServer {
         self.on_custom_request(T::METHOD, f)
     }
 
+    #[must_use]
+    pub fn on_io<F>(&self, f: F) -> Subscription
+    where
+        F: 'static + Send + FnMut(bool, &str),
+    {
+        let id = self.next_id.fetch_add(1, SeqCst);
+        self.io_handlers.lock().insert(id, Box::new(f));
+        Subscription::Io {
+            id,
+            io_handlers: Arc::downgrade(&self.io_handlers),
+        }
+    }
+
     pub fn remove_request_handler<T: request::Request>(&self) {
         self.notification_handlers.lock().remove(T::METHOD);
     }
@@ -490,7 +531,7 @@ impl LanguageServer {
             prev_handler.is_none(),
             "registered multiple handlers for the same LSP method"
         );
-        Subscription {
+        Subscription::Notification {
             method,
             notification_handlers: self.notification_handlers.clone(),
         }
@@ -537,7 +578,7 @@ impl LanguageServer {
                                             },
                                         };
                                         if let Some(response) =
-                                            serde_json::to_vec(&response).log_err()
+                                            serde_json::to_string(&response).log_err()
                                         {
                                             outbound_tx.try_send(response).ok();
                                         }
@@ -560,7 +601,7 @@ impl LanguageServer {
                                     message: error.to_string(),
                                 }),
                             };
-                            if let Some(response) = serde_json::to_vec(&response).log_err() {
+                            if let Some(response) = serde_json::to_string(&response).log_err() {
                                 outbound_tx.try_send(response).ok();
                             }
                         }
@@ -572,7 +613,7 @@ impl LanguageServer {
             prev_handler.is_none(),
             "registered multiple handlers for the same LSP method"
         );
-        Subscription {
+        Subscription::Notification {
             method,
             notification_handlers: self.notification_handlers.clone(),
         }
@@ -612,14 +653,14 @@ impl LanguageServer {
     fn request_internal<T: request::Request>(
         next_id: &AtomicUsize,
         response_handlers: &Mutex<Option<HashMap<usize, ResponseHandler>>>,
-        outbound_tx: &channel::Sender<Vec<u8>>,
+        outbound_tx: &channel::Sender<String>,
         params: T::Params,
     ) -> impl 'static + Future<Output = Result<T::Result>>
     where
         T::Result: 'static + Send,
     {
         let id = next_id.fetch_add(1, SeqCst);
-        let message = serde_json::to_vec(&Request {
+        let message = serde_json::to_string(&Request {
             jsonrpc: JSON_RPC_VERSION,
             id,
             method: T::METHOD,
@@ -662,10 +703,10 @@ impl LanguageServer {
     }
 
     fn notify_internal<T: notification::Notification>(
-        outbound_tx: &channel::Sender<Vec<u8>>,
+        outbound_tx: &channel::Sender<String>,
         params: T::Params,
     ) -> Result<()> {
-        let message = serde_json::to_vec(&Notification {
+        let message = serde_json::to_string(&Notification {
             jsonrpc: JSON_RPC_VERSION,
             method: T::METHOD,
             params,
@@ -686,7 +727,7 @@ impl Drop for LanguageServer {
 
 impl Subscription {
     pub fn detach(mut self) {
-        self.method = "";
+        *(&mut self) = Self::Detached;
     }
 }
 
@@ -698,7 +739,20 @@ impl fmt::Display for LanguageServerId {
 
 impl Drop for Subscription {
     fn drop(&mut self) {
-        self.notification_handlers.lock().remove(self.method);
+        match self {
+            Subscription::Detached => {}
+            Subscription::Notification {
+                method,
+                notification_handlers,
+            } => {
+                notification_handlers.lock().remove(method);
+            }
+            Subscription::Io { id, io_handlers } => {
+                if let Some(io_handlers) = io_handlers.upgrade() {
+                    io_handlers.lock().remove(id);
+                }
+            }
+        }
     }
 }