1use anyhow::{anyhow, Context, Result};
2use futures::{io::BufWriter, AsyncRead, AsyncWrite};
3use gpui::{executor, Task};
4use parking_lot::{Mutex, RwLock};
5use postage::{barrier, oneshot, prelude::Stream, sink::Sink};
6use serde::{Deserialize, Serialize};
7use serde_json::{json, value::RawValue, Value};
8use smol::{
9 channel,
10 io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
11 process::Command,
12};
13use std::{
14 collections::HashMap,
15 future::Future,
16 io::Write,
17 str::FromStr,
18 sync::{
19 atomic::{AtomicUsize, Ordering::SeqCst},
20 Arc,
21 },
22};
23use std::{path::Path, process::Stdio};
24use util::TryFutureExt;
25
26pub use lsp_types::*;
27
/// Value written to the `jsonrpc` member of every message this module builds.
const JSON_RPC_VERSION: &str = "2.0";
/// Header prefix that precedes the decimal byte length of each framed message.
/// The trailing space is part of the prefix that readers strip.
const CONTENT_LEN_HEADER: &str = "Content-Length: ";
30
31type NotificationHandler = Box<dyn Send + Sync + Fn(&str)>;
32type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
33
/// Client-side handle to a language server that is spoken to over a pair of
/// byte streams using `Content-Length`-framed JSON-RPC messages.
pub struct LanguageServer {
    /// Source of ids for outgoing requests; bumped once per request.
    next_id: AtomicUsize,
    /// Queue of serialized message bodies consumed by the output task.
    outbound_tx: channel::Sender<Vec<u8>>,
    /// Handlers for incoming notifications, keyed by method name.
    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
    /// Handlers awaiting a response, keyed by the id of the request they belong to.
    response_handlers: Arc<Mutex<HashMap<usize, ResponseHandler>>>,
    /// Executor on which the I/O, initialization and shutdown tasks are spawned.
    executor: Arc<executor::Background>,
    /// The (input, output) tasks; taken in `Drop` and held until the output side is done.
    io_tasks: Option<(Task<Option<()>>, Task<Option<()>>)>,
    /// Barrier released when the initialization task drops its sender.
    initialized: barrier::Receiver,
    /// Barrier released when the output task drops its sender; taken in `Drop`.
    output_done_rx: Option<barrier::Receiver>,
}
44
/// Handle for a registered notification handler. Dropping it removes the
/// entry stored under `method` from the shared handler table.
pub struct Subscription {
    /// Method name the handler was registered under.
    method: &'static str,
    /// The handler table the registration lives in.
    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
}
49
/// A JSON-RPC request with typed `params`.
#[derive(Serialize, Deserialize)]
struct Request<'a, T> {
    /// Protocol version string.
    jsonrpc: &'a str,
    /// Identifier used to match the response to this request.
    id: usize,
    /// Name of the method being invoked.
    method: &'a str,
    /// Method-specific parameters.
    params: T,
}
57
/// A response to any request, with the `result` payload left as unparsed JSON.
///
/// `result` has no default, so an object without a `result` member does not
/// deserialize into this type even when `error` is present.
#[derive(Serialize, Deserialize)]
struct AnyResponse<'a> {
    /// Id of the request this response answers.
    id: usize,
    /// Error reported by the peer; `None` when the member is absent.
    #[serde(default)]
    error: Option<Error>,
    /// Raw JSON of the result, borrowed from the message buffer.
    #[serde(borrow)]
    result: &'a RawValue,
}
66
/// A JSON-RPC notification (a message without an id) with typed `params`.
#[derive(Serialize, Deserialize)]
struct Notification<'a, T> {
    /// Protocol version string.
    #[serde(borrow)]
    jsonrpc: &'a str,
    /// Name of the notification method.
    #[serde(borrow)]
    method: &'a str,
    /// Method-specific parameters.
    params: T,
}
75
/// An incoming notification of any method, with `params` left as unparsed JSON
/// so it can be handed to the handler registered for `method`.
///
/// Both members are required; a message without `params` does not match.
#[derive(Deserialize)]
struct AnyNotification<'a> {
    /// Name of the notification method, borrowed from the message buffer.
    #[serde(borrow)]
    method: &'a str,
    /// Raw JSON of the parameters, borrowed from the message buffer.
    #[serde(borrow)]
    params: &'a RawValue,
}
83
/// Error object carried by a response; only its `message` member is captured.
#[derive(Debug, Serialize, Deserialize)]
struct Error {
    /// Human-readable description supplied by the peer.
    message: String,
}
88
89impl LanguageServer {
90 pub fn new(
91 binary_path: &Path,
92 root_path: &Path,
93 background: Arc<executor::Background>,
94 ) -> Result<Arc<Self>> {
95 let mut server = Command::new(binary_path)
96 .stdin(Stdio::piped())
97 .stdout(Stdio::piped())
98 .stderr(Stdio::inherit())
99 .spawn()?;
100 let stdin = server.stdin.take().unwrap();
101 let stdout = server.stdout.take().unwrap();
102 Self::new_internal(stdin, stdout, root_path, background)
103 }
104
105 fn new_internal<Stdin, Stdout>(
106 stdin: Stdin,
107 stdout: Stdout,
108 root_path: &Path,
109 executor: Arc<executor::Background>,
110 ) -> Result<Arc<Self>>
111 where
112 Stdin: AsyncWrite + Unpin + Send + 'static,
113 Stdout: AsyncRead + Unpin + Send + 'static,
114 {
115 let mut stdin = BufWriter::new(stdin);
116 let mut stdout = BufReader::new(stdout);
117 let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
118 let notification_handlers = Arc::new(RwLock::new(HashMap::<_, NotificationHandler>::new()));
119 let response_handlers = Arc::new(Mutex::new(HashMap::<_, ResponseHandler>::new()));
120 let input_task = executor.spawn(
121 {
122 let notification_handlers = notification_handlers.clone();
123 let response_handlers = response_handlers.clone();
124 async move {
125 let mut buffer = Vec::new();
126 loop {
127 buffer.clear();
128 stdout.read_until(b'\n', &mut buffer).await?;
129 stdout.read_until(b'\n', &mut buffer).await?;
130 let message_len: usize = std::str::from_utf8(&buffer)?
131 .strip_prefix(CONTENT_LEN_HEADER)
132 .ok_or_else(|| anyhow!("invalid header"))?
133 .trim_end()
134 .parse()?;
135
136 buffer.resize(message_len, 0);
137 stdout.read_exact(&mut buffer).await?;
138
139 if let Ok(AnyNotification { method, params }) =
140 serde_json::from_slice(&buffer)
141 {
142 if let Some(handler) = notification_handlers.read().get(method) {
143 handler(params.get());
144 } else {
145 log::info!(
146 "unhandled notification {}:\n{}",
147 method,
148 serde_json::to_string_pretty(
149 &Value::from_str(params.get()).unwrap()
150 )
151 .unwrap()
152 );
153 }
154 } else if let Ok(AnyResponse { id, error, result }) =
155 serde_json::from_slice(&buffer)
156 {
157 if let Some(handler) = response_handlers.lock().remove(&id) {
158 if let Some(error) = error {
159 handler(Err(error));
160 } else {
161 handler(Ok(result.get()));
162 }
163 }
164 } else {
165 return Err(anyhow!(
166 "failed to deserialize message:\n{}",
167 std::str::from_utf8(&buffer)?
168 ));
169 }
170 }
171 }
172 }
173 .log_err(),
174 );
175 let (output_done_tx, output_done_rx) = barrier::channel();
176 let output_task = executor.spawn(
177 async move {
178 let mut content_len_buffer = Vec::new();
179 while let Ok(message) = outbound_rx.recv().await {
180 content_len_buffer.clear();
181 write!(content_len_buffer, "{}", message.len()).unwrap();
182 stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
183 stdin.write_all(&content_len_buffer).await?;
184 stdin.write_all("\r\n\r\n".as_bytes()).await?;
185 stdin.write_all(&message).await?;
186 stdin.flush().await?;
187 }
188 drop(output_done_tx);
189 Ok(())
190 }
191 .log_err(),
192 );
193
194 let (initialized_tx, initialized_rx) = barrier::channel();
195 let this = Arc::new(Self {
196 notification_handlers,
197 response_handlers,
198 next_id: Default::default(),
199 outbound_tx,
200 executor: executor.clone(),
201 io_tasks: Some((input_task, output_task)),
202 initialized: initialized_rx,
203 output_done_rx: Some(output_done_rx),
204 });
205
206 let root_uri =
207 lsp_types::Url::from_file_path(root_path).map_err(|_| anyhow!("invalid root path"))?;
208 executor
209 .spawn({
210 let this = this.clone();
211 async move {
212 this.init(root_uri).log_err().await;
213 drop(initialized_tx);
214 }
215 })
216 .detach();
217
218 Ok(this)
219 }
220
221 async fn init(self: Arc<Self>, root_uri: lsp_types::Url) -> Result<()> {
222 #[allow(deprecated)]
223 let params = lsp_types::InitializeParams {
224 process_id: Default::default(),
225 root_path: Default::default(),
226 root_uri: Some(root_uri),
227 initialization_options: Default::default(),
228 capabilities: lsp_types::ClientCapabilities {
229 experimental: Some(json!({
230 "serverStatusNotification": true,
231 })),
232 ..Default::default()
233 },
234 trace: Default::default(),
235 workspace_folders: Default::default(),
236 client_info: Default::default(),
237 locale: Default::default(),
238 };
239
240 let this = self.clone();
241 Self::request_internal::<lsp_types::request::Initialize>(
242 &this.next_id,
243 &this.response_handlers,
244 &this.outbound_tx,
245 params,
246 )
247 .await?;
248 Self::notify_internal::<lsp_types::notification::Initialized>(
249 &this.outbound_tx,
250 lsp_types::InitializedParams {},
251 )?;
252 Ok(())
253 }
254
255 pub fn on_notification<T, F>(&self, f: F) -> Subscription
256 where
257 T: lsp_types::notification::Notification,
258 F: 'static + Send + Sync + Fn(T::Params),
259 {
260 let prev_handler = self.notification_handlers.write().insert(
261 T::METHOD,
262 Box::new(
263 move |notification| match serde_json::from_str(notification) {
264 Ok(notification) => f(notification),
265 Err(err) => log::error!("error parsing notification {}: {}", T::METHOD, err),
266 },
267 ),
268 );
269
270 assert!(
271 prev_handler.is_none(),
272 "registered multiple handlers for the same notification"
273 );
274
275 Subscription {
276 method: T::METHOD,
277 notification_handlers: self.notification_handlers.clone(),
278 }
279 }
280
281 pub fn request<T: lsp_types::request::Request>(
282 self: Arc<Self>,
283 params: T::Params,
284 ) -> impl Future<Output = Result<T::Result>>
285 where
286 T::Result: 'static + Send,
287 {
288 let this = self.clone();
289 async move {
290 this.initialized.clone().recv().await;
291 Self::request_internal::<T>(
292 &this.next_id,
293 &this.response_handlers,
294 &this.outbound_tx,
295 params,
296 )
297 .await
298 }
299 }
300
301 fn request_internal<T: lsp_types::request::Request>(
302 next_id: &AtomicUsize,
303 response_handlers: &Mutex<HashMap<usize, ResponseHandler>>,
304 outbound_tx: &channel::Sender<Vec<u8>>,
305 params: T::Params,
306 ) -> impl Future<Output = Result<T::Result>>
307 where
308 T::Result: 'static + Send,
309 {
310 let id = next_id.fetch_add(1, SeqCst);
311 let message = serde_json::to_vec(&Request {
312 jsonrpc: JSON_RPC_VERSION,
313 id,
314 method: T::METHOD,
315 params,
316 })
317 .unwrap();
318 let mut response_handlers = response_handlers.lock();
319 let (mut tx, mut rx) = oneshot::channel();
320 response_handlers.insert(
321 id,
322 Box::new(move |result| {
323 let response = match result {
324 Ok(response) => {
325 serde_json::from_str(response).context("failed to deserialize response")
326 }
327 Err(error) => Err(anyhow!("{}", error.message)),
328 };
329 let _ = tx.try_send(response);
330 }),
331 );
332
333 let send = outbound_tx.try_send(message);
334 async move {
335 send?;
336 rx.recv().await.unwrap()
337 }
338 }
339
340 pub fn notify<T: lsp_types::notification::Notification>(
341 self: &Arc<Self>,
342 params: T::Params,
343 ) -> impl Future<Output = Result<()>> {
344 let this = self.clone();
345 async move {
346 this.initialized.clone().recv().await;
347 Self::notify_internal::<T>(&this.outbound_tx, params)?;
348 Ok(())
349 }
350 }
351
352 fn notify_internal<T: lsp_types::notification::Notification>(
353 outbound_tx: &channel::Sender<Vec<u8>>,
354 params: T::Params,
355 ) -> Result<()> {
356 let message = serde_json::to_vec(&Notification {
357 jsonrpc: JSON_RPC_VERSION,
358 method: T::METHOD,
359 params,
360 })
361 .unwrap();
362 outbound_tx.try_send(message)?;
363 Ok(())
364 }
365}
366
367impl Drop for LanguageServer {
368 fn drop(&mut self) {
369 let tasks = self.io_tasks.take();
370 let response_handlers = self.response_handlers.clone();
371 let outbound_tx = self.outbound_tx.clone();
372 let next_id = AtomicUsize::new(self.next_id.load(SeqCst));
373 let mut output_done = self.output_done_rx.take().unwrap();
374 self.executor.spawn_critical(
375 async move {
376 Self::request_internal::<lsp_types::request::Shutdown>(
377 &next_id,
378 &response_handlers,
379 &outbound_tx,
380 (),
381 )
382 .await?;
383 Self::notify_internal::<lsp_types::notification::Exit>(&outbound_tx, ())?;
384 drop(outbound_tx);
385 output_done.recv().await;
386 drop(tasks);
387 Ok(())
388 }
389 .log_err(),
390 )
391 }
392}
393
394impl Subscription {
395 pub fn detach(mut self) {
396 self.method = "";
397 }
398}
399
400impl Drop for Subscription {
401 fn drop(&mut self) {
402 self.notification_handlers.write().remove(self.method);
403 }
404}
405
/// In-process stand-in for a language server, connected to a
/// `LanguageServer` through a pair of pipes.
#[cfg(any(test, feature = "test-support"))]
pub struct FakeLanguageServer {
    /// Scratch space holding the most recently received message.
    buffer: Vec<u8>,
    /// Read end of the pipe the client writes its messages to.
    stdin: smol::io::BufReader<async_pipe::PipeReader>,
    /// Write end of the pipe the client reads its messages from.
    stdout: smol::io::BufWriter<async_pipe::PipeWriter>,
}
412
/// Id of a request received by the fake server, tagged with the request
/// type `T` so that `respond` is given a result of the matching type.
#[cfg(any(test, feature = "test-support"))]
pub struct RequestId<T> {
    /// The `id` member of the received request.
    id: usize,
    /// Ties the id to its request type without storing a `T`.
    _type: std::marker::PhantomData<T>,
}
418
#[cfg(any(test, feature = "test-support"))]
impl LanguageServer {
    /// Creates a `LanguageServer` wired to a [`FakeLanguageServer`] through
    /// two in-process pipes, rooted at `/`, and plays the server side of the
    /// initialization handshake before returning both ends.
    ///
    /// # Panics
    ///
    /// Panics if the server cannot be constructed or if the handshake
    /// messages are not the expected ones.
    pub async fn fake(executor: Arc<executor::Background>) -> (Arc<Self>, FakeLanguageServer) {
        let (stdin_writer, stdin_reader) = async_pipe::pipe();
        let (stdout_writer, stdout_reader) = async_pipe::pipe();
        let mut fake = FakeLanguageServer {
            stdin: smol::io::BufReader::new(stdin_reader),
            stdout: smol::io::BufWriter::new(stdout_writer),
            buffer: Vec::new(),
        };

        let server =
            Self::new_internal(stdin_writer, stdout_reader, Path::new("/"), executor).unwrap();

        // Answer the `initialize` request, then consume `initialized`.
        let (init_id, _params) = fake.receive_request::<request::Initialize>().await;
        fake.respond(init_id, InitializeResult::default()).await;
        fake.receive_notification::<notification::Initialized>()
            .await;

        (server, fake)
    }
}
440
#[cfg(any(test, feature = "test-support"))]
impl FakeLanguageServer {
    /// Sends a notification of type `T` to the client.
    ///
    /// # Panics
    ///
    /// Panics if serialization or writing fails.
    pub async fn notify<T: notification::Notification>(&mut self, params: T::Params) {
        let notification = Notification {
            jsonrpc: JSON_RPC_VERSION,
            method: T::METHOD,
            params,
        };
        let bytes = serde_json::to_vec(&notification).unwrap();
        self.send(bytes).await;
    }

    /// Answers the request identified by `request_id` with `result`.
    ///
    /// # Panics
    ///
    /// Panics if serialization or writing fails.
    pub async fn respond<'a, T: request::Request>(
        &mut self,
        request_id: RequestId<T>,
        result: T::Result,
    ) {
        let result_json = serde_json::to_string(&result).unwrap();
        let raw_result = RawValue::from_string(result_json).unwrap();
        let response = AnyResponse {
            id: request_id.id,
            error: None,
            result: &*raw_result,
        };
        let bytes = serde_json::to_vec(&response).unwrap();
        self.send(bytes).await;
    }

    /// Reads the next message from the client and interprets it as a
    /// request of type `T`, returning its id and parameters.
    ///
    /// # Panics
    ///
    /// Panics if the message cannot be read or parsed, or if its method or
    /// protocol version is not the expected one.
    pub async fn receive_request<T: request::Request>(&mut self) -> (RequestId<T>, T::Params) {
        self.receive().await;
        let request = serde_json::from_slice::<Request<T::Params>>(&self.buffer).unwrap();
        assert_eq!(request.method, T::METHOD);
        assert_eq!(request.jsonrpc, JSON_RPC_VERSION);
        let id = RequestId {
            id: request.id,
            _type: std::marker::PhantomData,
        };
        (id, request.params)
    }

    /// Reads the next message from the client and interprets it as a
    /// notification of type `T`, returning its parameters.
    ///
    /// # Panics
    ///
    /// Panics if the message cannot be read or parsed, or if its method is
    /// not `T::METHOD`.
    pub async fn receive_notification<T: notification::Notification>(&mut self) -> T::Params {
        self.receive().await;
        let notification = serde_json::from_slice::<Notification<T::Params>>(&self.buffer).unwrap();
        assert_eq!(notification.method, T::METHOD);
        notification.params
    }

    /// Writes `message` to the client, preceded by its `Content-Length`
    /// header and a blank line, then flushes.
    async fn send(&mut self, message: Vec<u8>) {
        let length = message.len().to_string();
        self.stdout
            .write_all(CONTENT_LEN_HEADER.as_bytes())
            .await
            .unwrap();
        self.stdout.write_all(length.as_bytes()).await.unwrap();
        self.stdout.write_all(b"\r\n\r\n").await.unwrap();
        self.stdout.write_all(&message).await.unwrap();
        self.stdout.flush().await.unwrap();
    }

    /// Reads one framed message from the client into `self.buffer`,
    /// replacing whatever it held before.
    async fn receive(&mut self) {
        self.buffer.clear();
        // Two lines: the `Content-Length` header and the blank separator.
        for _ in 0..2 {
            self.stdin
                .read_until(b'\n', &mut self.buffer)
                .await
                .unwrap();
        }
        let header = std::str::from_utf8(&self.buffer).unwrap();
        let message_len: usize = header
            .strip_prefix(CONTENT_LEN_HEADER)
            .unwrap()
            .trim_end()
            .parse()
            .unwrap();
        self.buffer.resize(message_len, 0);
        self.stdin.read_exact(&mut self.buffer).await.unwrap();
    }
}
524
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use simplelog::SimpleLogger;
    use unindent::Unindent;
    use util::test::temp_tree;

    /// Runs a `rust-analyzer` binary against a temporary crate and checks
    /// the hover text for a string literal.
    #[gpui::test]
    async fn test_basic(cx: TestAppContext) {
        let lib_source = r#"
            fn fun() {
                let hello = "world";
            }
        "#
        .unindent();
        let root_dir = temp_tree(json!({
            "Cargo.toml": r#"
                [package]
                name = "temp"
                version = "0.1.0"
                edition = "2018"
            "#.unindent(),
            "src": {
                "lib.rs": &lib_source
            }
        }));
        let lib_path = root_dir.path().join("src/lib.rs");
        let lib_file_uri = lsp_types::Url::from_file_path(lib_path).unwrap();

        let server = cx.read(|cx| {
            let background = cx.background().clone();
            LanguageServer::new(Path::new("rust-analyzer"), root_dir.path(), background).unwrap()
        });
        server.next_idle_notification().await;

        let document = lsp_types::TextDocumentItem::new(
            lib_file_uri.clone(),
            "rust".to_string(),
            0,
            lib_source,
        );
        server
            .notify::<lsp_types::notification::DidOpenTextDocument>(
                lsp_types::DidOpenTextDocumentParams {
                    text_document: document,
                },
            )
            .await
            .unwrap();

        let hover_params = lsp_types::HoverParams {
            text_document_position_params: lsp_types::TextDocumentPositionParams {
                text_document: lsp_types::TextDocumentIdentifier::new(lib_file_uri),
                position: lsp_types::Position::new(1, 21),
            },
            work_done_progress_params: Default::default(),
        };
        let hover = server
            .request::<lsp_types::request::HoverRequest>(hover_params)
            .await
            .unwrap()
            .unwrap();
        let expected = lsp_types::HoverContents::Markup(lsp_types::MarkupContent {
            kind: lsp_types::MarkupKind::Markdown,
            value: "&str".to_string(),
        });
        assert_eq!(hover.contents, expected);
    }

    /// Exercises both directions of the protocol against the in-process
    /// fake, then checks the shutdown sequence triggered by dropping the
    /// server handle.
    #[gpui::test]
    async fn test_fake(cx: TestAppContext) {
        SimpleLogger::init(log::LevelFilter::Info, Default::default()).unwrap();

        let (server, mut fake) = LanguageServer::fake(cx.background()).await;

        let (shown_tx, shown_rx) = channel::unbounded();
        let (published_tx, published_rx) = channel::unbounded();
        server
            .on_notification::<notification::ShowMessage, _>(move |params| {
                shown_tx.try_send(params).unwrap()
            })
            .detach();
        server
            .on_notification::<notification::PublishDiagnostics, _>(move |params| {
                published_tx.try_send(params).unwrap()
            })
            .detach();

        // Client to server.
        let opened = DidOpenTextDocumentParams {
            text_document: TextDocumentItem::new(
                Url::from_str("file://a/b").unwrap(),
                "rust".to_string(),
                0,
                "".to_string(),
            ),
        };
        server
            .notify::<notification::DidOpenTextDocument>(opened)
            .await
            .unwrap();
        let received = fake
            .receive_notification::<notification::DidOpenTextDocument>()
            .await;
        assert_eq!(received.text_document.uri.as_str(), "file://a/b");

        // Server to client.
        fake.notify::<notification::ShowMessage>(ShowMessageParams {
            typ: MessageType::ERROR,
            message: "ok".to_string(),
        })
        .await;
        fake.notify::<notification::PublishDiagnostics>(PublishDiagnosticsParams {
            uri: Url::from_str("file://b/c").unwrap(),
            version: Some(5),
            diagnostics: vec![],
        })
        .await;
        assert_eq!(shown_rx.recv().await.unwrap().message, "ok");
        assert_eq!(
            published_rx.recv().await.unwrap().uri.as_str(),
            "file://b/c"
        );

        // Dropping the handle starts the shutdown/exit sequence.
        drop(server);
        let (shutdown_request, _) = fake.receive_request::<lsp_types::request::Shutdown>().await;
        fake.respond(shutdown_request, ()).await;
        fake.receive_notification::<lsp_types::notification::Exit>()
            .await;
    }

    impl LanguageServer {
        /// Waits for the next `experimental/serverStatus` notification whose
        /// `quiescent` flag is set. The handler is unregistered on return.
        async fn next_idle_notification(self: &Arc<Self>) {
            let (idle_tx, idle_rx) = channel::unbounded();
            let _subscription =
                self.on_notification::<ServerStatusNotification, _>(move |status| {
                    if status.quiescent {
                        idle_tx.try_send(()).unwrap();
                    }
                });
            let _ = idle_rx.recv().await;
        }
    }

    /// Marker type for the `experimental/serverStatus` notification.
    pub enum ServerStatusNotification {}

    impl lsp_types::notification::Notification for ServerStatusNotification {
        type Params = ServerStatusParams;
        const METHOD: &'static str = "experimental/serverStatus";
    }

    /// Parameters of the server status notification; only `quiescent` is read.
    #[derive(Deserialize, Serialize, PartialEq, Eq, Clone)]
    pub struct ServerStatusParams {
        pub quiescent: bool,
    }
}