input_handler.rs

  1use std::str;
  2use std::sync::Arc;
  3
  4use anyhow::{Context as _, Result};
  5use collections::HashMap;
  6use futures::{
  7    AsyncBufReadExt, AsyncRead, AsyncReadExt as _,
  8    channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded},
  9};
 10use gpui::{BackgroundExecutor, Task};
 11use log::warn;
 12use parking_lot::Mutex;
 13use smol::io::BufReader;
 14
 15use crate::{
 16    AnyResponse, CONTENT_LEN_HEADER, IoHandler, IoKind, NotificationOrRequest, RequestId,
 17    ResponseHandler,
 18};
 19
 20const HEADER_DELIMITER: &[u8; 4] = b"\r\n\r\n";
/// Handler for stdout of language server.
///
/// Holds the background task that reads framed messages from the server's
/// stdout, together with the receiving end of the channel that task uses to
/// forward deserialized [`NotificationOrRequest`] messages.
pub struct LspStdoutHandler {
    /// The read loop spawned by [`LspStdoutHandler::new`]. The loop only exits
    /// through an error, so a completed task yields the error that stopped it.
    pub(super) loop_handle: Task<Result<()>>,
    /// Messages the read loop deserialized as [`NotificationOrRequest`], in the
    /// order they were read.
    pub(super) incoming_messages: UnboundedReceiver<NotificationOrRequest>,
}
 26
 27async fn read_headers<Stdout>(reader: &mut BufReader<Stdout>, buffer: &mut Vec<u8>) -> Result<()>
 28where
 29    Stdout: AsyncRead + Unpin + Send + 'static,
 30{
 31    loop {
 32        if buffer.len() >= HEADER_DELIMITER.len()
 33            && buffer[(buffer.len() - HEADER_DELIMITER.len())..] == HEADER_DELIMITER[..]
 34        {
 35            return Ok(());
 36        }
 37
 38        if reader.read_until(b'\n', buffer).await? == 0 {
 39            anyhow::bail!("cannot read LSP message headers");
 40        }
 41    }
 42}
 43
 44impl LspStdoutHandler {
 45    pub fn new<Input>(
 46        stdout: Input,
 47        response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
 48        io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>,
 49        cx: BackgroundExecutor,
 50    ) -> Self
 51    where
 52        Input: AsyncRead + Unpin + Send + 'static,
 53    {
 54        let (tx, notifications_channel) = unbounded();
 55        let loop_handle = cx.spawn(Self::handler(stdout, tx, response_handlers, io_handlers));
 56        Self {
 57            loop_handle,
 58            incoming_messages: notifications_channel,
 59        }
 60    }
 61
 62    async fn handler<Input>(
 63        stdout: Input,
 64        notifications_sender: UnboundedSender<NotificationOrRequest>,
 65        response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
 66        io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>,
 67    ) -> anyhow::Result<()>
 68    where
 69        Input: AsyncRead + Unpin + Send + 'static,
 70    {
 71        let mut stdout = BufReader::new(stdout);
 72
 73        let mut buffer = Vec::new();
 74
 75        loop {
 76            buffer.clear();
 77
 78            read_headers(&mut stdout, &mut buffer).await?;
 79
 80            let headers = std::str::from_utf8(&buffer)?;
 81
 82            let message_len = headers
 83                .split('\n')
 84                .find(|line| line.starts_with(CONTENT_LEN_HEADER))
 85                .and_then(|line| line.strip_prefix(CONTENT_LEN_HEADER))
 86                .with_context(|| format!("invalid LSP message header {headers:?}"))?
 87                .trim_end()
 88                .parse()?;
 89
 90            buffer.resize(message_len, 0);
 91            stdout.read_exact(&mut buffer).await?;
 92
 93            if let Ok(message) = str::from_utf8(&buffer) {
 94                log::trace!("incoming message: {message}");
 95                for handler in io_handlers.lock().values_mut() {
 96                    handler(IoKind::StdOut, message);
 97                }
 98            }
 99
100            if let Ok(msg) = serde_json::from_slice::<NotificationOrRequest>(&buffer) {
101                notifications_sender.unbounded_send(msg)?;
102            } else if let Ok(AnyResponse {
103                id, error, result, ..
104            }) = serde_json::from_slice(&buffer)
105            {
106                let mut response_handlers = response_handlers.lock();
107                if let Some(handler) = response_handlers
108                    .as_mut()
109                    .and_then(|handlers| handlers.remove(&id))
110                {
111                    drop(response_handlers);
112                    if let Some(error) = error {
113                        handler(Err(error));
114                    } else if let Some(result) = result {
115                        handler(Ok(result.get().into()));
116                    } else {
117                        handler(Ok("null".into()));
118                    }
119                }
120            } else {
121                warn!(
122                    "failed to deserialize LSP message:\n{}",
123                    std::str::from_utf8(&buffer)?
124                );
125            }
126        }
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use futures::AsyncReadExt as _;

    /// Runs `read_headers` over `input`. Asserts that it consumed exactly
    /// `expected_headers` and left `expected_rest` unread in the reader.
    async fn assert_reads_headers(
        input: &'static [u8],
        expected_headers: &[u8],
        expected_rest: &[u8],
    ) {
        let mut reader = smol::io::BufReader::new(input);
        let mut buf = Vec::new();
        read_headers(&mut reader, &mut buf).await.unwrap();
        assert_eq!(buf, expected_headers);

        let mut rest = Vec::new();
        reader.read_to_end(&mut rest).await.unwrap();
        assert_eq!(rest, expected_rest);
    }

    #[gpui::test]
    async fn test_read_headers() {
        assert_reads_headers(
            b"Content-Length: 123\r\n\r\n",
            b"Content-Length: 123\r\n\r\n",
            b"",
        )
        .await;

        assert_reads_headers(
            b"Content-Type: application/vscode-jsonrpc\r\nContent-Length: 1235\r\n\r\n{\"somecontent\":123}",
            b"Content-Type: application/vscode-jsonrpc\r\nContent-Length: 1235\r\n\r\n",
            b"{\"somecontent\":123}",
        )
        .await;

        assert_reads_headers(
            b"Content-Length: 1235\r\nContent-Type: application/vscode-jsonrpc\r\n\r\n{\"somecontent\":true}",
            b"Content-Length: 1235\r\nContent-Type: application/vscode-jsonrpc\r\n\r\n",
            b"{\"somecontent\":true}",
        )
        .await;
    }

    #[gpui::test]
    async fn test_read_headers_fails_on_eof() {
        // Empty input, a header line with no terminating blank line, and a
        // delimiter cut off mid-way must all be reported as errors.
        let inputs: [&'static [u8]; 3] = [
            b"",
            b"Content-Length: 5\r\n",
            b"Content-Length: 5\r\n\r",
        ];
        for input in inputs {
            let mut reader = smol::io::BufReader::new(input);
            let mut buf = Vec::new();
            assert!(
                read_headers(&mut reader, &mut buf).await.is_err(),
                "expected an error for input {input:?}"
            );
        }
    }
}
157}