lib.rs

  1use anyhow::{anyhow, Context, Result};
  2use futures::{io::BufWriter, AsyncRead, AsyncWrite};
  3use gpui::{executor, Task};
  4use parking_lot::{Mutex, RwLock};
  5use postage::{barrier, oneshot, prelude::Stream, sink::Sink};
  6use serde::{Deserialize, Serialize};
  7use serde_json::{json, value::RawValue, Value};
  8use smol::{
  9    channel,
 10    io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
 11    process::Command,
 12};
 13use std::{
 14    collections::HashMap,
 15    future::Future,
 16    io::Write,
 17    str::FromStr,
 18    sync::{
 19        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
 20        Arc,
 21    },
 22};
 23use std::{path::Path, process::Stdio};
 24use util::TryFutureExt;
 25
 26pub use lsp_types::*;
 27
 28const JSON_RPC_VERSION: &'static str = "2.0";
 29const CONTENT_LEN_HEADER: &'static str = "Content-Length: ";
 30
 31type NotificationHandler = Box<dyn Send + Sync + Fn(&str)>;
 32type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
 33
/// Client side of a JSON-RPC connection to a language server, exchanging
/// `Content-Length`-framed messages over a pair of byte streams.
// NOTE: fields drop in declaration order after `Drop::drop` runs.
pub struct LanguageServer {
    /// Counter that supplies the id of each outgoing request.
    next_id: AtomicUsize,
    /// Queue of serialized outgoing messages, drained by the output task.
    /// Set to `None` by `shutdown`, after which sends fail.
    outbound_tx: RwLock<Option<channel::Sender<Vec<u8>>>>,
    /// Notification callbacks keyed by LSP method name.
    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
    /// Callbacks for in-flight requests keyed by request id; each is removed
    /// when its response arrives.
    response_handlers: Arc<Mutex<HashMap<usize, ResponseHandler>>>,
    /// Executor for background work, including the shutdown spawned on drop.
    executor: Arc<executor::Background>,
    /// The (input, output) I/O tasks; taken by `shutdown`.
    io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
    /// Released once the `initialize` handshake task finishes (whether or not it
    /// succeeded); `request` and `notify` wait on it.
    initialized: barrier::Receiver,
    /// Released when the output task exits; taken by `shutdown`.
    output_done_rx: Mutex<Option<barrier::Receiver>>,
}
 44
/// Handle for a notification handler registered with `on_notification`.
/// Dropping it unregisters the handler; call `detach` to keep it registered.
pub struct Subscription {
    /// Method name the handler was registered under.
    method: &'static str,
    /// Shared handler table of the owning `LanguageServer`.
    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
}
 49
/// A JSON-RPC request. The client serializes it; the fake server deserializes it.
// Field order determines the key order of the serialized JSON.
#[derive(Serialize, Deserialize)]
struct Request<'a, T> {
    /// Protocol version (this module always sends `JSON_RPC_VERSION`).
    jsonrpc: &'a str,
    /// Id the response echoes back, used to find the waiting handler.
    id: usize,
    /// LSP method name, e.g. `T::METHOD` of an `lsp_types` request.
    method: &'a str,
    params: T,
}
 57
/// A JSON-RPC response of any request type. The `result` is kept as raw JSON so
/// the waiting handler can deserialize it into the concrete result type.
// NOTE(review): `id` must be a number; a response with `"id": null` (allowed by
// JSON-RPC for some errors) will not deserialize into this struct.
#[derive(Serialize, Deserialize)]
struct AnyResponse<'a> {
    id: usize,
    /// Error object; absent on success.
    #[serde(default)]
    error: Option<Error>,
    /// Raw JSON of the `result` member, borrowed from the input buffer.
    #[serde(borrow)]
    result: Option<&'a RawValue>,
}
 66
/// A JSON-RPC notification (a message with no id, so no response is expected).
/// Used for outgoing notifications and for decoding them in the fake server.
#[derive(Serialize, Deserialize)]
struct Notification<'a, T> {
    #[serde(borrow)]
    jsonrpc: &'a str,
    /// LSP method name.
    #[serde(borrow)]
    method: &'a str,
    params: T,
}
 75
/// Incoming notification of any method, with `params` kept as raw JSON for the
/// per-method handler to parse.
// NOTE(review): `params` is required here, so a notification that omits it
// fails to match this struct. Unknown fields are ignored, so server->client
// requests (which also carry `method`/`params`) match as well.
#[derive(Deserialize)]
struct AnyNotification<'a> {
    #[serde(borrow)]
    method: &'a str,
    #[serde(borrow)]
    params: &'a RawValue,
}
 83
/// The `error` member of a JSON-RPC response. Only `message` is retained; other
/// members (such as `code`) are ignored when deserializing.
#[derive(Debug, Serialize, Deserialize)]
struct Error {
    message: String,
}
 88
 89impl LanguageServer {
 90    pub fn new(
 91        binary_path: &Path,
 92        root_path: &Path,
 93        background: Arc<executor::Background>,
 94    ) -> Result<Arc<Self>> {
 95        let mut server = Command::new(binary_path)
 96            .stdin(Stdio::piped())
 97            .stdout(Stdio::piped())
 98            .stderr(Stdio::inherit())
 99            .spawn()?;
100        let stdin = server.stdin.take().unwrap();
101        let stdout = server.stdout.take().unwrap();
102        Self::new_internal(stdin, stdout, root_path, background)
103    }
104
105    fn new_internal<Stdin, Stdout>(
106        stdin: Stdin,
107        stdout: Stdout,
108        root_path: &Path,
109        executor: Arc<executor::Background>,
110    ) -> Result<Arc<Self>>
111    where
112        Stdin: AsyncWrite + Unpin + Send + 'static,
113        Stdout: AsyncRead + Unpin + Send + 'static,
114    {
115        let mut stdin = BufWriter::new(stdin);
116        let mut stdout = BufReader::new(stdout);
117        let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
118        let notification_handlers = Arc::new(RwLock::new(HashMap::<_, NotificationHandler>::new()));
119        let response_handlers = Arc::new(Mutex::new(HashMap::<_, ResponseHandler>::new()));
120        let input_task = executor.spawn(
121            {
122                let notification_handlers = notification_handlers.clone();
123                let response_handlers = response_handlers.clone();
124                async move {
125                    let mut buffer = Vec::new();
126                    loop {
127                        buffer.clear();
128                        stdout.read_until(b'\n', &mut buffer).await?;
129                        stdout.read_until(b'\n', &mut buffer).await?;
130                        let message_len: usize = std::str::from_utf8(&buffer)?
131                            .strip_prefix(CONTENT_LEN_HEADER)
132                            .ok_or_else(|| anyhow!("invalid header"))?
133                            .trim_end()
134                            .parse()?;
135
136                        buffer.resize(message_len, 0);
137                        stdout.read_exact(&mut buffer).await?;
138
139                        println!("{}", std::str::from_utf8(&buffer).unwrap());
140                        if let Ok(AnyNotification { method, params }) =
141                            serde_json::from_slice(&buffer)
142                        {
143                            if let Some(handler) = notification_handlers.read().get(method) {
144                                handler(params.get());
145                            } else {
146                                log::info!(
147                                    "unhandled notification {}:\n{}",
148                                    method,
149                                    serde_json::to_string_pretty(
150                                        &Value::from_str(params.get()).unwrap()
151                                    )
152                                    .unwrap()
153                                );
154                            }
155                        } else if let Ok(AnyResponse { id, error, result }) =
156                            serde_json::from_slice(&buffer)
157                        {
158                            if let Some(handler) = response_handlers.lock().remove(&id) {
159                                if let Some(error) = error {
160                                    handler(Err(error));
161                                } else if let Some(result) = result {
162                                    handler(Ok(result.get()));
163                                } else {
164                                    handler(Ok("null"));
165                                }
166                            }
167                        } else {
168                            return Err(anyhow!(
169                                "failed to deserialize message:\n{}",
170                                std::str::from_utf8(&buffer)?
171                            ));
172                        }
173                    }
174                }
175            }
176            .log_err(),
177        );
178        let (output_done_tx, output_done_rx) = barrier::channel();
179        let output_task = executor.spawn(
180            async move {
181                let mut content_len_buffer = Vec::new();
182                while let Ok(message) = outbound_rx.recv().await {
183                    content_len_buffer.clear();
184                    write!(content_len_buffer, "{}", message.len()).unwrap();
185                    stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
186                    stdin.write_all(&content_len_buffer).await?;
187                    stdin.write_all("\r\n\r\n".as_bytes()).await?;
188                    stdin.write_all(&message).await?;
189                    stdin.flush().await?;
190                }
191                drop(output_done_tx);
192                Ok(())
193            }
194            .log_err(),
195        );
196
197        let (initialized_tx, initialized_rx) = barrier::channel();
198        let this = Arc::new(Self {
199            notification_handlers,
200            response_handlers,
201            next_id: Default::default(),
202            outbound_tx: RwLock::new(Some(outbound_tx)),
203            executor: executor.clone(),
204            io_tasks: Mutex::new(Some((input_task, output_task))),
205            initialized: initialized_rx,
206            output_done_rx: Mutex::new(Some(output_done_rx)),
207        });
208
209        let root_uri =
210            lsp_types::Url::from_file_path(root_path).map_err(|_| anyhow!("invalid root path"))?;
211        executor
212            .spawn({
213                let this = this.clone();
214                async move {
215                    this.init(root_uri).log_err().await;
216                    drop(initialized_tx);
217                }
218            })
219            .detach();
220
221        Ok(this)
222    }
223
224    async fn init(self: Arc<Self>, root_uri: lsp_types::Url) -> Result<()> {
225        #[allow(deprecated)]
226        let params = lsp_types::InitializeParams {
227            process_id: Default::default(),
228            root_path: Default::default(),
229            root_uri: Some(root_uri),
230            initialization_options: Default::default(),
231            capabilities: lsp_types::ClientCapabilities {
232                experimental: Some(json!({
233                    "serverStatusNotification": true,
234                })),
235                ..Default::default()
236            },
237            trace: Default::default(),
238            workspace_folders: Default::default(),
239            client_info: Default::default(),
240            locale: Default::default(),
241        };
242
243        let this = self.clone();
244        let request = Self::request_internal::<lsp_types::request::Initialize>(
245            &this.next_id,
246            &this.response_handlers,
247            this.outbound_tx.read().as_ref(),
248            params,
249        );
250        request.await?;
251        Self::notify_internal::<lsp_types::notification::Initialized>(
252            this.outbound_tx.read().as_ref(),
253            lsp_types::InitializedParams {},
254        )?;
255        Ok(())
256    }
257
258    pub fn shutdown(&self) -> Option<impl 'static + Send + Future<Output = Result<()>>> {
259        if let Some(tasks) = self.io_tasks.lock().take() {
260            let response_handlers = self.response_handlers.clone();
261            let outbound_tx = self.outbound_tx.write().take();
262            let next_id = AtomicUsize::new(self.next_id.load(SeqCst));
263            let mut output_done = self.output_done_rx.lock().take().unwrap();
264            Some(async move {
265                Self::request_internal::<lsp_types::request::Shutdown>(
266                    &next_id,
267                    &response_handlers,
268                    outbound_tx.as_ref(),
269                    (),
270                )
271                .await?;
272                Self::notify_internal::<lsp_types::notification::Exit>(outbound_tx.as_ref(), ())?;
273                drop(outbound_tx);
274                output_done.recv().await;
275                drop(tasks);
276                Ok(())
277            })
278        } else {
279            None
280        }
281    }
282
283    pub fn on_notification<T, F>(&self, f: F) -> Subscription
284    where
285        T: lsp_types::notification::Notification,
286        F: 'static + Send + Sync + Fn(T::Params),
287    {
288        let prev_handler = self.notification_handlers.write().insert(
289            T::METHOD,
290            Box::new(
291                move |notification| match serde_json::from_str(notification) {
292                    Ok(notification) => f(notification),
293                    Err(err) => log::error!("error parsing notification {}: {}", T::METHOD, err),
294                },
295            ),
296        );
297
298        assert!(
299            prev_handler.is_none(),
300            "registered multiple handlers for the same notification"
301        );
302
303        Subscription {
304            method: T::METHOD,
305            notification_handlers: self.notification_handlers.clone(),
306        }
307    }
308
309    pub fn request<T: lsp_types::request::Request>(
310        self: Arc<Self>,
311        params: T::Params,
312    ) -> impl Future<Output = Result<T::Result>>
313    where
314        T::Result: 'static + Send,
315    {
316        let this = self.clone();
317        async move {
318            this.initialized.clone().recv().await;
319            Self::request_internal::<T>(
320                &this.next_id,
321                &this.response_handlers,
322                this.outbound_tx.read().as_ref(),
323                params,
324            )
325            .await
326        }
327    }
328
329    fn request_internal<T: lsp_types::request::Request>(
330        next_id: &AtomicUsize,
331        response_handlers: &Mutex<HashMap<usize, ResponseHandler>>,
332        outbound_tx: Option<&channel::Sender<Vec<u8>>>,
333        params: T::Params,
334    ) -> impl 'static + Future<Output = Result<T::Result>>
335    where
336        T::Result: 'static + Send,
337    {
338        let id = next_id.fetch_add(1, SeqCst);
339        let message = serde_json::to_vec(&Request {
340            jsonrpc: JSON_RPC_VERSION,
341            id,
342            method: T::METHOD,
343            params,
344        })
345        .unwrap();
346        let mut response_handlers = response_handlers.lock();
347        let (mut tx, mut rx) = oneshot::channel();
348        response_handlers.insert(
349            id,
350            Box::new(move |result| {
351                let response = match result {
352                    Ok(response) => {
353                        serde_json::from_str(response).context("failed to deserialize response")
354                    }
355                    Err(error) => Err(anyhow!("{}", error.message)),
356                };
357                let _ = tx.try_send(response);
358            }),
359        );
360
361        let send = outbound_tx
362            .as_ref()
363            .ok_or_else(|| {
364                anyhow!("tried to send a request to a language server that has been shut down")
365            })
366            .and_then(|outbound_tx| {
367                outbound_tx.try_send(message)?;
368                Ok(())
369            });
370        async move {
371            send?;
372            rx.recv().await.unwrap()
373        }
374    }
375
376    pub fn notify<T: lsp_types::notification::Notification>(
377        self: &Arc<Self>,
378        params: T::Params,
379    ) -> impl Future<Output = Result<()>> {
380        let this = self.clone();
381        async move {
382            this.initialized.clone().recv().await;
383            Self::notify_internal::<T>(this.outbound_tx.read().as_ref(), params)?;
384            Ok(())
385        }
386    }
387
388    fn notify_internal<T: lsp_types::notification::Notification>(
389        outbound_tx: Option<&channel::Sender<Vec<u8>>>,
390        params: T::Params,
391    ) -> Result<()> {
392        let message = serde_json::to_vec(&Notification {
393            jsonrpc: JSON_RPC_VERSION,
394            method: T::METHOD,
395            params,
396        })
397        .unwrap();
398        let outbound_tx = outbound_tx
399            .as_ref()
400            .ok_or_else(|| anyhow!("tried to notify a language server that has been shut down"))?;
401        outbound_tx.try_send(message)?;
402        Ok(())
403    }
404}
405
406impl Drop for LanguageServer {
407    fn drop(&mut self) {
408        if let Some(shutdown) = self.shutdown() {
409            self.executor.spawn(shutdown).detach();
410        }
411    }
412}
413
414impl Subscription {
415    pub fn detach(mut self) {
416        self.method = "";
417    }
418}
419
420impl Drop for Subscription {
421    fn drop(&mut self) {
422        self.notification_handlers.write().remove(self.method);
423    }
424}
425
/// In-process stand-in for a language server, for tests. It is wired to a
/// `LanguageServer` through in-memory pipes (see `LanguageServer::fake`).
#[cfg(any(test, feature = "test-support"))]
pub struct FakeLanguageServer {
    /// Holds the body of the most recently received message.
    buffer: Vec<u8>,
    /// Read end of the pipe the client writes its outgoing messages to.
    stdin: smol::io::BufReader<async_pipe::PipeReader>,
    /// Write end of the pipe the client reads incoming messages from.
    stdout: smol::io::BufWriter<async_pipe::PipeWriter>,
    /// When false, `notify` panics.
    pub started: Arc<AtomicBool>,
}
433
/// Id of a request received by the fake server. It is tagged with the request
/// type `T` so `respond` only accepts a `T::Result` for it.
#[cfg(any(test, feature = "test-support"))]
pub struct RequestId<T> {
    id: usize,
    _type: std::marker::PhantomData<T>,
}
439
#[cfg(any(test, feature = "test-support"))]
impl LanguageServer {
    /// Creates a client connected to a `FakeLanguageServer` over in-memory pipes.
    /// The fake side answers the `initialize` / `initialized` handshake before
    /// both ends are returned.
    pub async fn fake(executor: Arc<executor::Background>) -> (Arc<Self>, FakeLanguageServer) {
        let (client_to_fake_writer, client_to_fake_reader) = async_pipe::pipe();
        let (fake_to_client_writer, fake_to_client_reader) = async_pipe::pipe();
        let mut fake = FakeLanguageServer {
            stdin: smol::io::BufReader::new(client_to_fake_reader),
            stdout: smol::io::BufWriter::new(fake_to_client_writer),
            buffer: Vec::new(),
            started: Arc::new(AtomicBool::new(true)),
        };

        let server = Self::new_internal(
            client_to_fake_writer,
            fake_to_client_reader,
            Path::new("/"),
            executor,
        )
        .unwrap();

        // Answer the handshake here; the client's `request`/`notify` calls wait
        // until its init task has finished.
        let (initialize_request, _) = fake.receive_request::<request::Initialize>().await;
        fake.respond(initialize_request, InitializeResult::default())
            .await;
        fake.receive_notification::<notification::Initialized>()
            .await;

        (server, fake)
    }
}
462
#[cfg(any(test, feature = "test-support"))]
impl FakeLanguageServer {
    /// Sends a notification of type `T` from the fake server to the client.
    ///
    /// # Panics
    /// Panics if `started` is false or if writing to the pipe fails.
    pub async fn notify<T: notification::Notification>(&mut self, params: T::Params) {
        assert!(
            self.started.load(SeqCst),
            "can't simulate an LSP notification before the server has been started"
        );
        let message = serde_json::to_vec(&Notification {
            jsonrpc: JSON_RPC_VERSION,
            method: T::METHOD,
            params,
        })
        .unwrap();
        self.send(message).await;
    }

    /// Sends a successful response carrying `result` for a request obtained
    /// from `receive_request`.
    pub async fn respond<T: request::Request>(
        &mut self,
        request_id: RequestId<T>,
        result: T::Result,
    ) {
        let result = RawValue::from_string(serde_json::to_string(&result).unwrap()).unwrap();
        let message = serde_json::to_vec(&AnyResponse {
            id: request_id.id,
            error: None,
            result: Some(&*result),
        })
        .unwrap();
        self.send(message).await;
    }

    /// Waits for the next message from the client and decodes it as a request
    /// of type `T`.
    ///
    /// # Panics
    /// Panics if the message isn't a `T` request.
    pub async fn receive_request<T: request::Request>(&mut self) -> (RequestId<T>, T::Params) {
        self.receive().await;
        let request = serde_json::from_slice::<Request<T::Params>>(&self.buffer).unwrap();
        assert_eq!(request.method, T::METHOD);
        assert_eq!(request.jsonrpc, JSON_RPC_VERSION);
        (
            RequestId {
                id: request.id,
                _type: std::marker::PhantomData,
            },
            request.params,
        )
    }

    /// Waits for the next message from the client and decodes it as a
    /// notification of type `T`.
    ///
    /// # Panics
    /// Panics if the message isn't a `T` notification.
    pub async fn receive_notification<T: notification::Notification>(&mut self) -> T::Params {
        self.receive().await;
        let notification = serde_json::from_slice::<Notification<T::Params>>(&self.buffer).unwrap();
        assert_eq!(notification.method, T::METHOD);
        notification.params
    }

    /// Writes one framed message (`Content-Length` header, blank line, body)
    /// and flushes it.
    async fn send(&mut self, message: Vec<u8>) {
        let header = format!("{}{}\r\n\r\n", CONTENT_LEN_HEADER, message.len());
        self.stdout.write_all(header.as_bytes()).await.unwrap();
        self.stdout.write_all(&message).await.unwrap();
        self.stdout.flush().await.unwrap();
    }

    /// Reads one framed message from the client into `self.buffer`.
    async fn receive(&mut self) {
        self.buffer.clear();
        // The client's output task writes exactly one header line followed by a
        // blank line, so two line reads cover the whole header section.
        self.stdin
            .read_until(b'\n', &mut self.buffer)
            .await
            .unwrap();
        self.stdin
            .read_until(b'\n', &mut self.buffer)
            .await
            .unwrap();
        let message_len: usize = std::str::from_utf8(&self.buffer)
            .expect("message header is not valid UTF-8")
            .strip_prefix(CONTENT_LEN_HEADER)
            .expect("message header does not start with Content-Length")
            .trim_end()
            .parse()
            .expect("Content-Length value is not a number");
        self.buffer.resize(message_len, 0);
        self.stdin.read_exact(&mut self.buffer).await.unwrap();
    }
}
549
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use simplelog::SimpleLogger;
    use unindent::Unindent;
    use util::test::temp_tree;

    // End-to-end check against a real `rust-analyzer`, which must be on PATH.
    #[gpui::test]
    async fn test_basic(cx: TestAppContext) {
        let lib_source = r#"
            fn fun() {
                let hello = "world";
            }
        "#
        .unindent();
        let root_dir = temp_tree(json!({
            "Cargo.toml": r#"
                [package]
                name = "temp"
                version = "0.1.0"
                edition = "2018"
            "#.unindent(),
            "src": {
                "lib.rs": &lib_source
            }
        }));
        let lib_file_uri =
            lsp_types::Url::from_file_path(root_dir.path().join("src/lib.rs")).unwrap();

        let server = cx.read(|cx| {
            LanguageServer::new(
                Path::new("rust-analyzer"),
                root_dir.path(),
                cx.background().clone(),
            )
            .unwrap()
        });
        server.next_idle_notification().await;

        server
            .notify::<lsp_types::notification::DidOpenTextDocument>(
                lsp_types::DidOpenTextDocumentParams {
                    text_document: lsp_types::TextDocumentItem::new(
                        lib_file_uri.clone(),
                        "rust".to_string(),
                        0,
                        lib_source,
                    ),
                },
            )
            .await
            .unwrap();

        let hover = server
            .request::<lsp_types::request::HoverRequest>(lsp_types::HoverParams {
                text_document_position_params: lsp_types::TextDocumentPositionParams {
                    text_document: lsp_types::TextDocumentIdentifier::new(lib_file_uri),
                    position: lsp_types::Position::new(1, 21),
                },
                work_done_progress_params: Default::default(),
            })
            .await
            .unwrap()
            .unwrap();
        assert_eq!(
            hover.contents,
            lsp_types::HoverContents::Markup(lsp_types::MarkupContent {
                kind: lsp_types::MarkupKind::Markdown,
                value: "&str".to_string()
            })
        );
    }

    // Exercises notification plumbing in both directions against the fake
    // server, then the shutdown sequence triggered by dropping the client.
    #[gpui::test]
    async fn test_fake(cx: TestAppContext) {
        // Installing the global logger fails if another test in this binary has
        // already installed one; that must not fail this test.
        let _ = SimpleLogger::init(log::LevelFilter::Info, Default::default());

        let (server, mut fake) = LanguageServer::fake(cx.background()).await;

        let (message_tx, message_rx) = channel::unbounded();
        let (diagnostics_tx, diagnostics_rx) = channel::unbounded();
        server
            .on_notification::<notification::ShowMessage, _>(move |params| {
                message_tx.try_send(params).unwrap()
            })
            .detach();
        server
            .on_notification::<notification::PublishDiagnostics, _>(move |params| {
                diagnostics_tx.try_send(params).unwrap()
            })
            .detach();

        server
            .notify::<notification::DidOpenTextDocument>(DidOpenTextDocumentParams {
                text_document: TextDocumentItem::new(
                    Url::from_str("file://a/b").unwrap(),
                    "rust".to_string(),
                    0,
                    "".to_string(),
                ),
            })
            .await
            .unwrap();
        assert_eq!(
            fake.receive_notification::<notification::DidOpenTextDocument>()
                .await
                .text_document
                .uri
                .as_str(),
            "file://a/b"
        );

        fake.notify::<notification::ShowMessage>(ShowMessageParams {
            typ: MessageType::ERROR,
            message: "ok".to_string(),
        })
        .await;
        fake.notify::<notification::PublishDiagnostics>(PublishDiagnosticsParams {
            uri: Url::from_str("file://b/c").unwrap(),
            version: Some(5),
            diagnostics: vec![],
        })
        .await;
        assert_eq!(message_rx.recv().await.unwrap().message, "ok");
        assert_eq!(
            diagnostics_rx.recv().await.unwrap().uri.as_str(),
            "file://b/c"
        );

        // Dropping the last handle spawns the shutdown/exit sequence.
        drop(server);
        let (shutdown_request, _) = fake.receive_request::<lsp_types::request::Shutdown>().await;
        fake.respond(shutdown_request, ()).await;
        fake.receive_notification::<lsp_types::notification::Exit>()
            .await;
    }

    impl LanguageServer {
        /// Waits for the next `experimental/serverStatus` notification that
        /// reports `quiescent: true`.
        async fn next_idle_notification(self: &Arc<Self>) {
            let (tx, rx) = channel::unbounded();
            let _subscription =
                self.on_notification::<ServerStatusNotification, _>(move |params| {
                    if params.quiescent {
                        tx.try_send(()).unwrap();
                    }
                });
            let _ = rx.recv().await;
        }
    }

    /// rust-analyzer's `experimental/serverStatus` notification, enabled through
    /// the `serverStatusNotification` experimental capability sent during init.
    pub enum ServerStatusNotification {}

    impl lsp_types::notification::Notification for ServerStatusNotification {
        type Params = ServerStatusParams;
        const METHOD: &'static str = "experimental/serverStatus";
    }

    #[derive(Deserialize, Serialize, PartialEq, Eq, Clone)]
    pub struct ServerStatusParams {
        pub quiescent: bool,
    }
}