input_handler.rs

  1use std::str;
  2use std::sync::Arc;
  3
  4use anyhow::{Context as _, Result};
  5use collections::HashMap;
  6use futures::{
  7    AsyncBufReadExt, AsyncRead, AsyncReadExt as _,
  8    channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded},
  9};
 10use gpui::{BackgroundExecutor, Task};
 11use log::warn;
 12use parking_lot::Mutex;
 13use smol::io::BufReader;
 14
 15use crate::{
 16    AnyNotification, AnyResponse, CONTENT_LEN_HEADER, IoHandler, IoKind, RequestId, ResponseHandler,
 17};
 18
 19const HEADER_DELIMITER: &[u8; 4] = b"\r\n\r\n";
 20/// Handler for stdout of language server.
 21pub struct LspStdoutHandler {
 22    pub(super) loop_handle: Task<Result<()>>,
 23    pub(super) notifications_channel: UnboundedReceiver<AnyNotification>,
 24}
 25
 26async fn read_headers<Stdout>(reader: &mut BufReader<Stdout>, buffer: &mut Vec<u8>) -> Result<()>
 27where
 28    Stdout: AsyncRead + Unpin + Send + 'static,
 29{
 30    loop {
 31        if buffer.len() >= HEADER_DELIMITER.len()
 32            && buffer[(buffer.len() - HEADER_DELIMITER.len())..] == HEADER_DELIMITER[..]
 33        {
 34            return Ok(());
 35        }
 36
 37        if reader.read_until(b'\n', buffer).await? == 0 {
 38            anyhow::bail!("cannot read LSP message headers");
 39        }
 40    }
 41}
 42
 43impl LspStdoutHandler {
 44    pub fn new<Input>(
 45        stdout: Input,
 46        response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
 47        io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>,
 48        cx: BackgroundExecutor,
 49    ) -> Self
 50    where
 51        Input: AsyncRead + Unpin + Send + 'static,
 52    {
 53        let (tx, notifications_channel) = unbounded();
 54        let loop_handle = cx.spawn(Self::handler(stdout, tx, response_handlers, io_handlers));
 55        Self {
 56            loop_handle,
 57            notifications_channel,
 58        }
 59    }
 60
 61    async fn handler<Input>(
 62        stdout: Input,
 63        notifications_sender: UnboundedSender<AnyNotification>,
 64        response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
 65        io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>,
 66    ) -> anyhow::Result<()>
 67    where
 68        Input: AsyncRead + Unpin + Send + 'static,
 69    {
 70        let mut stdout = BufReader::new(stdout);
 71
 72        let mut buffer = Vec::new();
 73
 74        loop {
 75            buffer.clear();
 76
 77            read_headers(&mut stdout, &mut buffer).await?;
 78
 79            let headers = std::str::from_utf8(&buffer)?;
 80
 81            let message_len = headers
 82                .split('\n')
 83                .find(|line| line.starts_with(CONTENT_LEN_HEADER))
 84                .and_then(|line| line.strip_prefix(CONTENT_LEN_HEADER))
 85                .with_context(|| format!("invalid LSP message header {headers:?}"))?
 86                .trim_end()
 87                .parse()?;
 88
 89            buffer.resize(message_len, 0);
 90            stdout.read_exact(&mut buffer).await?;
 91
 92            if let Ok(message) = str::from_utf8(&buffer) {
 93                log::trace!("incoming message: {message}");
 94                for handler in io_handlers.lock().values_mut() {
 95                    handler(IoKind::StdOut, message);
 96                }
 97            }
 98
 99            if let Ok(msg) = serde_json::from_slice::<AnyNotification>(&buffer) {
100                notifications_sender.unbounded_send(msg)?;
101            } else if let Ok(AnyResponse {
102                id, error, result, ..
103            }) = serde_json::from_slice(&buffer)
104            {
105                let mut response_handlers = response_handlers.lock();
106                if let Some(handler) = response_handlers
107                    .as_mut()
108                    .and_then(|handlers| handlers.remove(&id))
109                {
110                    drop(response_handlers);
111                    if let Some(error) = error {
112                        handler(Err(error));
113                    } else if let Some(result) = result {
114                        handler(Ok(result.get().into()));
115                    } else {
116                        handler(Ok("null".into()));
117                    }
118                }
119            } else {
120                warn!(
121                    "failed to deserialize LSP message:\n{}",
122                    std::str::from_utf8(&buffer)?
123                );
124            }
125        }
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    #[gpui::test]
    async fn test_read_headers() {
        let mut buf = Vec::new();
        let mut reader = smol::io::BufReader::new(b"Content-Length: 123\r\n\r\n" as &[u8]);
        read_headers(&mut reader, &mut buf).await.unwrap();
        assert_eq!(buf, b"Content-Length: 123\r\n\r\n");

        let mut buf = Vec::new();
        let mut reader = smol::io::BufReader::new(b"Content-Type: application/vscode-jsonrpc\r\nContent-Length: 1235\r\n\r\n{\"somecontent\":123}" as &[u8]);
        read_headers(&mut reader, &mut buf).await.unwrap();
        assert_eq!(
            buf,
            b"Content-Type: application/vscode-jsonrpc\r\nContent-Length: 1235\r\n\r\n"
        );

        let mut buf = Vec::new();
        let mut reader = smol::io::BufReader::new(b"Content-Length: 1235\r\nContent-Type: application/vscode-jsonrpc\r\n\r\n{\"somecontent\":true}" as &[u8]);
        read_headers(&mut reader, &mut buf).await.unwrap();
        assert_eq!(
            buf,
            b"Content-Length: 1235\r\nContent-Type: application/vscode-jsonrpc\r\n\r\n"
        );
    }

    /// EOF before the `\r\n\r\n` terminator must be reported as an error rather
    /// than returning a partial header block.
    #[gpui::test]
    async fn test_read_headers_fails_on_truncated_input() {
        let mut buf = Vec::new();
        let mut reader = smol::io::BufReader::new(b"" as &[u8]);
        assert!(read_headers(&mut reader, &mut buf).await.is_err());
        assert!(buf.is_empty());

        let mut buf = Vec::new();
        let mut reader = smol::io::BufReader::new(b"Content-Length: 123\r\n" as &[u8]);
        assert!(read_headers(&mut reader, &mut buf).await.is_err());
        assert_eq!(buf, b"Content-Length: 123\r\n");
    }
}