input_handler.rs

  1use std::str;
  2use std::sync::Arc;
  3
  4use anyhow::{anyhow, Result};
  5use collections::HashMap;
  6use futures::{
  7    channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender},
  8    AsyncBufReadExt, AsyncRead, AsyncReadExt as _,
  9};
 10use gpui::{BackgroundExecutor, Task};
 11use log::warn;
 12use parking_lot::Mutex;
 13use smol::io::BufReader;
 14
 15use crate::{
 16    AnyNotification, AnyResponse, IoHandler, IoKind, RequestId, ResponseHandler, CONTENT_LEN_HEADER,
 17};
 18
 19const HEADER_DELIMITER: &'static [u8; 4] = b"\r\n\r\n";
/// Handler for stdout of language server.
///
/// Built by `LspStdoutHandler::new`, which spawns the message-reading loop as a
/// background task and keeps the receiving side of the notification channel.
pub struct LspStdoutHandler {
    /// Handle to the background task (spawned in `new`) that reads and dispatches
    /// messages. The loop never exits successfully, so the task's output is an `Err`
    /// describing why reading stopped.
    pub(super) loop_handle: Task<Result<()>>,
    /// Receiving end of the unbounded channel on which the read loop forwards every
    /// message that decodes as a notification.
    pub(super) notifications_channel: UnboundedReceiver<AnyNotification>,
}
 25
 26pub(self) async fn read_headers<Stdout>(
 27    reader: &mut BufReader<Stdout>,
 28    buffer: &mut Vec<u8>,
 29) -> Result<()>
 30where
 31    Stdout: AsyncRead + Unpin + Send + 'static,
 32{
 33    loop {
 34        if buffer.len() >= HEADER_DELIMITER.len()
 35            && buffer[(buffer.len() - HEADER_DELIMITER.len())..] == HEADER_DELIMITER[..]
 36        {
 37            return Ok(());
 38        }
 39
 40        if reader.read_until(b'\n', buffer).await? == 0 {
 41            return Err(anyhow!("cannot read LSP message headers"));
 42        }
 43    }
 44}
 45
 46impl LspStdoutHandler {
 47    pub fn new<Input>(
 48        stdout: Input,
 49        response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
 50        io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>,
 51        cx: BackgroundExecutor,
 52    ) -> Self
 53    where
 54        Input: AsyncRead + Unpin + Send + 'static,
 55    {
 56        let (tx, notifications_channel) = unbounded();
 57        let loop_handle = cx.spawn(Self::handler(stdout, tx, response_handlers, io_handlers));
 58        Self {
 59            loop_handle,
 60            notifications_channel,
 61        }
 62    }
 63
 64    async fn handler<Input>(
 65        stdout: Input,
 66        notifications_sender: UnboundedSender<AnyNotification>,
 67        response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
 68        io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>,
 69    ) -> anyhow::Result<()>
 70    where
 71        Input: AsyncRead + Unpin + Send + 'static,
 72    {
 73        let mut stdout = BufReader::new(stdout);
 74
 75        let mut buffer = Vec::new();
 76
 77        loop {
 78            buffer.clear();
 79
 80            read_headers(&mut stdout, &mut buffer).await?;
 81
 82            let headers = std::str::from_utf8(&buffer)?;
 83
 84            let message_len = headers
 85                .split('\n')
 86                .find(|line| line.starts_with(CONTENT_LEN_HEADER))
 87                .and_then(|line| line.strip_prefix(CONTENT_LEN_HEADER))
 88                .ok_or_else(|| anyhow!("invalid LSP message header {headers:?}"))?
 89                .trim_end()
 90                .parse()?;
 91
 92            buffer.resize(message_len, 0);
 93            stdout.read_exact(&mut buffer).await?;
 94
 95            if let Ok(message) = str::from_utf8(&buffer) {
 96                log::trace!("incoming message: {message}");
 97                for handler in io_handlers.lock().values_mut() {
 98                    handler(IoKind::StdOut, message);
 99                }
100            }
101
102            if let Ok(msg) = serde_json::from_slice::<AnyNotification>(&buffer) {
103                notifications_sender.unbounded_send(msg)?;
104            } else if let Ok(AnyResponse {
105                id, error, result, ..
106            }) = serde_json::from_slice(&buffer)
107            {
108                let mut response_handlers = response_handlers.lock();
109                if let Some(handler) = response_handlers
110                    .as_mut()
111                    .and_then(|handlers| handlers.remove(&id))
112                {
113                    drop(response_handlers);
114                    if let Some(error) = error {
115                        handler(Err(error));
116                    } else if let Some(result) = result {
117                        handler(Ok(result.get().into()));
118                    } else {
119                        handler(Ok("null".into()));
120                    }
121                }
122            } else {
123                warn!(
124                    "failed to deserialize LSP message:\n{}",
125                    std::str::from_utf8(&buffer)?
126                );
127            }
128        }
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    /// `read_headers` stops right after the blank line, leaving any body unread.
    #[gpui::test]
    async fn test_read_headers() {
        let mut buf = Vec::new();
        let mut reader = smol::io::BufReader::new(b"Content-Length: 123\r\n\r\n" as &[u8]);
        read_headers(&mut reader, &mut buf).await.unwrap();
        assert_eq!(buf, b"Content-Length: 123\r\n\r\n");

        let mut buf = Vec::new();
        let mut reader = smol::io::BufReader::new(b"Content-Type: application/vscode-jsonrpc\r\nContent-Length: 1235\r\n\r\n{\"somecontent\":123}" as &[u8]);
        read_headers(&mut reader, &mut buf).await.unwrap();
        assert_eq!(
            buf,
            b"Content-Type: application/vscode-jsonrpc\r\nContent-Length: 1235\r\n\r\n"
        );

        let mut buf = Vec::new();
        let mut reader = smol::io::BufReader::new(b"Content-Length: 1235\r\nContent-Type: application/vscode-jsonrpc\r\n\r\n{\"somecontent\":true}" as &[u8]);
        read_headers(&mut reader, &mut buf).await.unwrap();
        assert_eq!(
            buf,
            b"Content-Length: 1235\r\nContent-Type: application/vscode-jsonrpc\r\n\r\n"
        );
    }

    /// `read_headers` reports an error when input ends before the header block is complete.
    #[gpui::test]
    async fn test_read_headers_fails_on_truncated_input() {
        // No input at all.
        let mut buf = Vec::new();
        let mut reader = smol::io::BufReader::new(b"" as &[u8]);
        assert!(read_headers(&mut reader, &mut buf).await.is_err());

        // A header line, but no terminating blank line.
        let mut buf = Vec::new();
        let mut reader = smol::io::BufReader::new(b"Content-Length: 123\r\n" as &[u8]);
        assert!(read_headers(&mut reader, &mut buf).await.is_err());
    }
}