1use anyhow::{anyhow, Context, Result};
2use futures::{io::BufWriter, AsyncRead, AsyncWrite};
3use gpui::{executor, Task};
4use parking_lot::{Mutex, RwLock};
5use postage::{barrier, oneshot, prelude::Stream, sink::Sink};
6use serde::{Deserialize, Serialize};
7use serde_json::{json, value::RawValue, Value};
8use smol::{
9 channel,
10 io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
11 process::Command,
12};
13use std::{
14 collections::HashMap,
15 future::Future,
16 io::Write,
17 str::FromStr,
18 sync::{
19 atomic::{AtomicUsize, Ordering::SeqCst},
20 Arc,
21 },
22};
23use std::{path::Path, process::Stdio};
24use util::TryFutureExt;
25
26pub use lsp_types::*;
27
/// JSON-RPC protocol version, used as the `jsonrpc` field of outgoing requests and notifications.
const JSON_RPC_VERSION: &str = "2.0";
/// Prefix of the header line that carries the byte length of a message body.
const CONTENT_LEN_HEADER: &str = "Content-Length: ";
30
31type NotificationHandler = Box<dyn Send + Sync + FnMut(&str)>;
32type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
33
/// Client-side connection to a language server that exchanges `Content-Length`-framed JSON-RPC
/// messages over a pair of byte streams.
///
/// Dropping it starts a graceful shutdown on `executor` (see the `Drop` impl).
pub struct LanguageServer {
    /// Source of ids for outgoing requests.
    next_id: AtomicUsize,
    /// Queue of serialized messages consumed by the output task; `None` once shutdown has begun.
    outbound_tx: RwLock<Option<channel::Sender<Vec<u8>>>>,
    /// Notification callbacks keyed by method name; shared with the input task and with
    /// outstanding `Subscription`s.
    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
    /// One-shot callbacks for in-flight requests keyed by request id; shared with the input task,
    /// which removes and invokes an entry when the matching response arrives.
    response_handlers: Arc<Mutex<HashMap<usize, ResponseHandler>>>,
    /// Executor that runs the I/O tasks, the `initialize` handshake, and drop-time shutdown.
    executor: Arc<executor::Background>,
    /// The (input, output) I/O tasks; taken by `shutdown`.
    io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
    /// Released once the `initialize` handshake has finished, whether or not it succeeded.
    initialized: barrier::Receiver,
    /// Released when the output task stops writing; taken by `shutdown`.
    output_done_rx: Mutex<Option<barrier::Receiver>>,
}
44
45pub struct Subscription {
46 method: &'static str,
47 notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
48}
49
/// Wire form of a JSON-RPC request: serialized for outgoing requests, and deserialized by the
/// fake server in tests.
#[derive(Serialize, Deserialize)]
struct Request<'a, T> {
    jsonrpc: &'a str,
    /// Identifier the peer echoes back in the matching response.
    id: usize,
    method: &'a str,
    params: T,
}
57
/// A JSON-RPC response whose `result` is kept as unparsed JSON, so it can be routed to its
/// handler by `id` before the concrete result type is known.
#[derive(Serialize, Deserialize)]
struct AnyResponse<'a> {
    id: usize,
    /// Present when the request failed.
    #[serde(default)]
    error: Option<Error>,
    /// Raw JSON text of the result, borrowed from the input buffer.
    #[serde(borrow)]
    result: Option<&'a RawValue>,
}
66
/// Wire form of a JSON-RPC notification (a message with no `id`) with typed `params`.
#[derive(Serialize, Deserialize)]
struct Notification<'a, T> {
    #[serde(borrow)]
    jsonrpc: &'a str,
    #[serde(borrow)]
    method: &'a str,
    params: T,
}
75
/// An incoming notification whose `params` stay as unparsed JSON until a handler for `method`
/// deserializes them. Both fields are required, so a message without `params` does not match.
#[derive(Deserialize)]
struct AnyNotification<'a> {
    #[serde(borrow)]
    method: &'a str,
    #[serde(borrow)]
    params: &'a RawValue,
}
83
/// Error object carried by a failed JSON-RPC response; only its `message` is captured.
#[derive(Debug, Serialize, Deserialize)]
struct Error {
    message: String,
}
88
89impl LanguageServer {
90 pub fn new(
91 binary_path: &Path,
92 root_path: &Path,
93 background: Arc<executor::Background>,
94 ) -> Result<Arc<Self>> {
95 let mut server = Command::new(binary_path)
96 .stdin(Stdio::piped())
97 .stdout(Stdio::piped())
98 .stderr(Stdio::inherit())
99 .spawn()?;
100 let stdin = server.stdin.take().unwrap();
101 let stdout = server.stdout.take().unwrap();
102 Self::new_internal(stdin, stdout, root_path, background)
103 }
104
105 fn new_internal<Stdin, Stdout>(
106 stdin: Stdin,
107 stdout: Stdout,
108 root_path: &Path,
109 executor: Arc<executor::Background>,
110 ) -> Result<Arc<Self>>
111 where
112 Stdin: AsyncWrite + Unpin + Send + 'static,
113 Stdout: AsyncRead + Unpin + Send + 'static,
114 {
115 let mut stdin = BufWriter::new(stdin);
116 let mut stdout = BufReader::new(stdout);
117 let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
118 let notification_handlers = Arc::new(RwLock::new(HashMap::<_, NotificationHandler>::new()));
119 let response_handlers = Arc::new(Mutex::new(HashMap::<_, ResponseHandler>::new()));
120 let input_task = executor.spawn(
121 {
122 let notification_handlers = notification_handlers.clone();
123 let response_handlers = response_handlers.clone();
124 async move {
125 let mut buffer = Vec::new();
126 loop {
127 buffer.clear();
128 stdout.read_until(b'\n', &mut buffer).await?;
129 stdout.read_until(b'\n', &mut buffer).await?;
130 let message_len: usize = std::str::from_utf8(&buffer)?
131 .strip_prefix(CONTENT_LEN_HEADER)
132 .ok_or_else(|| anyhow!("invalid header"))?
133 .trim_end()
134 .parse()?;
135
136 buffer.resize(message_len, 0);
137 stdout.read_exact(&mut buffer).await?;
138
139 if let Ok(AnyNotification { method, params }) =
140 serde_json::from_slice(&buffer)
141 {
142 if let Some(handler) = notification_handlers.write().get_mut(method) {
143 handler(params.get());
144 } else {
145 log::info!(
146 "unhandled notification {}:\n{}",
147 method,
148 serde_json::to_string_pretty(
149 &Value::from_str(params.get()).unwrap()
150 )
151 .unwrap()
152 );
153 }
154 } else if let Ok(AnyResponse { id, error, result }) =
155 serde_json::from_slice(&buffer)
156 {
157 if let Some(handler) = response_handlers.lock().remove(&id) {
158 if let Some(error) = error {
159 handler(Err(error));
160 } else if let Some(result) = result {
161 handler(Ok(result.get()));
162 } else {
163 handler(Ok("null"));
164 }
165 }
166 } else {
167 return Err(anyhow!(
168 "failed to deserialize message:\n{}",
169 std::str::from_utf8(&buffer)?
170 ));
171 }
172 }
173 }
174 }
175 .log_err(),
176 );
177 let (output_done_tx, output_done_rx) = barrier::channel();
178 let output_task = executor.spawn(
179 async move {
180 let mut content_len_buffer = Vec::new();
181 while let Ok(message) = outbound_rx.recv().await {
182 content_len_buffer.clear();
183 write!(content_len_buffer, "{}", message.len()).unwrap();
184 stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
185 stdin.write_all(&content_len_buffer).await?;
186 stdin.write_all("\r\n\r\n".as_bytes()).await?;
187 stdin.write_all(&message).await?;
188 stdin.flush().await?;
189 }
190 drop(output_done_tx);
191 Ok(())
192 }
193 .log_err(),
194 );
195
196 let (initialized_tx, initialized_rx) = barrier::channel();
197 let this = Arc::new(Self {
198 notification_handlers,
199 response_handlers,
200 next_id: Default::default(),
201 outbound_tx: RwLock::new(Some(outbound_tx)),
202 executor: executor.clone(),
203 io_tasks: Mutex::new(Some((input_task, output_task))),
204 initialized: initialized_rx,
205 output_done_rx: Mutex::new(Some(output_done_rx)),
206 });
207
208 let root_uri = Url::from_file_path(root_path).map_err(|_| anyhow!("invalid root path"))?;
209 executor
210 .spawn({
211 let this = this.clone();
212 async move {
213 this.init(root_uri).log_err().await;
214 drop(initialized_tx);
215 }
216 })
217 .detach();
218
219 Ok(this)
220 }
221
222 async fn init(self: Arc<Self>, root_uri: Url) -> Result<()> {
223 #[allow(deprecated)]
224 let params = InitializeParams {
225 process_id: Default::default(),
226 root_path: Default::default(),
227 root_uri: Some(root_uri),
228 initialization_options: Default::default(),
229 capabilities: ClientCapabilities {
230 text_document: Some(TextDocumentClientCapabilities {
231 definition: Some(GotoCapability {
232 link_support: Some(true),
233 ..Default::default()
234 }),
235 ..Default::default()
236 }),
237 experimental: Some(json!({
238 "serverStatusNotification": true,
239 })),
240 window: Some(WindowClientCapabilities {
241 work_done_progress: Some(true),
242 ..Default::default()
243 }),
244 ..Default::default()
245 },
246 trace: Default::default(),
247 workspace_folders: Default::default(),
248 client_info: Default::default(),
249 locale: Default::default(),
250 };
251
252 let this = self.clone();
253 let request = Self::request_internal::<request::Initialize>(
254 &this.next_id,
255 &this.response_handlers,
256 this.outbound_tx.read().as_ref(),
257 params,
258 );
259 request.await?;
260 Self::notify_internal::<notification::Initialized>(
261 this.outbound_tx.read().as_ref(),
262 InitializedParams {},
263 )?;
264 Ok(())
265 }
266
267 pub fn shutdown(&self) -> Option<impl 'static + Send + Future<Output = Result<()>>> {
268 if let Some(tasks) = self.io_tasks.lock().take() {
269 let response_handlers = self.response_handlers.clone();
270 let outbound_tx = self.outbound_tx.write().take();
271 let next_id = AtomicUsize::new(self.next_id.load(SeqCst));
272 let mut output_done = self.output_done_rx.lock().take().unwrap();
273 Some(async move {
274 Self::request_internal::<request::Shutdown>(
275 &next_id,
276 &response_handlers,
277 outbound_tx.as_ref(),
278 (),
279 )
280 .await?;
281 Self::notify_internal::<notification::Exit>(outbound_tx.as_ref(), ())?;
282 drop(outbound_tx);
283 output_done.recv().await;
284 drop(tasks);
285 Ok(())
286 })
287 } else {
288 None
289 }
290 }
291
292 pub fn on_notification<T, F>(&self, mut f: F) -> Subscription
293 where
294 T: notification::Notification,
295 F: 'static + Send + Sync + FnMut(T::Params),
296 {
297 let prev_handler = self.notification_handlers.write().insert(
298 T::METHOD,
299 Box::new(
300 move |notification| match serde_json::from_str(notification) {
301 Ok(notification) => f(notification),
302 Err(err) => log::error!("error parsing notification {}: {}", T::METHOD, err),
303 },
304 ),
305 );
306
307 assert!(
308 prev_handler.is_none(),
309 "registered multiple handlers for the same notification"
310 );
311
312 Subscription {
313 method: T::METHOD,
314 notification_handlers: self.notification_handlers.clone(),
315 }
316 }
317
318 pub fn request<T: request::Request>(
319 self: &Arc<Self>,
320 params: T::Params,
321 ) -> impl Future<Output = Result<T::Result>>
322 where
323 T::Result: 'static + Send,
324 {
325 let this = self.clone();
326 async move {
327 this.initialized.clone().recv().await;
328 Self::request_internal::<T>(
329 &this.next_id,
330 &this.response_handlers,
331 this.outbound_tx.read().as_ref(),
332 params,
333 )
334 .await
335 }
336 }
337
338 fn request_internal<T: request::Request>(
339 next_id: &AtomicUsize,
340 response_handlers: &Mutex<HashMap<usize, ResponseHandler>>,
341 outbound_tx: Option<&channel::Sender<Vec<u8>>>,
342 params: T::Params,
343 ) -> impl 'static + Future<Output = Result<T::Result>>
344 where
345 T::Result: 'static + Send,
346 {
347 let id = next_id.fetch_add(1, SeqCst);
348 let message = serde_json::to_vec(&Request {
349 jsonrpc: JSON_RPC_VERSION,
350 id,
351 method: T::METHOD,
352 params,
353 })
354 .unwrap();
355 let mut response_handlers = response_handlers.lock();
356 let (mut tx, mut rx) = oneshot::channel();
357 response_handlers.insert(
358 id,
359 Box::new(move |result| {
360 let response = match result {
361 Ok(response) => {
362 serde_json::from_str(response).context("failed to deserialize response")
363 }
364 Err(error) => Err(anyhow!("{}", error.message)),
365 };
366 let _ = tx.try_send(response);
367 }),
368 );
369
370 let send = outbound_tx
371 .as_ref()
372 .ok_or_else(|| {
373 anyhow!("tried to send a request to a language server that has been shut down")
374 })
375 .and_then(|outbound_tx| {
376 outbound_tx.try_send(message)?;
377 Ok(())
378 });
379 async move {
380 send?;
381 rx.recv().await.unwrap()
382 }
383 }
384
385 pub fn notify<T: notification::Notification>(
386 self: &Arc<Self>,
387 params: T::Params,
388 ) -> impl Future<Output = Result<()>> {
389 let this = self.clone();
390 async move {
391 this.initialized.clone().recv().await;
392 Self::notify_internal::<T>(this.outbound_tx.read().as_ref(), params)?;
393 Ok(())
394 }
395 }
396
397 fn notify_internal<T: notification::Notification>(
398 outbound_tx: Option<&channel::Sender<Vec<u8>>>,
399 params: T::Params,
400 ) -> Result<()> {
401 let message = serde_json::to_vec(&Notification {
402 jsonrpc: JSON_RPC_VERSION,
403 method: T::METHOD,
404 params,
405 })
406 .unwrap();
407 let outbound_tx = outbound_tx
408 .as_ref()
409 .ok_or_else(|| anyhow!("tried to notify a language server that has been shut down"))?;
410 outbound_tx.try_send(message)?;
411 Ok(())
412 }
413}
414
415impl Drop for LanguageServer {
416 fn drop(&mut self) {
417 if let Some(shutdown) = self.shutdown() {
418 self.executor.spawn(shutdown).detach();
419 }
420 }
421}
422
423impl Subscription {
424 pub fn detach(mut self) {
425 self.method = "";
426 }
427}
428
429impl Drop for Subscription {
430 fn drop(&mut self) {
431 self.notification_handlers.write().remove(self.method);
432 }
433}
434
/// In-memory stand-in for a language server process that lets tests script the server's side
/// of the protocol. Created by `LanguageServer::fake`.
#[cfg(any(test, feature = "test-support"))]
pub struct FakeLanguageServer {
    /// Body of the most recently received message.
    buffer: Vec<u8>,
    /// Read end of the pipe the client writes its outgoing messages into.
    stdin: smol::io::BufReader<async_pipe::PipeReader>,
    /// Write end of the pipe the client reads incoming messages from.
    stdout: smol::io::BufWriter<async_pipe::PipeWriter>,
    /// Checked by `notify`, which panics while this is `false`.
    pub started: Arc<std::sync::atomic::AtomicBool>,
}
442
/// Typed handle to a request received by the fake server, used to send a reply whose result
/// type matches the request type `T`.
#[cfg(any(test, feature = "test-support"))]
pub struct RequestId<T> {
    _type: std::marker::PhantomData<T>,
    id: usize,
}
448
#[cfg(any(test, feature = "test-support"))]
impl LanguageServer {
    /// Creates a `LanguageServer` wired to an in-memory [`FakeLanguageServer`] instead of a
    /// process. Before returning both ends, it plays the fake's side of the `initialize` /
    /// `initialized` handshake.
    pub async fn fake(executor: Arc<executor::Background>) -> (Arc<Self>, FakeLanguageServer) {
        // Each pipe is a (writer, reader) pair.
        let (client_writer, fake_reader) = async_pipe::pipe();
        let (fake_writer, client_reader) = async_pipe::pipe();
        let mut fake = FakeLanguageServer {
            stdin: smol::io::BufReader::new(fake_reader),
            stdout: smol::io::BufWriter::new(fake_writer),
            buffer: Vec::new(),
            started: Arc::new(std::sync::atomic::AtomicBool::new(true)),
        };

        let server =
            Self::new_internal(client_writer, client_reader, Path::new("/"), executor).unwrap();

        // Answer the client's startup handshake.
        let (init_id, _) = fake.receive_request::<request::Initialize>().await;
        fake.respond(init_id, InitializeResult::default()).await;
        fake.receive_notification::<notification::Initialized>().await;

        (server, fake)
    }
}
471
#[cfg(any(test, feature = "test-support"))]
impl FakeLanguageServer {
    /// Sends a `T` notification from the fake server to the client.
    ///
    /// # Panics
    ///
    /// Panics if `started` is `false` or if writing to the pipe fails.
    pub async fn notify<T: notification::Notification>(&mut self, params: T::Params) {
        assert!(
            self.started.load(std::sync::atomic::Ordering::SeqCst),
            "can't simulate an LSP notification before the server has been started"
        );
        let message = serde_json::to_vec(&Notification {
            jsonrpc: JSON_RPC_VERSION,
            method: T::METHOD,
            params,
        })
        .unwrap();
        self.send(message).await;
    }

    /// Replies successfully to a request previously obtained from `receive_request`.
    pub async fn respond<T: request::Request>(
        &mut self,
        request_id: RequestId<T>,
        result: T::Result,
    ) {
        let result = serde_json::to_string(&result).unwrap();
        let message = serde_json::to_vec(&AnyResponse {
            id: request_id.id,
            error: None,
            result: Some(&RawValue::from_string(result).unwrap()),
        })
        .unwrap();
        self.send(message).await;
    }

    /// Waits for the next message that parses as a `T` request and returns its id and params.
    /// Messages that don't parse that way are printed and skipped.
    ///
    /// # Panics
    ///
    /// Panics if a parsed request has a different method or JSON-RPC version.
    pub async fn receive_request<T: request::Request>(&mut self) -> (RequestId<T>, T::Params) {
        loop {
            self.receive().await;
            if let Ok(request) = serde_json::from_slice::<Request<T::Params>>(&self.buffer) {
                assert_eq!(request.method, T::METHOD);
                assert_eq!(request.jsonrpc, JSON_RPC_VERSION);
                return (
                    RequestId {
                        id: request.id,
                        _type: std::marker::PhantomData,
                    },
                    request.params,
                );
            } else {
                println!(
                    "skipping message in fake language server {:?}",
                    std::str::from_utf8(&self.buffer)
                );
            }
        }
    }

    /// Reads the next message and returns its params as a `T` notification.
    ///
    /// # Panics
    ///
    /// Panics if the next message is not a notification with `T`'s method and params shape.
    pub async fn receive_notification<T: notification::Notification>(&mut self) -> T::Params {
        self.receive().await;
        let notification = serde_json::from_slice::<Notification<T::Params>>(&self.buffer).unwrap();
        assert_eq!(notification.method, T::METHOD);
        notification.params
    }

    /// Sends a `$/progress` work-done "begin" notification for `token`.
    pub async fn start_progress(&mut self, token: impl Into<String>) {
        self.notify::<notification::Progress>(ProgressParams {
            token: NumberOrString::String(token.into()),
            value: ProgressParamsValue::WorkDone(WorkDoneProgress::Begin(Default::default())),
        })
        .await;
    }

    /// Sends a `$/progress` work-done "end" notification for `token`.
    pub async fn end_progress(&mut self, token: impl Into<String>) {
        self.notify::<notification::Progress>(ProgressParams {
            token: NumberOrString::String(token.into()),
            value: ProgressParamsValue::WorkDone(WorkDoneProgress::End(Default::default())),
        })
        .await;
    }

    /// Writes one `Content-Length`-framed message to the client and flushes it.
    async fn send(&mut self, message: Vec<u8>) {
        self.stdout
            .write_all(CONTENT_LEN_HEADER.as_bytes())
            .await
            .unwrap();
        self.stdout
            .write_all(message.len().to_string().as_bytes())
            .await
            .unwrap();
        self.stdout.write_all("\r\n\r\n".as_bytes()).await.unwrap();
        self.stdout.write_all(&message).await.unwrap();
        self.stdout.flush().await.unwrap();
    }

    /// Reads one framed message from the client into `self.buffer`.
    ///
    /// It expects exactly two header lines: `Content-Length` followed by the blank separator
    /// line.
    async fn receive(&mut self) {
        self.buffer.clear();
        self.stdin
            .read_until(b'\n', &mut self.buffer)
            .await
            .unwrap();
        self.stdin
            .read_until(b'\n', &mut self.buffer)
            .await
            .unwrap();
        let message_len: usize = std::str::from_utf8(&self.buffer)
            .unwrap()
            .strip_prefix(CONTENT_LEN_HEADER)
            .unwrap()
            .trim_end()
            .parse()
            .unwrap();
        self.buffer.resize(message_len, 0);
        self.stdin.read_exact(&mut self.buffer).await.unwrap();
    }
}
582
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use simplelog::SimpleLogger;
    use unindent::Unindent;
    use util::test::temp_tree;

    /// End-to-end check against a real `rust-analyzer` executable, resolved by name via the
    /// process search path. It opens a one-file crate and checks the hover text for a string
    /// literal.
    #[gpui::test]
    async fn test_basic(cx: TestAppContext) {
        let lib_source = r#"
            fn fun() {
                let hello = "world";
            }
        "#
        .unindent();
        let root_dir = temp_tree(json!({
            "Cargo.toml": r#"
                [package]
                name = "temp"
                version = "0.1.0"
                edition = "2018"
            "#.unindent(),
            "src": {
                "lib.rs": &lib_source
            }
        }));
        let lib_file_uri = Url::from_file_path(root_dir.path().join("src/lib.rs")).unwrap();

        let server = cx.read(|cx| {
            LanguageServer::new(
                Path::new("rust-analyzer"),
                root_dir.path(),
                cx.background().clone(),
            )
            .unwrap()
        });
        // Wait for the server to report it is quiescent before querying it.
        server.next_idle_notification().await;

        server
            .notify::<notification::DidOpenTextDocument>(DidOpenTextDocumentParams {
                text_document: TextDocumentItem::new(
                    lib_file_uri.clone(),
                    "rust".to_string(),
                    0,
                    lib_source,
                ),
            })
            .await
            .unwrap();

        // Zero-based line 1, character 21 falls inside the `"world"` literal.
        let hover = server
            .request::<request::HoverRequest>(HoverParams {
                text_document_position_params: TextDocumentPositionParams {
                    text_document: TextDocumentIdentifier::new(lib_file_uri),
                    position: Position::new(1, 21),
                },
                work_done_progress_params: Default::default(),
            })
            .await
            .unwrap()
            .unwrap();
        assert_eq!(
            hover.contents,
            HoverContents::Markup(MarkupContent {
                kind: MarkupKind::Markdown,
                value: "&str".to_string()
            })
        );
    }

    /// Exercises the client against `FakeLanguageServer`: notifications in both directions,
    /// then the shutdown/exit sequence triggered by dropping the client.
    #[gpui::test]
    async fn test_fake(cx: TestAppContext) {
        SimpleLogger::init(log::LevelFilter::Info, Default::default()).unwrap();

        let (server, mut fake) = LanguageServer::fake(cx.background()).await;

        let (message_tx, message_rx) = channel::unbounded();
        let (diagnostics_tx, diagnostics_rx) = channel::unbounded();
        server
            .on_notification::<notification::ShowMessage, _>(move |params| {
                message_tx.try_send(params).unwrap()
            })
            .detach();
        server
            .on_notification::<notification::PublishDiagnostics, _>(move |params| {
                diagnostics_tx.try_send(params).unwrap()
            })
            .detach();

        // Client -> fake.
        server
            .notify::<notification::DidOpenTextDocument>(DidOpenTextDocumentParams {
                text_document: TextDocumentItem::new(
                    Url::from_str("file://a/b").unwrap(),
                    "rust".to_string(),
                    0,
                    "".to_string(),
                ),
            })
            .await
            .unwrap();
        assert_eq!(
            fake.receive_notification::<notification::DidOpenTextDocument>()
                .await
                .text_document
                .uri
                .as_str(),
            "file://a/b"
        );

        // Fake -> client, delivered through the handlers registered above.
        fake.notify::<notification::ShowMessage>(ShowMessageParams {
            typ: MessageType::ERROR,
            message: "ok".to_string(),
        })
        .await;
        fake.notify::<notification::PublishDiagnostics>(PublishDiagnosticsParams {
            uri: Url::from_str("file://b/c").unwrap(),
            version: Some(5),
            diagnostics: vec![],
        })
        .await;
        assert_eq!(message_rx.recv().await.unwrap().message, "ok");
        assert_eq!(
            diagnostics_rx.recv().await.unwrap().uri.as_str(),
            "file://b/c"
        );

        // Dropping the client is expected to run its `Drop` impl, which spawns the shutdown
        // sequence that the fake then answers.
        drop(server);
        let (shutdown_request, _) = fake.receive_request::<request::Shutdown>().await;
        fake.respond(shutdown_request, ()).await;
        fake.receive_notification::<notification::Exit>().await;
    }

    impl LanguageServer {
        /// Resolves once the server sends an `experimental/serverStatus` notification with
        /// `quiescent: true`. The temporary subscription is dropped when this returns.
        async fn next_idle_notification(self: &Arc<Self>) {
            let (tx, rx) = channel::unbounded();
            let _subscription =
                self.on_notification::<ServerStatusNotification, _>(move |params| {
                    if params.quiescent {
                        tx.try_send(()).unwrap();
                    }
                });
            let _ = rx.recv().await;
        }
    }

    /// The `experimental/serverStatus` notification. The client opts into it via the
    /// `serverStatusNotification` experimental capability it sends in `initialize`.
    pub enum ServerStatusNotification {}

    impl notification::Notification for ServerStatusNotification {
        type Params = ServerStatusParams;
        const METHOD: &'static str = "experimental/serverStatus";
    }

    /// Params of [`ServerStatusNotification`]; only the `quiescent` flag is read.
    #[derive(Deserialize, Serialize, PartialEq, Eq, Clone)]
    pub struct ServerStatusParams {
        pub quiescent: bool,
    }
}