lib.rs

  1use anyhow::{anyhow, Context, Result};
  2use futures::{io::BufWriter, AsyncRead, AsyncWrite};
  3use gpui::{executor, Task};
  4use parking_lot::{Mutex, RwLock};
  5use postage::{barrier, oneshot, prelude::Stream, sink::Sink};
  6use serde::{Deserialize, Serialize};
  7use serde_json::{json, value::RawValue, Value};
  8use smol::{
  9    channel,
 10    io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
 11    process::Command,
 12};
 13use std::{
 14    collections::HashMap,
 15    future::Future,
 16    io::Write,
 17    str::FromStr,
 18    sync::{
 19        atomic::{AtomicUsize, Ordering::SeqCst},
 20        Arc,
 21    },
 22};
 23use std::{path::Path, process::Stdio};
 24use util::TryFutureExt;
 25
 26pub use lsp_types::*;
 27
 28const JSON_RPC_VERSION: &'static str = "2.0";
 29const CONTENT_LEN_HEADER: &'static str = "Content-Length: ";
 30
 31type NotificationHandler = Box<dyn Send + Sync + Fn(&str)>;
 32type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
 33
/// Client side of a connection to a language server that speaks LSP over a
/// pair of byte streams (a child process's stdin/stdout when created with
/// `LanguageServer::new`).
///
/// NOTE: field declaration order determines drop order once `Drop::drop` has
/// run, so keep it in mind before reordering fields.
pub struct LanguageServer {
    /// Counter that supplies the JSON-RPC id of each outgoing request.
    next_id: AtomicUsize,
    /// Queue of serialized JSON bodies; the output task frames and writes them.
    outbound_tx: channel::Sender<Vec<u8>>,
    /// Notification callbacks keyed by LSP method name. Shared with the input
    /// task and with every live `Subscription`.
    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
    /// Callbacks for requests still awaiting a response, keyed by request id.
    /// The input task removes and invokes the matching entry.
    response_handlers: Arc<Mutex<HashMap<usize, ResponseHandler>>>,
    /// Kept so that `Drop` can spawn the shutdown task.
    executor: Arc<executor::Background>,
    /// (input task, output task). Taken in `Drop` so the tasks stay alive
    /// while the shutdown handshake runs.
    io_tasks: Option<(Task<Option<()>>, Task<Option<()>>)>,
    /// Released when the background `initialize` exchange has finished
    /// (successfully or not). `request` and `notify` wait on it.
    initialized: barrier::Receiver,
    /// Released when the output task finishes; taken and awaited in `Drop`.
    output_done_rx: Option<barrier::Receiver>,
}
 44
 45pub struct Subscription {
 46    method: &'static str,
 47    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
 48}
 49
/// JSON-RPC request envelope. It is serialized for outgoing requests and
/// deserialized by `FakeLanguageServer::receive_request`.
/// Field order is the key order of the serialized JSON.
#[derive(Serialize, Deserialize)]
struct Request<'a, T> {
    /// Always `JSON_RPC_VERSION` when sent by this client.
    jsonrpc: &'a str,
    /// Id echoed back by the server in the matching response.
    id: usize,
    /// LSP method name, e.g. the request type's `METHOD`.
    method: &'a str,
    params: T,
}
 57
/// A JSON-RPC response whose `result` is kept as raw, unparsed JSON, so the
/// handler registered for `id` can decode it into the concrete result type.
#[derive(Serialize, Deserialize)]
struct AnyResponse<'a> {
    /// Id of the request this response answers.
    id: usize,
    /// Present when the server reports a failure.
    #[serde(default)]
    error: Option<Error>,
    /// Raw `result` JSON, borrowed from the buffer being deserialized.
    /// May be absent.
    #[serde(borrow)]
    result: Option<&'a RawValue>,
}
 66
/// JSON-RPC notification envelope: like `Request` but without an `id`, so no
/// response is expected. It is serialized for outgoing notifications and
/// deserialized by `FakeLanguageServer::receive_notification`.
#[derive(Serialize, Deserialize)]
struct Notification<'a, T> {
    #[serde(borrow)]
    jsonrpc: &'a str,
    #[serde(borrow)]
    method: &'a str,
    params: T,
}
 75
/// An incoming notification whose `params` stay as raw JSON, so they can be
/// handed to whichever handler is registered for `method`.
///
/// NOTE(review): `params` is not optional here. A notification that omits it
/// (JSON-RPC permits this) will fail to deserialize as `AnyNotification`.
#[derive(Deserialize)]
struct AnyNotification<'a> {
    #[serde(borrow)]
    method: &'a str,
    #[serde(borrow)]
    params: &'a RawValue,
}
 83
/// The `error` member of a JSON-RPC response. Only `message` is kept; other
/// members of the error object are ignored during deserialization.
#[derive(Debug, Serialize, Deserialize)]
struct Error {
    message: String,
}
 88
 89impl LanguageServer {
 90    pub fn new(
 91        binary_path: &Path,
 92        root_path: &Path,
 93        background: Arc<executor::Background>,
 94    ) -> Result<Arc<Self>> {
 95        let mut server = Command::new(binary_path)
 96            .stdin(Stdio::piped())
 97            .stdout(Stdio::piped())
 98            .stderr(Stdio::inherit())
 99            .spawn()?;
100        let stdin = server.stdin.take().unwrap();
101        let stdout = server.stdout.take().unwrap();
102        Self::new_internal(stdin, stdout, root_path, background)
103    }
104
105    fn new_internal<Stdin, Stdout>(
106        stdin: Stdin,
107        stdout: Stdout,
108        root_path: &Path,
109        executor: Arc<executor::Background>,
110    ) -> Result<Arc<Self>>
111    where
112        Stdin: AsyncWrite + Unpin + Send + 'static,
113        Stdout: AsyncRead + Unpin + Send + 'static,
114    {
115        let mut stdin = BufWriter::new(stdin);
116        let mut stdout = BufReader::new(stdout);
117        let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
118        let notification_handlers = Arc::new(RwLock::new(HashMap::<_, NotificationHandler>::new()));
119        let response_handlers = Arc::new(Mutex::new(HashMap::<_, ResponseHandler>::new()));
120        let input_task = executor.spawn(
121            {
122                let notification_handlers = notification_handlers.clone();
123                let response_handlers = response_handlers.clone();
124                async move {
125                    let mut buffer = Vec::new();
126                    loop {
127                        buffer.clear();
128                        stdout.read_until(b'\n', &mut buffer).await?;
129                        stdout.read_until(b'\n', &mut buffer).await?;
130                        let message_len: usize = std::str::from_utf8(&buffer)?
131                            .strip_prefix(CONTENT_LEN_HEADER)
132                            .ok_or_else(|| anyhow!("invalid header"))?
133                            .trim_end()
134                            .parse()?;
135
136                        buffer.resize(message_len, 0);
137                        stdout.read_exact(&mut buffer).await?;
138
139                        if let Ok(AnyNotification { method, params }) =
140                            serde_json::from_slice(&buffer)
141                        {
142                            if let Some(handler) = notification_handlers.read().get(method) {
143                                handler(params.get());
144                            } else {
145                                log::info!(
146                                    "unhandled notification {}:\n{}",
147                                    method,
148                                    serde_json::to_string_pretty(
149                                        &Value::from_str(params.get()).unwrap()
150                                    )
151                                    .unwrap()
152                                );
153                            }
154                        } else if let Ok(AnyResponse { id, error, result }) =
155                            serde_json::from_slice(&buffer)
156                        {
157                            if let Some(handler) = response_handlers.lock().remove(&id) {
158                                if let Some(error) = error {
159                                    handler(Err(error));
160                                } else if let Some(result) = result {
161                                    handler(Ok(result.get()));
162                                } else {
163                                    handler(Ok("null"));
164                                }
165                            }
166                        } else {
167                            return Err(anyhow!(
168                                "failed to deserialize message:\n{}",
169                                std::str::from_utf8(&buffer)?
170                            ));
171                        }
172                    }
173                }
174            }
175            .log_err(),
176        );
177        let (output_done_tx, output_done_rx) = barrier::channel();
178        let output_task = executor.spawn(
179            async move {
180                let mut content_len_buffer = Vec::new();
181                while let Ok(message) = outbound_rx.recv().await {
182                    content_len_buffer.clear();
183                    write!(content_len_buffer, "{}", message.len()).unwrap();
184                    stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
185                    stdin.write_all(&content_len_buffer).await?;
186                    stdin.write_all("\r\n\r\n".as_bytes()).await?;
187                    stdin.write_all(&message).await?;
188                    stdin.flush().await?;
189                }
190                drop(output_done_tx);
191                Ok(())
192            }
193            .log_err(),
194        );
195
196        let (initialized_tx, initialized_rx) = barrier::channel();
197        let this = Arc::new(Self {
198            notification_handlers,
199            response_handlers,
200            next_id: Default::default(),
201            outbound_tx,
202            executor: executor.clone(),
203            io_tasks: Some((input_task, output_task)),
204            initialized: initialized_rx,
205            output_done_rx: Some(output_done_rx),
206        });
207
208        let root_uri =
209            lsp_types::Url::from_file_path(root_path).map_err(|_| anyhow!("invalid root path"))?;
210        executor
211            .spawn({
212                let this = this.clone();
213                async move {
214                    this.init(root_uri).log_err().await;
215                    drop(initialized_tx);
216                }
217            })
218            .detach();
219
220        Ok(this)
221    }
222
223    async fn init(self: Arc<Self>, root_uri: lsp_types::Url) -> Result<()> {
224        #[allow(deprecated)]
225        let params = lsp_types::InitializeParams {
226            process_id: Default::default(),
227            root_path: Default::default(),
228            root_uri: Some(root_uri),
229            initialization_options: Default::default(),
230            capabilities: lsp_types::ClientCapabilities {
231                experimental: Some(json!({
232                    "serverStatusNotification": true,
233                })),
234                ..Default::default()
235            },
236            trace: Default::default(),
237            workspace_folders: Default::default(),
238            client_info: Default::default(),
239            locale: Default::default(),
240        };
241
242        let this = self.clone();
243        Self::request_internal::<lsp_types::request::Initialize>(
244            &this.next_id,
245            &this.response_handlers,
246            &this.outbound_tx,
247            params,
248        )
249        .await?;
250        Self::notify_internal::<lsp_types::notification::Initialized>(
251            &this.outbound_tx,
252            lsp_types::InitializedParams {},
253        )?;
254        Ok(())
255    }
256
257    pub fn on_notification<T, F>(&self, f: F) -> Subscription
258    where
259        T: lsp_types::notification::Notification,
260        F: 'static + Send + Sync + Fn(T::Params),
261    {
262        let prev_handler = self.notification_handlers.write().insert(
263            T::METHOD,
264            Box::new(
265                move |notification| match serde_json::from_str(notification) {
266                    Ok(notification) => f(notification),
267                    Err(err) => log::error!("error parsing notification {}: {}", T::METHOD, err),
268                },
269            ),
270        );
271
272        assert!(
273            prev_handler.is_none(),
274            "registered multiple handlers for the same notification"
275        );
276
277        Subscription {
278            method: T::METHOD,
279            notification_handlers: self.notification_handlers.clone(),
280        }
281    }
282
283    pub fn request<T: lsp_types::request::Request>(
284        self: Arc<Self>,
285        params: T::Params,
286    ) -> impl Future<Output = Result<T::Result>>
287    where
288        T::Result: 'static + Send,
289    {
290        let this = self.clone();
291        async move {
292            this.initialized.clone().recv().await;
293            Self::request_internal::<T>(
294                &this.next_id,
295                &this.response_handlers,
296                &this.outbound_tx,
297                params,
298            )
299            .await
300        }
301    }
302
303    fn request_internal<T: lsp_types::request::Request>(
304        next_id: &AtomicUsize,
305        response_handlers: &Mutex<HashMap<usize, ResponseHandler>>,
306        outbound_tx: &channel::Sender<Vec<u8>>,
307        params: T::Params,
308    ) -> impl Future<Output = Result<T::Result>>
309    where
310        T::Result: 'static + Send,
311    {
312        let id = next_id.fetch_add(1, SeqCst);
313        let message = serde_json::to_vec(&Request {
314            jsonrpc: JSON_RPC_VERSION,
315            id,
316            method: T::METHOD,
317            params,
318        })
319        .unwrap();
320        let mut response_handlers = response_handlers.lock();
321        let (mut tx, mut rx) = oneshot::channel();
322        response_handlers.insert(
323            id,
324            Box::new(move |result| {
325                let response = match result {
326                    Ok(response) => {
327                        serde_json::from_str(response).context("failed to deserialize response")
328                    }
329                    Err(error) => Err(anyhow!("{}", error.message)),
330                };
331                let _ = tx.try_send(response);
332            }),
333        );
334
335        let send = outbound_tx.try_send(message);
336        async move {
337            send?;
338            rx.recv().await.unwrap()
339        }
340    }
341
342    pub fn notify<T: lsp_types::notification::Notification>(
343        self: &Arc<Self>,
344        params: T::Params,
345    ) -> impl Future<Output = Result<()>> {
346        let this = self.clone();
347        async move {
348            this.initialized.clone().recv().await;
349            Self::notify_internal::<T>(&this.outbound_tx, params)?;
350            Ok(())
351        }
352    }
353
354    fn notify_internal<T: lsp_types::notification::Notification>(
355        outbound_tx: &channel::Sender<Vec<u8>>,
356        params: T::Params,
357    ) -> Result<()> {
358        let message = serde_json::to_vec(&Notification {
359            jsonrpc: JSON_RPC_VERSION,
360            method: T::METHOD,
361            params,
362        })
363        .unwrap();
364        outbound_tx.try_send(message)?;
365        Ok(())
366    }
367}
368
369impl Drop for LanguageServer {
370    fn drop(&mut self) {
371        let tasks = self.io_tasks.take();
372        let response_handlers = self.response_handlers.clone();
373        let outbound_tx = self.outbound_tx.clone();
374        let next_id = AtomicUsize::new(self.next_id.load(SeqCst));
375        let mut output_done = self.output_done_rx.take().unwrap();
376        self.executor.spawn_critical(
377            async move {
378                Self::request_internal::<lsp_types::request::Shutdown>(
379                    &next_id,
380                    &response_handlers,
381                    &outbound_tx,
382                    (),
383                )
384                .await?;
385                Self::notify_internal::<lsp_types::notification::Exit>(&outbound_tx, ())?;
386                drop(outbound_tx);
387                output_done.recv().await;
388                drop(tasks);
389                Ok(())
390            }
391            .log_err(),
392        )
393    }
394}
395
396impl Subscription {
397    pub fn detach(mut self) {
398        self.method = "";
399    }
400}
401
402impl Drop for Subscription {
403    fn drop(&mut self) {
404        self.notification_handlers.write().remove(self.method);
405    }
406}
407
/// In-process stand-in for a language server, driven step by step by tests.
///
/// `stdin` reads the bytes the client writes; `stdout` carries the bytes the
/// client reads.
#[cfg(any(test, feature = "test-support"))]
pub struct FakeLanguageServer {
    /// Scratch buffer; after a receive it holds the most recent message body.
    buffer: Vec<u8>,
    stdin: smol::io::BufReader<async_pipe::PipeReader>,
    stdout: smol::io::BufWriter<async_pipe::PipeWriter>,
}
414
/// Id of a request received by `FakeLanguageServer`.
///
/// It is tagged with the request type `T`, so `FakeLanguageServer::respond`
/// only accepts a matching `T::Result`.
#[cfg(any(test, feature = "test-support"))]
pub struct RequestId<T> {
    id: usize,
    _type: std::marker::PhantomData<T>,
}
420
#[cfg(any(test, feature = "test-support"))]
impl LanguageServer {
    /// Builds a client wired through in-memory pipes to a [`FakeLanguageServer`].
    ///
    /// The fake side completes the `initialize` / `initialized` exchange before
    /// the pair is returned.
    pub async fn fake(executor: Arc<executor::Background>) -> (Arc<Self>, FakeLanguageServer) {
        // Client -> fake direction.
        let (client_writer, fake_reader) = async_pipe::pipe();
        // Fake -> client direction.
        let (fake_writer, client_reader) = async_pipe::pipe();

        let mut fake = FakeLanguageServer {
            stdin: smol::io::BufReader::new(fake_reader),
            stdout: smol::io::BufWriter::new(fake_writer),
            buffer: Vec::new(),
        };

        let server =
            Self::new_internal(client_writer, client_reader, Path::new("/"), executor).unwrap();

        let (initialize_id, _initialize_params) =
            fake.receive_request::<request::Initialize>().await;
        fake.respond(initialize_id, InitializeResult::default())
            .await;
        fake.receive_notification::<notification::Initialized>()
            .await;

        (server, fake)
    }
}
442
#[cfg(any(test, feature = "test-support"))]
impl FakeLanguageServer {
    /// Sends a notification of type `T` to the client.
    pub async fn notify<T: notification::Notification>(&mut self, params: T::Params) {
        let message = serde_json::to_vec(&Notification {
            jsonrpc: JSON_RPC_VERSION,
            method: T::METHOD,
            params,
        })
        .unwrap();
        self.send(message).await;
    }

    /// Answers the request identified by `request_id` with a successful `result`.
    pub async fn respond<T: request::Request>(
        &mut self,
        request_id: RequestId<T>,
        result: T::Result,
    ) {
        let raw_result = RawValue::from_string(serde_json::to_string(&result).unwrap()).unwrap();
        let message = serde_json::to_vec(&AnyResponse {
            id: request_id.id,
            error: None,
            result: Some(&*raw_result),
        })
        .unwrap();
        self.send(message).await;
    }

    /// Reads the next message and decodes it as a request of type `T`.
    ///
    /// # Panics
    /// Panics if the message is not a `T` request or the JSON-RPC version differs.
    pub async fn receive_request<T: request::Request>(&mut self) -> (RequestId<T>, T::Params) {
        self.receive().await;
        let request = serde_json::from_slice::<Request<T::Params>>(&self.buffer).unwrap();
        assert_eq!(request.method, T::METHOD);
        assert_eq!(request.jsonrpc, JSON_RPC_VERSION);
        (
            RequestId {
                id: request.id,
                _type: std::marker::PhantomData,
            },
            request.params,
        )
    }

    /// Reads the next message and decodes it as a notification of type `T`.
    ///
    /// # Panics
    /// Panics if the message is not a `T` notification.
    pub async fn receive_notification<T: notification::Notification>(&mut self) -> T::Params {
        self.receive().await;
        let notification = serde_json::from_slice::<Notification<T::Params>>(&self.buffer).unwrap();
        assert_eq!(notification.method, T::METHOD);
        notification.params
    }

    /// Writes one framed message (`Content-Length` header, blank line, body)
    /// to the client.
    async fn send(&mut self, message: Vec<u8>) {
        let header = format!("{}{}\r\n\r\n", CONTENT_LEN_HEADER, message.len());
        self.stdout.write_all(header.as_bytes()).await.unwrap();
        self.stdout.write_all(&message).await.unwrap();
        self.stdout.flush().await.unwrap();
    }

    /// Reads one framed message from the client into `self.buffer`.
    ///
    /// Header lines are consumed up to the terminating blank line. Only
    /// `Content-Length` is interpreted; other headers are ignored.
    ///
    /// # Panics
    /// Panics if the stream ends mid-header or the header is malformed.
    async fn receive(&mut self) {
        let mut message_len: Option<usize> = None;
        loop {
            self.buffer.clear();
            let read = self
                .stdin
                .read_until(b'\n', &mut self.buffer)
                .await
                .unwrap();
            assert!(read > 0, "client closed the stream while a header was expected");
            let line = std::str::from_utf8(&self.buffer).unwrap().trim_end();
            if line.is_empty() {
                break;
            }
            if let Some(len) = line.strip_prefix(CONTENT_LEN_HEADER) {
                message_len = Some(len.parse().unwrap());
            }
        }
        let message_len = message_len.expect("message had no Content-Length header");
        self.buffer.clear();
        self.buffer.resize(message_len, 0);
        self.stdin.read_exact(&mut self.buffer).await.unwrap();
    }
}
526
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use simplelog::SimpleLogger;
    use unindent::Unindent;
    use util::test::temp_tree;

    /// End-to-end test against a real `rust-analyzer`, resolved through `PATH`
    /// by `Command::new`. It fails if the binary is not installed.
    #[gpui::test]
    async fn test_basic(cx: TestAppContext) {
        let lib_source = r#"
            fn fun() {
                let hello = "world";
            }
        "#
        .unindent();
        // Minimal Cargo project on disk so rust-analyzer has a workspace to load.
        let root_dir = temp_tree(json!({
            "Cargo.toml": r#"
                [package]
                name = "temp"
                version = "0.1.0"
                edition = "2018"
            "#.unindent(),
            "src": {
                "lib.rs": &lib_source
            }
        }));
        let lib_file_uri =
            lsp_types::Url::from_file_path(root_dir.path().join("src/lib.rs")).unwrap();

        let server = cx.read(|cx| {
            LanguageServer::new(
                Path::new("rust-analyzer"),
                root_dir.path(),
                cx.background().clone(),
            )
            .unwrap()
        });
        // Wait until the server reports it is quiescent before issuing requests.
        server.next_idle_notification().await;

        server
            .notify::<lsp_types::notification::DidOpenTextDocument>(
                lsp_types::DidOpenTextDocumentParams {
                    text_document: lsp_types::TextDocumentItem::new(
                        lib_file_uri.clone(),
                        "rust".to_string(),
                        0,
                        lib_source,
                    ),
                },
            )
            .await
            .unwrap();

        // Line 1, column 21 falls inside the `"world"` literal of the unindented
        // source, so the hover should report its type.
        let hover = server
            .request::<lsp_types::request::HoverRequest>(lsp_types::HoverParams {
                text_document_position_params: lsp_types::TextDocumentPositionParams {
                    text_document: lsp_types::TextDocumentIdentifier::new(lib_file_uri),
                    position: lsp_types::Position::new(1, 21),
                },
                work_done_progress_params: Default::default(),
            })
            .await
            .unwrap()
            .unwrap();
        assert_eq!(
            hover.contents,
            lsp_types::HoverContents::Markup(lsp_types::MarkupContent {
                kind: lsp_types::MarkupKind::Markdown,
                value: "&str".to_string()
            })
        );
    }

    /// Exercises the client against `FakeLanguageServer`:
    /// - outgoing notifications;
    /// - routing of incoming notifications to handlers;
    /// - the shutdown handshake triggered by dropping the client.
    #[gpui::test]
    async fn test_fake(cx: TestAppContext) {
        // NOTE: presumably fails (and so panics here) if another test in the same
        // process has already installed a global logger.
        SimpleLogger::init(log::LevelFilter::Info, Default::default()).unwrap();

        let (server, mut fake) = LanguageServer::fake(cx.background()).await;

        // Forward incoming notifications into channels the test can await on.
        let (message_tx, message_rx) = channel::unbounded();
        let (diagnostics_tx, diagnostics_rx) = channel::unbounded();
        server
            .on_notification::<notification::ShowMessage, _>(move |params| {
                message_tx.try_send(params).unwrap()
            })
            .detach();
        server
            .on_notification::<notification::PublishDiagnostics, _>(move |params| {
                diagnostics_tx.try_send(params).unwrap()
            })
            .detach();

        // Client -> fake: the notification arrives intact.
        server
            .notify::<notification::DidOpenTextDocument>(DidOpenTextDocumentParams {
                text_document: TextDocumentItem::new(
                    Url::from_str("file://a/b").unwrap(),
                    "rust".to_string(),
                    0,
                    "".to_string(),
                ),
            })
            .await
            .unwrap();
        assert_eq!(
            fake.receive_notification::<notification::DidOpenTextDocument>()
                .await
                .text_document
                .uri
                .as_str(),
            "file://a/b"
        );

        // Fake -> client: each notification reaches the handler for its method.
        fake.notify::<notification::ShowMessage>(ShowMessageParams {
            typ: MessageType::ERROR,
            message: "ok".to_string(),
        })
        .await;
        fake.notify::<notification::PublishDiagnostics>(PublishDiagnosticsParams {
            uri: Url::from_str("file://b/c").unwrap(),
            version: Some(5),
            diagnostics: vec![],
        })
        .await;
        assert_eq!(message_rx.recv().await.unwrap().message, "ok");
        assert_eq!(
            diagnostics_rx.recv().await.unwrap().uri.as_str(),
            "file://b/c"
        );

        // Dropping the last handle runs `Drop for LanguageServer`, which sends
        // `shutdown` and, once that is answered, `exit`.
        drop(server);
        let (shutdown_request, _) = fake.receive_request::<lsp_types::request::Shutdown>().await;
        fake.respond(shutdown_request, ()).await;
        fake.receive_notification::<lsp_types::notification::Exit>()
            .await;
    }

    impl LanguageServer {
        /// Waits for an `experimental/serverStatus` notification whose
        /// `quiescent` flag is true.
        ///
        /// `_subscription` is declared after `rx`, so it is dropped first. That
        /// unregisters the handler before the receiver goes away.
        async fn next_idle_notification(self: &Arc<Self>) {
            let (tx, rx) = channel::unbounded();
            let _subscription =
                self.on_notification::<ServerStatusNotification, _>(move |params| {
                    if params.quiescent {
                        tx.try_send(()).unwrap();
                    }
                });
            let _ = rx.recv().await;
        }
    }

    /// rust-analyzer's `experimental/serverStatus` notification, which is not
    /// part of `lsp_types`.
    pub enum ServerStatusNotification {}

    impl lsp_types::notification::Notification for ServerStatusNotification {
        type Params = ServerStatusParams;
        const METHOD: &'static str = "experimental/serverStatus";
    }

    /// Subset of the serverStatus params this test needs. Other fields in the
    /// payload are ignored during deserialization.
    #[derive(Deserialize, Serialize, PartialEq, Eq, Clone)]
    pub struct ServerStatusParams {
        pub quiescent: bool,
    }
}