lib.rs

  1use anyhow::{anyhow, Context, Result};
  2use futures::{io::BufWriter, AsyncRead, AsyncWrite};
  3use gpui::{executor, AppContext, Task};
  4use parking_lot::{Mutex, RwLock};
  5use postage::{barrier, oneshot, prelude::Stream, sink::Sink};
  6use serde::{Deserialize, Serialize};
  7use serde_json::{json, value::RawValue, Value};
  8use smol::{
  9    channel,
 10    io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
 11    process::Command,
 12};
 13use std::{
 14    collections::HashMap,
 15    future::Future,
 16    io::Write,
 17    str::FromStr,
 18    sync::{
 19        atomic::{AtomicUsize, Ordering::SeqCst},
 20        Arc,
 21    },
 22};
 23use std::{path::Path, process::Stdio};
 24use util::TryFutureExt;
 25
 26pub use lsp_types::*;
 27
 28const JSON_RPC_VERSION: &'static str = "2.0";
 29const CONTENT_LEN_HEADER: &'static str = "Content-Length: ";
 30
 31type NotificationHandler = Box<dyn Send + Sync + Fn(&str)>;
 32type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
 33
/// Client side of a connection to a language server, speaking JSON-RPC over
/// a pair of byte streams (the server's stdin/stdout).
///
/// Outgoing messages are queued on `outbound_tx` and written by a background
/// output task; a background input task reads incoming messages and routes
/// them to the registered notification and response handlers.
pub struct LanguageServer {
    /// Id handed to the next outgoing request.
    next_id: AtomicUsize,
    /// Queue of serialized message bodies; the output task prepends the
    /// `Content-Length` header before writing each one.
    outbound_tx: channel::Sender<Vec<u8>>,
    /// Notification handlers keyed by method name (at most one per method).
    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
    /// Callbacks for in-flight requests keyed by request id; an entry is
    /// removed when its response arrives.
    response_handlers: Arc<Mutex<HashMap<usize, ResponseHandler>>>,
    // The I/O tasks are stored only so they live as long as the server.
    // NOTE(review): presumably dropping a gpui `Task` cancels it — confirm.
    // Fields drop in declaration order, so keep this order stable.
    _input_task: Task<Option<()>>,
    _output_task: Task<Option<()>>,
    /// Released once the initialize handshake has finished (whether or not it
    /// succeeded); the public `request`/`notify` wait on it.
    initialized: barrier::Receiver,
}
 43
/// Guard returned by `LanguageServer::on_notification`.
///
/// Dropping it unregisters the handler for `method`; calling `detach` keeps
/// the handler registered instead.
pub struct Subscription {
    /// Method whose handler this guard owns; set to `""` by `detach`.
    method: &'static str,
    /// Handler table shared with the owning `LanguageServer`.
    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
}
 48
/// JSON-RPC request envelope. Serialized for requests the client sends and
/// deserialized by the test fake to inspect them.
///
/// serde writes keys in field-declaration order, so reordering fields changes
/// the bytes on the wire.
#[derive(Serialize, Deserialize)]
struct Request<'a, T> {
    jsonrpc: &'a str,
    /// Correlates this request with its response.
    id: usize,
    /// LSP method name (the request type's `METHOD`).
    method: &'a str,
    params: T,
}
 56
/// Response whose `result` is kept as unparsed JSON, so it can be routed by
/// `id` before the caller's expected result type is known.
///
/// NOTE(review): `result` has no `#[serde(default)]`, so a JSON-RPC error
/// response that carries only `error` does not deserialize as this type.
#[derive(Serialize, Deserialize)]
struct AnyResponse<'a> {
    id: usize,
    #[serde(default)]
    error: Option<Error>,
    #[serde(borrow)]
    result: &'a RawValue,
}
 65
/// JSON-RPC notification envelope (no `id`). Serialized when sending and
/// deserialized by the test fake to inspect what the client sent.
#[derive(Serialize, Deserialize)]
struct Notification<'a, T> {
    // `borrow` is redundant for `&str` (serde borrows it automatically) but harmless.
    #[serde(borrow)]
    jsonrpc: &'a str,
    #[serde(borrow)]
    method: &'a str,
    params: T,
}
 74
/// Incoming notification whose `params` is kept as unparsed JSON, to be
/// decoded by the handler registered for `method`.
///
/// NOTE(review): `params` is required, so a notification without `params`
/// does not match. Unknown fields such as `id` are ignored, so a
/// server-to-client request also matches this shape — confirm that's intended.
#[derive(Deserialize)]
struct AnyNotification<'a> {
    #[serde(borrow)]
    method: &'a str,
    #[serde(borrow)]
    params: &'a RawValue,
}
 82
/// Error object from a JSON-RPC error response.
///
/// Only `message` is kept; any other members of the error object are ignored
/// during deserialization.
#[derive(Debug, Serialize, Deserialize)]
struct Error {
    message: String,
}
 87
 88impl LanguageServer {
 89    pub fn rust(root_path: &Path, cx: &AppContext) -> Result<Arc<Self>> {
 90        const ZED_BUNDLE: Option<&'static str> = option_env!("ZED_BUNDLE");
 91        const ZED_TARGET: &'static str = env!("ZED_TARGET");
 92
 93        let rust_analyzer_name = format!("rust-analyzer-{}", ZED_TARGET);
 94        if ZED_BUNDLE.map_or(Ok(false), |b| b.parse())? {
 95            let rust_analyzer_path = cx
 96                .platform()
 97                .path_for_resource(Some(&rust_analyzer_name), None)?;
 98            Self::new(root_path, &rust_analyzer_path, &[], cx.background())
 99        } else {
100            Self::new(
101                root_path,
102                Path::new(&rust_analyzer_name),
103                &[],
104                cx.background(),
105            )
106        }
107    }
108
109    pub fn new(
110        root_path: &Path,
111        server_path: &Path,
112        server_args: &[&str],
113        background: &executor::Background,
114    ) -> Result<Arc<Self>> {
115        let mut server = Command::new(server_path)
116            .args(server_args)
117            .stdin(Stdio::piped())
118            .stdout(Stdio::piped())
119            .stderr(Stdio::inherit())
120            .spawn()?;
121        let stdin = server.stdin.take().unwrap();
122        let stdout = server.stdout.take().unwrap();
123        Self::new_internal(root_path, stdin, stdout, background)
124    }
125
126    fn new_internal<Stdin, Stdout>(
127        root_path: &Path,
128        stdin: Stdin,
129        stdout: Stdout,
130        background: &executor::Background,
131    ) -> Result<Arc<Self>>
132    where
133        Stdin: AsyncWrite + Unpin + Send + 'static,
134        Stdout: AsyncRead + Unpin + Send + 'static,
135    {
136        let mut stdin = BufWriter::new(stdin);
137        let mut stdout = BufReader::new(stdout);
138        let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
139        let notification_handlers = Arc::new(RwLock::new(HashMap::<_, NotificationHandler>::new()));
140        let response_handlers = Arc::new(Mutex::new(HashMap::<_, ResponseHandler>::new()));
141        let _input_task = background.spawn(
142            {
143                let notification_handlers = notification_handlers.clone();
144                let response_handlers = response_handlers.clone();
145                async move {
146                    let mut buffer = Vec::new();
147                    loop {
148                        buffer.clear();
149                        stdout.read_until(b'\n', &mut buffer).await?;
150                        stdout.read_until(b'\n', &mut buffer).await?;
151                        let message_len: usize = std::str::from_utf8(&buffer)?
152                            .strip_prefix(CONTENT_LEN_HEADER)
153                            .ok_or_else(|| anyhow!("invalid header"))?
154                            .trim_end()
155                            .parse()?;
156
157                        buffer.resize(message_len, 0);
158                        stdout.read_exact(&mut buffer).await?;
159
160                        if let Ok(AnyNotification { method, params }) =
161                            serde_json::from_slice(&buffer)
162                        {
163                            if let Some(handler) = notification_handlers.read().get(method) {
164                                handler(params.get());
165                            } else {
166                                log::info!(
167                                    "unhandled notification {}:\n{}",
168                                    method,
169                                    serde_json::to_string_pretty(
170                                        &Value::from_str(params.get()).unwrap()
171                                    )
172                                    .unwrap()
173                                );
174                            }
175                        } else if let Ok(AnyResponse { id, error, result }) =
176                            serde_json::from_slice(&buffer)
177                        {
178                            if let Some(handler) = response_handlers.lock().remove(&id) {
179                                if let Some(error) = error {
180                                    handler(Err(error));
181                                } else {
182                                    handler(Ok(result.get()));
183                                }
184                            }
185                        } else {
186                            return Err(anyhow!(
187                                "failed to deserialize message:\n{}",
188                                std::str::from_utf8(&buffer)?
189                            ));
190                        }
191                    }
192                }
193            }
194            .log_err(),
195        );
196        let _output_task = background.spawn(
197            async move {
198                let mut content_len_buffer = Vec::new();
199                loop {
200                    content_len_buffer.clear();
201
202                    let message = outbound_rx.recv().await?;
203                    write!(content_len_buffer, "{}", message.len()).unwrap();
204                    stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
205                    stdin.write_all(&content_len_buffer).await?;
206                    stdin.write_all("\r\n\r\n".as_bytes()).await?;
207                    stdin.write_all(&message).await?;
208                    stdin.flush().await?;
209                }
210            }
211            .log_err(),
212        );
213
214        let (initialized_tx, initialized_rx) = barrier::channel();
215        let this = Arc::new(Self {
216            notification_handlers,
217            response_handlers,
218            next_id: Default::default(),
219            outbound_tx,
220            _input_task,
221            _output_task,
222            initialized: initialized_rx,
223        });
224
225        let root_uri =
226            lsp_types::Url::from_file_path(root_path).map_err(|_| anyhow!("invalid root path"))?;
227        background
228            .spawn({
229                let this = this.clone();
230                async move {
231                    this.init(root_uri).log_err().await;
232                    drop(initialized_tx);
233                }
234            })
235            .detach();
236
237        Ok(this)
238    }
239
240    async fn init(self: Arc<Self>, root_uri: lsp_types::Url) -> Result<()> {
241        #[allow(deprecated)]
242        let params = lsp_types::InitializeParams {
243            process_id: Default::default(),
244            root_path: Default::default(),
245            root_uri: Some(root_uri),
246            initialization_options: Default::default(),
247            capabilities: lsp_types::ClientCapabilities {
248                experimental: Some(json!({
249                    "serverStatusNotification": true,
250                })),
251                ..Default::default()
252            },
253            trace: Default::default(),
254            workspace_folders: Default::default(),
255            client_info: Default::default(),
256            locale: Default::default(),
257        };
258
259        self.request_internal::<lsp_types::request::Initialize>(params)
260            .await?;
261        self.notify_internal::<lsp_types::notification::Initialized>(
262            lsp_types::InitializedParams {},
263        )
264        .await?;
265        Ok(())
266    }
267
268    pub fn on_notification<T, F>(&self, f: F) -> Subscription
269    where
270        T: lsp_types::notification::Notification,
271        F: 'static + Send + Sync + Fn(T::Params),
272    {
273        let prev_handler = self.notification_handlers.write().insert(
274            T::METHOD,
275            Box::new(
276                move |notification| match serde_json::from_str(notification) {
277                    Ok(notification) => f(notification),
278                    Err(err) => log::error!("error parsing notification {}: {}", T::METHOD, err),
279                },
280            ),
281        );
282
283        assert!(
284            prev_handler.is_none(),
285            "registered multiple handlers for the same notification"
286        );
287
288        Subscription {
289            method: T::METHOD,
290            notification_handlers: self.notification_handlers.clone(),
291        }
292    }
293
294    pub fn request<T: lsp_types::request::Request>(
295        self: Arc<Self>,
296        params: T::Params,
297    ) -> impl Future<Output = Result<T::Result>>
298    where
299        T::Result: 'static + Send,
300    {
301        let this = self.clone();
302        async move {
303            this.initialized.clone().recv().await;
304            this.request_internal::<T>(params).await
305        }
306    }
307
308    fn request_internal<T: lsp_types::request::Request>(
309        self: &Arc<Self>,
310        params: T::Params,
311    ) -> impl Future<Output = Result<T::Result>>
312    where
313        T::Result: 'static + Send,
314    {
315        let id = self.next_id.fetch_add(1, SeqCst);
316        let message = serde_json::to_vec(&Request {
317            jsonrpc: JSON_RPC_VERSION,
318            id,
319            method: T::METHOD,
320            params,
321        })
322        .unwrap();
323        let mut response_handlers = self.response_handlers.lock();
324        let (mut tx, mut rx) = oneshot::channel();
325        response_handlers.insert(
326            id,
327            Box::new(move |result| {
328                let response = match result {
329                    Ok(response) => {
330                        serde_json::from_str(response).context("failed to deserialize response")
331                    }
332                    Err(error) => Err(anyhow!("{}", error.message)),
333                };
334                let _ = tx.try_send(response);
335            }),
336        );
337
338        let this = self.clone();
339        async move {
340            this.outbound_tx.send(message).await?;
341            rx.recv().await.unwrap()
342        }
343    }
344
345    pub fn notify<T: lsp_types::notification::Notification>(
346        self: &Arc<Self>,
347        params: T::Params,
348    ) -> impl Future<Output = Result<()>> {
349        let this = self.clone();
350        async move {
351            this.initialized.clone().recv().await;
352            this.notify_internal::<T>(params).await
353        }
354    }
355
356    fn notify_internal<T: lsp_types::notification::Notification>(
357        self: &Arc<Self>,
358        params: T::Params,
359    ) -> impl Future<Output = Result<()>> {
360        let message = serde_json::to_vec(&Notification {
361            jsonrpc: JSON_RPC_VERSION,
362            method: T::METHOD,
363            params,
364        })
365        .unwrap();
366
367        let this = self.clone();
368        async move {
369            this.outbound_tx.send(message).await?;
370            Ok(())
371        }
372    }
373}
374
375impl Subscription {
376    pub fn detach(mut self) {
377        self.method = "";
378    }
379}
380
381impl Drop for Subscription {
382    fn drop(&mut self) {
383        self.notification_handlers.write().remove(self.method);
384    }
385}
386
/// In-process stand-in for a language server, used by tests to script the
/// server's side of the conversation over in-memory pipes.
#[cfg(any(test, feature = "test-support"))]
pub struct FakeLanguageServer {
    /// Holds the most recently received message body.
    buffer: Vec<u8>,
    /// Read end of the pipe the client writes into (the server's "stdin").
    stdin: smol::io::BufReader<async_pipe::PipeReader>,
    /// Write end of the pipe the client reads from (the server's "stdout").
    stdout: smol::io::BufWriter<async_pipe::PipeWriter>,
}
393
/// Id of a request received by `FakeLanguageServer`, tagged with the request
/// type so that `respond` only accepts that request's `Result` type.
#[cfg(any(test, feature = "test-support"))]
pub struct RequestId<T> {
    id: usize,
    _type: std::marker::PhantomData<T>,
}
399
#[cfg(any(test, feature = "test-support"))]
impl LanguageServer {
    /// Builds a client connected to a `FakeLanguageServer` through in-memory
    /// pipes and completes the initialize handshake on the fake's side before
    /// returning both ends.
    pub async fn fake(executor: &executor::Background) -> (Arc<Self>, FakeLanguageServer) {
        // Pipe carrying client -> server traffic.
        let (client_writer, fake_reader) = async_pipe::pipe();
        // Pipe carrying server -> client traffic.
        let (fake_writer, client_reader) = async_pipe::pipe();
        let mut fake = FakeLanguageServer {
            stdin: smol::io::BufReader::new(fake_reader),
            stdout: smol::io::BufWriter::new(fake_writer),
            buffer: Vec::new(),
        };

        let server =
            Self::new_internal(Path::new("/"), client_writer, client_reader, executor).unwrap();

        let (initialize_id, _) = fake.receive_request::<request::Initialize>().await;
        fake.respond(initialize_id, InitializeResult::default())
            .await;
        fake.receive_notification::<notification::Initialized>()
            .await;

        (server, fake)
    }
}
421
#[cfg(any(test, feature = "test-support"))]
impl FakeLanguageServer {
    /// Sends notification `T` to the client.
    pub async fn notify<T: notification::Notification>(&mut self, params: T::Params) {
        let body = {
            let envelope = Notification {
                jsonrpc: JSON_RPC_VERSION,
                method: T::METHOD,
                params,
            };
            serde_json::to_vec(&envelope).unwrap()
        };
        self.send(body).await;
    }

    /// Answers the request identified by `request_id` with a successful `result`.
    pub async fn respond<'a, T: request::Request>(
        &mut self,
        request_id: RequestId<T>,
        result: T::Result,
    ) {
        let body = {
            let raw_result = RawValue::from_string(serde_json::to_string(&result).unwrap()).unwrap();
            let response = AnyResponse {
                id: request_id.id,
                error: None,
                result: &raw_result,
            };
            serde_json::to_vec(&response).unwrap()
        };
        self.send(body).await;
    }

    /// Reads the next message, asserts that it is a `T` request, and returns
    /// its id and params.
    pub async fn receive_request<T: request::Request>(&mut self) -> (RequestId<T>, T::Params) {
        self.receive().await;
        let Request {
            jsonrpc,
            id,
            method,
            params,
        } = serde_json::from_slice::<Request<T::Params>>(&self.buffer).unwrap();
        assert_eq!(method, T::METHOD);
        assert_eq!(jsonrpc, JSON_RPC_VERSION);
        let request_id = RequestId {
            id,
            _type: std::marker::PhantomData,
        };
        (request_id, params)
    }

    /// Reads the next message, asserts that it is a `T` notification, and
    /// returns its params.
    pub async fn receive_notification<T: notification::Notification>(&mut self) -> T::Params {
        self.receive().await;
        let Notification { method, params, .. } =
            serde_json::from_slice::<Notification<T::Params>>(&self.buffer).unwrap();
        assert_eq!(method, T::METHOD);
        params
    }

    /// Writes `message` to the client, framed with a `Content-Length` header.
    async fn send(&mut self, message: Vec<u8>) {
        let header = format!("{}{}\r\n\r\n", CONTENT_LEN_HEADER, message.len());
        self.stdout.write_all(header.as_bytes()).await.unwrap();
        self.stdout.write_all(&message).await.unwrap();
        self.stdout.flush().await.unwrap();
    }

    /// Reads one framed message from the client into `self.buffer`.
    ///
    /// Expects exactly a `Content-Length` line followed by a blank line.
    async fn receive(&mut self) {
        self.buffer.clear();
        for _ in 0..2 {
            self.stdin
                .read_until(b'\n', &mut self.buffer)
                .await
                .unwrap();
        }
        let header = std::str::from_utf8(&self.buffer).unwrap();
        let body_len = header
            .strip_prefix(CONTENT_LEN_HEADER)
            .unwrap()
            .trim_end()
            .parse::<usize>()
            .unwrap();
        self.buffer.resize(body_len, 0);
        self.stdin.read_exact(&mut self.buffer).await.unwrap();
    }
}
505
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use simplelog::SimpleLogger;
    use unindent::Unindent;
    use util::test::temp_tree;

    /// Runs a hover request against a real rust-analyzer on a tiny crate.
    #[gpui::test]
    async fn test_basic(cx: TestAppContext) {
        let lib_source = r#"
            fn fun() {
                let hello = "world";
            }
        "#
        .unindent();
        let root_dir = temp_tree(json!({
            "Cargo.toml": r#"
                [package]
                name = "temp"
                version = "0.1.0"
                edition = "2018"
            "#.unindent(),
            "src": {
                "lib.rs": &lib_source
            }
        }));
        let lib_file_uri =
            lsp_types::Url::from_file_path(root_dir.path().join("src/lib.rs")).unwrap();

        let server = cx.read(|cx| LanguageServer::rust(root_dir.path(), cx).unwrap());
        // Wait until rust-analyzer reports it is idle before talking to it.
        server.next_idle_notification().await;

        let open_params = lsp_types::DidOpenTextDocumentParams {
            text_document: lsp_types::TextDocumentItem::new(
                lib_file_uri.clone(),
                "rust".to_string(),
                0,
                lib_source,
            ),
        };
        server
            .notify::<lsp_types::notification::DidOpenTextDocument>(open_params)
            .await
            .unwrap();

        let hover_params = lsp_types::HoverParams {
            text_document_position_params: lsp_types::TextDocumentPositionParams {
                text_document: lsp_types::TextDocumentIdentifier::new(lib_file_uri),
                position: lsp_types::Position::new(1, 21),
            },
            work_done_progress_params: Default::default(),
        };
        let hover = server
            .request::<lsp_types::request::HoverRequest>(hover_params)
            .await
            .unwrap()
            .unwrap();

        let expected_contents = lsp_types::HoverContents::Markup(lsp_types::MarkupContent {
            kind: lsp_types::MarkupKind::Markdown,
            value: "&str".to_string(),
        });
        assert_eq!(hover.contents, expected_contents);
    }

    /// Exchanges notifications in both directions with `FakeLanguageServer`.
    #[gpui::test]
    async fn test_fake(cx: TestAppContext) {
        SimpleLogger::init(log::LevelFilter::Info, Default::default()).unwrap();

        let (server, mut fake) = LanguageServer::fake(&cx.background()).await;

        // Forward the notifications under test into channels the test can await.
        let (message_tx, message_rx) = channel::unbounded();
        server
            .on_notification::<notification::ShowMessage, _>(move |params| {
                message_tx.try_send(params).unwrap()
            })
            .detach();
        let (diagnostics_tx, diagnostics_rx) = channel::unbounded();
        server
            .on_notification::<notification::PublishDiagnostics, _>(move |params| {
                diagnostics_tx.try_send(params).unwrap()
            })
            .detach();

        // Client -> server.
        let opened_document = TextDocumentItem::new(
            Url::from_str("file://a/b").unwrap(),
            "rust".to_string(),
            0,
            "".to_string(),
        );
        server
            .notify::<notification::DidOpenTextDocument>(DidOpenTextDocumentParams {
                text_document: opened_document,
            })
            .await
            .unwrap();
        let received = fake
            .receive_notification::<notification::DidOpenTextDocument>()
            .await;
        assert_eq!(received.text_document.uri.as_str(), "file://a/b");

        // Server -> client.
        fake.notify::<notification::ShowMessage>(ShowMessageParams {
            typ: MessageType::ERROR,
            message: "ok".to_string(),
        })
        .await;
        fake.notify::<notification::PublishDiagnostics>(PublishDiagnosticsParams {
            uri: Url::from_str("file://b/c").unwrap(),
            version: Some(5),
            diagnostics: vec![],
        })
        .await;

        let shown = message_rx.recv().await.unwrap();
        assert_eq!(shown.message, "ok");
        let published = diagnostics_rx.recv().await.unwrap();
        assert_eq!(published.uri.as_str(), "file://b/c");
    }

    impl LanguageServer {
        /// Resolves on the next `experimental/serverStatus` notification whose
        /// `quiescent` flag is set.
        async fn next_idle_notification(self: &Arc<Self>) {
            let (idle_tx, idle_rx) = channel::unbounded();
            let _subscription =
                self.on_notification::<ServerStatusNotification, _>(move |status| {
                    if status.quiescent {
                        idle_tx.try_send(()).unwrap();
                    }
                });
            idle_rx.recv().await.ok();
        }
    }

    /// rust-analyzer's `experimental/serverStatus` notification, enabled via
    /// the `serverStatusNotification` experimental client capability.
    pub enum ServerStatusNotification {}

    impl lsp_types::notification::Notification for ServerStatusNotification {
        type Params = ServerStatusParams;
        const METHOD: &'static str = "experimental/serverStatus";
    }

    /// The subset of the server-status payload these tests read.
    #[derive(Deserialize, Serialize, PartialEq, Eq, Clone)]
    pub struct ServerStatusParams {
        pub quiescent: bool,
    }
}