lib.rs

  1use anyhow::{anyhow, Context, Result};
  2use gpui::{executor, Task};
  3use parking_lot::Mutex;
  4use serde::{Deserialize, Serialize};
  5use serde_json::value::RawValue;
  6use smol::{
  7    channel,
  8    io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
  9    process::Command,
 10};
 11use std::{
 12    collections::HashMap,
 13    future::Future,
 14    io::Write,
 15    sync::{
 16        atomic::{AtomicUsize, Ordering::SeqCst},
 17        Arc,
 18    },
 19};
 20use std::{path::Path, process::Stdio};
 21use util::TryFutureExt;
 22
 23const JSON_RPC_VERSION: &'static str = "2.0";
 24const CONTENT_LEN_HEADER: &'static str = "Content-Length: ";
 25
/// Client side of a connection to a language server child process, exchanging
/// JSON-RPC messages over the process's stdin/stdout.
pub struct LanguageServer {
    /// Source of ids for outgoing requests; incremented once per request.
    next_id: AtomicUsize,
    /// Serialized outbound message bodies (without the `Content-Length` framing,
    /// which the output task adds before writing to the server's stdin).
    outbound_tx: channel::Sender<Vec<u8>>,
    /// Callbacks for in-flight requests, keyed by request id. A callback is removed
    /// from the map when the response with the matching id is read.
    response_handlers: Arc<Mutex<HashMap<usize, ResponseHandler>>>,
    /// Task reading and dispatching messages from the server's stdout. Stored here
    /// to tie its lifetime to this value (presumably dropping a gpui `Task` cancels
    /// it — confirm against gpui).
    _input_task: Task<Option<()>>,
    /// Task framing queued messages and writing them to the server's stdin.
    _output_task: Task<Option<()>>,
}
 33
 34type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
 35
/// Envelope for an outgoing JSON-RPC request, serialized to JSON in
/// `LanguageServer::request`.
#[derive(Serialize)]
struct Request<T> {
    /// Set to `JSON_RPC_VERSION` where this is constructed.
    jsonrpc: &'static str,
    /// Id the server echoes back in its response; used to find the response handler.
    id: usize,
    /// Method name (the `lsp_types` request's `METHOD`).
    method: &'static str,
    /// Method-specific parameters.
    params: T,
}
 43
/// An incoming JSON-RPC response.
///
/// `result` stays as raw JSON borrowed from the message buffer, so the matching
/// response handler can deserialize it into the request's concrete result type.
///
/// NOTE(review): with `Option<&RawValue>`, an explicit `"result": null`
/// deserializes to `None`, the same as a missing `result`. Code matching on this
/// struct should treat "neither `result` nor `error`" as a null result. Also,
/// `id` is a `usize`, so a response whose `id` is `null` or a string will not
/// deserialize as this type.
#[derive(Deserialize)]
struct Response<'a> {
    id: usize,
    #[serde(default)]
    error: Option<Error>,
    #[serde(default, borrow)]
    result: Option<&'a RawValue>,
}
 52
/// Envelope for an outgoing JSON-RPC notification, serialized in
/// `LanguageServer::notify`. Unlike `Request` it carries no `id`, and no
/// response handler is registered for it.
#[derive(Serialize)]
struct OutboundNotification<T> {
    /// Set to `JSON_RPC_VERSION` where this is constructed.
    jsonrpc: &'static str,
    /// Method name (the `lsp_types` notification's `METHOD`).
    method: &'static str,
    /// Method-specific parameters.
    params: T,
}
 59
 60#[derive(Deserialize)]
 61struct InboundNotification<'a> {
 62    #[serde(borrow)]
 63    method: &'a str,
 64    #[serde(borrow)]
 65    params: &'a RawValue,
 66}
 67
/// The `error` member of a JSON-RPC response. Only `message` is deserialized;
/// any other members present in the object are ignored.
#[derive(Deserialize)]
struct Error {
    /// Human-readable error text; becomes the text of the error returned to the
    /// caller of `LanguageServer::request`.
    message: String,
}
 72
 73impl LanguageServer {
 74    pub fn new(path: &Path, background: &executor::Background) -> Result<Arc<Self>> {
 75        let mut server = Command::new(path)
 76            .stdin(Stdio::piped())
 77            .stdout(Stdio::piped())
 78            .stderr(Stdio::inherit())
 79            .spawn()?;
 80        let mut stdin = server.stdin.take().unwrap();
 81        let mut stdout = BufReader::new(server.stdout.take().unwrap());
 82        let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
 83        let response_handlers = Arc::new(Mutex::new(HashMap::<usize, ResponseHandler>::new()));
 84        let _input_task = background.spawn(
 85            {
 86                let response_handlers = response_handlers.clone();
 87                async move {
 88                    let mut buffer = Vec::new();
 89                    loop {
 90                        buffer.clear();
 91
 92                        stdout.read_until(b'\n', &mut buffer).await?;
 93                        stdout.read_until(b'\n', &mut buffer).await?;
 94                        let message_len: usize = std::str::from_utf8(&buffer)?
 95                            .strip_prefix(CONTENT_LEN_HEADER)
 96                            .ok_or_else(|| anyhow!("invalid header"))?
 97                            .trim_end()
 98                            .parse()?;
 99
100                        buffer.resize(message_len, 0);
101                        stdout.read_exact(&mut buffer).await?;
102                        if let Ok(InboundNotification { .. }) = serde_json::from_slice(&buffer) {
103                        } else if let Ok(Response { id, error, result }) =
104                            serde_json::from_slice(&buffer)
105                        {
106                            if let Some(handler) = response_handlers.lock().remove(&id) {
107                                if let Some(result) = result {
108                                    handler(Ok(result.get()));
109                                } else if let Some(error) = error {
110                                    handler(Err(error));
111                                }
112                            }
113                        } else {
114                            return Err(anyhow!(
115                                "failed to deserialize message:\n{}",
116                                std::str::from_utf8(&buffer)?
117                            ));
118                        }
119                    }
120                }
121            }
122            .log_err(),
123        );
124        let _output_task = background.spawn(
125            async move {
126                let mut content_len_buffer = Vec::new();
127                loop {
128                    let message = outbound_rx.recv().await?;
129                    write!(content_len_buffer, "{}", message.len()).unwrap();
130                    stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
131                    stdin.write_all(&content_len_buffer).await?;
132                    stdin.write_all("\r\n\r\n".as_bytes()).await?;
133                    stdin.write_all(&message).await?;
134                }
135            }
136            .log_err(),
137        );
138
139        let this = Arc::new(Self {
140            response_handlers,
141            next_id: Default::default(),
142            outbound_tx,
143            _input_task,
144            _output_task,
145        });
146        let init = this.clone().init();
147        background
148            .spawn(async move {
149                init.log_err().await;
150            })
151            .detach();
152
153        Ok(this)
154    }
155
156    async fn init(self: Arc<Self>) -> Result<()> {
157        self.request::<lsp_types::request::Initialize>(lsp_types::InitializeParams {
158            process_id: Default::default(),
159            root_path: Default::default(),
160            root_uri: Default::default(),
161            initialization_options: Default::default(),
162            capabilities: Default::default(),
163            trace: Default::default(),
164            workspace_folders: Default::default(),
165            client_info: Default::default(),
166            locale: Default::default(),
167        })
168        .await?;
169        self.notify::<lsp_types::notification::Initialized>(lsp_types::InitializedParams {})?;
170        Ok(())
171    }
172
173    pub fn request<T: lsp_types::request::Request>(
174        self: &Arc<Self>,
175        params: T::Params,
176    ) -> impl Future<Output = Result<T::Result>>
177    where
178        T::Result: 'static + Send,
179    {
180        let id = self.next_id.fetch_add(1, SeqCst);
181        let message = serde_json::to_vec(&Request {
182            jsonrpc: JSON_RPC_VERSION,
183            id,
184            method: T::METHOD,
185            params,
186        })
187        .unwrap();
188        let mut response_handlers = self.response_handlers.lock();
189        let (tx, rx) = smol::channel::bounded(1);
190        response_handlers.insert(
191            id,
192            Box::new(move |result| {
193                let response = match result {
194                    Ok(response) => {
195                        serde_json::from_str(response).context("failed to deserialize response")
196                    }
197                    Err(error) => Err(anyhow!("{}", error.message)),
198                };
199                let _ = smol::block_on(tx.send(response));
200            }),
201        );
202
203        let outbound_tx = self.outbound_tx.clone();
204        async move {
205            outbound_tx.send(message).await?;
206            rx.recv().await?
207        }
208    }
209
210    pub fn notify<T: lsp_types::notification::Notification>(
211        &self,
212        params: T::Params,
213    ) -> Result<()> {
214        let message = serde_json::to_vec(&OutboundNotification {
215            jsonrpc: JSON_RPC_VERSION,
216            method: T::METHOD,
217            params,
218        })
219        .unwrap();
220        smol::block_on(self.outbound_tx.send(message))?;
221        Ok(())
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;

    /// `LanguageServer::new` must report a failure to spawn the server process as
    /// an error. It uses a path that cannot exist, so no language server needs to
    /// be installed to run this test.
    #[gpui::test]
    async fn test_basic(cx: TestAppContext) {
        // NOTE(review): assumes `TestAppContext::background()` returns the test's
        // background executor — confirm against gpui.
        let result = LanguageServer::new(
            Path::new("/nonexistent/language-server-for-tests"),
            &cx.background(),
        );
        assert!(result.is_err());
    }
}