lib.rs

  1use anyhow::{anyhow, Context, Result};
  2use gpui::{executor, AppContext, Task};
  3use parking_lot::{Mutex, RwLock};
  4use postage::{barrier, oneshot, prelude::Stream, sink::Sink};
  5use serde::{Deserialize, Serialize};
  6use serde_json::{json, value::RawValue, Value};
  7use smol::{
  8    channel,
  9    io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
 10    process::Command,
 11};
 12use std::{
 13    collections::HashMap,
 14    future::Future,
 15    io::Write,
 16    str::FromStr,
 17    sync::{
 18        atomic::{AtomicUsize, Ordering::SeqCst},
 19        Arc,
 20    },
 21};
 22use std::{path::Path, process::Stdio};
 23use util::TryFutureExt;
 24
 25const JSON_RPC_VERSION: &'static str = "2.0";
 26const CONTENT_LEN_HEADER: &'static str = "Content-Length: ";
 27
 28type NotificationHandler = Box<dyn Send + Sync + Fn(&str)>;
 29type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
 30
/// Client side of a connection to a language server process, exchanging JSON-RPC
/// messages over the process's stdin/stdout.
pub struct LanguageServer {
    /// Source of JSON-RPC request ids; `request_internal` takes the next value per request.
    next_id: AtomicUsize,
    /// Serialized message bodies waiting for the output task to frame and write them.
    outbound_tx: channel::Sender<Vec<u8>>,
    /// Notification callbacks keyed by LSP method name (see `on_notification`).
    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
    /// Callbacks for in-flight requests keyed by request id; each is removed when its
    /// response arrives.
    response_handlers: Arc<Mutex<HashMap<usize, ResponseHandler>>>,
    /// Task that reads messages from the server's stdout and dispatches them.
    _input_task: Task<Option<()>>,
    /// Task that writes queued outbound messages to the server's stdin.
    _output_task: Task<Option<()>>,
    /// Released once the `initialize` handshake task finishes; `request` and `notify`
    /// wait on it before sending.
    initialized: barrier::Receiver,
}
 40
/// Guard returned by `LanguageServer::on_notification`; dropping it removes the
/// handler registered for `method`.
pub struct Subscription {
    /// LSP method name whose handler this guard owns.
    method: &'static str,
    /// Shared handler table of the server, so the handler can be removed on drop.
    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
}
 45
 46#[derive(Serialize)]
 47struct Request<T> {
 48    jsonrpc: &'static str,
 49    id: usize,
 50    method: &'static str,
 51    params: T,
 52}
 53
 54#[derive(Deserialize)]
 55struct Response<'a> {
 56    id: usize,
 57    #[serde(default)]
 58    error: Option<Error>,
 59    #[serde(borrow)]
 60    result: &'a RawValue,
 61}
 62
 63#[derive(Serialize)]
 64struct OutboundNotification<T> {
 65    jsonrpc: &'static str,
 66    method: &'static str,
 67    params: T,
 68}
 69
 70#[derive(Deserialize)]
 71struct InboundNotification<'a> {
 72    #[serde(borrow)]
 73    method: &'a str,
 74    #[serde(borrow)]
 75    params: &'a RawValue,
 76}
 77
 78#[derive(Debug, Deserialize)]
 79struct Error {
 80    message: String,
 81}
 82
 83impl LanguageServer {
 84    pub fn rust(root_path: &Path, cx: &AppContext) -> Result<Arc<Self>> {
 85        const ZED_BUNDLE: Option<&'static str> = option_env!("ZED_BUNDLE");
 86        const ZED_TARGET: &'static str = env!("ZED_TARGET");
 87
 88        let rust_analyzer_name = format!("rust-analyzer-{}", ZED_TARGET);
 89        if ZED_BUNDLE.map_or(Ok(false), |b| b.parse())? {
 90            let rust_analyzer_path = cx
 91                .platform()
 92                .path_for_resource(Some(&rust_analyzer_name), None)?;
 93            Self::new(root_path, &rust_analyzer_path, cx.background())
 94        } else {
 95            Self::new(root_path, Path::new(&rust_analyzer_name), cx.background())
 96        }
 97    }
 98
 99    pub fn new(
100        root_path: &Path,
101        server_path: &Path,
102        background: &executor::Background,
103    ) -> Result<Arc<Self>> {
104        let mut server = Command::new(server_path)
105            .stdin(Stdio::piped())
106            .stdout(Stdio::piped())
107            .stderr(Stdio::inherit())
108            .spawn()?;
109        let mut stdin = server.stdin.take().unwrap();
110        let mut stdout = BufReader::new(server.stdout.take().unwrap());
111        let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
112        let notification_handlers = Arc::new(RwLock::new(HashMap::<_, NotificationHandler>::new()));
113        let response_handlers = Arc::new(Mutex::new(HashMap::<_, ResponseHandler>::new()));
114        let _input_task = background.spawn(
115            {
116                let notification_handlers = notification_handlers.clone();
117                let response_handlers = response_handlers.clone();
118                async move {
119                    let mut buffer = Vec::new();
120                    loop {
121                        buffer.clear();
122
123                        stdout.read_until(b'\n', &mut buffer).await?;
124                        stdout.read_until(b'\n', &mut buffer).await?;
125                        let message_len: usize = std::str::from_utf8(&buffer)?
126                            .strip_prefix(CONTENT_LEN_HEADER)
127                            .ok_or_else(|| anyhow!("invalid header"))?
128                            .trim_end()
129                            .parse()?;
130
131                        buffer.resize(message_len, 0);
132                        stdout.read_exact(&mut buffer).await?;
133
134                        if let Ok(InboundNotification { method, params }) =
135                            serde_json::from_slice(&buffer)
136                        {
137                            if let Some(handler) = notification_handlers.read().get(method) {
138                                handler(params.get());
139                            } else {
140                                log::info!(
141                                    "unhandled notification {}:\n{}",
142                                    method,
143                                    serde_json::to_string_pretty(
144                                        &Value::from_str(params.get()).unwrap()
145                                    )
146                                    .unwrap()
147                                );
148                            }
149                        } else if let Ok(Response { id, error, result }) =
150                            serde_json::from_slice(&buffer)
151                        {
152                            if let Some(handler) = response_handlers.lock().remove(&id) {
153                                if let Some(error) = error {
154                                    handler(Err(error));
155                                } else {
156                                    handler(Ok(result.get()));
157                                }
158                            }
159                        } else {
160                            return Err(anyhow!(
161                                "failed to deserialize message:\n{}",
162                                std::str::from_utf8(&buffer)?
163                            ));
164                        }
165                    }
166                }
167            }
168            .log_err(),
169        );
170        let _output_task = background.spawn(
171            async move {
172                let mut content_len_buffer = Vec::new();
173                loop {
174                    content_len_buffer.clear();
175
176                    let message = outbound_rx.recv().await?;
177                    write!(content_len_buffer, "{}", message.len()).unwrap();
178                    stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
179                    stdin.write_all(&content_len_buffer).await?;
180                    stdin.write_all("\r\n\r\n".as_bytes()).await?;
181                    stdin.write_all(&message).await?;
182                }
183            }
184            .log_err(),
185        );
186
187        let (initialized_tx, initialized_rx) = barrier::channel();
188        let this = Arc::new(Self {
189            notification_handlers,
190            response_handlers,
191            next_id: Default::default(),
192            outbound_tx,
193            _input_task,
194            _output_task,
195            initialized: initialized_rx,
196        });
197
198        let root_uri =
199            lsp_types::Url::from_file_path(root_path).map_err(|_| anyhow!("invalid root path"))?;
200        background
201            .spawn({
202                let this = this.clone();
203                async move {
204                    this.init(root_uri).log_err().await;
205                    drop(initialized_tx);
206                }
207            })
208            .detach();
209
210        Ok(this)
211    }
212
213    async fn init(self: Arc<Self>, root_uri: lsp_types::Url) -> Result<()> {
214        self.request_internal::<lsp_types::request::Initialize>(lsp_types::InitializeParams {
215            process_id: Default::default(),
216            root_path: Default::default(),
217            root_uri: Some(root_uri),
218            initialization_options: Default::default(),
219            capabilities: lsp_types::ClientCapabilities {
220                experimental: Some(json!({
221                    "serverStatusNotification": true,
222                })),
223                ..Default::default()
224            },
225            trace: Default::default(),
226            workspace_folders: Default::default(),
227            client_info: Default::default(),
228            locale: Default::default(),
229        })
230        .await?;
231        self.notify_internal::<lsp_types::notification::Initialized>(
232            lsp_types::InitializedParams {},
233        )
234        .await?;
235        Ok(())
236    }
237
238    pub fn on_notification<T, F>(&self, f: F) -> Subscription
239    where
240        T: lsp_types::notification::Notification,
241        F: 'static + Send + Sync + Fn(T::Params),
242    {
243        let prev_handler = self.notification_handlers.write().insert(
244            T::METHOD,
245            Box::new(
246                move |notification| match serde_json::from_str(notification) {
247                    Ok(notification) => f(notification),
248                    Err(err) => log::error!("error parsing notification {}: {}", T::METHOD, err),
249                },
250            ),
251        );
252
253        assert!(
254            prev_handler.is_none(),
255            "registered multiple handlers for the same notification"
256        );
257
258        Subscription {
259            method: T::METHOD,
260            notification_handlers: self.notification_handlers.clone(),
261        }
262    }
263
264    pub fn request<T: lsp_types::request::Request>(
265        self: Arc<Self>,
266        params: T::Params,
267    ) -> impl Future<Output = Result<T::Result>>
268    where
269        T::Result: 'static + Send,
270    {
271        let this = self.clone();
272        async move {
273            this.initialized.clone().recv().await;
274            this.request_internal::<T>(params).await
275        }
276    }
277
278    fn request_internal<T: lsp_types::request::Request>(
279        self: &Arc<Self>,
280        params: T::Params,
281    ) -> impl Future<Output = Result<T::Result>>
282    where
283        T::Result: 'static + Send,
284    {
285        let id = self.next_id.fetch_add(1, SeqCst);
286        let message = serde_json::to_vec(&Request {
287            jsonrpc: JSON_RPC_VERSION,
288            id,
289            method: T::METHOD,
290            params,
291        })
292        .unwrap();
293        let mut response_handlers = self.response_handlers.lock();
294        let (mut tx, mut rx) = oneshot::channel();
295        response_handlers.insert(
296            id,
297            Box::new(move |result| {
298                let response = match result {
299                    Ok(response) => {
300                        serde_json::from_str(response).context("failed to deserialize response")
301                    }
302                    Err(error) => Err(anyhow!("{}", error.message)),
303                };
304                let _ = tx.try_send(response);
305            }),
306        );
307
308        let this = self.clone();
309        async move {
310            this.outbound_tx.send(message).await?;
311            rx.recv().await.unwrap()
312        }
313    }
314
315    pub fn notify<T: lsp_types::notification::Notification>(
316        self: &Arc<Self>,
317        params: T::Params,
318    ) -> impl Future<Output = Result<()>> {
319        let this = self.clone();
320        async move {
321            this.initialized.clone().recv().await;
322            this.notify_internal::<T>(params).await
323        }
324    }
325
326    fn notify_internal<T: lsp_types::notification::Notification>(
327        self: &Arc<Self>,
328        params: T::Params,
329    ) -> impl Future<Output = Result<()>> {
330        let message = serde_json::to_vec(&OutboundNotification {
331            jsonrpc: JSON_RPC_VERSION,
332            method: T::METHOD,
333            params,
334        })
335        .unwrap();
336
337        let this = self.clone();
338        async move {
339            this.outbound_tx.send(message).await?;
340            Ok(())
341        }
342    }
343}
344
345impl Drop for Subscription {
346    fn drop(&mut self) {
347        self.notification_handlers.write().remove(self.method);
348    }
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use unindent::Unindent;
    use util::test::temp_tree;

    /// End-to-end check against a real rust-analyzer binary: create a tiny crate, open its
    /// `src/lib.rs`, and expect hover at line 1, column 21 to report `&str`.
    #[gpui::test]
    async fn test_basic(cx: TestAppContext) {
        let lib_source = r#"
            fn fun() {
                let hello = "world";
            }
        "#
        .unindent();
        let root_dir = temp_tree(json!({
            "Cargo.toml": r#"
                [package]
                name = "temp"
                version = "0.1.0"
                edition = "2018"
            "#.unindent(),
            "src": {
                "lib.rs": &lib_source
            }
        }));
        let lib_path = root_dir.path().join("src/lib.rs");
        let lib_file_uri = lsp_types::Url::from_file_path(lib_path).unwrap();

        let server = cx.read(|cx| LanguageServer::rust(root_dir.path(), cx).unwrap());
        // Wait until the server reports itself quiescent before issuing queries.
        server.next_idle_notification().await;

        let open_params = lsp_types::DidOpenTextDocumentParams {
            text_document: lsp_types::TextDocumentItem::new(
                lib_file_uri.clone(),
                "rust".to_string(),
                0,
                lib_source,
            ),
        };
        server
            .notify::<lsp_types::notification::DidOpenTextDocument>(open_params)
            .await
            .unwrap();

        let hover_params = lsp_types::HoverParams {
            text_document_position_params: lsp_types::TextDocumentPositionParams {
                text_document: lsp_types::TextDocumentIdentifier::new(lib_file_uri),
                position: lsp_types::Position::new(1, 21),
            },
            work_done_progress_params: Default::default(),
        };
        let hover = server
            .request::<lsp_types::request::HoverRequest>(hover_params)
            .await
            .unwrap()
            .unwrap();

        let expected_contents = lsp_types::HoverContents::Markup(lsp_types::MarkupContent {
            kind: lsp_types::MarkupKind::Markdown,
            value: "&str".to_string(),
        });
        assert_eq!(hover.contents, expected_contents);
    }

    impl LanguageServer {
        /// Resolves once the server sends an `experimental/serverStatus` notification with
        /// `quiescent: true`.
        async fn next_idle_notification(self: &Arc<Self>) {
            let (idle_tx, idle_rx) = channel::unbounded();
            let _subscription =
                self.on_notification::<ServerStatusNotification, _>(move |status| {
                    if status.quiescent {
                        idle_tx.try_send(()).unwrap();
                    }
                });
            let _ = idle_rx.recv().await;
        }
    }

    /// rust-analyzer's `experimental/serverStatus` notification, enabled through the
    /// `serverStatusNotification` capability sent during `init`.
    pub enum ServerStatusNotification {}

    impl lsp_types::notification::Notification for ServerStatusNotification {
        type Params = ServerStatusParams;
        const METHOD: &'static str = "experimental/serverStatus";
    }

    /// The subset of the server-status payload this test inspects.
    #[derive(Deserialize, Serialize, PartialEq, Eq, Clone)]
    pub struct ServerStatusParams {
        pub quiescent: bool,
    }
}