lsp.rs

  1use anyhow::{anyhow, Context, Result};
  2use futures::{io::BufWriter, AsyncRead, AsyncWrite};
  3use gpui::{executor, Task};
  4use parking_lot::{Mutex, RwLock};
  5use postage::{barrier, oneshot, prelude::Stream, sink::Sink};
  6use serde::{Deserialize, Serialize};
  7use serde_json::{json, value::RawValue, Value};
  8use smol::{
  9    channel,
 10    io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
 11    process::Command,
 12};
 13use std::{
 14    collections::HashMap,
 15    future::Future,
 16    io::Write,
 17    str::FromStr,
 18    sync::{
 19        atomic::{AtomicUsize, Ordering::SeqCst},
 20        Arc,
 21    },
 22};
 23use std::{path::Path, process::Stdio};
 24use util::TryFutureExt;
 25
 26pub use lsp_types::*;
 27
 28const JSON_RPC_VERSION: &'static str = "2.0";
 29const CONTENT_LEN_HEADER: &'static str = "Content-Length: ";
 30
 31type NotificationHandler = Box<dyn Send + Sync + FnMut(&str)>;
 32type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
 33
/// Client-side handle to a language server that exchanges JSON-RPC messages over byte streams.
pub struct LanguageServer {
    /// Source of ids for outgoing requests; advanced once per request.
    next_id: AtomicUsize,
    /// Queue of serialized messages consumed by the output task; set to `None` by `shutdown`.
    outbound_tx: RwLock<Option<channel::Sender<Vec<u8>>>>,
    /// Notification callbacks keyed by method name.
    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
    /// Callbacks for in-flight requests keyed by request id.
    response_handlers: Arc<Mutex<HashMap<usize, ResponseHandler>>>,
    /// Executor used to run the shutdown future when the server is dropped.
    executor: Arc<executor::Background>,
    /// The input and output tasks, in that order; taken by the first call to `shutdown`.
    io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
    /// Barrier whose sender is dropped once the initialization task has finished.
    initialized: barrier::Receiver,
    /// Barrier whose sender is dropped once the output task has drained its queue.
    output_done_rx: Mutex<Option<barrier::Receiver>>,
}
 44
/// Handle for a registered notification handler; dropping it removes the handler for `method`.
pub struct Subscription {
    /// Method name under which the handler was registered.
    method: &'static str,
    /// Shared handler map the entry is removed from on drop.
    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
}
 49
/// JSON-RPC request envelope, generic over the type of its `params`.
#[derive(Serialize, Deserialize)]
struct Request<'a, T> {
    /// Protocol version string.
    jsonrpc: &'a str,
    /// Identifier used to match the response to this request.
    id: usize,
    /// Name of the method being invoked.
    method: &'a str,
    /// Method-specific parameters.
    params: T,
}
 57
/// JSON-RPC response envelope whose result is kept as unparsed JSON.
#[derive(Serialize, Deserialize)]
struct AnyResponse<'a> {
    /// Id of the request this response answers.
    id: usize,
    /// Error reported for the request, if any; defaults to `None` when absent.
    #[serde(default)]
    error: Option<Error>,
    /// Raw JSON of the result, borrowed from the message buffer.
    #[serde(borrow)]
    result: Option<&'a RawValue>,
}
 66
/// JSON-RPC notification envelope, generic over the type of its `params`.
#[derive(Serialize, Deserialize)]
struct Notification<'a, T> {
    /// Protocol version string.
    #[serde(borrow)]
    jsonrpc: &'a str,
    /// Name of the notification method.
    #[serde(borrow)]
    method: &'a str,
    /// Method-specific parameters.
    params: T,
}
 75
/// Incoming notification with its `params` left as unparsed JSON so it can be routed by method.
#[derive(Deserialize)]
struct AnyNotification<'a> {
    /// Name of the notification method.
    #[serde(borrow)]
    method: &'a str,
    /// Raw JSON of the parameters, borrowed from the message buffer.
    #[serde(borrow)]
    params: &'a RawValue,
}
 83
/// Error object carried by a JSON-RPC response; only its message is represented here.
#[derive(Debug, Serialize, Deserialize)]
struct Error {
    /// Human-readable description of the failure.
    message: String,
}
 88
 89impl LanguageServer {
 90    pub fn new(
 91        binary_path: &Path,
 92        root_path: &Path,
 93        background: Arc<executor::Background>,
 94    ) -> Result<Arc<Self>> {
 95        let mut server = Command::new(binary_path)
 96            .stdin(Stdio::piped())
 97            .stdout(Stdio::piped())
 98            .stderr(Stdio::inherit())
 99            .spawn()?;
100        let stdin = server.stdin.take().unwrap();
101        let stdout = server.stdout.take().unwrap();
102        Self::new_internal(stdin, stdout, root_path, background)
103    }
104
105    fn new_internal<Stdin, Stdout>(
106        stdin: Stdin,
107        stdout: Stdout,
108        root_path: &Path,
109        executor: Arc<executor::Background>,
110    ) -> Result<Arc<Self>>
111    where
112        Stdin: AsyncWrite + Unpin + Send + 'static,
113        Stdout: AsyncRead + Unpin + Send + 'static,
114    {
115        let mut stdin = BufWriter::new(stdin);
116        let mut stdout = BufReader::new(stdout);
117        let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
118        let notification_handlers = Arc::new(RwLock::new(HashMap::<_, NotificationHandler>::new()));
119        let response_handlers = Arc::new(Mutex::new(HashMap::<_, ResponseHandler>::new()));
120        let input_task = executor.spawn(
121            {
122                let notification_handlers = notification_handlers.clone();
123                let response_handlers = response_handlers.clone();
124                async move {
125                    let mut buffer = Vec::new();
126                    loop {
127                        buffer.clear();
128                        stdout.read_until(b'\n', &mut buffer).await?;
129                        stdout.read_until(b'\n', &mut buffer).await?;
130                        let message_len: usize = std::str::from_utf8(&buffer)?
131                            .strip_prefix(CONTENT_LEN_HEADER)
132                            .ok_or_else(|| anyhow!("invalid header"))?
133                            .trim_end()
134                            .parse()?;
135
136                        buffer.resize(message_len, 0);
137                        stdout.read_exact(&mut buffer).await?;
138
139                        if let Ok(AnyNotification { method, params }) =
140                            serde_json::from_slice(&buffer)
141                        {
142                            if let Some(handler) = notification_handlers.write().get_mut(method) {
143                                handler(params.get());
144                            } else {
145                                log::info!(
146                                    "unhandled notification {}:\n{}",
147                                    method,
148                                    serde_json::to_string_pretty(
149                                        &Value::from_str(params.get()).unwrap()
150                                    )
151                                    .unwrap()
152                                );
153                            }
154                        } else if let Ok(AnyResponse { id, error, result }) =
155                            serde_json::from_slice(&buffer)
156                        {
157                            if let Some(handler) = response_handlers.lock().remove(&id) {
158                                if let Some(error) = error {
159                                    handler(Err(error));
160                                } else if let Some(result) = result {
161                                    handler(Ok(result.get()));
162                                } else {
163                                    handler(Ok("null"));
164                                }
165                            }
166                        } else {
167                            return Err(anyhow!(
168                                "failed to deserialize message:\n{}",
169                                std::str::from_utf8(&buffer)?
170                            ));
171                        }
172                    }
173                }
174            }
175            .log_err(),
176        );
177        let (output_done_tx, output_done_rx) = barrier::channel();
178        let output_task = executor.spawn(
179            async move {
180                let mut content_len_buffer = Vec::new();
181                while let Ok(message) = outbound_rx.recv().await {
182                    content_len_buffer.clear();
183                    write!(content_len_buffer, "{}", message.len()).unwrap();
184                    stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
185                    stdin.write_all(&content_len_buffer).await?;
186                    stdin.write_all("\r\n\r\n".as_bytes()).await?;
187                    stdin.write_all(&message).await?;
188                    stdin.flush().await?;
189                }
190                drop(output_done_tx);
191                Ok(())
192            }
193            .log_err(),
194        );
195
196        let (initialized_tx, initialized_rx) = barrier::channel();
197        let this = Arc::new(Self {
198            notification_handlers,
199            response_handlers,
200            next_id: Default::default(),
201            outbound_tx: RwLock::new(Some(outbound_tx)),
202            executor: executor.clone(),
203            io_tasks: Mutex::new(Some((input_task, output_task))),
204            initialized: initialized_rx,
205            output_done_rx: Mutex::new(Some(output_done_rx)),
206        });
207
208        let root_uri = Url::from_file_path(root_path).map_err(|_| anyhow!("invalid root path"))?;
209        executor
210            .spawn({
211                let this = this.clone();
212                async move {
213                    this.init(root_uri).log_err().await;
214                    drop(initialized_tx);
215                }
216            })
217            .detach();
218
219        Ok(this)
220    }
221
222    async fn init(self: Arc<Self>, root_uri: Url) -> Result<()> {
223        #[allow(deprecated)]
224        let params = InitializeParams {
225            process_id: Default::default(),
226            root_path: Default::default(),
227            root_uri: Some(root_uri),
228            initialization_options: Default::default(),
229            capabilities: ClientCapabilities {
230                text_document: Some(TextDocumentClientCapabilities {
231                    definition: Some(GotoCapability {
232                        link_support: Some(true),
233                        ..Default::default()
234                    }),
235                    ..Default::default()
236                }),
237                experimental: Some(json!({
238                    "serverStatusNotification": true,
239                })),
240                window: Some(WindowClientCapabilities {
241                    work_done_progress: Some(true),
242                    ..Default::default()
243                }),
244                ..Default::default()
245            },
246            trace: Default::default(),
247            workspace_folders: Default::default(),
248            client_info: Default::default(),
249            locale: Default::default(),
250        };
251
252        let this = self.clone();
253        let request = Self::request_internal::<request::Initialize>(
254            &this.next_id,
255            &this.response_handlers,
256            this.outbound_tx.read().as_ref(),
257            params,
258        );
259        request.await?;
260        Self::notify_internal::<notification::Initialized>(
261            this.outbound_tx.read().as_ref(),
262            InitializedParams {},
263        )?;
264        Ok(())
265    }
266
267    pub fn shutdown(&self) -> Option<impl 'static + Send + Future<Output = Result<()>>> {
268        if let Some(tasks) = self.io_tasks.lock().take() {
269            let response_handlers = self.response_handlers.clone();
270            let outbound_tx = self.outbound_tx.write().take();
271            let next_id = AtomicUsize::new(self.next_id.load(SeqCst));
272            let mut output_done = self.output_done_rx.lock().take().unwrap();
273            Some(async move {
274                Self::request_internal::<request::Shutdown>(
275                    &next_id,
276                    &response_handlers,
277                    outbound_tx.as_ref(),
278                    (),
279                )
280                .await?;
281                Self::notify_internal::<notification::Exit>(outbound_tx.as_ref(), ())?;
282                drop(outbound_tx);
283                output_done.recv().await;
284                drop(tasks);
285                Ok(())
286            })
287        } else {
288            None
289        }
290    }
291
292    pub fn on_notification<T, F>(&self, mut f: F) -> Subscription
293    where
294        T: notification::Notification,
295        F: 'static + Send + Sync + FnMut(T::Params),
296    {
297        let prev_handler = self.notification_handlers.write().insert(
298            T::METHOD,
299            Box::new(
300                move |notification| match serde_json::from_str(notification) {
301                    Ok(notification) => f(notification),
302                    Err(err) => log::error!("error parsing notification {}: {}", T::METHOD, err),
303                },
304            ),
305        );
306
307        assert!(
308            prev_handler.is_none(),
309            "registered multiple handlers for the same notification"
310        );
311
312        Subscription {
313            method: T::METHOD,
314            notification_handlers: self.notification_handlers.clone(),
315        }
316    }
317
318    pub fn request<T: request::Request>(
319        self: &Arc<Self>,
320        params: T::Params,
321    ) -> impl Future<Output = Result<T::Result>>
322    where
323        T::Result: 'static + Send,
324    {
325        let this = self.clone();
326        async move {
327            this.initialized.clone().recv().await;
328            Self::request_internal::<T>(
329                &this.next_id,
330                &this.response_handlers,
331                this.outbound_tx.read().as_ref(),
332                params,
333            )
334            .await
335        }
336    }
337
338    fn request_internal<T: request::Request>(
339        next_id: &AtomicUsize,
340        response_handlers: &Mutex<HashMap<usize, ResponseHandler>>,
341        outbound_tx: Option<&channel::Sender<Vec<u8>>>,
342        params: T::Params,
343    ) -> impl 'static + Future<Output = Result<T::Result>>
344    where
345        T::Result: 'static + Send,
346    {
347        let id = next_id.fetch_add(1, SeqCst);
348        let message = serde_json::to_vec(&Request {
349            jsonrpc: JSON_RPC_VERSION,
350            id,
351            method: T::METHOD,
352            params,
353        })
354        .unwrap();
355        let mut response_handlers = response_handlers.lock();
356        let (mut tx, mut rx) = oneshot::channel();
357        response_handlers.insert(
358            id,
359            Box::new(move |result| {
360                let response = match result {
361                    Ok(response) => {
362                        serde_json::from_str(response).context("failed to deserialize response")
363                    }
364                    Err(error) => Err(anyhow!("{}", error.message)),
365                };
366                let _ = tx.try_send(response);
367            }),
368        );
369
370        let send = outbound_tx
371            .as_ref()
372            .ok_or_else(|| {
373                anyhow!("tried to send a request to a language server that has been shut down")
374            })
375            .and_then(|outbound_tx| {
376                outbound_tx.try_send(message)?;
377                Ok(())
378            });
379        async move {
380            send?;
381            rx.recv().await.unwrap()
382        }
383    }
384
385    pub fn notify<T: notification::Notification>(
386        self: &Arc<Self>,
387        params: T::Params,
388    ) -> impl Future<Output = Result<()>> {
389        let this = self.clone();
390        async move {
391            this.initialized.clone().recv().await;
392            Self::notify_internal::<T>(this.outbound_tx.read().as_ref(), params)?;
393            Ok(())
394        }
395    }
396
397    fn notify_internal<T: notification::Notification>(
398        outbound_tx: Option<&channel::Sender<Vec<u8>>>,
399        params: T::Params,
400    ) -> Result<()> {
401        let message = serde_json::to_vec(&Notification {
402            jsonrpc: JSON_RPC_VERSION,
403            method: T::METHOD,
404            params,
405        })
406        .unwrap();
407        let outbound_tx = outbound_tx
408            .as_ref()
409            .ok_or_else(|| anyhow!("tried to notify a language server that has been shut down"))?;
410        outbound_tx.try_send(message)?;
411        Ok(())
412    }
413}
414
415impl Drop for LanguageServer {
416    fn drop(&mut self) {
417        if let Some(shutdown) = self.shutdown() {
418            self.executor.spawn(shutdown).detach();
419        }
420    }
421}
422
impl Subscription {
    /// Consumes the subscription while leaving its handler registered.
    ///
    /// The method name is blanked so that the `Drop` implementation, which removes the
    /// entry stored under `self.method`, removes the `""` key instead of the real one.
    pub fn detach(mut self) {
        self.method = "";
    }
}
428
/// Unregisters the handler stored under this subscription's method name.
impl Drop for Subscription {
    fn drop(&mut self) {
        // Takes the write lock on the shared handler map for the duration of the removal.
        self.notification_handlers.write().remove(self.method);
    }
}
434
/// In-process stand-in for a language server, connected to a `LanguageServer` through pipes.
#[cfg(any(test, feature = "test-support"))]
pub struct FakeLanguageServer {
    /// Scratch buffer holding the most recently received message.
    buffer: Vec<u8>,
    /// Read end of the pipe the client writes its messages to.
    stdin: smol::io::BufReader<async_pipe::PipeReader>,
    /// Write end of the pipe the client reads its messages from.
    stdout: smol::io::BufWriter<async_pipe::PipeWriter>,
    /// Flag checked by `notify`, which panics while it is `false`.
    pub started: Arc<std::sync::atomic::AtomicBool>,
}
442
/// Id of a request received by the fake server, tagged with the request type `T` so that
/// `respond` is given a result of the matching type.
#[cfg(any(test, feature = "test-support"))]
pub struct RequestId<T> {
    /// The JSON-RPC id of the request.
    id: usize,
    /// Marker tying the id to its request type; stores no data.
    _type: std::marker::PhantomData<T>,
}
448
#[cfg(any(test, feature = "test-support"))]
impl LanguageServer {
    /// Creates a `LanguageServer` wired to an in-process [`FakeLanguageServer`].
    ///
    /// The fake answers the `initialize` request with a default result and consumes the
    /// `initialized` notification before both ends are returned.
    ///
    /// # Panics
    ///
    /// Panics if the client cannot be constructed or the handshake messages are not the
    /// ones expected.
    pub async fn fake(executor: Arc<executor::Background>) -> (Arc<Self>, FakeLanguageServer) {
        // One pipe per direction: the client writes into the fake's stdin and reads from
        // the fake's stdout.
        let (client_stdin, fake_stdin) = async_pipe::pipe();
        let (fake_stdout, client_stdout) = async_pipe::pipe();
        let mut fake = FakeLanguageServer {
            buffer: Vec::new(),
            stdin: smol::io::BufReader::new(fake_stdin),
            stdout: smol::io::BufWriter::new(fake_stdout),
            started: Arc::new(std::sync::atomic::AtomicBool::new(true)),
        };

        let server =
            Self::new_internal(client_stdin, client_stdout, Path::new("/"), executor).unwrap();

        let (initialize_id, _params) = fake.receive_request::<request::Initialize>().await;
        fake.respond(initialize_id, InitializeResult::default())
            .await;
        fake.receive_notification::<notification::Initialized>()
            .await;

        (server, fake)
    }
}
471
472#[cfg(any(test, feature = "test-support"))]
473impl FakeLanguageServer {
474    pub async fn notify<T: notification::Notification>(&mut self, params: T::Params) {
475        if !self.started.load(std::sync::atomic::Ordering::SeqCst) {
476            panic!("can't simulate an LSP notification before the server has been started");
477        }
478        let message = serde_json::to_vec(&Notification {
479            jsonrpc: JSON_RPC_VERSION,
480            method: T::METHOD,
481            params,
482        })
483        .unwrap();
484        self.send(message).await;
485    }
486
487    pub async fn respond<'a, T: request::Request>(
488        &mut self,
489        request_id: RequestId<T>,
490        result: T::Result,
491    ) {
492        let result = serde_json::to_string(&result).unwrap();
493        let message = serde_json::to_vec(&AnyResponse {
494            id: request_id.id,
495            error: None,
496            result: Some(&RawValue::from_string(result).unwrap()),
497        })
498        .unwrap();
499        self.send(message).await;
500    }
501
502    pub async fn receive_request<T: request::Request>(&mut self) -> (RequestId<T>, T::Params) {
503        loop {
504            self.receive().await;
505            if let Ok(request) = serde_json::from_slice::<Request<T::Params>>(&self.buffer) {
506                assert_eq!(request.method, T::METHOD);
507                assert_eq!(request.jsonrpc, JSON_RPC_VERSION);
508                return (
509                    RequestId {
510                        id: request.id,
511                        _type: std::marker::PhantomData,
512                    },
513                    request.params,
514                );
515            } else {
516                println!(
517                    "skipping message in fake language server {:?}",
518                    std::str::from_utf8(&self.buffer)
519                );
520            }
521        }
522    }
523
524    pub async fn receive_notification<T: notification::Notification>(&mut self) -> T::Params {
525        self.receive().await;
526        let notification = serde_json::from_slice::<Notification<T::Params>>(&self.buffer).unwrap();
527        assert_eq!(notification.method, T::METHOD);
528        notification.params
529    }
530
531    pub async fn start_progress(&mut self, token: impl Into<String>) {
532        self.notify::<notification::Progress>(ProgressParams {
533            token: NumberOrString::String(token.into()),
534            value: ProgressParamsValue::WorkDone(WorkDoneProgress::Begin(Default::default())),
535        })
536        .await;
537    }
538
539    pub async fn end_progress(&mut self, token: impl Into<String>) {
540        self.notify::<notification::Progress>(ProgressParams {
541            token: NumberOrString::String(token.into()),
542            value: ProgressParamsValue::WorkDone(WorkDoneProgress::End(Default::default())),
543        })
544        .await;
545    }
546
547    async fn send(&mut self, message: Vec<u8>) {
548        self.stdout
549            .write_all(CONTENT_LEN_HEADER.as_bytes())
550            .await
551            .unwrap();
552        self.stdout
553            .write_all((format!("{}", message.len())).as_bytes())
554            .await
555            .unwrap();
556        self.stdout.write_all("\r\n\r\n".as_bytes()).await.unwrap();
557        self.stdout.write_all(&message).await.unwrap();
558        self.stdout.flush().await.unwrap();
559    }
560
561    async fn receive(&mut self) {
562        self.buffer.clear();
563        self.stdin
564            .read_until(b'\n', &mut self.buffer)
565            .await
566            .unwrap();
567        self.stdin
568            .read_until(b'\n', &mut self.buffer)
569            .await
570            .unwrap();
571        let message_len: usize = std::str::from_utf8(&self.buffer)
572            .unwrap()
573            .strip_prefix(CONTENT_LEN_HEADER)
574            .unwrap()
575            .trim_end()
576            .parse()
577            .unwrap();
578        self.buffer.resize(message_len, 0);
579        self.stdin.read_exact(&mut self.buffer).await.unwrap();
580    }
581}
582
583#[cfg(test)]
584mod tests {
585    use super::*;
586    use gpui::TestAppContext;
587    use simplelog::SimpleLogger;
588    use unindent::Unindent;
589    use util::test::temp_tree;
590
591    #[gpui::test]
592    async fn test_basic(cx: TestAppContext) {
593        let lib_source = r#"
594            fn fun() {
595                let hello = "world";
596            }
597        "#
598        .unindent();
599        let root_dir = temp_tree(json!({
600            "Cargo.toml": r#"
601                [package]
602                name = "temp"
603                version = "0.1.0"
604                edition = "2018"
605            "#.unindent(),
606            "src": {
607                "lib.rs": &lib_source
608            }
609        }));
610        let lib_file_uri = Url::from_file_path(root_dir.path().join("src/lib.rs")).unwrap();
611
612        let server = cx.read(|cx| {
613            LanguageServer::new(
614                Path::new("rust-analyzer"),
615                root_dir.path(),
616                cx.background().clone(),
617            )
618            .unwrap()
619        });
620        server.next_idle_notification().await;
621
622        server
623            .notify::<notification::DidOpenTextDocument>(DidOpenTextDocumentParams {
624                text_document: TextDocumentItem::new(
625                    lib_file_uri.clone(),
626                    "rust".to_string(),
627                    0,
628                    lib_source,
629                ),
630            })
631            .await
632            .unwrap();
633
634        let hover = server
635            .request::<request::HoverRequest>(HoverParams {
636                text_document_position_params: TextDocumentPositionParams {
637                    text_document: TextDocumentIdentifier::new(lib_file_uri),
638                    position: Position::new(1, 21),
639                },
640                work_done_progress_params: Default::default(),
641            })
642            .await
643            .unwrap()
644            .unwrap();
645        assert_eq!(
646            hover.contents,
647            HoverContents::Markup(MarkupContent {
648                kind: MarkupKind::Markdown,
649                value: "&str".to_string()
650            })
651        );
652    }
653
654    #[gpui::test]
655    async fn test_fake(cx: TestAppContext) {
656        SimpleLogger::init(log::LevelFilter::Info, Default::default()).unwrap();
657
658        let (server, mut fake) = LanguageServer::fake(cx.background()).await;
659
660        let (message_tx, message_rx) = channel::unbounded();
661        let (diagnostics_tx, diagnostics_rx) = channel::unbounded();
662        server
663            .on_notification::<notification::ShowMessage, _>(move |params| {
664                message_tx.try_send(params).unwrap()
665            })
666            .detach();
667        server
668            .on_notification::<notification::PublishDiagnostics, _>(move |params| {
669                diagnostics_tx.try_send(params).unwrap()
670            })
671            .detach();
672
673        server
674            .notify::<notification::DidOpenTextDocument>(DidOpenTextDocumentParams {
675                text_document: TextDocumentItem::new(
676                    Url::from_str("file://a/b").unwrap(),
677                    "rust".to_string(),
678                    0,
679                    "".to_string(),
680                ),
681            })
682            .await
683            .unwrap();
684        assert_eq!(
685            fake.receive_notification::<notification::DidOpenTextDocument>()
686                .await
687                .text_document
688                .uri
689                .as_str(),
690            "file://a/b"
691        );
692
693        fake.notify::<notification::ShowMessage>(ShowMessageParams {
694            typ: MessageType::ERROR,
695            message: "ok".to_string(),
696        })
697        .await;
698        fake.notify::<notification::PublishDiagnostics>(PublishDiagnosticsParams {
699            uri: Url::from_str("file://b/c").unwrap(),
700            version: Some(5),
701            diagnostics: vec![],
702        })
703        .await;
704        assert_eq!(message_rx.recv().await.unwrap().message, "ok");
705        assert_eq!(
706            diagnostics_rx.recv().await.unwrap().uri.as_str(),
707            "file://b/c"
708        );
709
710        drop(server);
711        let (shutdown_request, _) = fake.receive_request::<request::Shutdown>().await;
712        fake.respond(shutdown_request, ()).await;
713        fake.receive_notification::<notification::Exit>().await;
714    }
715
716    impl LanguageServer {
717        async fn next_idle_notification(self: &Arc<Self>) {
718            let (tx, rx) = channel::unbounded();
719            let _subscription =
720                self.on_notification::<ServerStatusNotification, _>(move |params| {
721                    if params.quiescent {
722                        tx.try_send(()).unwrap();
723                    }
724                });
725            let _ = rx.recv().await;
726        }
727    }
728
    /// Marker type for the `experimental/serverStatus` notification.
    pub enum ServerStatusNotification {}

    impl notification::Notification for ServerStatusNotification {
        type Params = ServerStatusParams;
        const METHOD: &'static str = "experimental/serverStatus";
    }
735
    /// Parameters of the server status notification; only the `quiescent` flag is read.
    #[derive(Deserialize, Serialize, PartialEq, Eq, Clone)]
    pub struct ServerStatusParams {
        /// Checked by `next_idle_notification` to decide when to stop waiting.
        pub quiescent: bool,
    }
740}