lib.rs

  1use anyhow::{anyhow, Context, Result};
  2use futures::{io::BufWriter, AsyncRead, AsyncWrite};
  3use gpui::{executor, Task};
  4use parking_lot::{Mutex, RwLock};
  5use postage::{barrier, oneshot, prelude::Stream, sink::Sink};
  6use serde::{Deserialize, Serialize};
  7use serde_json::{json, value::RawValue, Value};
  8use smol::{
  9    channel,
 10    io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
 11    process::Command,
 12};
 13use std::{
 14    collections::HashMap,
 15    future::Future,
 16    io::Write,
 17    str::FromStr,
 18    sync::{
 19        atomic::{AtomicUsize, Ordering::SeqCst},
 20        Arc,
 21    },
 22};
 23use std::{path::Path, process::Stdio};
 24use util::TryFutureExt;
 25
 26pub use lsp_types::*;
 27
 28const JSON_RPC_VERSION: &'static str = "2.0";
 29const CONTENT_LEN_HEADER: &'static str = "Content-Length: ";
 30
 31type NotificationHandler = Box<dyn Send + Sync + Fn(&str)>;
 32type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
 33
 34pub struct LanguageServer {
 35    next_id: AtomicUsize,
 36    outbound_tx: channel::Sender<Vec<u8>>,
 37    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
 38    response_handlers: Arc<Mutex<HashMap<usize, ResponseHandler>>>,
 39    executor: Arc<executor::Background>,
 40    io_tasks: Option<(Task<Option<()>>, Task<Option<()>>)>,
 41    initialized: barrier::Receiver,
 42    output_done_rx: Option<barrier::Receiver>,
 43}
 44
 45pub struct Subscription {
 46    method: &'static str,
 47    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
 48}
 49
 50#[derive(Serialize, Deserialize)]
 51struct Request<'a, T> {
 52    jsonrpc: &'a str,
 53    id: usize,
 54    method: &'a str,
 55    params: T,
 56}
 57
 58#[derive(Serialize, Deserialize)]
 59struct AnyResponse<'a> {
 60    id: usize,
 61    #[serde(default)]
 62    error: Option<Error>,
 63    #[serde(borrow)]
 64    result: &'a RawValue,
 65}
 66
 67#[derive(Serialize, Deserialize)]
 68struct Notification<'a, T> {
 69    #[serde(borrow)]
 70    jsonrpc: &'a str,
 71    #[serde(borrow)]
 72    method: &'a str,
 73    params: T,
 74}
 75
 76#[derive(Deserialize)]
 77struct AnyNotification<'a> {
 78    #[serde(borrow)]
 79    method: &'a str,
 80    #[serde(borrow)]
 81    params: &'a RawValue,
 82}
 83
 84#[derive(Debug, Serialize, Deserialize)]
 85struct Error {
 86    message: String,
 87}
 88
 89impl LanguageServer {
 90    pub fn new(
 91        binary_path: &Path,
 92        root_path: &Path,
 93        background: Arc<executor::Background>,
 94    ) -> Result<Arc<Self>> {
 95        let mut server = Command::new(binary_path)
 96            .stdin(Stdio::piped())
 97            .stdout(Stdio::piped())
 98            .stderr(Stdio::inherit())
 99            .spawn()?;
100        let stdin = server.stdin.take().unwrap();
101        let stdout = server.stdout.take().unwrap();
102        Self::new_internal(stdin, stdout, root_path, background)
103    }
104
105    fn new_internal<Stdin, Stdout>(
106        stdin: Stdin,
107        stdout: Stdout,
108        root_path: &Path,
109        executor: Arc<executor::Background>,
110    ) -> Result<Arc<Self>>
111    where
112        Stdin: AsyncWrite + Unpin + Send + 'static,
113        Stdout: AsyncRead + Unpin + Send + 'static,
114    {
115        let mut stdin = BufWriter::new(stdin);
116        let mut stdout = BufReader::new(stdout);
117        let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
118        let notification_handlers = Arc::new(RwLock::new(HashMap::<_, NotificationHandler>::new()));
119        let response_handlers = Arc::new(Mutex::new(HashMap::<_, ResponseHandler>::new()));
120        let input_task = executor.spawn(
121            {
122                let notification_handlers = notification_handlers.clone();
123                let response_handlers = response_handlers.clone();
124                async move {
125                    let mut buffer = Vec::new();
126                    loop {
127                        buffer.clear();
128                        stdout.read_until(b'\n', &mut buffer).await?;
129                        stdout.read_until(b'\n', &mut buffer).await?;
130                        let message_len: usize = std::str::from_utf8(&buffer)?
131                            .strip_prefix(CONTENT_LEN_HEADER)
132                            .ok_or_else(|| anyhow!("invalid header"))?
133                            .trim_end()
134                            .parse()?;
135
136                        buffer.resize(message_len, 0);
137                        stdout.read_exact(&mut buffer).await?;
138
139                        if let Ok(AnyNotification { method, params }) =
140                            serde_json::from_slice(&buffer)
141                        {
142                            if let Some(handler) = notification_handlers.read().get(method) {
143                                handler(params.get());
144                            } else {
145                                log::info!(
146                                    "unhandled notification {}:\n{}",
147                                    method,
148                                    serde_json::to_string_pretty(
149                                        &Value::from_str(params.get()).unwrap()
150                                    )
151                                    .unwrap()
152                                );
153                            }
154                        } else if let Ok(AnyResponse { id, error, result }) =
155                            serde_json::from_slice(&buffer)
156                        {
157                            if let Some(handler) = response_handlers.lock().remove(&id) {
158                                if let Some(error) = error {
159                                    handler(Err(error));
160                                } else {
161                                    handler(Ok(result.get()));
162                                }
163                            }
164                        } else {
165                            return Err(anyhow!(
166                                "failed to deserialize message:\n{}",
167                                std::str::from_utf8(&buffer)?
168                            ));
169                        }
170                    }
171                }
172            }
173            .log_err(),
174        );
175        let (output_done_tx, output_done_rx) = barrier::channel();
176        let output_task = executor.spawn(
177            async move {
178                let mut content_len_buffer = Vec::new();
179                while let Ok(message) = outbound_rx.recv().await {
180                    content_len_buffer.clear();
181                    write!(content_len_buffer, "{}", message.len()).unwrap();
182                    stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
183                    stdin.write_all(&content_len_buffer).await?;
184                    stdin.write_all("\r\n\r\n".as_bytes()).await?;
185                    stdin.write_all(&message).await?;
186                    stdin.flush().await?;
187                }
188                drop(output_done_tx);
189                Ok(())
190            }
191            .log_err(),
192        );
193
194        let (initialized_tx, initialized_rx) = barrier::channel();
195        let this = Arc::new(Self {
196            notification_handlers,
197            response_handlers,
198            next_id: Default::default(),
199            outbound_tx,
200            executor: executor.clone(),
201            io_tasks: Some((input_task, output_task)),
202            initialized: initialized_rx,
203            output_done_rx: Some(output_done_rx),
204        });
205
206        let root_uri =
207            lsp_types::Url::from_file_path(root_path).map_err(|_| anyhow!("invalid root path"))?;
208        executor
209            .spawn({
210                let this = this.clone();
211                async move {
212                    this.init(root_uri).log_err().await;
213                    drop(initialized_tx);
214                }
215            })
216            .detach();
217
218        Ok(this)
219    }
220
221    async fn init(self: Arc<Self>, root_uri: lsp_types::Url) -> Result<()> {
222        #[allow(deprecated)]
223        let params = lsp_types::InitializeParams {
224            process_id: Default::default(),
225            root_path: Default::default(),
226            root_uri: Some(root_uri),
227            initialization_options: Default::default(),
228            capabilities: lsp_types::ClientCapabilities {
229                experimental: Some(json!({
230                    "serverStatusNotification": true,
231                })),
232                ..Default::default()
233            },
234            trace: Default::default(),
235            workspace_folders: Default::default(),
236            client_info: Default::default(),
237            locale: Default::default(),
238        };
239
240        let this = self.clone();
241        Self::request_internal::<lsp_types::request::Initialize>(
242            &this.next_id,
243            &this.response_handlers,
244            &this.outbound_tx,
245            params,
246        )
247        .await?;
248        Self::notify_internal::<lsp_types::notification::Initialized>(
249            &this.outbound_tx,
250            lsp_types::InitializedParams {},
251        )?;
252        Ok(())
253    }
254
255    pub fn on_notification<T, F>(&self, f: F) -> Subscription
256    where
257        T: lsp_types::notification::Notification,
258        F: 'static + Send + Sync + Fn(T::Params),
259    {
260        let prev_handler = self.notification_handlers.write().insert(
261            T::METHOD,
262            Box::new(
263                move |notification| match serde_json::from_str(notification) {
264                    Ok(notification) => f(notification),
265                    Err(err) => log::error!("error parsing notification {}: {}", T::METHOD, err),
266                },
267            ),
268        );
269
270        assert!(
271            prev_handler.is_none(),
272            "registered multiple handlers for the same notification"
273        );
274
275        Subscription {
276            method: T::METHOD,
277            notification_handlers: self.notification_handlers.clone(),
278        }
279    }
280
281    pub fn request<T: lsp_types::request::Request>(
282        self: Arc<Self>,
283        params: T::Params,
284    ) -> impl Future<Output = Result<T::Result>>
285    where
286        T::Result: 'static + Send,
287    {
288        let this = self.clone();
289        async move {
290            this.initialized.clone().recv().await;
291            Self::request_internal::<T>(
292                &this.next_id,
293                &this.response_handlers,
294                &this.outbound_tx,
295                params,
296            )
297            .await
298        }
299    }
300
301    fn request_internal<T: lsp_types::request::Request>(
302        next_id: &AtomicUsize,
303        response_handlers: &Mutex<HashMap<usize, ResponseHandler>>,
304        outbound_tx: &channel::Sender<Vec<u8>>,
305        params: T::Params,
306    ) -> impl Future<Output = Result<T::Result>>
307    where
308        T::Result: 'static + Send,
309    {
310        let id = next_id.fetch_add(1, SeqCst);
311        let message = serde_json::to_vec(&Request {
312            jsonrpc: JSON_RPC_VERSION,
313            id,
314            method: T::METHOD,
315            params,
316        })
317        .unwrap();
318        let mut response_handlers = response_handlers.lock();
319        let (mut tx, mut rx) = oneshot::channel();
320        response_handlers.insert(
321            id,
322            Box::new(move |result| {
323                let response = match result {
324                    Ok(response) => {
325                        serde_json::from_str(response).context("failed to deserialize response")
326                    }
327                    Err(error) => Err(anyhow!("{}", error.message)),
328                };
329                let _ = tx.try_send(response);
330            }),
331        );
332
333        let send = outbound_tx.try_send(message);
334        async move {
335            send?;
336            rx.recv().await.unwrap()
337        }
338    }
339
340    pub fn notify<T: lsp_types::notification::Notification>(
341        self: &Arc<Self>,
342        params: T::Params,
343    ) -> impl Future<Output = Result<()>> {
344        let this = self.clone();
345        async move {
346            this.initialized.clone().recv().await;
347            Self::notify_internal::<T>(&this.outbound_tx, params)?;
348            Ok(())
349        }
350    }
351
352    fn notify_internal<T: lsp_types::notification::Notification>(
353        outbound_tx: &channel::Sender<Vec<u8>>,
354        params: T::Params,
355    ) -> Result<()> {
356        let message = serde_json::to_vec(&Notification {
357            jsonrpc: JSON_RPC_VERSION,
358            method: T::METHOD,
359            params,
360        })
361        .unwrap();
362        outbound_tx.try_send(message)?;
363        Ok(())
364    }
365}
366
367impl Drop for LanguageServer {
368    fn drop(&mut self) {
369        let tasks = self.io_tasks.take();
370        let response_handlers = self.response_handlers.clone();
371        let outbound_tx = self.outbound_tx.clone();
372        let next_id = AtomicUsize::new(self.next_id.load(SeqCst));
373        let mut output_done = self.output_done_rx.take().unwrap();
374        self.executor.spawn_critical(
375            async move {
376                Self::request_internal::<lsp_types::request::Shutdown>(
377                    &next_id,
378                    &response_handlers,
379                    &outbound_tx,
380                    (),
381                )
382                .await?;
383                Self::notify_internal::<lsp_types::notification::Exit>(&outbound_tx, ())?;
384                drop(outbound_tx);
385                output_done.recv().await;
386                drop(tasks);
387                Ok(())
388            }
389            .log_err(),
390        )
391    }
392}
393
394impl Subscription {
395    pub fn detach(mut self) {
396        self.method = "";
397    }
398}
399
400impl Drop for Subscription {
401    fn drop(&mut self) {
402        self.notification_handlers.write().remove(self.method);
403    }
404}
405
/// In-process stand-in for a language server, connected to a `LanguageServer`
/// through in-memory pipes.
#[cfg(any(test, feature = "test-support"))]
pub struct FakeLanguageServer {
    /// Scratch space holding the most recently received message.
    buffer: Vec<u8>,
    /// Read end of the pipe the client writes its outgoing messages into.
    stdin: smol::io::BufReader<async_pipe::PipeReader>,
    /// Write end of the pipe the client reads incoming messages from.
    stdout: smol::io::BufWriter<async_pipe::PipeWriter>,
}
412
/// Id of a request received by the fake server, tagged with the request type `T`
/// so that a response can only carry `T::Result`.
#[cfg(any(test, feature = "test-support"))]
pub struct RequestId<T> {
    id: usize,
    _type: std::marker::PhantomData<T>,
}
418
#[cfg(any(test, feature = "test-support"))]
impl LanguageServer {
    /// Builds a client connected to a [`FakeLanguageServer`] over in-memory pipes,
    /// with the `initialize` / `initialized` handshake already completed on the fake
    /// side.
    pub async fn fake(executor: Arc<executor::Background>) -> (Arc<Self>, FakeLanguageServer) {
        // Client -> fake: what the client treats as the server's stdin.
        let (client_writer, fake_reader) = async_pipe::pipe();
        // Fake -> client: what the client treats as the server's stdout.
        let (fake_writer, client_reader) = async_pipe::pipe();

        let mut fake = FakeLanguageServer {
            buffer: Vec::new(),
            stdin: smol::io::BufReader::new(fake_reader),
            stdout: smol::io::BufWriter::new(fake_writer),
        };
        let server =
            Self::new_internal(client_writer, client_reader, Path::new("/"), executor).unwrap();

        let (initialize_id, _) = fake.receive_request::<request::Initialize>().await;
        fake.respond(initialize_id, InitializeResult::default()).await;
        fake.receive_notification::<notification::Initialized>().await;

        (server, fake)
    }
}
440
#[cfg(any(test, feature = "test-support"))]
impl FakeLanguageServer {
    /// Sends a `T` notification carrying `params` to the client.
    pub async fn notify<T: notification::Notification>(&mut self, params: T::Params) {
        let message = serde_json::to_vec(&Notification {
            jsonrpc: JSON_RPC_VERSION,
            method: T::METHOD,
            params,
        })
        .unwrap();
        self.send(message).await;
    }

    /// Sends a successful response with `result` for a request previously returned by
    /// [`FakeLanguageServer::receive_request`].
    pub async fn respond<T: request::Request>(
        &mut self,
        request_id: RequestId<T>,
        result: T::Result,
    ) {
        let result = serde_json::to_string(&result).unwrap();
        let message = serde_json::to_vec(&AnyResponse {
            id: request_id.id,
            error: None,
            result: &RawValue::from_string(result).unwrap(),
        })
        .unwrap();
        self.send(message).await;
    }

    /// Reads the next message, asserting that it is a JSON-RPC 2.0 `T` request.
    ///
    /// # Panics
    ///
    /// Panics if the message is not a well-formed `T` request.
    pub async fn receive_request<T: request::Request>(&mut self) -> (RequestId<T>, T::Params) {
        self.receive().await;
        let request = serde_json::from_slice::<Request<T::Params>>(&self.buffer).unwrap();
        assert_eq!(request.method, T::METHOD);
        assert_eq!(request.jsonrpc, JSON_RPC_VERSION);
        (
            RequestId {
                id: request.id,
                _type: std::marker::PhantomData,
            },
            request.params,
        )
    }

    /// Reads the next message, asserting that it is a JSON-RPC 2.0 `T` notification.
    ///
    /// # Panics
    ///
    /// Panics if the message is not a well-formed `T` notification.
    pub async fn receive_notification<T: notification::Notification>(&mut self) -> T::Params {
        self.receive().await;
        let notification = serde_json::from_slice::<Notification<T::Params>>(&self.buffer).unwrap();
        assert_eq!(notification.method, T::METHOD);
        // Checked for consistency with `receive_request`.
        assert_eq!(notification.jsonrpc, JSON_RPC_VERSION);
        notification.params
    }

    /// Frames `message` with a `Content-Length` header and writes it to the client.
    async fn send(&mut self, message: Vec<u8>) {
        let header = format!("{}{}\r\n\r\n", CONTENT_LEN_HEADER, message.len());
        self.stdout.write_all(header.as_bytes()).await.unwrap();
        self.stdout.write_all(&message).await.unwrap();
        self.stdout.flush().await.unwrap();
    }

    /// Reads one framed message from the client into `self.buffer`.
    ///
    /// Expects a single `Content-Length` header line followed by a blank line, which
    /// is the framing the client's output task writes.
    async fn receive(&mut self) {
        self.buffer.clear();
        self.stdin
            .read_until(b'\n', &mut self.buffer)
            .await
            .unwrap();
        self.stdin
            .read_until(b'\n', &mut self.buffer)
            .await
            .unwrap();
        let message_len: usize = std::str::from_utf8(&self.buffer)
            .unwrap()
            .strip_prefix(CONTENT_LEN_HEADER)
            .unwrap()
            .trim_end()
            .parse()
            .unwrap();
        self.buffer.resize(message_len, 0);
        self.stdin.read_exact(&mut self.buffer).await.unwrap();
    }
}
524
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use simplelog::SimpleLogger;
    use unindent::Unindent;
    use util::test::temp_tree;

    /// Drives a real `rust-analyzer` process (it must be on `PATH`): opens a file and
    /// checks the hover text over a string literal.
    #[gpui::test]
    async fn test_basic(cx: TestAppContext) {
        let lib_source = r#"
            fn fun() {
                let hello = "world";
            }
        "#
        .unindent();
        let cargo_manifest = r#"
            [package]
            name = "temp"
            version = "0.1.0"
            edition = "2018"
        "#
        .unindent();
        let root_dir = temp_tree(json!({
            "Cargo.toml": cargo_manifest,
            "src": {
                "lib.rs": &lib_source
            }
        }));
        let lib_file_uri =
            lsp_types::Url::from_file_path(root_dir.path().join("src/lib.rs")).unwrap();

        let server = cx.read(|cx| {
            let binary = Path::new("rust-analyzer");
            LanguageServer::new(binary, root_dir.path(), cx.background().clone()).unwrap()
        });
        server.next_idle_notification().await;

        let document = lsp_types::TextDocumentItem::new(
            lib_file_uri.clone(),
            "rust".to_string(),
            0,
            lib_source,
        );
        server
            .notify::<lsp_types::notification::DidOpenTextDocument>(
                lsp_types::DidOpenTextDocumentParams {
                    text_document: document,
                },
            )
            .await
            .unwrap();

        let position = lsp_types::TextDocumentPositionParams {
            text_document: lsp_types::TextDocumentIdentifier::new(lib_file_uri),
            position: lsp_types::Position::new(1, 21),
        };
        let hover = server
            .request::<lsp_types::request::HoverRequest>(lsp_types::HoverParams {
                text_document_position_params: position,
                work_done_progress_params: Default::default(),
            })
            .await
            .unwrap()
            .unwrap();

        let expected = lsp_types::HoverContents::Markup(lsp_types::MarkupContent {
            kind: lsp_types::MarkupKind::Markdown,
            value: "&str".to_string(),
        });
        assert_eq!(hover.contents, expected);
    }

    /// Exercises notifications in both directions, plus the shutdown handshake
    /// triggered by dropping the client, against a `FakeLanguageServer`.
    #[gpui::test]
    async fn test_fake(cx: TestAppContext) {
        SimpleLogger::init(log::LevelFilter::Info, Default::default()).unwrap();

        let (server, mut fake) = LanguageServer::fake(cx.background()).await;

        let (show_message_tx, show_message_rx) = channel::unbounded();
        let (diagnostics_tx, diagnostics_rx) = channel::unbounded();
        let show_message_subscription =
            server.on_notification::<notification::ShowMessage, _>(move |params| {
                show_message_tx.try_send(params).unwrap()
            });
        show_message_subscription.detach();
        let diagnostics_subscription =
            server.on_notification::<notification::PublishDiagnostics, _>(move |params| {
                diagnostics_tx.try_send(params).unwrap()
            });
        diagnostics_subscription.detach();

        let opened = DidOpenTextDocumentParams {
            text_document: TextDocumentItem::new(
                Url::from_str("file://a/b").unwrap(),
                "rust".to_string(),
                0,
                "".to_string(),
            ),
        };
        server
            .notify::<notification::DidOpenTextDocument>(opened)
            .await
            .unwrap();
        let received = fake
            .receive_notification::<notification::DidOpenTextDocument>()
            .await;
        assert_eq!(received.text_document.uri.as_str(), "file://a/b");

        fake.notify::<notification::ShowMessage>(ShowMessageParams {
            typ: MessageType::ERROR,
            message: "ok".to_string(),
        })
        .await;
        fake.notify::<notification::PublishDiagnostics>(PublishDiagnosticsParams {
            uri: Url::from_str("file://b/c").unwrap(),
            version: Some(5),
            diagnostics: vec![],
        })
        .await;

        let shown = show_message_rx.recv().await.unwrap();
        assert_eq!(shown.message, "ok");
        let published = diagnostics_rx.recv().await.unwrap();
        assert_eq!(published.uri.as_str(), "file://b/c");

        // Dropping the client sends `shutdown`, then `exit` once it is answered.
        drop(server);
        let (shutdown_id, _) = fake
            .receive_request::<lsp_types::request::Shutdown>()
            .await;
        fake.respond(shutdown_id, ()).await;
        fake.receive_notification::<lsp_types::notification::Exit>()
            .await;
    }

    impl LanguageServer {
        /// Waits until an `experimental/serverStatus` notification reports
        /// `quiescent: true`.
        async fn next_idle_notification(self: &Arc<Self>) {
            let (idle_tx, idle_rx) = channel::unbounded();
            let _subscription =
                self.on_notification::<ServerStatusNotification, _>(move |status| {
                    if status.quiescent {
                        idle_tx.try_send(()).unwrap();
                    }
                });
            let _ = idle_rx.recv().await;
        }
    }

    /// The `experimental/serverStatus` notification, which the client opts into
    /// through the `serverStatusNotification` experimental capability sent in `init`.
    pub enum ServerStatusNotification {}

    impl lsp_types::notification::Notification for ServerStatusNotification {
        type Params = ServerStatusParams;
        const METHOD: &'static str = "experimental/serverStatus";
    }

    /// The part of the `experimental/serverStatus` payload these tests look at.
    #[derive(Deserialize, Serialize, PartialEq, Eq, Clone)]
    pub struct ServerStatusParams {
        pub quiescent: bool,
    }
}