lib.rs

  1use anyhow::{anyhow, Context, Result};
  2use gpui::{executor, AppContext, Task};
  3use parking_lot::{Mutex, RwLock};
  4use postage::{barrier, oneshot, prelude::Stream, sink::Sink};
  5use serde::{Deserialize, Serialize};
  6use serde_json::{json, value::RawValue};
  7use smol::{
  8    channel,
  9    io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
 10    process::Command,
 11};
 12use std::{
 13    collections::HashMap,
 14    future::Future,
 15    io::Write,
 16    sync::{
 17        atomic::{AtomicUsize, Ordering::SeqCst},
 18        Arc,
 19    },
 20};
 21use std::{path::Path, process::Stdio};
 22use util::TryFutureExt;
 23
 24const JSON_RPC_VERSION: &'static str = "2.0";
 25const CONTENT_LEN_HEADER: &'static str = "Content-Length: ";
 26
/// Callback registered for one notification method; receives the raw JSON
/// text of the notification's `params`.
type NotificationHandler = Box<dyn Send + Sync + Fn(&str)>;
/// One-shot callback for an in-flight request; receives the raw JSON text of
/// the response's `result`, or the error object the server returned.
type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
 29
/// Handle to a running language server process that speaks LSP over its stdio.
///
/// Created by [`LanguageServer::new`] (or [`LanguageServer::rust`]) and shared
/// behind an `Arc`.
pub struct LanguageServer {
    /// Source of JSON-RPC request ids; incremented once per outgoing request.
    next_id: AtomicUsize,
    /// Queue of serialized JSON messages; the output task prefixes each with a
    /// `Content-Length` header and writes it to the server's stdin.
    outbound_tx: channel::Sender<Vec<u8>>,
    /// Notification callbacks keyed by LSP method name; shared with the input
    /// task and with every live [`Subscription`].
    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
    /// Callbacks for in-flight requests keyed by request id; each one is
    /// removed and invoked when the response with that id arrives.
    response_handlers: Arc<Mutex<HashMap<usize, ResponseHandler>>>,
    /// Task reading and dispatching messages from the server's stdout. Stored
    /// only so this handle owns it (never read).
    _input_task: Task<Option<()>>,
    /// Task draining `outbound_tx` into the server's stdin. Stored only so this
    /// handle owns it (never read).
    _output_task: Task<Option<()>>,
    /// Barrier released once the `initialize` handshake has finished
    /// (successfully or not); `request` and `notify` wait on it.
    initialized: barrier::Receiver,
}
 39
 40pub struct Subscription {
 41    method: &'static str,
 42    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
 43}
 44
/// An outgoing JSON-RPC request, serialized as an object with `jsonrpc`, `id`,
/// `method` and `params` members.
#[derive(Serialize)]
struct Request<T> {
    /// Protocol version; set to `JSON_RPC_VERSION` by `request_internal`.
    jsonrpc: &'static str,
    /// Id the server echoes back so the response can be routed to its handler.
    id: usize,
    /// LSP method name (`T::METHOD` of the request type).
    method: &'static str,
    /// Method-specific parameters, serialized as-is.
    params: T,
}
 52
 53#[derive(Deserialize)]
 54struct Response<'a> {
 55    id: usize,
 56    #[serde(default)]
 57    error: Option<Error>,
 58    #[serde(borrow)]
 59    result: &'a RawValue,
 60}
 61
/// An outgoing JSON-RPC notification: like [`Request`] but without an `id`,
/// so the server sends no response.
#[derive(Serialize)]
struct OutboundNotification<T> {
    /// Protocol version; set to `JSON_RPC_VERSION` by `notify_internal`.
    jsonrpc: &'static str,
    /// LSP method name (`T::METHOD` of the notification type).
    method: &'static str,
    /// Method-specific parameters, serialized as-is.
    params: T,
}
 68
/// A message from the server that has a `method` (and `params`), i.e. a
/// notification. The input task tries this shape before [`Response`].
///
/// Both fields borrow from the read buffer. `params` stays raw JSON so each
/// registered handler can deserialize its own parameter type. Other members
/// (such as an `id`) are not declared and so are ignored.
#[derive(Deserialize)]
struct InboundNotification<'a> {
    /// LSP method name used to look up the notification handler.
    #[serde(borrow)]
    method: &'a str,
    /// Raw JSON text of the notification's parameters.
    #[serde(borrow)]
    params: &'a RawValue,
}
 76
/// The `error` object of a JSON-RPC response. Only `message` is read; other
/// members of the object are not declared and are ignored on deserialization.
#[derive(Debug, Deserialize)]
struct Error {
    /// Description from the server; becomes the error of the failed request.
    message: String,
}
 81
 82impl LanguageServer {
 83    pub fn rust(root_path: &Path, cx: &AppContext) -> Result<Arc<Self>> {
 84        const ZED_BUNDLE: Option<&'static str> = option_env!("ZED_BUNDLE");
 85        const ZED_TARGET: &'static str = env!("ZED_TARGET");
 86
 87        let rust_analyzer_name = format!("rust-analyzer-{}", ZED_TARGET);
 88        if ZED_BUNDLE.map_or(Ok(false), |b| b.parse())? {
 89            let rust_analyzer_path = cx
 90                .platform()
 91                .path_for_resource(Some(&rust_analyzer_name), None)?;
 92            Self::new(root_path, &rust_analyzer_path, cx.background())
 93        } else {
 94            Self::new(root_path, Path::new(&rust_analyzer_name), cx.background())
 95        }
 96    }
 97
 98    pub fn new(
 99        root_path: &Path,
100        server_path: &Path,
101        background: &executor::Background,
102    ) -> Result<Arc<Self>> {
103        let mut server = Command::new(server_path)
104            .stdin(Stdio::piped())
105            .stdout(Stdio::piped())
106            .stderr(Stdio::inherit())
107            .spawn()?;
108        let mut stdin = server.stdin.take().unwrap();
109        let mut stdout = BufReader::new(server.stdout.take().unwrap());
110        let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
111        let notification_handlers = Arc::new(RwLock::new(HashMap::<_, NotificationHandler>::new()));
112        let response_handlers = Arc::new(Mutex::new(HashMap::<_, ResponseHandler>::new()));
113        let _input_task = background.spawn(
114            {
115                let notification_handlers = notification_handlers.clone();
116                let response_handlers = response_handlers.clone();
117                async move {
118                    let mut buffer = Vec::new();
119                    loop {
120                        buffer.clear();
121
122                        stdout.read_until(b'\n', &mut buffer).await?;
123                        stdout.read_until(b'\n', &mut buffer).await?;
124                        let message_len: usize = std::str::from_utf8(&buffer)?
125                            .strip_prefix(CONTENT_LEN_HEADER)
126                            .ok_or_else(|| anyhow!("invalid header"))?
127                            .trim_end()
128                            .parse()?;
129
130                        buffer.resize(message_len, 0);
131                        stdout.read_exact(&mut buffer).await?;
132
133                        if let Ok(InboundNotification { method, params }) =
134                            serde_json::from_slice(&buffer)
135                        {
136                            if let Some(handler) = notification_handlers.read().get(method) {
137                                handler(params.get());
138                            }
139                        } else if let Ok(Response { id, error, result }) =
140                            serde_json::from_slice(&buffer)
141                        {
142                            if let Some(handler) = response_handlers.lock().remove(&id) {
143                                if let Some(error) = error {
144                                    handler(Err(error));
145                                } else {
146                                    handler(Ok(result.get()));
147                                }
148                            }
149                        } else {
150                            return Err(anyhow!(
151                                "failed to deserialize message:\n{}",
152                                std::str::from_utf8(&buffer)?
153                            ));
154                        }
155                    }
156                }
157            }
158            .log_err(),
159        );
160        let _output_task = background.spawn(
161            async move {
162                let mut content_len_buffer = Vec::new();
163                loop {
164                    content_len_buffer.clear();
165
166                    let message = outbound_rx.recv().await?;
167                    write!(content_len_buffer, "{}", message.len()).unwrap();
168                    stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
169                    stdin.write_all(&content_len_buffer).await?;
170                    stdin.write_all("\r\n\r\n".as_bytes()).await?;
171                    stdin.write_all(&message).await?;
172                }
173            }
174            .log_err(),
175        );
176
177        let (initialized_tx, initialized_rx) = barrier::channel();
178        let this = Arc::new(Self {
179            notification_handlers,
180            response_handlers,
181            next_id: Default::default(),
182            outbound_tx,
183            _input_task,
184            _output_task,
185            initialized: initialized_rx,
186        });
187
188        let root_uri =
189            lsp_types::Url::from_file_path(root_path).map_err(|_| anyhow!("invalid root path"))?;
190        background
191            .spawn({
192                let this = this.clone();
193                async move {
194                    this.init(root_uri).log_err().await;
195                    drop(initialized_tx);
196                }
197            })
198            .detach();
199
200        Ok(this)
201    }
202
203    async fn init(self: Arc<Self>, root_uri: lsp_types::Url) -> Result<()> {
204        self.request_internal::<lsp_types::request::Initialize>(lsp_types::InitializeParams {
205            process_id: Default::default(),
206            root_path: Default::default(),
207            root_uri: Some(root_uri),
208            initialization_options: Default::default(),
209            capabilities: lsp_types::ClientCapabilities {
210                experimental: Some(json!({
211                    "serverStatusNotification": true,
212                })),
213                ..Default::default()
214            },
215            trace: Default::default(),
216            workspace_folders: Default::default(),
217            client_info: Default::default(),
218            locale: Default::default(),
219        })
220        .await?;
221        self.notify_internal::<lsp_types::notification::Initialized>(
222            lsp_types::InitializedParams {},
223        )
224        .await?;
225        Ok(())
226    }
227
228    pub fn on_notification<T, F>(&self, f: F) -> Subscription
229    where
230        T: lsp_types::notification::Notification,
231        F: 'static + Send + Sync + Fn(T::Params),
232    {
233        let prev_handler = self.notification_handlers.write().insert(
234            T::METHOD,
235            Box::new(
236                move |notification| match serde_json::from_str(notification) {
237                    Ok(notification) => f(notification),
238                    Err(err) => log::error!("error parsing notification {}: {}", T::METHOD, err),
239                },
240            ),
241        );
242
243        assert!(
244            prev_handler.is_none(),
245            "registered multiple handlers for the same notification"
246        );
247
248        Subscription {
249            method: T::METHOD,
250            notification_handlers: self.notification_handlers.clone(),
251        }
252    }
253
254    pub fn request<T: lsp_types::request::Request>(
255        self: Arc<Self>,
256        params: T::Params,
257    ) -> impl Future<Output = Result<T::Result>>
258    where
259        T::Result: 'static + Send,
260    {
261        let this = self.clone();
262        async move {
263            this.initialized.clone().recv().await;
264            this.request_internal::<T>(params).await
265        }
266    }
267
268    fn request_internal<T: lsp_types::request::Request>(
269        self: &Arc<Self>,
270        params: T::Params,
271    ) -> impl Future<Output = Result<T::Result>>
272    where
273        T::Result: 'static + Send,
274    {
275        let id = self.next_id.fetch_add(1, SeqCst);
276        let message = serde_json::to_vec(&Request {
277            jsonrpc: JSON_RPC_VERSION,
278            id,
279            method: T::METHOD,
280            params,
281        })
282        .unwrap();
283        let mut response_handlers = self.response_handlers.lock();
284        let (mut tx, mut rx) = oneshot::channel();
285        response_handlers.insert(
286            id,
287            Box::new(move |result| {
288                let response = match result {
289                    Ok(response) => {
290                        serde_json::from_str(response).context("failed to deserialize response")
291                    }
292                    Err(error) => Err(anyhow!("{}", error.message)),
293                };
294                let _ = tx.try_send(response);
295            }),
296        );
297
298        let this = self.clone();
299        async move {
300            this.outbound_tx.send(message).await?;
301            rx.recv().await.unwrap()
302        }
303    }
304
305    pub fn notify<T: lsp_types::notification::Notification>(
306        self: &Arc<Self>,
307        params: T::Params,
308    ) -> impl Future<Output = Result<()>> {
309        let this = self.clone();
310        async move {
311            this.initialized.clone().recv().await;
312            this.notify_internal::<T>(params).await
313        }
314    }
315
316    fn notify_internal<T: lsp_types::notification::Notification>(
317        self: &Arc<Self>,
318        params: T::Params,
319    ) -> impl Future<Output = Result<()>> {
320        let message = serde_json::to_vec(&OutboundNotification {
321            jsonrpc: JSON_RPC_VERSION,
322            method: T::METHOD,
323            params,
324        })
325        .unwrap();
326
327        let this = self.clone();
328        async move {
329            this.outbound_tx.send(message).await?;
330            Ok(())
331        }
332    }
333}
334
335impl Drop for Subscription {
336    fn drop(&mut self) {
337        self.notification_handlers.write().remove(self.method);
338    }
339}
340
// Integration tests: these spawn a real `rust-analyzer` via `LanguageServer::rust`,
// so the binary must be resolvable in the test environment.
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use unindent::Unindent;
    use util::test::temp_tree;

    // Hovering over a string literal in a minimal Cargo project reports `&str`.
    #[gpui::test]
    async fn test_basic(cx: TestAppContext) {
        let root_dir = temp_tree(json!({
            "Cargo.toml": r#"
                [package]
                name = "temp"
                version = "0.1.0"
                edition = "2018"
            "#.unindent(),
            "src": {
                "lib.rs": r#"
                    fn fun() {
                        let hello = "world";
                    }
                "#.unindent()
            }
        }));

        let server = cx.read(|cx| LanguageServer::rust(root_dir.path(), cx).unwrap());
        // Wait for a `quiescent` server status before querying. Presumably this
        // means initial analysis is done; confirm against rust-analyzer docs.
        server.next_idle_notification().await;

        // Line 1, column 21 (zero-based) falls inside the `"world"` literal.
        let hover = server
            .request::<lsp_types::request::HoverRequest>(lsp_types::HoverParams {
                text_document_position_params: lsp_types::TextDocumentPositionParams {
                    text_document: lsp_types::TextDocumentIdentifier::new(
                        lsp_types::Url::from_file_path(root_dir.path().join("src/lib.rs")).unwrap(),
                    ),
                    position: lsp_types::Position::new(1, 21),
                },
                work_done_progress_params: Default::default(),
            })
            .await
            .unwrap()
            .unwrap();
        assert_eq!(
            hover.contents,
            lsp_types::HoverContents::Markup(lsp_types::MarkupContent {
                kind: lsp_types::MarkupKind::Markdown,
                value: "&str".to_string()
            })
        );
    }

    impl LanguageServer {
        // Waits for an `experimental/serverStatus` notification with `quiescent: true`.
        // Only notifications arriving after the subscription below is registered are
        // seen; the subscription is dropped (unregistered) when this returns.
        async fn next_idle_notification(self: &Arc<Self>) {
            let (tx, rx) = channel::unbounded();
            let _subscription =
                self.on_notification::<ServerStatusNotification, _>(move |params| {
                    if params.quiescent {
                        tx.try_send(()).unwrap();
                    }
                });
            let _ = rx.recv().await;
        }
    }

    // rust-analyzer extension notification, enabled by the `serverStatusNotification`
    // experimental capability that `init` sends.
    pub enum ServerStatusNotification {}

    impl lsp_types::notification::Notification for ServerStatusNotification {
        type Params = ServerStatusParams;
        const METHOD: &'static str = "experimental/serverStatus";
    }

    // Only the field the tests need; other members of the payload are ignored.
    #[derive(Deserialize, Serialize, PartialEq, Eq, Clone)]
    pub struct ServerStatusParams {
        pub quiescent: bool,
    }
}