1use anyhow::{anyhow, Context, Result};
2use futures::{io::BufWriter, AsyncRead, AsyncWrite};
3use gpui::{executor, Task};
4use parking_lot::{Mutex, RwLock, RwLockReadGuard};
5use postage::{barrier, oneshot, prelude::Stream, sink::Sink};
6use serde::{Deserialize, Serialize};
7use serde_json::{json, value::RawValue, Value};
8use smol::{
9 channel,
10 io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
11 process::Command,
12};
13use std::{
14 collections::HashMap,
15 future::Future,
16 io::Write,
17 str::FromStr,
18 sync::{
19 atomic::{AtomicUsize, Ordering::SeqCst},
20 Arc,
21 },
22};
23use std::{path::Path, process::Stdio};
24use util::TryFutureExt;
25
26pub use lsp_types::*;
27
/// Value written to the `jsonrpc` field of every outgoing request and notification.
const JSON_RPC_VERSION: &str = "2.0";
/// Prefix (including the trailing `": "`) of the header line that announces a message's
/// length in bytes.
const CONTENT_LEN_HEADER: &str = "Content-Length: ";
30
31type NotificationHandler = Box<dyn Send + Sync + FnMut(&str)>;
32type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
33
/// Client-side handle to a language server, exchanging JSON-RPC messages with it over a
/// pair of byte streams.
pub struct LanguageServer {
    /// Source of ids for outgoing requests; advanced by one for every request built.
    next_id: AtomicUsize,
    /// Sender for serialized messages consumed by the output task. `shutdown` takes it,
    /// leaving `None`, after which requests and notifications fail with an error.
    outbound_tx: RwLock<Option<channel::Sender<Vec<u8>>>>,
    /// Capabilities from the server's `initialize` response; holds the default value
    /// until that response has been stored.
    capabilities: RwLock<lsp_types::ServerCapabilities>,
    /// Notification callbacks, keyed by method name.
    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
    /// Callbacks for requests that are awaiting a response, keyed by request id.
    response_handlers: Arc<Mutex<HashMap<usize, ResponseHandler>>>,
    /// Executor used to spawn the shutdown future when this value is dropped.
    executor: Arc<executor::Background>,
    /// The (input, output) tasks that talk to the server; taken by `shutdown`.
    io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
    /// Barrier that `request` and `notify` wait on before sending; its sender is
    /// dropped once the initialization attempt has finished.
    initialized: barrier::Receiver,
    /// Barrier whose sender is dropped when the output task stops draining messages;
    /// taken by `shutdown`.
    output_done_rx: Mutex<Option<barrier::Receiver>>,
}
45
/// Handle for a registered notification handler. Dropping it removes the handler stored
/// under `method`; `detach` keeps the handler registered instead.
pub struct Subscription {
    /// Method name under which the handler was registered.
    method: &'static str,
    /// Shared handler table that the handler is removed from on drop.
    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
}
50
/// A JSON-RPC request message with typed `params`.
#[derive(Serialize, Deserialize)]
struct Request<'a, T> {
    /// Protocol version string; outgoing requests use `JSON_RPC_VERSION`.
    jsonrpc: &'a str,
    /// Identifier used to match the response to this request.
    id: usize,
    /// Name of the method being invoked.
    method: &'a str,
    /// Method-specific parameters.
    params: T,
}
58
/// A response to any request, with the result left as unparsed JSON so that it can be
/// decoded later by the handler registered for `id`.
#[derive(Serialize, Deserialize)]
struct AnyResponse<'a> {
    /// Id of the request this response answers.
    id: usize,
    /// Error reported by the server; `None` when the field is absent.
    #[serde(default)]
    error: Option<Error>,
    /// Raw JSON of the result, borrowed from the message buffer.
    #[serde(borrow)]
    result: Option<&'a RawValue>,
}
67
/// A JSON-RPC notification (a message without an id) with typed `params`.
#[derive(Serialize, Deserialize)]
struct Notification<'a, T> {
    /// Protocol version string; outgoing notifications use `JSON_RPC_VERSION`.
    #[serde(borrow)]
    jsonrpc: &'a str,
    /// Name of the notification method.
    #[serde(borrow)]
    method: &'a str,
    /// Method-specific parameters.
    params: T,
}
76
/// An incoming notification of any method, with `params` left as unparsed JSON so that
/// it can be handed to the handler registered for `method`.
#[derive(Deserialize)]
struct AnyNotification<'a> {
    /// Name of the notification method, borrowed from the message buffer.
    #[serde(borrow)]
    method: &'a str,
    /// Raw JSON of the parameters, borrowed from the message buffer.
    #[serde(borrow)]
    params: &'a RawValue,
}
84
/// Error payload of a response; only the `message` field is modeled here.
#[derive(Debug, Serialize, Deserialize)]
struct Error {
    /// Description of the failure, surfaced to the caller of the failed request.
    message: String,
}
89
90impl LanguageServer {
91 pub fn new(
92 binary_path: &Path,
93 root_path: &Path,
94 background: Arc<executor::Background>,
95 ) -> Result<Arc<Self>> {
96 let mut server = Command::new(binary_path)
97 .stdin(Stdio::piped())
98 .stdout(Stdio::piped())
99 .stderr(Stdio::inherit())
100 .spawn()?;
101 let stdin = server.stdin.take().unwrap();
102 let stdout = server.stdout.take().unwrap();
103 Self::new_internal(stdin, stdout, root_path, background)
104 }
105
106 fn new_internal<Stdin, Stdout>(
107 stdin: Stdin,
108 stdout: Stdout,
109 root_path: &Path,
110 executor: Arc<executor::Background>,
111 ) -> Result<Arc<Self>>
112 where
113 Stdin: AsyncWrite + Unpin + Send + 'static,
114 Stdout: AsyncRead + Unpin + Send + 'static,
115 {
116 let mut stdin = BufWriter::new(stdin);
117 let mut stdout = BufReader::new(stdout);
118 let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
119 let notification_handlers = Arc::new(RwLock::new(HashMap::<_, NotificationHandler>::new()));
120 let response_handlers = Arc::new(Mutex::new(HashMap::<_, ResponseHandler>::new()));
121 let input_task = executor.spawn(
122 {
123 let notification_handlers = notification_handlers.clone();
124 let response_handlers = response_handlers.clone();
125 async move {
126 let mut buffer = Vec::new();
127 loop {
128 buffer.clear();
129 stdout.read_until(b'\n', &mut buffer).await?;
130 stdout.read_until(b'\n', &mut buffer).await?;
131 let message_len: usize = std::str::from_utf8(&buffer)?
132 .strip_prefix(CONTENT_LEN_HEADER)
133 .ok_or_else(|| anyhow!("invalid header"))?
134 .trim_end()
135 .parse()?;
136
137 buffer.resize(message_len, 0);
138 stdout.read_exact(&mut buffer).await?;
139
140 if let Ok(AnyNotification { method, params }) =
141 serde_json::from_slice(&buffer)
142 {
143 if let Some(handler) = notification_handlers.write().get_mut(method) {
144 handler(params.get());
145 } else {
146 log::info!(
147 "unhandled notification {}:\n{}",
148 method,
149 serde_json::to_string_pretty(
150 &Value::from_str(params.get()).unwrap()
151 )
152 .unwrap()
153 );
154 }
155 } else if let Ok(AnyResponse { id, error, result }) =
156 serde_json::from_slice(&buffer)
157 {
158 if let Some(handler) = response_handlers.lock().remove(&id) {
159 if let Some(error) = error {
160 handler(Err(error));
161 } else if let Some(result) = result {
162 handler(Ok(result.get()));
163 } else {
164 handler(Ok("null"));
165 }
166 }
167 } else {
168 return Err(anyhow!(
169 "failed to deserialize message:\n{}",
170 std::str::from_utf8(&buffer)?
171 ));
172 }
173 }
174 }
175 }
176 .log_err(),
177 );
178 let (output_done_tx, output_done_rx) = barrier::channel();
179 let output_task = executor.spawn(
180 async move {
181 let mut content_len_buffer = Vec::new();
182 while let Ok(message) = outbound_rx.recv().await {
183 content_len_buffer.clear();
184 write!(content_len_buffer, "{}", message.len()).unwrap();
185 stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
186 stdin.write_all(&content_len_buffer).await?;
187 stdin.write_all("\r\n\r\n".as_bytes()).await?;
188 stdin.write_all(&message).await?;
189 stdin.flush().await?;
190 }
191 drop(output_done_tx);
192 Ok(())
193 }
194 .log_err(),
195 );
196
197 let (initialized_tx, initialized_rx) = barrier::channel();
198 let this = Arc::new(Self {
199 notification_handlers,
200 response_handlers,
201 capabilities: Default::default(),
202 next_id: Default::default(),
203 outbound_tx: RwLock::new(Some(outbound_tx)),
204 executor: executor.clone(),
205 io_tasks: Mutex::new(Some((input_task, output_task))),
206 initialized: initialized_rx,
207 output_done_rx: Mutex::new(Some(output_done_rx)),
208 });
209
210 let root_uri = Url::from_file_path(root_path).map_err(|_| anyhow!("invalid root path"))?;
211 executor
212 .spawn({
213 let this = this.clone();
214 async move {
215 this.init(root_uri).log_err().await;
216 drop(initialized_tx);
217 }
218 })
219 .detach();
220
221 Ok(this)
222 }
223
224 async fn init(self: Arc<Self>, root_uri: Url) -> Result<()> {
225 #[allow(deprecated)]
226 let params = InitializeParams {
227 process_id: Default::default(),
228 root_path: Default::default(),
229 root_uri: Some(root_uri),
230 initialization_options: Default::default(),
231 capabilities: ClientCapabilities {
232 text_document: Some(TextDocumentClientCapabilities {
233 definition: Some(GotoCapability {
234 link_support: Some(true),
235 ..Default::default()
236 }),
237 completion: Some(CompletionClientCapabilities {
238 completion_item: Some(CompletionItemCapability {
239 resolve_support: Some(CompletionItemCapabilityResolveSupport {
240 properties: vec!["additionalTextEdits".to_string()],
241 }),
242 ..Default::default()
243 }),
244 ..Default::default()
245 }),
246 ..Default::default()
247 }),
248 experimental: Some(json!({
249 "serverStatusNotification": true,
250 })),
251 window: Some(WindowClientCapabilities {
252 work_done_progress: Some(true),
253 ..Default::default()
254 }),
255 ..Default::default()
256 },
257 trace: Default::default(),
258 workspace_folders: Default::default(),
259 client_info: Default::default(),
260 locale: Default::default(),
261 };
262
263 let this = self.clone();
264 let request = Self::request_internal::<request::Initialize>(
265 &this.next_id,
266 &this.response_handlers,
267 this.outbound_tx.read().as_ref(),
268 params,
269 );
270 let response = request.await?;
271 *this.capabilities.write() = response.capabilities;
272 Self::notify_internal::<notification::Initialized>(
273 this.outbound_tx.read().as_ref(),
274 InitializedParams {},
275 )?;
276 Ok(())
277 }
278
279 pub fn shutdown(&self) -> Option<impl 'static + Send + Future<Output = Result<()>>> {
280 if let Some(tasks) = self.io_tasks.lock().take() {
281 let response_handlers = self.response_handlers.clone();
282 let outbound_tx = self.outbound_tx.write().take();
283 let next_id = AtomicUsize::new(self.next_id.load(SeqCst));
284 let mut output_done = self.output_done_rx.lock().take().unwrap();
285 Some(async move {
286 Self::request_internal::<request::Shutdown>(
287 &next_id,
288 &response_handlers,
289 outbound_tx.as_ref(),
290 (),
291 )
292 .await?;
293 Self::notify_internal::<notification::Exit>(outbound_tx.as_ref(), ())?;
294 drop(outbound_tx);
295 output_done.recv().await;
296 drop(tasks);
297 Ok(())
298 })
299 } else {
300 None
301 }
302 }
303
304 pub fn on_notification<T, F>(&self, mut f: F) -> Subscription
305 where
306 T: notification::Notification,
307 F: 'static + Send + Sync + FnMut(T::Params),
308 {
309 let prev_handler = self.notification_handlers.write().insert(
310 T::METHOD,
311 Box::new(
312 move |notification| match serde_json::from_str(notification) {
313 Ok(notification) => f(notification),
314 Err(err) => log::error!("error parsing notification {}: {}", T::METHOD, err),
315 },
316 ),
317 );
318
319 assert!(
320 prev_handler.is_none(),
321 "registered multiple handlers for the same notification"
322 );
323
324 Subscription {
325 method: T::METHOD,
326 notification_handlers: self.notification_handlers.clone(),
327 }
328 }
329
330 pub fn capabilities(&self) -> RwLockReadGuard<ServerCapabilities> {
331 self.capabilities.read()
332 }
333
334 pub fn request<T: request::Request>(
335 self: &Arc<Self>,
336 params: T::Params,
337 ) -> impl Future<Output = Result<T::Result>>
338 where
339 T::Result: 'static + Send,
340 {
341 let this = self.clone();
342 async move {
343 this.initialized.clone().recv().await;
344 Self::request_internal::<T>(
345 &this.next_id,
346 &this.response_handlers,
347 this.outbound_tx.read().as_ref(),
348 params,
349 )
350 .await
351 }
352 }
353
354 fn request_internal<T: request::Request>(
355 next_id: &AtomicUsize,
356 response_handlers: &Mutex<HashMap<usize, ResponseHandler>>,
357 outbound_tx: Option<&channel::Sender<Vec<u8>>>,
358 params: T::Params,
359 ) -> impl 'static + Future<Output = Result<T::Result>>
360 where
361 T::Result: 'static + Send,
362 {
363 let id = next_id.fetch_add(1, SeqCst);
364 let message = serde_json::to_vec(&Request {
365 jsonrpc: JSON_RPC_VERSION,
366 id,
367 method: T::METHOD,
368 params,
369 })
370 .unwrap();
371 let mut response_handlers = response_handlers.lock();
372 let (mut tx, mut rx) = oneshot::channel();
373 response_handlers.insert(
374 id,
375 Box::new(move |result| {
376 let response = match result {
377 Ok(response) => {
378 serde_json::from_str(response).context("failed to deserialize response")
379 }
380 Err(error) => Err(anyhow!("{}", error.message)),
381 };
382 let _ = tx.try_send(response);
383 }),
384 );
385
386 let send = outbound_tx
387 .as_ref()
388 .ok_or_else(|| {
389 anyhow!("tried to send a request to a language server that has been shut down")
390 })
391 .and_then(|outbound_tx| {
392 outbound_tx.try_send(message)?;
393 Ok(())
394 });
395 async move {
396 send?;
397 rx.recv().await.unwrap()
398 }
399 }
400
401 pub fn notify<T: notification::Notification>(
402 self: &Arc<Self>,
403 params: T::Params,
404 ) -> impl Future<Output = Result<()>> {
405 let this = self.clone();
406 async move {
407 this.initialized.clone().recv().await;
408 Self::notify_internal::<T>(this.outbound_tx.read().as_ref(), params)?;
409 Ok(())
410 }
411 }
412
413 fn notify_internal<T: notification::Notification>(
414 outbound_tx: Option<&channel::Sender<Vec<u8>>>,
415 params: T::Params,
416 ) -> Result<()> {
417 let message = serde_json::to_vec(&Notification {
418 jsonrpc: JSON_RPC_VERSION,
419 method: T::METHOD,
420 params,
421 })
422 .unwrap();
423 let outbound_tx = outbound_tx
424 .as_ref()
425 .ok_or_else(|| anyhow!("tried to notify a language server that has been shut down"))?;
426 outbound_tx.try_send(message)?;
427 Ok(())
428 }
429}
430
431impl Drop for LanguageServer {
432 fn drop(&mut self) {
433 if let Some(shutdown) = self.shutdown() {
434 self.executor.spawn(shutdown).detach();
435 }
436 }
437}
438
impl Subscription {
    /// Consumes the subscription while leaving its handler registered.
    pub fn detach(mut self) {
        // `Drop` still runs for `self`, but it will now try to remove the empty method
        // name instead of the one the handler was registered under.
        self.method = "";
    }
}
444
445impl Drop for Subscription {
446 fn drop(&mut self) {
447 self.notification_handlers.write().remove(self.method);
448 }
449}
450
/// In-process stand-in for a language server, connected to a [`LanguageServer`] through
/// a pair of pipes. Available in tests and with the `test-support` feature.
#[cfg(any(test, feature = "test-support"))]
pub struct FakeLanguageServer {
    /// Scratch space holding the body of the most recently received message.
    buffer: Vec<u8>,
    /// Stream from which messages written by the client are read.
    stdin: smol::io::BufReader<async_pipe::PipeReader>,
    /// Stream on which messages for the client are written.
    stdout: smol::io::BufWriter<async_pipe::PipeWriter>,
    /// Flag checked by `notify`, which panics while it is `false`.
    pub started: Arc<std::sync::atomic::AtomicBool>,
}
458
/// Id of a request received by the fake server, tagged with the request type `T` so
/// that `respond` requires a result of the matching type.
#[cfg(any(test, feature = "test-support"))]
pub struct RequestId<T> {
    /// The `id` field of the received request.
    id: usize,
    /// Carries the request type without storing a value of it.
    _type: std::marker::PhantomData<T>,
}
464
#[cfg(any(test, feature = "test-support"))]
impl LanguageServer {
    /// Creates a client wired to a [`FakeLanguageServer`] that reports default
    /// capabilities.
    pub async fn fake(executor: Arc<executor::Background>) -> (Arc<Self>, FakeLanguageServer) {
        let capabilities = ServerCapabilities::default();
        Self::fake_with_capabilities(capabilities, executor).await
    }

    /// Creates a client wired to a [`FakeLanguageServer`] and drives the initialization
    /// handshake to completion, answering the `initialize` request with `capabilities`.
    pub async fn fake_with_capabilities(
        capabilities: ServerCapabilities,
        executor: Arc<executor::Background>,
    ) -> (Arc<Self>, FakeLanguageServer) {
        // One pipe per direction: the client writes into the fake's stdin and reads
        // from the fake's stdout.
        let (client_writer, fake_reader) = async_pipe::pipe();
        let (fake_writer, client_reader) = async_pipe::pipe();
        let mut fake = FakeLanguageServer {
            stdin: smol::io::BufReader::new(fake_reader),
            stdout: smol::io::BufWriter::new(fake_writer),
            buffer: Vec::new(),
            started: Arc::new(std::sync::atomic::AtomicBool::new(true)),
        };

        let root = Path::new("/");
        let server = Self::new_internal(client_writer, client_reader, root, executor).unwrap();

        // Play the server's side of the handshake.
        let (init_id, _) = fake.receive_request::<request::Initialize>().await;
        let init_result = InitializeResult {
            capabilities,
            ..Default::default()
        };
        fake.respond(init_id, init_result).await;
        fake.receive_notification::<notification::Initialized>()
            .await;

        (server, fake)
    }
}
501
#[cfg(any(test, feature = "test-support"))]
impl FakeLanguageServer {
    /// Sends a notification of type `T` to the client.
    ///
    /// # Panics
    ///
    /// Panics if the `started` flag is `false`.
    pub async fn notify<T: notification::Notification>(&mut self, params: T::Params) {
        let started = self.started.load(std::sync::atomic::Ordering::SeqCst);
        assert!(
            started,
            "can't simulate an LSP notification before the server has been started"
        );
        let notification = Notification {
            jsonrpc: JSON_RPC_VERSION,
            method: T::METHOD,
            params,
        };
        let bytes = serde_json::to_vec(&notification).unwrap();
        self.send(bytes).await;
    }

    /// Sends `result` to the client as the response to the request `request_id`.
    pub async fn respond<'a, T: request::Request>(
        &mut self,
        request_id: RequestId<T>,
        result: T::Result,
    ) {
        let bytes = {
            let raw_result = RawValue::from_string(serde_json::to_string(&result).unwrap()).unwrap();
            let response = AnyResponse {
                id: request_id.id,
                error: None,
                result: Some(&*raw_result),
            };
            serde_json::to_vec(&response).unwrap()
        };
        self.send(bytes).await;
    }

    /// Waits for the next message that parses as a request with `T`'s parameters and
    /// returns its id and parameters. Messages that do not parse are printed and
    /// skipped.
    ///
    /// # Panics
    ///
    /// Panics if the parsed request names a different method or protocol version.
    pub async fn receive_request<T: request::Request>(&mut self) -> (RequestId<T>, T::Params) {
        loop {
            self.receive().await;
            match serde_json::from_slice::<Request<T::Params>>(&self.buffer) {
                Ok(request) => {
                    assert_eq!(request.method, T::METHOD);
                    assert_eq!(request.jsonrpc, JSON_RPC_VERSION);
                    let request_id = RequestId {
                        id: request.id,
                        _type: std::marker::PhantomData,
                    };
                    return (request_id, request.params);
                }
                Err(_) => {
                    println!(
                        "skipping message in fake language server {:?}",
                        std::str::from_utf8(&self.buffer)
                    );
                }
            }
        }
    }

    /// Reads the next message and returns its parameters.
    ///
    /// # Panics
    ///
    /// Panics if the message is not a notification with `T`'s parameters or names a
    /// different method.
    pub async fn receive_notification<T: notification::Notification>(&mut self) -> T::Params {
        self.receive().await;
        let parsed = serde_json::from_slice::<Notification<T::Params>>(&self.buffer).unwrap();
        assert_eq!(parsed.method, T::METHOD);
        parsed.params
    }

    /// Notifies the client that work-done progress identified by `token` has begun.
    pub async fn start_progress(&mut self, token: impl Into<String>) {
        let begin = WorkDoneProgress::Begin(Default::default());
        let params = ProgressParams {
            token: NumberOrString::String(token.into()),
            value: ProgressParamsValue::WorkDone(begin),
        };
        self.notify::<notification::Progress>(params).await;
    }

    /// Notifies the client that work-done progress identified by `token` has ended.
    pub async fn end_progress(&mut self, token: impl Into<String>) {
        let end = WorkDoneProgress::End(Default::default());
        let params = ProgressParams {
            token: NumberOrString::String(token.into()),
            value: ProgressParamsValue::WorkDone(end),
        };
        self.notify::<notification::Progress>(params).await;
    }

    /// Writes `message` to the client, preceded by its length header, and flushes.
    async fn send(&mut self, message: Vec<u8>) {
        let header = format!("{}{}\r\n\r\n", CONTENT_LEN_HEADER, message.len());
        self.stdout.write_all(header.as_bytes()).await.unwrap();
        self.stdout.write_all(&message).await.unwrap();
        self.stdout.flush().await.unwrap();
    }

    /// Reads one message from the client into `self.buffer`.
    ///
    /// # Panics
    ///
    /// Panics on I/O errors or if the header is not a valid length header.
    async fn receive(&mut self) {
        self.buffer.clear();
        // The header section is two lines: the length header and the blank line that
        // follows it.
        for _ in 0..2 {
            self.stdin
                .read_until(b'\n', &mut self.buffer)
                .await
                .unwrap();
        }
        let header = std::str::from_utf8(&self.buffer).unwrap();
        let length_text = header.strip_prefix(CONTENT_LEN_HEADER).unwrap().trim_end();
        let message_len: usize = length_text.parse().unwrap();
        self.buffer.resize(message_len, 0);
        self.stdin.read_exact(&mut self.buffer).await.unwrap();
    }
}
612
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use simplelog::SimpleLogger;
    use unindent::Unindent;
    use util::test::temp_tree;

    /// Runs a `rust-analyzer` binary against a temporary crate and checks the hover
    /// result for a position in `src/lib.rs`.
    #[gpui::test]
    async fn test_rust_analyzer(cx: TestAppContext) {
        let lib_source = r#"
 fn fun() {
 let hello = "world";
 }
 "#
        .unindent();
        let root_dir = temp_tree(json!({
            "Cargo.toml": r#"
 [package]
 name = "temp"
 version = "0.1.0"
 edition = "2018"
 "#.unindent(),
            "src": {
                "lib.rs": &lib_source
            }
        }));
        let lib_file_uri = Url::from_file_path(root_dir.path().join("src/lib.rs")).unwrap();

        let server = cx.read(|cx| {
            LanguageServer::new(
                Path::new("rust-analyzer"),
                root_dir.path(),
                cx.background().clone(),
            )
            .unwrap()
        });
        // Wait for a status notification that reports the server as quiescent.
        server.next_idle_notification().await;

        server
            .notify::<notification::DidOpenTextDocument>(DidOpenTextDocumentParams {
                text_document: TextDocumentItem::new(
                    lib_file_uri.clone(),
                    "rust".to_string(),
                    0,
                    lib_source,
                ),
            })
            .await
            .unwrap();

        // Request hover information at line 1, character 21 of `lib.rs`.
        let hover = server
            .request::<request::HoverRequest>(HoverParams {
                text_document_position_params: TextDocumentPositionParams {
                    text_document: TextDocumentIdentifier::new(lib_file_uri),
                    position: Position::new(1, 21),
                },
                work_done_progress_params: Default::default(),
            })
            .await
            .unwrap()
            .unwrap();
        assert_eq!(
            hover.contents,
            HoverContents::Markup(MarkupContent {
                kind: MarkupKind::PlainText,
                value: "&str".to_string()
            })
        );
    }

    /// Exercises the client against a `FakeLanguageServer`: a client notification
    /// reaches the fake, notifications from the fake reach the registered handlers, and
    /// dropping the client produces the shutdown request followed by the exit
    /// notification.
    #[gpui::test]
    async fn test_fake(cx: TestAppContext) {
        SimpleLogger::init(log::LevelFilter::Info, Default::default()).unwrap();

        let (server, mut fake) = LanguageServer::fake(cx.background()).await;

        // Forward received notifications into channels the test can await on. The
        // subscriptions are detached so the handlers stay registered.
        let (message_tx, message_rx) = channel::unbounded();
        let (diagnostics_tx, diagnostics_rx) = channel::unbounded();
        server
            .on_notification::<notification::ShowMessage, _>(move |params| {
                message_tx.try_send(params).unwrap()
            })
            .detach();
        server
            .on_notification::<notification::PublishDiagnostics, _>(move |params| {
                diagnostics_tx.try_send(params).unwrap()
            })
            .detach();

        // Client -> server notification.
        server
            .notify::<notification::DidOpenTextDocument>(DidOpenTextDocumentParams {
                text_document: TextDocumentItem::new(
                    Url::from_str("file://a/b").unwrap(),
                    "rust".to_string(),
                    0,
                    "".to_string(),
                ),
            })
            .await
            .unwrap();
        assert_eq!(
            fake.receive_notification::<notification::DidOpenTextDocument>()
                .await
                .text_document
                .uri
                .as_str(),
            "file://a/b"
        );

        // Server -> client notifications.
        fake.notify::<notification::ShowMessage>(ShowMessageParams {
            typ: MessageType::ERROR,
            message: "ok".to_string(),
        })
        .await;
        fake.notify::<notification::PublishDiagnostics>(PublishDiagnosticsParams {
            uri: Url::from_str("file://b/c").unwrap(),
            version: Some(5),
            diagnostics: vec![],
        })
        .await;
        assert_eq!(message_rx.recv().await.unwrap().message, "ok");
        assert_eq!(
            diagnostics_rx.recv().await.unwrap().uri.as_str(),
            "file://b/c"
        );

        // Drop the client handle; the fake is then expected to receive the shutdown
        // request and, after responding to it, the exit notification.
        drop(server);
        let (shutdown_request, _) = fake.receive_request::<request::Shutdown>().await;
        fake.respond(shutdown_request, ()).await;
        fake.receive_notification::<notification::Exit>().await;
    }

    impl LanguageServer {
        /// Waits until a `ServerStatusNotification` with `quiescent` set to `true` is
        /// received. The handler is unregistered when this function returns, because
        /// the subscription is dropped at that point.
        async fn next_idle_notification(self: &Arc<Self>) {
            let (tx, rx) = channel::unbounded();
            let _subscription =
                self.on_notification::<ServerStatusNotification, _>(move |params| {
                    if params.quiescent {
                        tx.try_send(()).unwrap();
                    }
                });
            let _ = rx.recv().await;
        }
    }

    /// Marker type for the `experimental/serverStatus` notification.
    pub enum ServerStatusNotification {}

    impl notification::Notification for ServerStatusNotification {
        type Params = ServerStatusParams;
        const METHOD: &'static str = "experimental/serverStatus";
    }

    /// Parameters of `ServerStatusNotification`; only the `quiescent` field is modeled.
    #[derive(Deserialize, Serialize, PartialEq, Eq, Clone)]
    pub struct ServerStatusParams {
        pub quiescent: bool,
    }
}