Read LSP message headers at once (#7449)

Federico Dionisi and Caius created

The current LSP header reader implementation assumes a specific order
(i.e., `Content-Length` first, and then `Content-Type`). Unfortunately,
this assumption is not always valid, as no specification enforces the
rule. @caius and I encountered this issue while implementing the
Terraform LSP, where `Content-Type` comes first, breaking the
implementation in #6929.

This PR introduces a `read_headers` function, which asynchronously reads
from the incoming pipe up to and including the headers' delimiter (i.e.,
`\r\n\r\n`), appending the bytes read to the message buffer, and returns
an error if the stream ends before the delimiter is found.

I added a few tests but only considered scenarios where headers are
delivered at once (which should be the case?). I'm unsure if this
suffices or if I should consider more scenarios; I would love to hear
others' opinions.


Release Notes:

- N/A

---------

Co-authored-by: Caius <caius@caius.name>

Change summary

crates/lsp/src/lsp.rs | 99 +++++++++++++++++++++++++-------------------
1 file changed, 56 insertions(+), 43 deletions(-)

Detailed changes

crates/lsp/src/lsp.rs 🔗

@@ -31,6 +31,7 @@ use std::{
 use std::{path::Path, process::Stdio};
 use util::{ResultExt, TryFutureExt};
 
+const HEADER_DELIMITER: &'static [u8; 4] = b"\r\n\r\n";
 const JSON_RPC_VERSION: &str = "2.0";
 const CONTENT_LEN_HEADER: &str = "Content-Length: ";
 const LSP_REQUEST_TIMEOUT: Duration = Duration::from_secs(60 * 2);
@@ -323,47 +324,17 @@ impl LanguageServer {
         loop {
             buffer.clear();
 
-            if stdout.read_until(b'\n', &mut buffer).await? == 0 {
-                break;
-            };
-
-            if stdout.read_until(b'\n', &mut buffer).await? == 0 {
-                break;
-            };
-
-            let header = std::str::from_utf8(&buffer)?;
-            let mut segments = header.lines();
-
-            let message_len: usize = segments
-                .next()
-                .with_context(|| {
-                    format!("unable to find the first line of the LSP message header `{header}`")
-                })?
-                .strip_prefix(CONTENT_LEN_HEADER)
-                .with_context(|| format!("invalid LSP message header `{header}`"))?
-                .parse()
-                .with_context(|| {
-                    format!("failed to parse Content-Length of LSP message header: `{header}`")
-                })?;
-
-            if let Some(second_segment) = segments.next() {
-                match second_segment {
-                    "" => (), // Header end
-                    header_field => {
-                        if header_field.starts_with("Content-Type:") {
-                            stdout.read_until(b'\n', &mut buffer).await?;
-                        } else {
-                            anyhow::bail!(
-                                "inside `{header}`, expected a Content-Type header field or a header ending CRLF, got `{second_segment:?}`"
-                            )
-                        }
-                    }
-                }
-            } else {
-                anyhow::bail!(
-                    "unable to find the second line of the LSP message header `{header}`"
-                );
-            }
+            read_headers(&mut stdout, &mut buffer).await?;
+
+            let headers = std::str::from_utf8(&buffer)?;
+
+            let message_len = headers
+                .split("\n")
+                .find(|line| line.starts_with(CONTENT_LEN_HEADER))
+                .and_then(|line| line.strip_prefix(CONTENT_LEN_HEADER))
+                .ok_or_else(|| anyhow!("invalid LSP message header {headers:?}"))?
+                .trim_end()
+                .parse()?;
 
             buffer.resize(message_len, 0);
             stdout.read_exact(&mut buffer).await?;
@@ -412,8 +383,6 @@ impl LanguageServer {
             // Don't starve the main thread when receiving lots of messages at once.
             smol::future::yield_now().await;
         }
-
-        Ok(())
     }
 
     async fn handle_stderr<Stderr>(
@@ -1254,6 +1223,26 @@ impl FakeLanguageServer {
     }
 }
 
+pub(self) async fn read_headers<Stdout>(
+    reader: &mut BufReader<Stdout>,
+    buffer: &mut Vec<u8>,
+) -> Result<()>
+where
+    Stdout: AsyncRead + Unpin + Send + 'static,
+{
+    loop {
+        if buffer.len() >= HEADER_DELIMITER.len()
+            && buffer[(buffer.len() - HEADER_DELIMITER.len())..] == HEADER_DELIMITER[..]
+        {
+            return Ok(());
+        }
+
+        if reader.read_until(b'\n', buffer).await? == 0 {
+            return Err(anyhow!("cannot read LSP message headers"));
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -1327,4 +1316,28 @@ mod tests {
         drop(server);
         fake.receive_notification::<notification::Exit>().await;
     }
+
+    #[gpui::test]
+    async fn test_read_headers() {
+        let mut buf = Vec::new();
+        let mut reader = smol::io::BufReader::new(b"Content-Length: 123\r\n\r\n" as &[u8]);
+        read_headers(&mut reader, &mut buf).await.unwrap();
+        assert_eq!(buf, b"Content-Length: 123\r\n\r\n");
+
+        let mut buf = Vec::new();
+        let mut reader = smol::io::BufReader::new(b"Content-Type: application/vscode-jsonrpc\r\nContent-Length: 1235\r\n\r\n{\"somecontent\":123}" as &[u8]);
+        read_headers(&mut reader, &mut buf).await.unwrap();
+        assert_eq!(
+            buf,
+            b"Content-Type: application/vscode-jsonrpc\r\nContent-Length: 1235\r\n\r\n"
+        );
+
+        let mut buf = Vec::new();
+        let mut reader = smol::io::BufReader::new(b"Content-Length: 1235\r\nContent-Type: application/vscode-jsonrpc\r\n\r\n{\"somecontent\":true}" as &[u8]);
+        read_headers(&mut reader, &mut buf).await.unwrap();
+        assert_eq!(
+            buf,
+            b"Content-Length: 1235\r\nContent-Type: application/vscode-jsonrpc\r\n\r\n"
+        );
+    }
 }