1use anyhow::{anyhow, Context, Result};
2use futures::{io::BufWriter, AsyncRead, AsyncWrite};
3use gpui::{executor, Task};
4use parking_lot::{Mutex, RwLock};
5use postage::{barrier, oneshot, prelude::Stream, sink::Sink};
6use serde::{Deserialize, Serialize};
7use serde_json::{json, value::RawValue, Value};
8use smol::{
9 channel,
10 io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
11 process::Command,
12};
13use std::{
14 collections::HashMap,
15 future::Future,
16 io::Write,
17 str::FromStr,
18 sync::{
19 atomic::{AtomicUsize, Ordering::SeqCst},
20 Arc,
21 },
22};
23use std::{path::Path, process::Stdio};
24use util::TryFutureExt;
25
26pub use lsp_types::*;
27
/// Value of the `jsonrpc` field on every message this client sends.
const JSON_RPC_VERSION: &str = "2.0";
/// Prefix of the `Content-Length` header that frames each LSP message.
const CONTENT_LEN_HEADER: &str = "Content-Length: ";
30
31type NotificationHandler = Box<dyn Send + Sync + FnMut(&str)>;
32type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
33
/// Client-side handle to a language server, speaking JSON-RPC over a pair of
/// byte streams (the server's stdin/stdout).
pub struct LanguageServer {
    /// Id assigned to the next outgoing request.
    next_id: AtomicUsize,
    /// Queue of serialized messages consumed by the writer task; `None` once
    /// `shutdown` has been started.
    outbound_tx: RwLock<Option<channel::Sender<Vec<u8>>>>,
    /// Notification callbacks keyed by LSP method name, shared with the reader task.
    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
    /// Callbacks for in-flight requests keyed by request id, shared with the reader task.
    response_handlers: Arc<Mutex<HashMap<usize, ResponseHandler>>>,
    /// Used to spawn the shutdown sequence when the server handle is dropped.
    executor: Arc<executor::Background>,
    /// The (reader, writer) I/O tasks; taken by `shutdown`.
    io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
    /// Released once the `initialize` handshake has finished, successfully or not.
    initialized: barrier::Receiver,
    /// Released when the writer task exits; taken by `shutdown`.
    output_done_rx: Mutex<Option<barrier::Receiver>>,
}
44
/// Registration of a notification handler; dropping it unregisters the handler.
pub struct Subscription {
    /// Method whose handler is removed on drop; set to `""` by `detach`.
    method: &'static str,
    /// The handler map shared with the owning `LanguageServer`.
    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
}
49
/// Outgoing JSON-RPC request envelope. The fake server also deserializes it to
/// inspect what the client sent.
#[derive(Serialize, Deserialize)]
struct Request<'a, T> {
    jsonrpc: &'a str,
    id: usize,
    method: &'a str,
    params: T,
}
57
58#[derive(Serialize, Deserialize)]
59struct AnyResponse<'a> {
60 id: usize,
61 #[serde(default)]
62 error: Option<Error>,
63 #[serde(borrow)]
64 result: Option<&'a RawValue>,
65}
66
/// Outgoing JSON-RPC notification envelope (no `id`). The fake server also
/// deserializes it to check notifications sent by the client.
#[derive(Serialize, Deserialize)]
struct Notification<'a, T> {
    #[serde(borrow)]
    jsonrpc: &'a str,
    #[serde(borrow)]
    method: &'a str,
    params: T,
}
75
76#[derive(Deserialize)]
77struct AnyNotification<'a> {
78 #[serde(borrow)]
79 method: &'a str,
80 #[serde(borrow)]
81 params: &'a RawValue,
82}
83
/// Error object carried by a JSON-RPC response. Only `message` is kept; any
/// other members of the error object are not captured by this struct.
#[derive(Debug, Serialize, Deserialize)]
struct Error {
    message: String,
}
88
89impl LanguageServer {
90 pub fn new(
91 binary_path: &Path,
92 root_path: &Path,
93 background: Arc<executor::Background>,
94 ) -> Result<Arc<Self>> {
95 let mut server = Command::new(binary_path)
96 .stdin(Stdio::piped())
97 .stdout(Stdio::piped())
98 .stderr(Stdio::inherit())
99 .spawn()?;
100 let stdin = server.stdin.take().unwrap();
101 let stdout = server.stdout.take().unwrap();
102 Self::new_internal(stdin, stdout, root_path, background)
103 }
104
105 fn new_internal<Stdin, Stdout>(
106 stdin: Stdin,
107 stdout: Stdout,
108 root_path: &Path,
109 executor: Arc<executor::Background>,
110 ) -> Result<Arc<Self>>
111 where
112 Stdin: AsyncWrite + Unpin + Send + 'static,
113 Stdout: AsyncRead + Unpin + Send + 'static,
114 {
115 let mut stdin = BufWriter::new(stdin);
116 let mut stdout = BufReader::new(stdout);
117 let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
118 let notification_handlers = Arc::new(RwLock::new(HashMap::<_, NotificationHandler>::new()));
119 let response_handlers = Arc::new(Mutex::new(HashMap::<_, ResponseHandler>::new()));
120 let input_task = executor.spawn(
121 {
122 let notification_handlers = notification_handlers.clone();
123 let response_handlers = response_handlers.clone();
124 async move {
125 let mut buffer = Vec::new();
126 loop {
127 buffer.clear();
128 stdout.read_until(b'\n', &mut buffer).await?;
129 stdout.read_until(b'\n', &mut buffer).await?;
130 let message_len: usize = std::str::from_utf8(&buffer)?
131 .strip_prefix(CONTENT_LEN_HEADER)
132 .ok_or_else(|| anyhow!("invalid header"))?
133 .trim_end()
134 .parse()?;
135
136 buffer.resize(message_len, 0);
137 stdout.read_exact(&mut buffer).await?;
138
139 if let Ok(AnyNotification { method, params }) =
140 serde_json::from_slice(&buffer)
141 {
142 if let Some(handler) = notification_handlers.write().get_mut(method) {
143 handler(params.get());
144 } else {
145 log::info!(
146 "unhandled notification {}:\n{}",
147 method,
148 serde_json::to_string_pretty(
149 &Value::from_str(params.get()).unwrap()
150 )
151 .unwrap()
152 );
153 }
154 } else if let Ok(AnyResponse { id, error, result }) =
155 serde_json::from_slice(&buffer)
156 {
157 if let Some(handler) = response_handlers.lock().remove(&id) {
158 if let Some(error) = error {
159 handler(Err(error));
160 } else if let Some(result) = result {
161 handler(Ok(result.get()));
162 } else {
163 handler(Ok("null"));
164 }
165 }
166 } else {
167 return Err(anyhow!(
168 "failed to deserialize message:\n{}",
169 std::str::from_utf8(&buffer)?
170 ));
171 }
172 }
173 }
174 }
175 .log_err(),
176 );
177 let (output_done_tx, output_done_rx) = barrier::channel();
178 let output_task = executor.spawn(
179 async move {
180 let mut content_len_buffer = Vec::new();
181 while let Ok(message) = outbound_rx.recv().await {
182 content_len_buffer.clear();
183 write!(content_len_buffer, "{}", message.len()).unwrap();
184 stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
185 stdin.write_all(&content_len_buffer).await?;
186 stdin.write_all("\r\n\r\n".as_bytes()).await?;
187 stdin.write_all(&message).await?;
188 stdin.flush().await?;
189 }
190 drop(output_done_tx);
191 Ok(())
192 }
193 .log_err(),
194 );
195
196 let (initialized_tx, initialized_rx) = barrier::channel();
197 let this = Arc::new(Self {
198 notification_handlers,
199 response_handlers,
200 next_id: Default::default(),
201 outbound_tx: RwLock::new(Some(outbound_tx)),
202 executor: executor.clone(),
203 io_tasks: Mutex::new(Some((input_task, output_task))),
204 initialized: initialized_rx,
205 output_done_rx: Mutex::new(Some(output_done_rx)),
206 });
207
208 let root_uri =
209 lsp_types::Url::from_file_path(root_path).map_err(|_| anyhow!("invalid root path"))?;
210 executor
211 .spawn({
212 let this = this.clone();
213 async move {
214 this.init(root_uri).log_err().await;
215 drop(initialized_tx);
216 }
217 })
218 .detach();
219
220 Ok(this)
221 }
222
223 async fn init(self: Arc<Self>, root_uri: lsp_types::Url) -> Result<()> {
224 #[allow(deprecated)]
225 let params = lsp_types::InitializeParams {
226 process_id: Default::default(),
227 root_path: Default::default(),
228 root_uri: Some(root_uri),
229 initialization_options: Default::default(),
230 capabilities: lsp_types::ClientCapabilities {
231 experimental: Some(json!({
232 "serverStatusNotification": true,
233 })),
234 window: Some(lsp_types::WindowClientCapabilities {
235 work_done_progress: Some(true),
236 ..Default::default()
237 }),
238 ..Default::default()
239 },
240 trace: Default::default(),
241 workspace_folders: Default::default(),
242 client_info: Default::default(),
243 locale: Default::default(),
244 };
245
246 let this = self.clone();
247 let request = Self::request_internal::<lsp_types::request::Initialize>(
248 &this.next_id,
249 &this.response_handlers,
250 this.outbound_tx.read().as_ref(),
251 params,
252 );
253 request.await?;
254 Self::notify_internal::<lsp_types::notification::Initialized>(
255 this.outbound_tx.read().as_ref(),
256 lsp_types::InitializedParams {},
257 )?;
258 Ok(())
259 }
260
261 pub fn shutdown(&self) -> Option<impl 'static + Send + Future<Output = Result<()>>> {
262 if let Some(tasks) = self.io_tasks.lock().take() {
263 let response_handlers = self.response_handlers.clone();
264 let outbound_tx = self.outbound_tx.write().take();
265 let next_id = AtomicUsize::new(self.next_id.load(SeqCst));
266 let mut output_done = self.output_done_rx.lock().take().unwrap();
267 Some(async move {
268 Self::request_internal::<lsp_types::request::Shutdown>(
269 &next_id,
270 &response_handlers,
271 outbound_tx.as_ref(),
272 (),
273 )
274 .await?;
275 Self::notify_internal::<lsp_types::notification::Exit>(outbound_tx.as_ref(), ())?;
276 drop(outbound_tx);
277 output_done.recv().await;
278 drop(tasks);
279 Ok(())
280 })
281 } else {
282 None
283 }
284 }
285
286 pub fn on_notification<T, F>(&self, mut f: F) -> Subscription
287 where
288 T: lsp_types::notification::Notification,
289 F: 'static + Send + Sync + FnMut(T::Params),
290 {
291 let prev_handler = self.notification_handlers.write().insert(
292 T::METHOD,
293 Box::new(
294 move |notification| match serde_json::from_str(notification) {
295 Ok(notification) => f(notification),
296 Err(err) => log::error!("error parsing notification {}: {}", T::METHOD, err),
297 },
298 ),
299 );
300
301 assert!(
302 prev_handler.is_none(),
303 "registered multiple handlers for the same notification"
304 );
305
306 Subscription {
307 method: T::METHOD,
308 notification_handlers: self.notification_handlers.clone(),
309 }
310 }
311
312 pub fn request<T: lsp_types::request::Request>(
313 self: Arc<Self>,
314 params: T::Params,
315 ) -> impl Future<Output = Result<T::Result>>
316 where
317 T::Result: 'static + Send,
318 {
319 let this = self.clone();
320 async move {
321 this.initialized.clone().recv().await;
322 Self::request_internal::<T>(
323 &this.next_id,
324 &this.response_handlers,
325 this.outbound_tx.read().as_ref(),
326 params,
327 )
328 .await
329 }
330 }
331
332 fn request_internal<T: lsp_types::request::Request>(
333 next_id: &AtomicUsize,
334 response_handlers: &Mutex<HashMap<usize, ResponseHandler>>,
335 outbound_tx: Option<&channel::Sender<Vec<u8>>>,
336 params: T::Params,
337 ) -> impl 'static + Future<Output = Result<T::Result>>
338 where
339 T::Result: 'static + Send,
340 {
341 let id = next_id.fetch_add(1, SeqCst);
342 let message = serde_json::to_vec(&Request {
343 jsonrpc: JSON_RPC_VERSION,
344 id,
345 method: T::METHOD,
346 params,
347 })
348 .unwrap();
349 let mut response_handlers = response_handlers.lock();
350 let (mut tx, mut rx) = oneshot::channel();
351 response_handlers.insert(
352 id,
353 Box::new(move |result| {
354 let response = match result {
355 Ok(response) => {
356 serde_json::from_str(response).context("failed to deserialize response")
357 }
358 Err(error) => Err(anyhow!("{}", error.message)),
359 };
360 let _ = tx.try_send(response);
361 }),
362 );
363
364 let send = outbound_tx
365 .as_ref()
366 .ok_or_else(|| {
367 anyhow!("tried to send a request to a language server that has been shut down")
368 })
369 .and_then(|outbound_tx| {
370 outbound_tx.try_send(message)?;
371 Ok(())
372 });
373 async move {
374 send?;
375 rx.recv().await.unwrap()
376 }
377 }
378
379 pub fn notify<T: lsp_types::notification::Notification>(
380 self: &Arc<Self>,
381 params: T::Params,
382 ) -> impl Future<Output = Result<()>> {
383 let this = self.clone();
384 async move {
385 this.initialized.clone().recv().await;
386 Self::notify_internal::<T>(this.outbound_tx.read().as_ref(), params)?;
387 Ok(())
388 }
389 }
390
391 fn notify_internal<T: lsp_types::notification::Notification>(
392 outbound_tx: Option<&channel::Sender<Vec<u8>>>,
393 params: T::Params,
394 ) -> Result<()> {
395 let message = serde_json::to_vec(&Notification {
396 jsonrpc: JSON_RPC_VERSION,
397 method: T::METHOD,
398 params,
399 })
400 .unwrap();
401 let outbound_tx = outbound_tx
402 .as_ref()
403 .ok_or_else(|| anyhow!("tried to notify a language server that has been shut down"))?;
404 outbound_tx.try_send(message)?;
405 Ok(())
406 }
407}
408
409impl Drop for LanguageServer {
410 fn drop(&mut self) {
411 if let Some(shutdown) = self.shutdown() {
412 self.executor.spawn(shutdown).detach();
413 }
414 }
415}
416
417impl Subscription {
418 pub fn detach(mut self) {
419 self.method = "";
420 }
421}
422
423impl Drop for Subscription {
424 fn drop(&mut self) {
425 self.notification_handlers.write().remove(self.method);
426 }
427}
428
/// In-memory stand-in for a language server, driven by tests through pipes
/// connected to a real `LanguageServer` client (see `LanguageServer::fake`).
#[cfg(any(test, feature = "test-support"))]
pub struct FakeLanguageServer {
    /// Holds the body of the most recently received message.
    buffer: Vec<u8>,
    /// Reads what the client writes.
    stdin: smol::io::BufReader<async_pipe::PipeReader>,
    /// Writes to the client.
    stdout: smol::io::BufWriter<async_pipe::PipeWriter>,
    /// Checked by `notify`, which panics when it is `false`.
    /// NOTE(review): initialized to `true`; presumably toggled by callers elsewhere.
    pub started: Arc<std::sync::atomic::AtomicBool>,
}
436
/// Id of a request received by the fake server, tagged with the request type
/// so `respond` only accepts the matching result type.
#[cfg(any(test, feature = "test-support"))]
pub struct RequestId<T> {
    id: usize,
    _type: std::marker::PhantomData<T>,
}
442
#[cfg(any(test, feature = "test-support"))]
impl LanguageServer {
    /// Connects a client to a [`FakeLanguageServer`] over in-memory pipes and
    /// completes the `initialize` / `initialized` handshake before returning
    /// both ends.
    pub async fn fake(executor: Arc<executor::Background>) -> (Arc<Self>, FakeLanguageServer) {
        // Client -> server direction.
        let (client_writer, server_reader) = async_pipe::pipe();
        // Server -> client direction.
        let (server_writer, client_reader) = async_pipe::pipe();

        let mut fake = FakeLanguageServer {
            stdin: smol::io::BufReader::new(server_reader),
            stdout: smol::io::BufWriter::new(server_writer),
            buffer: Vec::new(),
            started: Arc::new(std::sync::atomic::AtomicBool::new(true)),
        };

        let server =
            Self::new_internal(client_writer, client_reader, Path::new("/"), executor).unwrap();

        // Play the server side of the startup handshake.
        let (initialize_id, _) = fake.receive_request::<request::Initialize>().await;
        fake.respond(initialize_id, InitializeResult::default())
            .await;
        fake.receive_notification::<notification::Initialized>()
            .await;

        (server, fake)
    }
}
465
#[cfg(any(test, feature = "test-support"))]
impl FakeLanguageServer {
    /// Sends a notification of type `T` to the client.
    ///
    /// # Panics
    ///
    /// Panics if `started` is `false`, or if writing to the client fails.
    pub async fn notify<T: notification::Notification>(&mut self, params: T::Params) {
        if !self.started.load(std::sync::atomic::Ordering::SeqCst) {
            panic!("can't simulate an LSP notification before the server has been started");
        }
        let message = serde_json::to_vec(&Notification {
            jsonrpc: JSON_RPC_VERSION,
            method: T::METHOD,
            params,
        })
        .unwrap();
        self.send(message).await;
    }

    /// Sends a successful response carrying `result` for a request previously
    /// obtained from `receive_request`.
    pub async fn respond<T: request::Request>(
        &mut self,
        request_id: RequestId<T>,
        result: T::Result,
    ) {
        let result = serde_json::to_string(&result).unwrap();
        let message = serde_json::to_vec(&AnyResponse {
            id: request_id.id,
            error: None,
            result: Some(&RawValue::from_string(result).unwrap()),
        })
        .unwrap();
        self.send(message).await;
    }

    /// Waits for the next message that parses as a request of type `T`,
    /// printing and skipping any message that doesn't.
    ///
    /// # Panics
    ///
    /// Panics if a parsed request's method or `jsonrpc` version doesn't match.
    pub async fn receive_request<T: request::Request>(&mut self) -> (RequestId<T>, T::Params) {
        loop {
            self.receive().await;
            if let Ok(request) = serde_json::from_slice::<Request<T::Params>>(&self.buffer) {
                assert_eq!(request.method, T::METHOD);
                assert_eq!(request.jsonrpc, JSON_RPC_VERSION);
                return (
                    RequestId {
                        id: request.id,
                        _type: std::marker::PhantomData,
                    },
                    request.params,
                );
            } else {
                println!(
                    "skipping message in fake language server {:?}",
                    std::str::from_utf8(&self.buffer)
                );
            }
        }
    }

    /// Reads the next message and decodes it as a notification of type `T`.
    ///
    /// # Panics
    ///
    /// Unlike `receive_request`, this does not skip other messages. It panics if
    /// the next message is not a notification with `T`'s method and params.
    pub async fn receive_notification<T: notification::Notification>(&mut self) -> T::Params {
        self.receive().await;
        let notification = serde_json::from_slice::<Notification<T::Params>>(&self.buffer).unwrap();
        assert_eq!(notification.method, T::METHOD);
        notification.params
    }

    /// Sends a `$/progress` notification beginning work-done progress for `token`.
    pub async fn start_progress(&mut self, token: impl Into<String>) {
        self.notify::<notification::Progress>(ProgressParams {
            token: NumberOrString::String(token.into()),
            value: ProgressParamsValue::WorkDone(WorkDoneProgress::Begin(Default::default())),
        })
        .await;
    }

    /// Sends a `$/progress` notification ending work-done progress for `token`.
    pub async fn end_progress(&mut self, token: impl Into<String>) {
        self.notify::<notification::Progress>(ProgressParams {
            token: NumberOrString::String(token.into()),
            value: ProgressParamsValue::WorkDone(WorkDoneProgress::End(Default::default())),
        })
        .await;
    }

    /// Writes `message` to the client, framed with a `Content-Length` header.
    async fn send(&mut self, message: Vec<u8>) {
        self.stdout
            .write_all(CONTENT_LEN_HEADER.as_bytes())
            .await
            .unwrap();
        self.stdout
            .write_all(message.len().to_string().as_bytes())
            .await
            .unwrap();
        self.stdout.write_all("\r\n\r\n".as_bytes()).await.unwrap();
        self.stdout.write_all(&message).await.unwrap();
        self.stdout.flush().await.unwrap();
    }

    /// Reads one framed message from the client into `self.buffer`.
    ///
    /// Expects exactly one `Content-Length` header line followed by a blank
    /// line, which is the framing the client's writer task emits.
    async fn receive(&mut self) {
        self.buffer.clear();
        self.stdin
            .read_until(b'\n', &mut self.buffer)
            .await
            .unwrap();
        self.stdin
            .read_until(b'\n', &mut self.buffer)
            .await
            .unwrap();
        let message_len: usize = std::str::from_utf8(&self.buffer)
            .unwrap()
            .strip_prefix(CONTENT_LEN_HEADER)
            .unwrap()
            .trim_end()
            .parse()
            .unwrap();
        self.buffer.resize(message_len, 0);
        self.stdin.read_exact(&mut self.buffer).await.unwrap();
    }
}
576
// Tests for the LSP client: `test_basic` drives a real `rust-analyzer`
// process, while `test_fake` runs against the in-memory `FakeLanguageServer`.
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use simplelog::SimpleLogger;
    use unindent::Unindent;
    use util::test::temp_tree;

    // End-to-end check against a real server: open a small crate, wait for the
    // server to go idle, then request a hover.
    // NOTE(review): `Command::new("rust-analyzer")` resolves the binary by name,
    // so this presumably requires `rust-analyzer` on PATH.
    #[gpui::test]
    async fn test_basic(cx: TestAppContext) {
        let lib_source = r#"
            fn fun() {
                let hello = "world";
            }
        "#
        .unindent();
        let root_dir = temp_tree(json!({
            "Cargo.toml": r#"
                [package]
                name = "temp"
                version = "0.1.0"
                edition = "2018"
            "#.unindent(),
            "src": {
                "lib.rs": &lib_source
            }
        }));
        let lib_file_uri =
            lsp_types::Url::from_file_path(root_dir.path().join("src/lib.rs")).unwrap();

        let server = cx.read(|cx| {
            LanguageServer::new(
                Path::new("rust-analyzer"),
                root_dir.path(),
                cx.background().clone(),
            )
            .unwrap()
        });
        // Wait until the server reports it has finished its initial work.
        server.next_idle_notification().await;

        server
            .notify::<lsp_types::notification::DidOpenTextDocument>(
                lsp_types::DidOpenTextDocumentParams {
                    text_document: lsp_types::TextDocumentItem::new(
                        lib_file_uri.clone(),
                        "rust".to_string(),
                        0,
                        lib_source,
                    ),
                },
            )
            .await
            .unwrap();

        // Hover on line 1 (zero-based), i.e. the `let hello = "world";` line; the
        // expected type is that of the string literal.
        let hover = server
            .request::<lsp_types::request::HoverRequest>(lsp_types::HoverParams {
                text_document_position_params: lsp_types::TextDocumentPositionParams {
                    text_document: lsp_types::TextDocumentIdentifier::new(lib_file_uri),
                    position: lsp_types::Position::new(1, 21),
                },
                work_done_progress_params: Default::default(),
            })
            .await
            .unwrap()
            .unwrap();
        assert_eq!(
            hover.contents,
            lsp_types::HoverContents::Markup(lsp_types::MarkupContent {
                kind: lsp_types::MarkupKind::Markdown,
                value: "&str".to_string()
            })
        );
    }

    // Exercises the client against the fake server: an outgoing notification,
    // incoming notifications routed to registered handlers, and the shutdown
    // sequence triggered by dropping the client.
    #[gpui::test]
    async fn test_fake(cx: TestAppContext) {
        SimpleLogger::init(log::LevelFilter::Info, Default::default()).unwrap();

        let (server, mut fake) = LanguageServer::fake(cx.background()).await;

        // Forward incoming notifications into channels the test can await.
        let (message_tx, message_rx) = channel::unbounded();
        let (diagnostics_tx, diagnostics_rx) = channel::unbounded();
        server
            .on_notification::<notification::ShowMessage, _>(move |params| {
                message_tx.try_send(params).unwrap()
            })
            .detach();
        server
            .on_notification::<notification::PublishDiagnostics, _>(move |params| {
                diagnostics_tx.try_send(params).unwrap()
            })
            .detach();

        server
            .notify::<notification::DidOpenTextDocument>(DidOpenTextDocumentParams {
                text_document: TextDocumentItem::new(
                    Url::from_str("file://a/b").unwrap(),
                    "rust".to_string(),
                    0,
                    "".to_string(),
                ),
            })
            .await
            .unwrap();
        assert_eq!(
            fake.receive_notification::<notification::DidOpenTextDocument>()
                .await
                .text_document
                .uri
                .as_str(),
            "file://a/b"
        );

        fake.notify::<notification::ShowMessage>(ShowMessageParams {
            typ: MessageType::ERROR,
            message: "ok".to_string(),
        })
        .await;
        fake.notify::<notification::PublishDiagnostics>(PublishDiagnosticsParams {
            uri: Url::from_str("file://b/c").unwrap(),
            version: Some(5),
            diagnostics: vec![],
        })
        .await;
        assert_eq!(message_rx.recv().await.unwrap().message, "ok");
        assert_eq!(
            diagnostics_rx.recv().await.unwrap().uri.as_str(),
            "file://b/c"
        );

        // Dropping the last `Arc` runs `LanguageServer::drop`, which spawns the
        // shutdown sequence: a `shutdown` request followed by an `exit` notification.
        drop(server);
        let (shutdown_request, _) = fake.receive_request::<lsp_types::request::Shutdown>().await;
        fake.respond(shutdown_request, ()).await;
        fake.receive_notification::<lsp_types::notification::Exit>()
            .await;
    }

    impl LanguageServer {
        // Waits for a server-status notification reporting `quiescent: true`.
        // The temporary handler is unregistered when `_subscription` is dropped
        // on return.
        async fn next_idle_notification(self: &Arc<Self>) {
            let (tx, rx) = channel::unbounded();
            let _subscription =
                self.on_notification::<ServerStatusNotification, _>(move |params| {
                    if params.quiescent {
                        tx.try_send(()).unwrap();
                    }
                });
            let _ = rx.recv().await;
        }
    }

    // The `experimental/serverStatus` notification. The client opts into it
    // through the `serverStatusNotification` experimental capability sent in `init`.
    pub enum ServerStatusNotification {}

    impl lsp_types::notification::Notification for ServerStatusNotification {
        type Params = ServerStatusParams;
        const METHOD: &'static str = "experimental/serverStatus";
    }

    // Params of `ServerStatusNotification`; only `quiescent` is read here.
    #[derive(Deserialize, Serialize, PartialEq, Eq, Clone)]
    pub struct ServerStatusParams {
        pub quiescent: bool,
    }
}