lib.rs

  1use anyhow::{anyhow, Context, Result};
  2use futures::{io::BufWriter, AsyncRead, AsyncWrite};
  3use gpui::{executor, AppContext, Task};
  4use parking_lot::{Mutex, RwLock};
  5use postage::{barrier, oneshot, prelude::Stream, sink::Sink};
  6use serde::{Deserialize, Serialize};
  7use serde_json::{json, value::RawValue, Value};
  8use smol::{
  9    channel,
 10    io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
 11    process::Command,
 12};
 13use std::{
 14    collections::HashMap,
 15    future::Future,
 16    io::Write,
 17    str::FromStr,
 18    sync::{
 19        atomic::{AtomicUsize, Ordering::SeqCst},
 20        Arc,
 21    },
 22};
 23use std::{path::Path, process::Stdio};
 24use util::TryFutureExt;
 25
 26pub use lsp_types::*;
 27
 28const JSON_RPC_VERSION: &'static str = "2.0";
 29const CONTENT_LEN_HEADER: &'static str = "Content-Length: ";
 30
 31type NotificationHandler = Box<dyn Send + Sync + Fn(&str)>;
 32type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
 33
 34pub struct LanguageServer {
 35    next_id: AtomicUsize,
 36    outbound_tx: channel::Sender<Vec<u8>>,
 37    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
 38    response_handlers: Arc<Mutex<HashMap<usize, ResponseHandler>>>,
 39    _input_task: Task<Option<()>>,
 40    _output_task: Task<Option<()>>,
 41    initialized: barrier::Receiver,
 42}
 43
 44pub struct Subscription {
 45    method: &'static str,
 46    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
 47}
 48
 49#[derive(Serialize, Deserialize)]
 50struct Request<'a, T> {
 51    jsonrpc: &'a str,
 52    id: usize,
 53    method: &'a str,
 54    params: T,
 55}
 56
 57#[derive(Serialize, Deserialize)]
 58struct AnyResponse<'a> {
 59    id: usize,
 60    #[serde(default)]
 61    error: Option<Error>,
 62    #[serde(borrow)]
 63    result: &'a RawValue,
 64}
 65
 66#[derive(Serialize, Deserialize)]
 67struct Notification<'a, T> {
 68    #[serde(borrow)]
 69    jsonrpc: &'a str,
 70    #[serde(borrow)]
 71    method: &'a str,
 72    params: T,
 73}
 74
 75#[derive(Deserialize)]
 76struct AnyNotification<'a> {
 77    #[serde(borrow)]
 78    method: &'a str,
 79    #[serde(borrow)]
 80    params: &'a RawValue,
 81}
 82
 83#[derive(Debug, Serialize, Deserialize)]
 84struct Error {
 85    message: String,
 86}
 87
 88impl LanguageServer {
 89    pub fn rust(root_path: &Path, cx: &AppContext) -> Result<Arc<Self>> {
 90        const ZED_BUNDLE: Option<&'static str> = option_env!("ZED_BUNDLE");
 91        const ZED_TARGET: &'static str = env!("ZED_TARGET");
 92
 93        let rust_analyzer_name = format!("rust-analyzer-{}", ZED_TARGET);
 94        if ZED_BUNDLE.map_or(Ok(false), |b| b.parse())? {
 95            let rust_analyzer_path = cx
 96                .platform()
 97                .path_for_resource(Some(&rust_analyzer_name), None)?;
 98            Self::new(root_path, &rust_analyzer_path, &[], cx.background())
 99        } else {
100            Self::new(
101                root_path,
102                Path::new(&rust_analyzer_name),
103                &[],
104                cx.background(),
105            )
106        }
107    }
108
109    pub fn new(
110        root_path: &Path,
111        server_path: &Path,
112        server_args: &[&str],
113        background: &executor::Background,
114    ) -> Result<Arc<Self>> {
115        let mut server = Command::new(server_path)
116            .args(server_args)
117            .stdin(Stdio::piped())
118            .stdout(Stdio::piped())
119            .stderr(Stdio::inherit())
120            .spawn()?;
121        let stdin = server.stdin.take().unwrap();
122        let stdout = server.stdout.take().unwrap();
123        Self::new_internal(root_path, stdin, stdout, background)
124    }
125
126    fn new_internal<Stdin, Stdout>(
127        root_path: &Path,
128        stdin: Stdin,
129        stdout: Stdout,
130        background: &executor::Background,
131    ) -> Result<Arc<Self>>
132    where
133        Stdin: AsyncWrite + Unpin + Send + 'static,
134        Stdout: AsyncRead + Unpin + Send + 'static,
135    {
136        let mut stdin = BufWriter::new(stdin);
137        let mut stdout = BufReader::new(stdout);
138        let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
139        let notification_handlers = Arc::new(RwLock::new(HashMap::<_, NotificationHandler>::new()));
140        let response_handlers = Arc::new(Mutex::new(HashMap::<_, ResponseHandler>::new()));
141        let _input_task = background.spawn(
142            {
143                let notification_handlers = notification_handlers.clone();
144                let response_handlers = response_handlers.clone();
145                async move {
146                    let mut buffer = Vec::new();
147                    loop {
148                        buffer.clear();
149                        stdout.read_until(b'\n', &mut buffer).await?;
150                        stdout.read_until(b'\n', &mut buffer).await?;
151                        let message_len: usize = std::str::from_utf8(&buffer)?
152                            .strip_prefix(CONTENT_LEN_HEADER)
153                            .ok_or_else(|| anyhow!("invalid header"))?
154                            .trim_end()
155                            .parse()?;
156
157                        buffer.resize(message_len, 0);
158                        stdout.read_exact(&mut buffer).await?;
159
160                        println!("{}", std::str::from_utf8(&buffer).unwrap());
161                        if let Ok(AnyNotification { method, params }) =
162                            serde_json::from_slice(&buffer)
163                        {
164                            if let Some(handler) = notification_handlers.read().get(method) {
165                                handler(params.get());
166                            } else {
167                                log::info!(
168                                    "unhandled notification {}:\n{}",
169                                    method,
170                                    serde_json::to_string_pretty(
171                                        &Value::from_str(params.get()).unwrap()
172                                    )
173                                    .unwrap()
174                                );
175                            }
176                        } else if let Ok(AnyResponse { id, error, result }) =
177                            serde_json::from_slice(&buffer)
178                        {
179                            if let Some(handler) = response_handlers.lock().remove(&id) {
180                                if let Some(error) = error {
181                                    handler(Err(error));
182                                } else {
183                                    handler(Ok(result.get()));
184                                }
185                            }
186                        } else {
187                            return Err(anyhow!(
188                                "failed to deserialize message:\n{}",
189                                std::str::from_utf8(&buffer)?
190                            ));
191                        }
192                    }
193                }
194            }
195            .log_err(),
196        );
197        let _output_task = background.spawn(
198            async move {
199                let mut content_len_buffer = Vec::new();
200                loop {
201                    content_len_buffer.clear();
202
203                    let message = outbound_rx.recv().await?;
204                    println!("{}", std::str::from_utf8(&message).unwrap());
205                    write!(content_len_buffer, "{}", message.len()).unwrap();
206                    stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
207                    stdin.write_all(&content_len_buffer).await?;
208                    stdin.write_all("\r\n\r\n".as_bytes()).await?;
209                    stdin.write_all(&message).await?;
210                    stdin.flush().await?;
211                }
212            }
213            .log_err(),
214        );
215
216        let (initialized_tx, initialized_rx) = barrier::channel();
217        let this = Arc::new(Self {
218            notification_handlers,
219            response_handlers,
220            next_id: Default::default(),
221            outbound_tx,
222            _input_task,
223            _output_task,
224            initialized: initialized_rx,
225        });
226
227        let root_uri =
228            lsp_types::Url::from_file_path(root_path).map_err(|_| anyhow!("invalid root path"))?;
229        background
230            .spawn({
231                let this = this.clone();
232                async move {
233                    this.init(root_uri).log_err().await;
234                    drop(initialized_tx);
235                }
236            })
237            .detach();
238
239        Ok(this)
240    }
241
242    async fn init(self: Arc<Self>, root_uri: lsp_types::Url) -> Result<()> {
243        #[allow(deprecated)]
244        let params = lsp_types::InitializeParams {
245            process_id: Default::default(),
246            root_path: Default::default(),
247            root_uri: Some(root_uri),
248            initialization_options: Default::default(),
249            capabilities: lsp_types::ClientCapabilities {
250                experimental: Some(json!({
251                    "serverStatusNotification": true,
252                })),
253                ..Default::default()
254            },
255            trace: Default::default(),
256            workspace_folders: Default::default(),
257            client_info: Default::default(),
258            locale: Default::default(),
259        };
260
261        self.request_internal::<lsp_types::request::Initialize>(params)
262            .await?;
263        self.notify_internal::<lsp_types::notification::Initialized>(
264            lsp_types::InitializedParams {},
265        )
266        .await?;
267        Ok(())
268    }
269
270    pub fn on_notification<T, F>(&self, f: F) -> Subscription
271    where
272        T: lsp_types::notification::Notification,
273        F: 'static + Send + Sync + Fn(T::Params),
274    {
275        let prev_handler = self.notification_handlers.write().insert(
276            T::METHOD,
277            Box::new(
278                move |notification| match serde_json::from_str(notification) {
279                    Ok(notification) => f(notification),
280                    Err(err) => log::error!("error parsing notification {}: {}", T::METHOD, err),
281                },
282            ),
283        );
284
285        assert!(
286            prev_handler.is_none(),
287            "registered multiple handlers for the same notification"
288        );
289
290        Subscription {
291            method: T::METHOD,
292            notification_handlers: self.notification_handlers.clone(),
293        }
294    }
295
296    pub fn request<T: lsp_types::request::Request>(
297        self: Arc<Self>,
298        params: T::Params,
299    ) -> impl Future<Output = Result<T::Result>>
300    where
301        T::Result: 'static + Send,
302    {
303        let this = self.clone();
304        async move {
305            this.initialized.clone().recv().await;
306            this.request_internal::<T>(params).await
307        }
308    }
309
310    fn request_internal<T: lsp_types::request::Request>(
311        self: &Arc<Self>,
312        params: T::Params,
313    ) -> impl Future<Output = Result<T::Result>>
314    where
315        T::Result: 'static + Send,
316    {
317        let id = self.next_id.fetch_add(1, SeqCst);
318        let message = serde_json::to_vec(&Request {
319            jsonrpc: JSON_RPC_VERSION,
320            id,
321            method: T::METHOD,
322            params,
323        })
324        .unwrap();
325        let mut response_handlers = self.response_handlers.lock();
326        let (mut tx, mut rx) = oneshot::channel();
327        response_handlers.insert(
328            id,
329            Box::new(move |result| {
330                let response = match result {
331                    Ok(response) => {
332                        serde_json::from_str(response).context("failed to deserialize response")
333                    }
334                    Err(error) => Err(anyhow!("{}", error.message)),
335                };
336                let _ = tx.try_send(response);
337            }),
338        );
339
340        let this = self.clone();
341        async move {
342            this.outbound_tx.send(message).await?;
343            rx.recv().await.unwrap()
344        }
345    }
346
347    pub fn notify<T: lsp_types::notification::Notification>(
348        self: &Arc<Self>,
349        params: T::Params,
350    ) -> impl Future<Output = Result<()>> {
351        let this = self.clone();
352        async move {
353            this.initialized.clone().recv().await;
354            this.notify_internal::<T>(params).await
355        }
356    }
357
358    fn notify_internal<T: lsp_types::notification::Notification>(
359        self: &Arc<Self>,
360        params: T::Params,
361    ) -> impl Future<Output = Result<()>> {
362        let message = serde_json::to_vec(&Notification {
363            jsonrpc: JSON_RPC_VERSION,
364            method: T::METHOD,
365            params,
366        })
367        .unwrap();
368
369        let this = self.clone();
370        async move {
371            this.outbound_tx.send(message).await?;
372            Ok(())
373        }
374    }
375}
376
377impl Subscription {
378    pub fn detach(mut self) {
379        self.method = "";
380    }
381}
382
383impl Drop for Subscription {
384    fn drop(&mut self) {
385        self.notification_handlers.write().remove(self.method);
386    }
387}
388
389#[cfg(any(test, feature = "test-support"))]
390pub struct FakeLanguageServer {
391    buffer: Vec<u8>,
392    stdin: smol::io::BufReader<async_pipe::PipeReader>,
393    stdout: smol::io::BufWriter<async_pipe::PipeWriter>,
394}
395
396#[cfg(any(test, feature = "test-support"))]
397pub struct RequestId<T> {
398    id: usize,
399    _type: std::marker::PhantomData<T>,
400}
401
402#[cfg(any(test, feature = "test-support"))]
403impl LanguageServer {
404    pub async fn fake(executor: &executor::Background) -> (Arc<Self>, FakeLanguageServer) {
405        let stdin = async_pipe::pipe();
406        let stdout = async_pipe::pipe();
407        let mut fake = FakeLanguageServer {
408            stdin: smol::io::BufReader::new(stdin.1),
409            stdout: smol::io::BufWriter::new(stdout.0),
410            buffer: Vec::new(),
411        };
412
413        let server = Self::new_internal(Path::new("/"), stdin.0, stdout.1, executor).unwrap();
414
415        let (init_id, _) = fake.receive_request::<request::Initialize>().await;
416        fake.respond(init_id, InitializeResult::default()).await;
417        fake.receive_notification::<notification::Initialized>()
418            .await;
419
420        (server, fake)
421    }
422}
423
424#[cfg(any(test, feature = "test-support"))]
425impl FakeLanguageServer {
426    pub async fn notify<T: notification::Notification>(&mut self, params: T::Params) {
427        let message = serde_json::to_vec(&Notification {
428            jsonrpc: JSON_RPC_VERSION,
429            method: T::METHOD,
430            params,
431        })
432        .unwrap();
433        self.send(message).await;
434    }
435
436    pub async fn respond<'a, T: request::Request>(
437        &mut self,
438        request_id: RequestId<T>,
439        result: T::Result,
440    ) {
441        let result = serde_json::to_string(&result).unwrap();
442        let message = serde_json::to_vec(&AnyResponse {
443            id: request_id.id,
444            error: None,
445            result: &RawValue::from_string(result).unwrap(),
446        })
447        .unwrap();
448        self.send(message).await;
449    }
450
451    pub async fn receive_request<T: request::Request>(&mut self) -> (RequestId<T>, T::Params) {
452        self.receive().await;
453        let request = serde_json::from_slice::<Request<T::Params>>(&self.buffer).unwrap();
454        assert_eq!(request.method, T::METHOD);
455        assert_eq!(request.jsonrpc, JSON_RPC_VERSION);
456        (
457            RequestId {
458                id: request.id,
459                _type: std::marker::PhantomData,
460            },
461            request.params,
462        )
463    }
464
465    pub async fn receive_notification<T: notification::Notification>(&mut self) -> T::Params {
466        self.receive().await;
467        let notification = serde_json::from_slice::<Notification<T::Params>>(&self.buffer).unwrap();
468        assert_eq!(notification.method, T::METHOD);
469        notification.params
470    }
471
472    async fn send(&mut self, message: Vec<u8>) {
473        self.stdout
474            .write_all(CONTENT_LEN_HEADER.as_bytes())
475            .await
476            .unwrap();
477        self.stdout
478            .write_all((format!("{}", message.len())).as_bytes())
479            .await
480            .unwrap();
481        self.stdout.write_all("\r\n\r\n".as_bytes()).await.unwrap();
482        self.stdout.write_all(&message).await.unwrap();
483        self.stdout.flush().await.unwrap();
484    }
485
486    async fn receive(&mut self) {
487        self.buffer.clear();
488        self.stdin
489            .read_until(b'\n', &mut self.buffer)
490            .await
491            .unwrap();
492        self.stdin
493            .read_until(b'\n', &mut self.buffer)
494            .await
495            .unwrap();
496        let message_len: usize = std::str::from_utf8(&self.buffer)
497            .unwrap()
498            .strip_prefix(CONTENT_LEN_HEADER)
499            .unwrap()
500            .trim_end()
501            .parse()
502            .unwrap();
503        self.buffer.resize(message_len, 0);
504        self.stdin.read_exact(&mut self.buffer).await.unwrap();
505    }
506}
507
508#[cfg(test)]
509mod tests {
510    use super::*;
511    use gpui::TestAppContext;
512    use simplelog::SimpleLogger;
513    use unindent::Unindent;
514    use util::test::temp_tree;
515
516    #[gpui::test]
517    async fn test_basic(cx: TestAppContext) {
518        let lib_source = r#"
519            fn fun() {
520                let hello = "world";
521            }
522        "#
523        .unindent();
524        let root_dir = temp_tree(json!({
525            "Cargo.toml": r#"
526                [package]
527                name = "temp"
528                version = "0.1.0"
529                edition = "2018"
530            "#.unindent(),
531            "src": {
532                "lib.rs": &lib_source
533            }
534        }));
535        let lib_file_uri =
536            lsp_types::Url::from_file_path(root_dir.path().join("src/lib.rs")).unwrap();
537
538        let server = cx.read(|cx| LanguageServer::rust(root_dir.path(), cx).unwrap());
539        server.next_idle_notification().await;
540
541        server
542            .notify::<lsp_types::notification::DidOpenTextDocument>(
543                lsp_types::DidOpenTextDocumentParams {
544                    text_document: lsp_types::TextDocumentItem::new(
545                        lib_file_uri.clone(),
546                        "rust".to_string(),
547                        0,
548                        lib_source,
549                    ),
550                },
551            )
552            .await
553            .unwrap();
554
555        let hover = server
556            .request::<lsp_types::request::HoverRequest>(lsp_types::HoverParams {
557                text_document_position_params: lsp_types::TextDocumentPositionParams {
558                    text_document: lsp_types::TextDocumentIdentifier::new(lib_file_uri),
559                    position: lsp_types::Position::new(1, 21),
560                },
561                work_done_progress_params: Default::default(),
562            })
563            .await
564            .unwrap()
565            .unwrap();
566        assert_eq!(
567            hover.contents,
568            lsp_types::HoverContents::Markup(lsp_types::MarkupContent {
569                kind: lsp_types::MarkupKind::Markdown,
570                value: "&str".to_string()
571            })
572        );
573    }
574
575    #[gpui::test]
576    async fn test_fake(cx: TestAppContext) {
577        SimpleLogger::init(log::LevelFilter::Info, Default::default()).unwrap();
578
579        let (server, mut fake) = LanguageServer::fake(&cx.background()).await;
580
581        let (message_tx, message_rx) = channel::unbounded();
582        let (diagnostics_tx, diagnostics_rx) = channel::unbounded();
583        server
584            .on_notification::<notification::ShowMessage, _>(move |params| {
585                message_tx.try_send(params).unwrap()
586            })
587            .detach();
588        server
589            .on_notification::<notification::PublishDiagnostics, _>(move |params| {
590                diagnostics_tx.try_send(params).unwrap()
591            })
592            .detach();
593
594        server
595            .notify::<notification::DidOpenTextDocument>(DidOpenTextDocumentParams {
596                text_document: TextDocumentItem::new(
597                    Url::from_str("file://a/b").unwrap(),
598                    "rust".to_string(),
599                    0,
600                    "".to_string(),
601                ),
602            })
603            .await
604            .unwrap();
605        assert_eq!(
606            fake.receive_notification::<notification::DidOpenTextDocument>()
607                .await
608                .text_document
609                .uri
610                .as_str(),
611            "file://a/b"
612        );
613
614        fake.notify::<notification::ShowMessage>(ShowMessageParams {
615            typ: MessageType::ERROR,
616            message: "ok".to_string(),
617        })
618        .await;
619        fake.notify::<notification::PublishDiagnostics>(PublishDiagnosticsParams {
620            uri: Url::from_str("file://b/c").unwrap(),
621            version: Some(5),
622            diagnostics: vec![],
623        })
624        .await;
625        assert_eq!(message_rx.recv().await.unwrap().message, "ok");
626        assert_eq!(
627            diagnostics_rx.recv().await.unwrap().uri.as_str(),
628            "file://b/c"
629        );
630    }
631
632    impl LanguageServer {
633        async fn next_idle_notification(self: &Arc<Self>) {
634            let (tx, rx) = channel::unbounded();
635            let _subscription =
636                self.on_notification::<ServerStatusNotification, _>(move |params| {
637                    if params.quiescent {
638                        tx.try_send(()).unwrap();
639                    }
640                });
641            let _ = rx.recv().await;
642        }
643    }
644
645    pub enum ServerStatusNotification {}
646
647    impl lsp_types::notification::Notification for ServerStatusNotification {
648        type Params = ServerStatusParams;
649        const METHOD: &'static str = "experimental/serverStatus";
650    }
651
652    #[derive(Deserialize, Serialize, PartialEq, Eq, Clone)]
653    pub struct ServerStatusParams {
654        pub quiescent: bool,
655    }
656}