lsp: Handle responses in background thread (#12640)

Bennet Bo Fenner and Piotr created

Release Notes:

- Improved performance when handling large responses from language
servers

---------

Co-authored-by: Piotr <piotr@zed.dev>

Change summary

crates/lsp/src/input_handler.rs | 159 ++++++++++++++++++++++++++++++++++
crates/lsp/src/lsp.rs           | 161 ++++++----------------------------
2 files changed, 189 insertions(+), 131 deletions(-)

Detailed changes

crates/lsp/src/input_handler.rs 🔗

@@ -0,0 +1,159 @@
+use std::str;
+use std::sync::Arc;
+
+use anyhow::{anyhow, Result};
+use collections::HashMap;
+use futures::{
+    channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender},
+    AsyncBufReadExt, AsyncRead, AsyncReadExt as _,
+};
+use gpui::{BackgroundExecutor, Task};
+use log::warn;
+use parking_lot::Mutex;
+use smol::io::BufReader;
+
+use crate::{
+    AnyNotification, AnyResponse, IoHandler, IoKind, RequestId, ResponseHandler, CONTENT_LEN_HEADER,
+};
+
+const HEADER_DELIMITER: &'static [u8; 4] = b"\r\n\r\n";
+/// Handler for stdout of language server.
+pub struct LspStdoutHandler {
+    pub(super) loop_handle: Task<Result<()>>,
+    pub(super) notifications_channel: UnboundedReceiver<AnyNotification>,
+}
+
+pub(self) async fn read_headers<Stdout>(
+    reader: &mut BufReader<Stdout>,
+    buffer: &mut Vec<u8>,
+) -> Result<()>
+where
+    Stdout: AsyncRead + Unpin + Send + 'static,
+{
+    loop {
+        if buffer.len() >= HEADER_DELIMITER.len()
+            && buffer[(buffer.len() - HEADER_DELIMITER.len())..] == HEADER_DELIMITER[..]
+        {
+            return Ok(());
+        }
+
+        if reader.read_until(b'\n', buffer).await? == 0 {
+            return Err(anyhow!("cannot read LSP message headers"));
+        }
+    }
+}
+
+impl LspStdoutHandler {
+    pub fn new<Input>(
+        stdout: Input,
+        response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
+        io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>,
+        cx: BackgroundExecutor,
+    ) -> Self
+    where
+        Input: AsyncRead + Unpin + Send + 'static,
+    {
+        let (tx, notifications_channel) = unbounded();
+        let loop_handle = cx.spawn(Self::handler(stdout, tx, response_handlers, io_handlers));
+        Self {
+            loop_handle,
+            notifications_channel,
+        }
+    }
+
+    async fn handler<Input>(
+        stdout: Input,
+        notifications_sender: UnboundedSender<AnyNotification>,
+        response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
+        io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>,
+    ) -> anyhow::Result<()>
+    where
+        Input: AsyncRead + Unpin + Send + 'static,
+    {
+        let mut stdout = BufReader::new(stdout);
+
+        let mut buffer = Vec::new();
+
+        loop {
+            buffer.clear();
+
+            read_headers(&mut stdout, &mut buffer).await?;
+
+            let headers = std::str::from_utf8(&buffer)?;
+
+            let message_len = headers
+                .split('\n')
+                .find(|line| line.starts_with(CONTENT_LEN_HEADER))
+                .and_then(|line| line.strip_prefix(CONTENT_LEN_HEADER))
+                .ok_or_else(|| anyhow!("invalid LSP message header {headers:?}"))?
+                .trim_end()
+                .parse()?;
+
+            buffer.resize(message_len, 0);
+            stdout.read_exact(&mut buffer).await?;
+
+            if let Ok(message) = str::from_utf8(&buffer) {
+                log::trace!("incoming message: {message}");
+                for handler in io_handlers.lock().values_mut() {
+                    handler(IoKind::StdOut, message);
+                }
+            }
+
+            if let Ok(msg) = serde_json::from_slice::<AnyNotification>(&buffer) {
+                notifications_sender.unbounded_send(msg)?;
+            } else if let Ok(AnyResponse {
+                id, error, result, ..
+            }) = serde_json::from_slice(&buffer)
+            {
+                let mut response_handlers = response_handlers.lock();
+                if let Some(handler) = response_handlers
+                    .as_mut()
+                    .and_then(|handlers| handlers.remove(&id))
+                {
+                    drop(response_handlers);
+                    if let Some(error) = error {
+                        handler(Err(error));
+                    } else if let Some(result) = result {
+                        handler(Ok(result.get().into()));
+                    } else {
+                        handler(Ok("null".into()));
+                    }
+                }
+            } else {
+                warn!(
+                    "failed to deserialize LSP message:\n{}",
+                    std::str::from_utf8(&buffer)?
+                );
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[gpui::test]
+    async fn test_read_headers() {
+        let mut buf = Vec::new();
+        let mut reader = smol::io::BufReader::new(b"Content-Length: 123\r\n\r\n" as &[u8]);
+        read_headers(&mut reader, &mut buf).await.unwrap();
+        assert_eq!(buf, b"Content-Length: 123\r\n\r\n");
+
+        let mut buf = Vec::new();
+        let mut reader = smol::io::BufReader::new(b"Content-Type: application/vscode-jsonrpc\r\nContent-Length: 1235\r\n\r\n{\"somecontent\":123}" as &[u8]);
+        read_headers(&mut reader, &mut buf).await.unwrap();
+        assert_eq!(
+            buf,
+            b"Content-Type: application/vscode-jsonrpc\r\nContent-Length: 1235\r\n\r\n"
+        );
+
+        let mut buf = Vec::new();
+        let mut reader = smol::io::BufReader::new(b"Content-Length: 1235\r\nContent-Type: application/vscode-jsonrpc\r\n\r\n{\"somecontent\":true}" as &[u8]);
+        read_headers(&mut reader, &mut buf).await.unwrap();
+        assert_eq!(
+            buf,
+            b"Content-Length: 1235\r\nContent-Type: application/vscode-jsonrpc\r\n\r\n"
+        );
+    }
+}

crates/lsp/src/lsp.rs 🔗

@@ -1,4 +1,5 @@
-use log::warn;
+mod input_handler;
+
 pub use lsp_types::request::*;
 pub use lsp_types::*;
 
@@ -12,7 +13,7 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use serde_json::{json, value::RawValue, Value};
 use smol::{
     channel,
-    io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
+    io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
     process::{self, Child},
 };
 
@@ -25,7 +26,6 @@ use std::{
     io::Write,
     path::PathBuf,
     pin::Pin,
-    str::{self, FromStr as _},
     sync::{
         atomic::{AtomicI32, Ordering::SeqCst},
         Arc, Weak,
@@ -36,13 +36,13 @@ use std::{
 use std::{path::Path, process::Stdio};
 use util::{ResultExt, TryFutureExt};
 
-const HEADER_DELIMITER: &'static [u8; 4] = b"\r\n\r\n";
 const JSON_RPC_VERSION: &str = "2.0";
 const CONTENT_LEN_HEADER: &str = "Content-Length: ";
+
 const LSP_REQUEST_TIMEOUT: Duration = Duration::from_secs(60 * 2);
 const SERVER_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
 
-type NotificationHandler = Box<dyn Send + FnMut(Option<RequestId>, &str, AsyncAppContext)>;
+type NotificationHandler = Box<dyn Send + FnMut(Option<RequestId>, Value, AsyncAppContext)>;
 type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
 type IoHandler = Box<dyn Send + FnMut(IoKind, &str)>;
 
@@ -164,13 +164,12 @@ struct Notification<'a, T> {
 
 /// Language server RPC notification message before it is deserialized into a concrete type.
 #[derive(Debug, Clone, Deserialize)]
-struct AnyNotification<'a> {
+struct AnyNotification {
     #[serde(default)]
     id: Option<RequestId>,
-    #[serde(borrow)]
-    method: &'a str,
-    #[serde(borrow, default)]
-    params: Option<&'a RawValue>,
+    method: String,
+    #[serde(default)]
+    params: Option<Value>,
 }
 
 #[derive(Debug, Serialize, Deserialize)]
@@ -297,13 +296,7 @@ impl LanguageServer {
                     "Language server with id {} sent unhandled notification {}:\n{}",
                     server_id,
                     notification.method,
-                    serde_json::to_string_pretty(
-                        &notification
-                            .params
-                            .and_then(|params| Value::from_str(params.get()).ok())
-                            .unwrap_or(Value::Null)
-                    )
-                    .unwrap(),
+                    serde_json::to_string_pretty(&notification.params).unwrap(),
                 );
             },
         );
@@ -418,79 +411,36 @@ impl LanguageServer {
         Stdout: AsyncRead + Unpin + Send + 'static,
         F: FnMut(AnyNotification) + 'static + Send,
     {
-        let mut stdout = BufReader::new(stdout);
+        use smol::stream::StreamExt;
+        let stdout = BufReader::new(stdout);
         let _clear_response_handlers = util::defer({
             let response_handlers = response_handlers.clone();
             move || {
                 response_handlers.lock().take();
             }
         });
-        let mut buffer = Vec::new();
-        loop {
-            buffer.clear();
-
-            read_headers(&mut stdout, &mut buffer).await?;
-
-            let headers = std::str::from_utf8(&buffer)?;
-
-            let message_len = headers
-                .split('\n')
-                .find(|line| line.starts_with(CONTENT_LEN_HEADER))
-                .and_then(|line| line.strip_prefix(CONTENT_LEN_HEADER))
-                .ok_or_else(|| anyhow!("invalid LSP message header {headers:?}"))?
-                .trim_end()
-                .parse()?;
-
-            buffer.resize(message_len, 0);
-            stdout.read_exact(&mut buffer).await?;
-
-            if let Ok(message) = str::from_utf8(&buffer) {
-                log::trace!("incoming message: {message}");
-                for handler in io_handlers.lock().values_mut() {
-                    handler(IoKind::StdOut, message);
-                }
-            }
+        let mut input_handler = input_handler::LspStdoutHandler::new(
+            stdout,
+            response_handlers,
+            io_handlers,
+            cx.background_executor().clone(),
+        );
 
-            if let Ok(msg) = serde_json::from_slice::<AnyNotification>(&buffer) {
+        while let Some(msg) = input_handler.notifications_channel.next().await {
+            {
                 let mut notification_handlers = notification_handlers.lock();
-                if let Some(handler) = notification_handlers.get_mut(msg.method) {
-                    handler(
-                        msg.id,
-                        msg.params.map(|params| params.get()).unwrap_or("null"),
-                        cx.clone(),
-                    );
+                if let Some(handler) = notification_handlers.get_mut(msg.method.as_str()) {
+                    handler(msg.id, msg.params.unwrap_or(Value::Null), cx.clone());
                 } else {
                     drop(notification_handlers);
                     on_unhandled_notification(msg);
                 }
-            } else if let Ok(AnyResponse {
-                id, error, result, ..
-            }) = serde_json::from_slice(&buffer)
-            {
-                let mut response_handlers = response_handlers.lock();
-                if let Some(handler) = response_handlers
-                    .as_mut()
-                    .and_then(|handlers| handlers.remove(&id))
-                {
-                    drop(response_handlers);
-                    if let Some(error) = error {
-                        handler(Err(error));
-                    } else if let Some(result) = result {
-                        handler(Ok(result.get().into()));
-                    } else {
-                        handler(Ok("null".into()));
-                    }
-                }
-            } else {
-                warn!(
-                    "failed to deserialize LSP message:\n{}",
-                    std::str::from_utf8(&buffer)?
-                );
             }
 
-            // Don't starve the main thread when receiving lots of messages at once.
+            // Don't starve the main thread when receiving lots of notifications at once.
             smol::future::yield_now().await;
         }
+        input_handler.loop_handle.await
     }
 
     async fn handle_stderr<Stderr>(
@@ -512,7 +462,7 @@ impl LanguageServer {
                 return Ok(());
             }
 
-            if let Ok(message) = str::from_utf8(&buffer) {
+            if let Ok(message) = std::str::from_utf8(&buffer) {
                 log::trace!("incoming stderr message:{message}");
                 for handler in io_handlers.lock().values_mut() {
                     handler(IoKind::StdErr, message);
@@ -850,7 +800,7 @@ impl LanguageServer {
         let prev_handler = self.notification_handlers.lock().insert(
             method,
             Box::new(move |_, params, cx| {
-                if let Some(params) = serde_json::from_str(params).log_err() {
+                if let Some(params) = serde_json::from_value(params).log_err() {
                     f(params, cx);
                 }
             }),
@@ -878,7 +828,7 @@ impl LanguageServer {
             method,
             Box::new(move |id, params, cx| {
                 if let Some(id) = id {
-                    match serde_json::from_str(params) {
+                    match serde_json::from_value(params) {
                         Ok(params) => {
                             let response = f(params, cx.clone());
                             cx.foreground_executor()
@@ -910,12 +860,7 @@ impl LanguageServer {
                         }
 
                         Err(error) => {
-                            log::error!(
-                                "error deserializing {} request: {:?}, message: {:?}",
-                                method,
-                                error,
-                                params
-                            );
+                            log::error!("error deserializing {} request: {:?}", method, error);
                             let response = AnyResponse {
                                 jsonrpc: JSON_RPC_VERSION,
                                 id,
@@ -1202,10 +1147,7 @@ impl FakeLanguageServer {
                         notifications_tx
                             .try_send((
                                 msg.method.to_string(),
-                                msg.params
-                                    .map(|raw_value| raw_value.get())
-                                    .unwrap_or("null")
-                                    .to_string(),
+                                msg.params.unwrap_or(Value::Null).to_string(),
                             ))
                             .ok();
                     },
@@ -1372,30 +1314,11 @@ impl FakeLanguageServer {
     }
 }
 
-pub(self) async fn read_headers<Stdout>(
-    reader: &mut BufReader<Stdout>,
-    buffer: &mut Vec<u8>,
-) -> Result<()>
-where
-    Stdout: AsyncRead + Unpin + Send + 'static,
-{
-    loop {
-        if buffer.len() >= HEADER_DELIMITER.len()
-            && buffer[(buffer.len() - HEADER_DELIMITER.len())..] == HEADER_DELIMITER[..]
-        {
-            return Ok(());
-        }
-
-        if reader.read_until(b'\n', buffer).await? == 0 {
-            return Err(anyhow!("cannot read LSP message headers"));
-        }
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
     use gpui::TestAppContext;
+    use std::str::FromStr;
 
     #[ctor::ctor]
     fn init_logger() {
@@ -1475,30 +1398,6 @@ mod tests {
         fake.receive_notification::<notification::Exit>().await;
     }
 
-    #[gpui::test]
-    async fn test_read_headers() {
-        let mut buf = Vec::new();
-        let mut reader = smol::io::BufReader::new(b"Content-Length: 123\r\n\r\n" as &[u8]);
-        read_headers(&mut reader, &mut buf).await.unwrap();
-        assert_eq!(buf, b"Content-Length: 123\r\n\r\n");
-
-        let mut buf = Vec::new();
-        let mut reader = smol::io::BufReader::new(b"Content-Type: application/vscode-jsonrpc\r\nContent-Length: 1235\r\n\r\n{\"somecontent\":123}" as &[u8]);
-        read_headers(&mut reader, &mut buf).await.unwrap();
-        assert_eq!(
-            buf,
-            b"Content-Type: application/vscode-jsonrpc\r\nContent-Length: 1235\r\n\r\n"
-        );
-
-        let mut buf = Vec::new();
-        let mut reader = smol::io::BufReader::new(b"Content-Length: 1235\r\nContent-Type: application/vscode-jsonrpc\r\n\r\n{\"somecontent\":true}" as &[u8]);
-        read_headers(&mut reader, &mut buf).await.unwrap();
-        assert_eq!(
-            buf,
-            b"Content-Length: 1235\r\nContent-Type: application/vscode-jsonrpc\r\n\r\n"
-        );
-    }
-
     #[gpui::test]
     fn test_deserialize_string_digit_id() {
         let json = r#"{"jsonrpc":"2.0","id":"2","method":"workspace/configuration","params":{"items":[{"scopeUri":"file:///Users/mph/Devel/personal/hello-scala/","section":"metals"}]}}"#;