1use anyhow::{anyhow, Context, Result};
2use futures::{io::BufWriter, AsyncRead, AsyncWrite};
3use gpui::{executor, Task};
4use parking_lot::{Mutex, RwLock};
5use postage::{barrier, oneshot, prelude::Stream, sink::Sink};
6use serde::{Deserialize, Serialize};
7use serde_json::{json, value::RawValue, Value};
8use smol::{
9 channel,
10 io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
11 process::Command,
12};
13use std::{
14 collections::HashMap,
15 future::Future,
16 io::Write,
17 str::FromStr,
18 sync::{
19 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
20 Arc,
21 },
22};
23use std::{path::Path, process::Stdio};
24use util::TryFutureExt;
25
26pub use lsp_types::*;
27
/// Value of the `jsonrpc` member on every message this module sends.
const JSON_RPC_VERSION: &str = "2.0";
/// Prefix of the `Content-Length` header that frames each message; the byte length of the JSON
/// body follows it.
const CONTENT_LEN_HEADER: &str = "Content-Length: ";
30
31type NotificationHandler = Box<dyn Send + Sync + Fn(&str)>;
32type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
33
/// A client-side connection to a language server that exchanges LSP messages over a pair of
/// byte streams.
pub struct LanguageServer {
    /// Id assigned to the next outgoing request.
    next_id: AtomicUsize,
    /// Queue of serialized outgoing messages drained by the output task; set to `None` when
    /// `shutdown` starts, after which nothing more can be sent.
    outbound_tx: RwLock<Option<channel::Sender<Vec<u8>>>>,
    /// Handlers for incoming notifications, keyed by LSP method name.
    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
    /// Callbacks for in-flight requests, keyed by request id.
    response_handlers: Arc<Mutex<HashMap<usize, ResponseHandler>>>,
    /// Executor used to run the shutdown sequence when the server is dropped.
    executor: Arc<executor::Background>,
    /// The (input, output) I/O tasks; taken by `shutdown`.
    io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
    /// Released once the `initialize` handshake task has finished, whether or not it succeeded.
    initialized: barrier::Receiver,
    /// Released once the output task stops writing; taken by `shutdown`.
    output_done_rx: Mutex<Option<barrier::Receiver>>,
}
44
45pub struct Subscription {
46 method: &'static str,
47 notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
48}
49
/// JSON-RPC request in wire format. The client serializes it for outgoing requests, and the
/// fake server deserializes it to inspect what the client sent.
#[derive(Serialize, Deserialize)]
struct Request<'a, T> {
    jsonrpc: &'a str,
    id: usize,
    method: &'a str,
    params: T,
}
57
58#[derive(Serialize, Deserialize)]
59struct AnyResponse<'a> {
60 id: usize,
61 #[serde(default)]
62 error: Option<Error>,
63 #[serde(borrow)]
64 result: Option<&'a RawValue>,
65}
66
/// JSON-RPC notification in wire format. The client serializes it for outgoing notifications,
/// and the fake server uses it both to send notifications and to parse the client's.
#[derive(Serialize, Deserialize)]
struct Notification<'a, T> {
    #[serde(borrow)]
    jsonrpc: &'a str,
    #[serde(borrow)]
    method: &'a str,
    params: T,
}
75
/// An incoming message with a `method`, whose `params` are left as unparsed JSON for the
/// method's handler to decode.
///
/// NOTE(review): `params` is a required field, so a notification that omits `params` (which
/// JSON-RPC 2.0 permits) will not deserialize as this type.
#[derive(Deserialize)]
struct AnyNotification<'a> {
    #[serde(borrow)]
    method: &'a str,
    #[serde(borrow)]
    params: &'a RawValue,
}
83
/// The `error` member of a JSON-RPC response.
///
/// Only `message` is kept. serde ignores any other members of the error object (such as
/// `code`) when deserializing.
#[derive(Debug, Serialize, Deserialize)]
struct Error {
    message: String,
}
88
89impl LanguageServer {
90 pub fn new(
91 binary_path: &Path,
92 root_path: &Path,
93 background: Arc<executor::Background>,
94 ) -> Result<Arc<Self>> {
95 let mut server = Command::new(binary_path)
96 .stdin(Stdio::piped())
97 .stdout(Stdio::piped())
98 .stderr(Stdio::inherit())
99 .spawn()?;
100 let stdin = server.stdin.take().unwrap();
101 let stdout = server.stdout.take().unwrap();
102 Self::new_internal(stdin, stdout, root_path, background)
103 }
104
105 fn new_internal<Stdin, Stdout>(
106 stdin: Stdin,
107 stdout: Stdout,
108 root_path: &Path,
109 executor: Arc<executor::Background>,
110 ) -> Result<Arc<Self>>
111 where
112 Stdin: AsyncWrite + Unpin + Send + 'static,
113 Stdout: AsyncRead + Unpin + Send + 'static,
114 {
115 let mut stdin = BufWriter::new(stdin);
116 let mut stdout = BufReader::new(stdout);
117 let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
118 let notification_handlers = Arc::new(RwLock::new(HashMap::<_, NotificationHandler>::new()));
119 let response_handlers = Arc::new(Mutex::new(HashMap::<_, ResponseHandler>::new()));
120 let input_task = executor.spawn(
121 {
122 let notification_handlers = notification_handlers.clone();
123 let response_handlers = response_handlers.clone();
124 async move {
125 let mut buffer = Vec::new();
126 loop {
127 buffer.clear();
128 stdout.read_until(b'\n', &mut buffer).await?;
129 stdout.read_until(b'\n', &mut buffer).await?;
130 let message_len: usize = std::str::from_utf8(&buffer)?
131 .strip_prefix(CONTENT_LEN_HEADER)
132 .ok_or_else(|| anyhow!("invalid header"))?
133 .trim_end()
134 .parse()?;
135
136 buffer.resize(message_len, 0);
137 stdout.read_exact(&mut buffer).await?;
138
139 println!("{}", std::str::from_utf8(&buffer).unwrap());
140 if let Ok(AnyNotification { method, params }) =
141 serde_json::from_slice(&buffer)
142 {
143 if let Some(handler) = notification_handlers.read().get(method) {
144 handler(params.get());
145 } else {
146 log::info!(
147 "unhandled notification {}:\n{}",
148 method,
149 serde_json::to_string_pretty(
150 &Value::from_str(params.get()).unwrap()
151 )
152 .unwrap()
153 );
154 }
155 } else if let Ok(AnyResponse { id, error, result }) =
156 serde_json::from_slice(&buffer)
157 {
158 if let Some(handler) = response_handlers.lock().remove(&id) {
159 if let Some(error) = error {
160 handler(Err(error));
161 } else if let Some(result) = result {
162 handler(Ok(result.get()));
163 } else {
164 handler(Ok("null"));
165 }
166 }
167 } else {
168 return Err(anyhow!(
169 "failed to deserialize message:\n{}",
170 std::str::from_utf8(&buffer)?
171 ));
172 }
173 }
174 }
175 }
176 .log_err(),
177 );
178 let (output_done_tx, output_done_rx) = barrier::channel();
179 let output_task = executor.spawn(
180 async move {
181 let mut content_len_buffer = Vec::new();
182 while let Ok(message) = outbound_rx.recv().await {
183 content_len_buffer.clear();
184 write!(content_len_buffer, "{}", message.len()).unwrap();
185 stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
186 stdin.write_all(&content_len_buffer).await?;
187 stdin.write_all("\r\n\r\n".as_bytes()).await?;
188 stdin.write_all(&message).await?;
189 stdin.flush().await?;
190 }
191 drop(output_done_tx);
192 Ok(())
193 }
194 .log_err(),
195 );
196
197 let (initialized_tx, initialized_rx) = barrier::channel();
198 let this = Arc::new(Self {
199 notification_handlers,
200 response_handlers,
201 next_id: Default::default(),
202 outbound_tx: RwLock::new(Some(outbound_tx)),
203 executor: executor.clone(),
204 io_tasks: Mutex::new(Some((input_task, output_task))),
205 initialized: initialized_rx,
206 output_done_rx: Mutex::new(Some(output_done_rx)),
207 });
208
209 let root_uri =
210 lsp_types::Url::from_file_path(root_path).map_err(|_| anyhow!("invalid root path"))?;
211 executor
212 .spawn({
213 let this = this.clone();
214 async move {
215 this.init(root_uri).log_err().await;
216 drop(initialized_tx);
217 }
218 })
219 .detach();
220
221 Ok(this)
222 }
223
224 async fn init(self: Arc<Self>, root_uri: lsp_types::Url) -> Result<()> {
225 #[allow(deprecated)]
226 let params = lsp_types::InitializeParams {
227 process_id: Default::default(),
228 root_path: Default::default(),
229 root_uri: Some(root_uri),
230 initialization_options: Default::default(),
231 capabilities: lsp_types::ClientCapabilities {
232 experimental: Some(json!({
233 "serverStatusNotification": true,
234 })),
235 ..Default::default()
236 },
237 trace: Default::default(),
238 workspace_folders: Default::default(),
239 client_info: Default::default(),
240 locale: Default::default(),
241 };
242
243 let this = self.clone();
244 let request = Self::request_internal::<lsp_types::request::Initialize>(
245 &this.next_id,
246 &this.response_handlers,
247 this.outbound_tx.read().as_ref(),
248 params,
249 );
250 request.await?;
251 Self::notify_internal::<lsp_types::notification::Initialized>(
252 this.outbound_tx.read().as_ref(),
253 lsp_types::InitializedParams {},
254 )?;
255 Ok(())
256 }
257
258 pub fn shutdown(&self) -> Option<impl 'static + Send + Future<Output = Result<()>>> {
259 if let Some(tasks) = self.io_tasks.lock().take() {
260 let response_handlers = self.response_handlers.clone();
261 let outbound_tx = self.outbound_tx.write().take();
262 let next_id = AtomicUsize::new(self.next_id.load(SeqCst));
263 let mut output_done = self.output_done_rx.lock().take().unwrap();
264 Some(async move {
265 Self::request_internal::<lsp_types::request::Shutdown>(
266 &next_id,
267 &response_handlers,
268 outbound_tx.as_ref(),
269 (),
270 )
271 .await?;
272 Self::notify_internal::<lsp_types::notification::Exit>(outbound_tx.as_ref(), ())?;
273 drop(outbound_tx);
274 output_done.recv().await;
275 drop(tasks);
276 Ok(())
277 })
278 } else {
279 None
280 }
281 }
282
283 pub fn on_notification<T, F>(&self, f: F) -> Subscription
284 where
285 T: lsp_types::notification::Notification,
286 F: 'static + Send + Sync + Fn(T::Params),
287 {
288 let prev_handler = self.notification_handlers.write().insert(
289 T::METHOD,
290 Box::new(
291 move |notification| match serde_json::from_str(notification) {
292 Ok(notification) => f(notification),
293 Err(err) => log::error!("error parsing notification {}: {}", T::METHOD, err),
294 },
295 ),
296 );
297
298 assert!(
299 prev_handler.is_none(),
300 "registered multiple handlers for the same notification"
301 );
302
303 Subscription {
304 method: T::METHOD,
305 notification_handlers: self.notification_handlers.clone(),
306 }
307 }
308
309 pub fn request<T: lsp_types::request::Request>(
310 self: Arc<Self>,
311 params: T::Params,
312 ) -> impl Future<Output = Result<T::Result>>
313 where
314 T::Result: 'static + Send,
315 {
316 let this = self.clone();
317 async move {
318 this.initialized.clone().recv().await;
319 Self::request_internal::<T>(
320 &this.next_id,
321 &this.response_handlers,
322 this.outbound_tx.read().as_ref(),
323 params,
324 )
325 .await
326 }
327 }
328
329 fn request_internal<T: lsp_types::request::Request>(
330 next_id: &AtomicUsize,
331 response_handlers: &Mutex<HashMap<usize, ResponseHandler>>,
332 outbound_tx: Option<&channel::Sender<Vec<u8>>>,
333 params: T::Params,
334 ) -> impl 'static + Future<Output = Result<T::Result>>
335 where
336 T::Result: 'static + Send,
337 {
338 let id = next_id.fetch_add(1, SeqCst);
339 let message = serde_json::to_vec(&Request {
340 jsonrpc: JSON_RPC_VERSION,
341 id,
342 method: T::METHOD,
343 params,
344 })
345 .unwrap();
346 let mut response_handlers = response_handlers.lock();
347 let (mut tx, mut rx) = oneshot::channel();
348 response_handlers.insert(
349 id,
350 Box::new(move |result| {
351 let response = match result {
352 Ok(response) => {
353 serde_json::from_str(response).context("failed to deserialize response")
354 }
355 Err(error) => Err(anyhow!("{}", error.message)),
356 };
357 let _ = tx.try_send(response);
358 }),
359 );
360
361 let send = outbound_tx
362 .as_ref()
363 .ok_or_else(|| {
364 anyhow!("tried to send a request to a language server that has been shut down")
365 })
366 .and_then(|outbound_tx| {
367 outbound_tx.try_send(message)?;
368 Ok(())
369 });
370 async move {
371 send?;
372 rx.recv().await.unwrap()
373 }
374 }
375
376 pub fn notify<T: lsp_types::notification::Notification>(
377 self: &Arc<Self>,
378 params: T::Params,
379 ) -> impl Future<Output = Result<()>> {
380 let this = self.clone();
381 async move {
382 this.initialized.clone().recv().await;
383 Self::notify_internal::<T>(this.outbound_tx.read().as_ref(), params)?;
384 Ok(())
385 }
386 }
387
388 fn notify_internal<T: lsp_types::notification::Notification>(
389 outbound_tx: Option<&channel::Sender<Vec<u8>>>,
390 params: T::Params,
391 ) -> Result<()> {
392 let message = serde_json::to_vec(&Notification {
393 jsonrpc: JSON_RPC_VERSION,
394 method: T::METHOD,
395 params,
396 })
397 .unwrap();
398 let outbound_tx = outbound_tx
399 .as_ref()
400 .ok_or_else(|| anyhow!("tried to notify a language server that has been shut down"))?;
401 outbound_tx.try_send(message)?;
402 Ok(())
403 }
404}
405
406impl Drop for LanguageServer {
407 fn drop(&mut self) {
408 if let Some(shutdown) = self.shutdown() {
409 self.executor.spawn(shutdown).detach();
410 }
411 }
412}
413
414impl Subscription {
415 pub fn detach(mut self) {
416 self.method = "";
417 }
418}
419
420impl Drop for Subscription {
421 fn drop(&mut self) {
422 self.notification_handlers.write().remove(self.method);
423 }
424}
425
/// In-process stand-in for a language server, used by tests to script the server side of an
/// LSP conversation over in-memory pipes.
#[cfg(any(test, feature = "test-support"))]
pub struct FakeLanguageServer {
    /// Scratch buffer; holds the body of the most recently received message.
    buffer: Vec<u8>,
    /// Reads the messages written by the client.
    stdin: smol::io::BufReader<async_pipe::PipeReader>,
    /// Writes messages for the client to read.
    stdout: smol::io::BufWriter<async_pipe::PipeWriter>,
    /// When `false`, `notify` panics instead of sending.
    pub started: Arc<AtomicBool>,
}
433
/// Id of a request received by `FakeLanguageServer`.
///
/// It is typed by the request kind `T`, so `respond` only accepts a result of the matching
/// type.
#[cfg(any(test, feature = "test-support"))]
pub struct RequestId<T> {
    id: usize,
    _type: std::marker::PhantomData<T>,
}
439
#[cfg(any(test, feature = "test-support"))]
impl LanguageServer {
    /// Creates a `LanguageServer` wired to an in-memory `FakeLanguageServer`.
    ///
    /// The fake side answers `initialize` and consumes `initialized` before this returns, so
    /// the returned client has finished its handshake.
    pub async fn fake(executor: Arc<executor::Background>) -> (Arc<Self>, FakeLanguageServer) {
        // Client -> fake: the client writes into this pipe and the fake reads from it.
        let (client_writer, fake_reader) = async_pipe::pipe();
        // Fake -> client: the fake writes into this pipe and the client reads from it.
        let (fake_writer, client_reader) = async_pipe::pipe();

        let mut fake = FakeLanguageServer {
            buffer: Vec::new(),
            stdin: smol::io::BufReader::new(fake_reader),
            stdout: smol::io::BufWriter::new(fake_writer),
            started: Arc::new(AtomicBool::new(true)),
        };
        let server =
            Self::new_internal(client_writer, client_reader, Path::new("/"), executor).unwrap();

        let (initialize_id, _) = fake.receive_request::<request::Initialize>().await;
        fake.respond(initialize_id, InitializeResult::default()).await;
        fake.receive_notification::<notification::Initialized>().await;

        (server, fake)
    }
}
462
#[cfg(any(test, feature = "test-support"))]
impl FakeLanguageServer {
    /// Sends a notification of type `T` to the client.
    ///
    /// # Panics
    ///
    /// Panics if `started` is `false`, or if writing to the client fails.
    pub async fn notify<T: notification::Notification>(&mut self, params: T::Params) {
        if !self.started.load(SeqCst) {
            panic!("can't simulate an LSP notification before the server has been started");
        }
        let message = serde_json::to_vec(&Notification {
            jsonrpc: JSON_RPC_VERSION,
            method: T::METHOD,
            params,
        })
        .unwrap();
        self.send(message).await;
    }

    /// Answers the request identified by `request_id` with a successful `result`.
    ///
    /// # Panics
    ///
    /// Panics if `result` can't be serialized or writing to the client fails.
    pub async fn respond<T: request::Request>(
        &mut self,
        request_id: RequestId<T>,
        result: T::Result,
    ) {
        let result = serde_json::to_string(&result).unwrap();
        let message = serde_json::to_vec(&AnyResponse {
            id: request_id.id,
            error: None,
            result: Some(&RawValue::from_string(result).unwrap()),
        })
        .unwrap();
        self.send(message).await;
    }

    /// Reads the next message from the client and returns it as a request of type `T`, along
    /// with an id for answering it through `respond`.
    ///
    /// # Panics
    ///
    /// Panics if the next message is not a JSON-RPC 2.0 request for `T::METHOD`.
    pub async fn receive_request<T: request::Request>(&mut self) -> (RequestId<T>, T::Params) {
        self.receive().await;
        let request = serde_json::from_slice::<Request<T::Params>>(&self.buffer).unwrap();
        assert_eq!(request.method, T::METHOD);
        assert_eq!(request.jsonrpc, JSON_RPC_VERSION);
        (
            RequestId {
                id: request.id,
                _type: std::marker::PhantomData,
            },
            request.params,
        )
    }

    /// Reads the next message from the client and returns the params of a notification of
    /// type `T`.
    ///
    /// # Panics
    ///
    /// Panics if the next message is not a JSON-RPC 2.0 notification for `T::METHOD`.
    pub async fn receive_notification<T: notification::Notification>(&mut self) -> T::Params {
        self.receive().await;
        let notification = serde_json::from_slice::<Notification<T::Params>>(&self.buffer).unwrap();
        assert_eq!(notification.method, T::METHOD);
        assert_eq!(notification.jsonrpc, JSON_RPC_VERSION);
        notification.params
    }

    /// Frames `message` with a `Content-Length` header and writes it to the client.
    async fn send(&mut self, message: Vec<u8>) {
        self.stdout
            .write_all(CONTENT_LEN_HEADER.as_bytes())
            .await
            .unwrap();
        self.stdout
            .write_all(message.len().to_string().as_bytes())
            .await
            .unwrap();
        self.stdout.write_all(b"\r\n\r\n").await.unwrap();
        self.stdout.write_all(&message).await.unwrap();
        self.stdout.flush().await.unwrap();
    }

    /// Reads one message from the client into `self.buffer`.
    ///
    /// This expects exactly one `Content-Length` header line followed by a blank line, which is
    /// the framing `LanguageServer`'s output task writes.
    async fn receive(&mut self) {
        self.buffer.clear();
        self.stdin
            .read_until(b'\n', &mut self.buffer)
            .await
            .unwrap();
        self.stdin
            .read_until(b'\n', &mut self.buffer)
            .await
            .unwrap();
        let message_len: usize = std::str::from_utf8(&self.buffer)
            .unwrap()
            .strip_prefix(CONTENT_LEN_HEADER)
            .unwrap()
            .trim_end()
            .parse()
            .unwrap();
        self.buffer.resize(message_len, 0);
        self.stdin.read_exact(&mut self.buffer).await.unwrap();
    }
}
549
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use simplelog::SimpleLogger;
    use unindent::Unindent;
    use util::test::temp_tree;

    /// Drives a real `rust-analyzer` (which must be on `PATH`) through opening a file and
    /// requesting hover information.
    #[gpui::test]
    async fn test_basic(cx: TestAppContext) {
        let lib_source = r#"
            fn fun() {
                let hello = "world";
            }
        "#
        .unindent();
        let root_dir = temp_tree(json!({
            "Cargo.toml": r#"
                [package]
                name = "temp"
                version = "0.1.0"
                edition = "2018"
            "#.unindent(),
            "src": {
                "lib.rs": &lib_source
            }
        }));
        let lib_file_uri =
            lsp_types::Url::from_file_path(root_dir.path().join("src/lib.rs")).unwrap();

        let server = cx.read(|cx| {
            LanguageServer::new(
                Path::new("rust-analyzer"),
                root_dir.path(),
                cx.background().clone(),
            )
            .unwrap()
        });
        // Wait until the server reports that it is quiescent.
        server.next_idle_notification().await;

        server
            .notify::<lsp_types::notification::DidOpenTextDocument>(
                lsp_types::DidOpenTextDocumentParams {
                    text_document: lsp_types::TextDocumentItem::new(
                        lib_file_uri.clone(),
                        "rust".to_string(),
                        0,
                        lib_source,
                    ),
                },
            )
            .await
            .unwrap();

        let hover = server
            .request::<lsp_types::request::HoverRequest>(lsp_types::HoverParams {
                text_document_position_params: lsp_types::TextDocumentPositionParams {
                    text_document: lsp_types::TextDocumentIdentifier::new(lib_file_uri),
                    position: lsp_types::Position::new(1, 21),
                },
                work_done_progress_params: Default::default(),
            })
            .await
            .unwrap()
            .unwrap();
        assert_eq!(
            hover.contents,
            lsp_types::HoverContents::Markup(lsp_types::MarkupContent {
                kind: lsp_types::MarkupKind::Markdown,
                value: "&str".to_string()
            })
        );
    }

    /// Exercises the client against `FakeLanguageServer`: notifications in both directions,
    /// then the shutdown sequence triggered by dropping the client.
    #[gpui::test]
    async fn test_fake(cx: TestAppContext) {
        // Only one global logger can be installed per process. Ignore the error if another test
        // already installed one, rather than panicking.
        let _ = SimpleLogger::init(log::LevelFilter::Info, Default::default());

        let (server, mut fake) = LanguageServer::fake(cx.background()).await;

        let (message_tx, message_rx) = channel::unbounded();
        let (diagnostics_tx, diagnostics_rx) = channel::unbounded();
        server
            .on_notification::<notification::ShowMessage, _>(move |params| {
                message_tx.try_send(params).unwrap()
            })
            .detach();
        server
            .on_notification::<notification::PublishDiagnostics, _>(move |params| {
                diagnostics_tx.try_send(params).unwrap()
            })
            .detach();

        server
            .notify::<notification::DidOpenTextDocument>(DidOpenTextDocumentParams {
                text_document: TextDocumentItem::new(
                    Url::from_str("file://a/b").unwrap(),
                    "rust".to_string(),
                    0,
                    "".to_string(),
                ),
            })
            .await
            .unwrap();
        assert_eq!(
            fake.receive_notification::<notification::DidOpenTextDocument>()
                .await
                .text_document
                .uri
                .as_str(),
            "file://a/b"
        );

        fake.notify::<notification::ShowMessage>(ShowMessageParams {
            typ: MessageType::ERROR,
            message: "ok".to_string(),
        })
        .await;
        fake.notify::<notification::PublishDiagnostics>(PublishDiagnosticsParams {
            uri: Url::from_str("file://b/c").unwrap(),
            version: Some(5),
            diagnostics: vec![],
        })
        .await;
        assert_eq!(message_rx.recv().await.unwrap().message, "ok");
        assert_eq!(
            diagnostics_rx.recv().await.unwrap().uri.as_str(),
            "file://b/c"
        );

        // Dropping the client sends `shutdown`, then `exit` once `shutdown` is answered.
        drop(server);
        let (shutdown_request, _) = fake.receive_request::<lsp_types::request::Shutdown>().await;
        fake.respond(shutdown_request, ()).await;
        fake.receive_notification::<lsp_types::notification::Exit>()
            .await;
    }

    impl LanguageServer {
        /// Waits for the next `experimental/serverStatus` notification that reports the server
        /// as quiescent.
        async fn next_idle_notification(self: &Arc<Self>) {
            let (tx, rx) = channel::unbounded();
            let _subscription =
                self.on_notification::<ServerStatusNotification, _>(move |params| {
                    if params.quiescent {
                        tx.try_send(()).unwrap();
                    }
                });
            let _ = rx.recv().await;
        }
    }

    /// The `experimental/serverStatus` notification, which the client opts into through the
    /// `serverStatusNotification` experimental capability it sends during initialization.
    pub enum ServerStatusNotification {}

    impl lsp_types::notification::Notification for ServerStatusNotification {
        type Params = ServerStatusParams;
        const METHOD: &'static str = "experimental/serverStatus";
    }

    /// Params of [`ServerStatusNotification`]; only the `quiescent` flag is read.
    #[derive(Deserialize, Serialize, PartialEq, Eq, Clone)]
    pub struct ServerStatusParams {
        pub quiescent: bool,
    }
}