1use log::warn;
2pub use lsp_types::request::*;
3pub use lsp_types::*;
4
5use anyhow::{anyhow, Context, Result};
6use collections::HashMap;
7use futures::{channel::oneshot, io::BufWriter, AsyncRead, AsyncWrite};
8use gpui::{executor, AsyncAppContext, Task};
9use parking_lot::Mutex;
10use postage::{barrier, prelude::Stream};
11use serde::{de::DeserializeOwned, Deserialize, Serialize};
12use serde_json::{json, value::RawValue, Value};
13use smol::{
14 channel,
15 io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
16 process::{self, Child},
17};
18use std::{
19 future::Future,
20 io::Write,
21 path::PathBuf,
22 str::FromStr,
23 sync::{
24 atomic::{AtomicUsize, Ordering::SeqCst},
25 Arc,
26 },
27};
28use std::{path::Path, process::Stdio};
29use util::{ResultExt, TryFutureExt};
30
/// Value of the `jsonrpc` member written into every outgoing message.
const JSON_RPC_VERSION: &'static str = "2.0";
/// Prefix of the base-protocol header that carries a message's byte length.
const CONTENT_LEN_HEADER: &'static str = "Content-Length: ";
33
34type NotificationHandler = Box<dyn Send + FnMut(Option<usize>, &str, AsyncAppContext)>;
35type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
36
37pub struct LanguageServer {
38 server_id: usize,
39 next_id: AtomicUsize,
40 outbound_tx: channel::Sender<Vec<u8>>,
41 name: String,
42 capabilities: ServerCapabilities,
43 notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
44 response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
45 executor: Arc<executor::Background>,
46 #[allow(clippy::type_complexity)]
47 io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
48 output_done_rx: Mutex<Option<barrier::Receiver>>,
49 root_path: PathBuf,
50 _server: Option<Child>,
51}
52
53pub struct Subscription {
54 method: &'static str,
55 notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
56}
57
58#[derive(Serialize, Deserialize)]
59struct Request<'a, T> {
60 jsonrpc: &'static str,
61 id: usize,
62 method: &'a str,
63 params: T,
64}
65
66#[derive(Serialize, Deserialize)]
67struct AnyResponse<'a> {
68 jsonrpc: &'a str,
69 id: usize,
70 #[serde(default)]
71 error: Option<Error>,
72 #[serde(borrow)]
73 result: Option<&'a RawValue>,
74}
75
76#[derive(Serialize)]
77struct Response<T> {
78 jsonrpc: &'static str,
79 id: usize,
80 result: Option<T>,
81 error: Option<Error>,
82}
83
84#[derive(Serialize, Deserialize)]
85struct Notification<'a, T> {
86 jsonrpc: &'static str,
87 #[serde(borrow)]
88 method: &'a str,
89 params: T,
90}
91
92#[derive(Deserialize)]
93struct AnyNotification<'a> {
94 #[serde(default)]
95 id: Option<usize>,
96 #[serde(borrow)]
97 method: &'a str,
98 #[serde(borrow)]
99 params: &'a RawValue,
100}
101
102#[derive(Debug, Serialize, Deserialize)]
103struct Error {
104 message: String,
105}
106
107impl LanguageServer {
108 pub fn new<T: AsRef<std::ffi::OsStr>>(
109 server_id: usize,
110 binary_path: &Path,
111 arguments: &[T],
112 root_path: &Path,
113 cx: AsyncAppContext,
114 ) -> Result<Self> {
115 let working_dir = if root_path.is_dir() {
116 root_path
117 } else {
118 root_path.parent().unwrap_or_else(|| Path::new("/"))
119 };
120
121 let mut server = process::Command::new(binary_path)
122 .current_dir(working_dir)
123 .args(arguments)
124 .stdin(Stdio::piped())
125 .stdout(Stdio::piped())
126 .stderr(Stdio::inherit())
127 .kill_on_drop(true)
128 .spawn()?;
129
130 let stdin = server.stdin.take().unwrap();
131 let stout = server.stdout.take().unwrap();
132 let mut server = Self::new_internal(
133 server_id,
134 stdin,
135 stout,
136 Some(server),
137 root_path,
138 cx,
139 |notification| {
140 log::info!(
141 "unhandled notification {}:\n{}",
142 notification.method,
143 serde_json::to_string_pretty(
144 &Value::from_str(notification.params.get()).unwrap()
145 )
146 .unwrap()
147 );
148 },
149 );
150
151 if let Some(name) = binary_path.file_name() {
152 server.name = name.to_string_lossy().to_string();
153 }
154 Ok(server)
155 }
156
157 fn new_internal<Stdin, Stdout, F>(
158 server_id: usize,
159 stdin: Stdin,
160 stdout: Stdout,
161 server: Option<Child>,
162 root_path: &Path,
163 cx: AsyncAppContext,
164 on_unhandled_notification: F,
165 ) -> Self
166 where
167 Stdin: AsyncWrite + Unpin + Send + 'static,
168 Stdout: AsyncRead + Unpin + Send + 'static,
169 F: FnMut(AnyNotification) + 'static + Send,
170 {
171 let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
172 let notification_handlers =
173 Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
174 let response_handlers =
175 Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
176 let input_task = cx.spawn(|cx| {
177 let notification_handlers = notification_handlers.clone();
178 let response_handlers = response_handlers.clone();
179 Self::handle_input(
180 stdout,
181 on_unhandled_notification,
182 notification_handlers,
183 response_handlers,
184 cx,
185 )
186 .log_err()
187 });
188 let (output_done_tx, output_done_rx) = barrier::channel();
189 let output_task = cx.background().spawn({
190 let response_handlers = response_handlers.clone();
191 Self::handle_output(stdin, outbound_rx, output_done_tx, response_handlers).log_err()
192 });
193
194 Self {
195 server_id,
196 notification_handlers,
197 response_handlers,
198 name: Default::default(),
199 capabilities: Default::default(),
200 next_id: Default::default(),
201 outbound_tx,
202 executor: cx.background(),
203 io_tasks: Mutex::new(Some((input_task, output_task))),
204 output_done_rx: Mutex::new(Some(output_done_rx)),
205 root_path: root_path.to_path_buf(),
206 _server: server,
207 }
208 }
209
210 async fn handle_input<Stdout, F>(
211 stdout: Stdout,
212 mut on_unhandled_notification: F,
213 notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
214 response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
215 cx: AsyncAppContext,
216 ) -> anyhow::Result<()>
217 where
218 Stdout: AsyncRead + Unpin + Send + 'static,
219 F: FnMut(AnyNotification) + 'static + Send,
220 {
221 let mut stdout = BufReader::new(stdout);
222 let _clear_response_handlers = util::defer({
223 let response_handlers = response_handlers.clone();
224 move || {
225 response_handlers.lock().take();
226 }
227 });
228 let mut buffer = Vec::new();
229 loop {
230 buffer.clear();
231 stdout.read_until(b'\n', &mut buffer).await?;
232 stdout.read_until(b'\n', &mut buffer).await?;
233 let message_len: usize = std::str::from_utf8(&buffer)?
234 .strip_prefix(CONTENT_LEN_HEADER)
235 .ok_or_else(|| anyhow!("invalid header"))?
236 .trim_end()
237 .parse()?;
238
239 buffer.resize(message_len, 0);
240 stdout.read_exact(&mut buffer).await?;
241 log::trace!("incoming message:{}", String::from_utf8_lossy(&buffer));
242
243 if let Ok(msg) = serde_json::from_slice::<AnyNotification>(&buffer) {
244 if let Some(handler) = notification_handlers.lock().get_mut(msg.method) {
245 handler(msg.id, msg.params.get(), cx.clone());
246 } else {
247 on_unhandled_notification(msg);
248 }
249 } else if let Ok(AnyResponse {
250 id, error, result, ..
251 }) = serde_json::from_slice(&buffer)
252 {
253 if let Some(handler) = response_handlers
254 .lock()
255 .as_mut()
256 .and_then(|handlers| handlers.remove(&id))
257 {
258 if let Some(error) = error {
259 handler(Err(error));
260 } else if let Some(result) = result {
261 handler(Ok(result.get()));
262 } else {
263 handler(Ok("null"));
264 }
265 }
266 } else {
267 warn!(
268 "Failed to deserialize message:\n{}",
269 std::str::from_utf8(&buffer)?
270 );
271 }
272
273 // Don't starve the main thread when receiving lots of messages at once.
274 smol::future::yield_now().await;
275 }
276 }
277
278 async fn handle_output<Stdin>(
279 stdin: Stdin,
280 outbound_rx: channel::Receiver<Vec<u8>>,
281 output_done_tx: barrier::Sender,
282 response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
283 ) -> anyhow::Result<()>
284 where
285 Stdin: AsyncWrite + Unpin + Send + 'static,
286 {
287 let mut stdin = BufWriter::new(stdin);
288 let _clear_response_handlers = util::defer({
289 let response_handlers = response_handlers.clone();
290 move || {
291 response_handlers.lock().take();
292 }
293 });
294 let mut content_len_buffer = Vec::new();
295 while let Ok(message) = outbound_rx.recv().await {
296 log::trace!("outgoing message:{}", String::from_utf8_lossy(&message));
297 content_len_buffer.clear();
298 write!(content_len_buffer, "{}", message.len()).unwrap();
299 stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
300 stdin.write_all(&content_len_buffer).await?;
301 stdin.write_all("\r\n\r\n".as_bytes()).await?;
302 stdin.write_all(&message).await?;
303 stdin.flush().await?;
304 }
305 drop(output_done_tx);
306 Ok(())
307 }
308
309 /// Initializes a language server.
310 /// Note that `options` is used directly to construct [`InitializeParams`],
311 /// which is why it is owned.
312 pub async fn initialize(mut self, options: Option<Value>) -> Result<Arc<Self>> {
313 let root_uri = Url::from_file_path(&self.root_path).unwrap();
314 #[allow(deprecated)]
315 let params = InitializeParams {
316 process_id: Default::default(),
317 root_path: Default::default(),
318 root_uri: Some(root_uri.clone()),
319 initialization_options: options,
320 capabilities: ClientCapabilities {
321 workspace: Some(WorkspaceClientCapabilities {
322 configuration: Some(true),
323 did_change_watched_files: Some(DynamicRegistrationClientCapabilities {
324 dynamic_registration: Some(true),
325 }),
326 did_change_configuration: Some(DynamicRegistrationClientCapabilities {
327 dynamic_registration: Some(true),
328 }),
329 ..Default::default()
330 }),
331 text_document: Some(TextDocumentClientCapabilities {
332 definition: Some(GotoCapability {
333 link_support: Some(true),
334 ..Default::default()
335 }),
336 code_action: Some(CodeActionClientCapabilities {
337 code_action_literal_support: Some(CodeActionLiteralSupport {
338 code_action_kind: CodeActionKindLiteralSupport {
339 value_set: vec![
340 CodeActionKind::REFACTOR.as_str().into(),
341 CodeActionKind::QUICKFIX.as_str().into(),
342 CodeActionKind::SOURCE.as_str().into(),
343 ],
344 },
345 }),
346 data_support: Some(true),
347 resolve_support: Some(CodeActionCapabilityResolveSupport {
348 properties: vec!["edit".to_string(), "command".to_string()],
349 }),
350 ..Default::default()
351 }),
352 completion: Some(CompletionClientCapabilities {
353 completion_item: Some(CompletionItemCapability {
354 snippet_support: Some(true),
355 resolve_support: Some(CompletionItemCapabilityResolveSupport {
356 properties: vec!["additionalTextEdits".to_string()],
357 }),
358 ..Default::default()
359 }),
360 ..Default::default()
361 }),
362 rename: Some(RenameClientCapabilities {
363 prepare_support: Some(true),
364 ..Default::default()
365 }),
366 hover: Some(HoverClientCapabilities {
367 content_format: Some(vec![MarkupKind::Markdown]),
368 ..Default::default()
369 }),
370 ..Default::default()
371 }),
372 experimental: Some(json!({
373 "serverStatusNotification": true,
374 })),
375 window: Some(WindowClientCapabilities {
376 work_done_progress: Some(true),
377 ..Default::default()
378 }),
379 ..Default::default()
380 },
381 trace: Default::default(),
382 workspace_folders: Some(vec![WorkspaceFolder {
383 uri: root_uri,
384 name: Default::default(),
385 }]),
386 client_info: Default::default(),
387 locale: Default::default(),
388 };
389
390 let response = self.request::<request::Initialize>(params).await?;
391 if let Some(info) = response.server_info {
392 self.name = info.name;
393 }
394 self.capabilities = response.capabilities;
395
396 self.notify::<notification::Initialized>(InitializedParams {})?;
397 Ok(Arc::new(self))
398 }
399
400 pub fn shutdown(&self) -> Option<impl 'static + Send + Future<Output = Option<()>>> {
401 if let Some(tasks) = self.io_tasks.lock().take() {
402 let response_handlers = self.response_handlers.clone();
403 let next_id = AtomicUsize::new(self.next_id.load(SeqCst));
404 let outbound_tx = self.outbound_tx.clone();
405 let mut output_done = self.output_done_rx.lock().take().unwrap();
406 let shutdown_request = Self::request_internal::<request::Shutdown>(
407 &next_id,
408 &response_handlers,
409 &outbound_tx,
410 (),
411 );
412 let exit = Self::notify_internal::<notification::Exit>(&outbound_tx, ());
413 outbound_tx.close();
414 Some(
415 async move {
416 log::debug!("language server shutdown started");
417 shutdown_request.await?;
418 response_handlers.lock().take();
419 exit?;
420 output_done.recv().await;
421 log::debug!("language server shutdown finished");
422 drop(tasks);
423 anyhow::Ok(())
424 }
425 .log_err(),
426 )
427 } else {
428 None
429 }
430 }
431
432 #[must_use]
433 pub fn on_notification<T, F>(&self, f: F) -> Subscription
434 where
435 T: notification::Notification,
436 F: 'static + Send + FnMut(T::Params, AsyncAppContext),
437 {
438 self.on_custom_notification(T::METHOD, f)
439 }
440
441 #[must_use]
442 pub fn on_request<T, F, Fut>(&self, f: F) -> Subscription
443 where
444 T: request::Request,
445 T::Params: 'static + Send,
446 F: 'static + Send + FnMut(T::Params, AsyncAppContext) -> Fut,
447 Fut: 'static + Future<Output = Result<T::Result>>,
448 {
449 self.on_custom_request(T::METHOD, f)
450 }
451
452 pub fn remove_request_handler<T: request::Request>(&self) {
453 self.notification_handlers.lock().remove(T::METHOD);
454 }
455
456 pub fn remove_notification_handler<T: notification::Notification>(&self) {
457 self.notification_handlers.lock().remove(T::METHOD);
458 }
459
460 #[must_use]
461 pub fn on_custom_notification<Params, F>(&self, method: &'static str, mut f: F) -> Subscription
462 where
463 F: 'static + Send + FnMut(Params, AsyncAppContext),
464 Params: DeserializeOwned,
465 {
466 let prev_handler = self.notification_handlers.lock().insert(
467 method,
468 Box::new(move |_, params, cx| {
469 if let Some(params) = serde_json::from_str(params).log_err() {
470 f(params, cx);
471 }
472 }),
473 );
474 assert!(
475 prev_handler.is_none(),
476 "registered multiple handlers for the same LSP method"
477 );
478 Subscription {
479 method,
480 notification_handlers: self.notification_handlers.clone(),
481 }
482 }
483
484 #[must_use]
485 pub fn on_custom_request<Params, Res, Fut, F>(
486 &self,
487 method: &'static str,
488 mut f: F,
489 ) -> Subscription
490 where
491 F: 'static + Send + FnMut(Params, AsyncAppContext) -> Fut,
492 Fut: 'static + Future<Output = Result<Res>>,
493 Params: DeserializeOwned + Send + 'static,
494 Res: Serialize,
495 {
496 let outbound_tx = self.outbound_tx.clone();
497 let prev_handler = self.notification_handlers.lock().insert(
498 method,
499 Box::new(move |id, params, cx| {
500 if let Some(id) = id {
501 match serde_json::from_str(params) {
502 Ok(params) => {
503 let response = f(params, cx.clone());
504 cx.foreground()
505 .spawn({
506 let outbound_tx = outbound_tx.clone();
507 async move {
508 let response = match response.await {
509 Ok(result) => Response {
510 jsonrpc: JSON_RPC_VERSION,
511 id,
512 result: Some(result),
513 error: None,
514 },
515 Err(error) => Response {
516 jsonrpc: JSON_RPC_VERSION,
517 id,
518 result: None,
519 error: Some(Error {
520 message: error.to_string(),
521 }),
522 },
523 };
524 if let Some(response) =
525 serde_json::to_vec(&response).log_err()
526 {
527 outbound_tx.try_send(response).ok();
528 }
529 }
530 })
531 .detach();
532 }
533 Err(error) => {
534 log::error!(
535 "error deserializing {} request: {:?}, message: {:?}",
536 method,
537 error,
538 params
539 );
540 let response = AnyResponse {
541 jsonrpc: JSON_RPC_VERSION,
542 id,
543 result: None,
544 error: Some(Error {
545 message: error.to_string(),
546 }),
547 };
548 if let Some(response) = serde_json::to_vec(&response).log_err() {
549 outbound_tx.try_send(response).ok();
550 }
551 }
552 }
553 }
554 }),
555 );
556 assert!(
557 prev_handler.is_none(),
558 "registered multiple handlers for the same LSP method"
559 );
560 Subscription {
561 method,
562 notification_handlers: self.notification_handlers.clone(),
563 }
564 }
565
566 pub fn name<'a>(self: &'a Arc<Self>) -> &'a str {
567 &self.name
568 }
569
570 pub fn capabilities<'a>(self: &'a Arc<Self>) -> &'a ServerCapabilities {
571 &self.capabilities
572 }
573
574 pub fn server_id(&self) -> usize {
575 self.server_id
576 }
577
578 pub fn root_path(&self) -> &PathBuf {
579 &self.root_path
580 }
581
582 pub fn request<T: request::Request>(
583 &self,
584 params: T::Params,
585 ) -> impl Future<Output = Result<T::Result>>
586 where
587 T::Result: 'static + Send,
588 {
589 Self::request_internal::<T>(
590 &self.next_id,
591 &self.response_handlers,
592 &self.outbound_tx,
593 params,
594 )
595 }
596
597 fn request_internal<T: request::Request>(
598 next_id: &AtomicUsize,
599 response_handlers: &Mutex<Option<HashMap<usize, ResponseHandler>>>,
600 outbound_tx: &channel::Sender<Vec<u8>>,
601 params: T::Params,
602 ) -> impl 'static + Future<Output = Result<T::Result>>
603 where
604 T::Result: 'static + Send,
605 {
606 let id = next_id.fetch_add(1, SeqCst);
607 let message = serde_json::to_vec(&Request {
608 jsonrpc: JSON_RPC_VERSION,
609 id,
610 method: T::METHOD,
611 params,
612 })
613 .unwrap();
614
615 let (tx, rx) = oneshot::channel();
616 let handle_response = response_handlers
617 .lock()
618 .as_mut()
619 .ok_or_else(|| anyhow!("server shut down"))
620 .map(|handlers| {
621 handlers.insert(
622 id,
623 Box::new(move |result| {
624 let response = match result {
625 Ok(response) => serde_json::from_str(response)
626 .context("failed to deserialize response"),
627 Err(error) => Err(anyhow!("{}", error.message)),
628 };
629 let _ = tx.send(response);
630 }),
631 );
632 });
633
634 let send = outbound_tx
635 .try_send(message)
636 .context("failed to write to language server's stdin");
637
638 async move {
639 handle_response?;
640 send?;
641 rx.await?
642 }
643 }
644
645 pub fn notify<T: notification::Notification>(&self, params: T::Params) -> Result<()> {
646 Self::notify_internal::<T>(&self.outbound_tx, params)
647 }
648
649 fn notify_internal<T: notification::Notification>(
650 outbound_tx: &channel::Sender<Vec<u8>>,
651 params: T::Params,
652 ) -> Result<()> {
653 let message = serde_json::to_vec(&Notification {
654 jsonrpc: JSON_RPC_VERSION,
655 method: T::METHOD,
656 params,
657 })
658 .unwrap();
659 outbound_tx.try_send(message)?;
660 Ok(())
661 }
662}
663
664impl Drop for LanguageServer {
665 fn drop(&mut self) {
666 if let Some(shutdown) = self.shutdown() {
667 self.executor.spawn(shutdown).detach();
668 }
669 }
670}
671
672impl Subscription {
673 pub fn detach(mut self) {
674 self.method = "";
675 }
676}
677
678impl Drop for Subscription {
679 fn drop(&mut self) {
680 self.notification_handlers.lock().remove(self.method);
681 }
682}
683
/// The server half of an in-process fake LSP connection, for tests.
///
/// `server` speaks to the client created alongside it. Messages it receives
/// with no registered handler arrive on `notifications_rx` as
/// `(method, raw JSON params)` pairs.
#[cfg(any(test, feature = "test-support"))]
#[derive(Clone)]
pub struct FakeLanguageServer {
    pub server: Arc<LanguageServer>,
    notifications_rx: channel::Receiver<(String, String)>,
}
690
#[cfg(any(test, feature = "test-support"))]
impl LanguageServer {
    /// Server capabilities with document highlights, code actions, and
    /// full/range formatting enabled.
    pub fn full_capabilities() -> ServerCapabilities {
        ServerCapabilities {
            document_highlight_provider: Some(OneOf::Left(true)),
            code_action_provider: Some(CodeActionProviderCapability::Simple(true)),
            document_formatting_provider: Some(OneOf::Left(true)),
            document_range_formatting_provider: Some(OneOf::Left(true)),
            ..Default::default()
        }
    }

    /// Creates a client [`LanguageServer`] connected through in-memory pipes to
    /// a [`FakeLanguageServer`].
    ///
    /// The fake answers `initialize` with `capabilities` and a server info
    /// carrying `name`. Every other message it receives without a handler is
    /// forwarded to its notification channel.
    pub fn fake(
        name: String,
        capabilities: ServerCapabilities,
        cx: AsyncAppContext,
    ) -> (Self, FakeLanguageServer) {
        let (stdin_writer, stdin_reader) = async_pipe::pipe();
        let (stdout_writer, stdout_reader) = async_pipe::pipe();
        let (notifications_tx, notifications_rx) = channel::unbounded();

        // Client side: writes into the "stdin" pipe and reads the "stdout" pipe.
        let server = Self::new_internal(
            0,
            stdin_writer,
            stdout_reader,
            None,
            Path::new("/"),
            cx.clone(),
            |_| {},
        );

        // Fake side: the mirror image of the client's pipes.
        let fake_side = Self::new_internal(
            0,
            stdout_writer,
            stdin_reader,
            None,
            Path::new("/"),
            cx,
            move |msg| {
                notifications_tx
                    .try_send((msg.method.to_string(), msg.params.get().to_string()))
                    .ok();
            },
        );
        let fake = FakeLanguageServer {
            server: Arc::new(fake_side),
            notifications_rx,
        };

        fake.handle_request::<request::Initialize, _, _>(move |_, _| {
            let capabilities = capabilities.clone();
            let name = name.clone();
            async move {
                Ok(InitializeResult {
                    capabilities,
                    server_info: Some(ServerInfo {
                        name,
                        ..Default::default()
                    }),
                })
            }
        });

        (server, fake)
    }
}
757
#[cfg(any(test, feature = "test-support"))]
impl FakeLanguageServer {
    /// Sends notification `T` from the fake server to the client, ignoring
    /// failures to queue it.
    pub fn notify<T: notification::Notification>(&self, params: T::Params) {
        self.server.notify::<T>(params).ok();
    }

    /// Sends request `T` from the fake server to the client and awaits the
    /// result.
    pub async fn request<T>(&self, params: T::Params) -> Result<T::Result>
    where
        T: request::Request,
        T::Result: 'static + Send,
    {
        self.server.request::<T>(params).await
    }

    /// Waits for the next unhandled message for method `T`, discarding messages
    /// for other methods.
    ///
    /// # Panics
    /// Panics if the notification channel closes first, or if the params fail
    /// to deserialize as `T::Params`.
    pub async fn receive_notification<T: notification::Notification>(&mut self) -> T::Params {
        self.try_receive_notification::<T>().await.unwrap()
    }

    /// Waits for the next unhandled message for method `T`, returning `None` if
    /// the notification channel closes first.
    ///
    /// Messages for other methods are logged and discarded.
    ///
    /// # Panics
    /// Panics if a matching message's params fail to deserialize as `T::Params`.
    pub async fn try_receive_notification<T: notification::Notification>(
        &mut self,
    ) -> Option<T::Params> {
        use futures::StreamExt as _;

        loop {
            let (method, params) = self.notifications_rx.next().await?;
            if method == T::METHOD {
                return Some(serde_json::from_str::<T::Params>(&params).unwrap());
            } else {
                log::info!("skipping message in fake language server {:?}", params);
            }
        }
    }

    /// Replaces any existing handler for request `T` with `handler`.
    ///
    /// Each response is delayed by a simulated random delay. The returned
    /// receiver yields `()` after every request has been answered.
    pub fn handle_request<T, F, Fut>(
        &self,
        mut handler: F,
    ) -> futures::channel::mpsc::UnboundedReceiver<()>
    where
        T: 'static + request::Request,
        T::Params: 'static + Send,
        F: 'static + Send + FnMut(T::Params, gpui::AsyncAppContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<T::Result>>,
    {
        let (responded_tx, responded_rx) = futures::channel::mpsc::unbounded();
        self.server.remove_request_handler::<T>();
        self.server
            .on_request::<T, _, _>(move |params, cx| {
                let result = handler(params, cx.clone());
                let responded_tx = responded_tx.clone();
                async move {
                    cx.background().simulate_random_delay().await;
                    let result = result.await;
                    responded_tx.unbounded_send(()).ok();
                    result
                }
            })
            .detach();
        responded_rx
    }

    /// Replaces any existing handler for notification `T` with `handler`.
    ///
    /// The returned receiver yields `()` after every notification has been
    /// handled.
    pub fn handle_notification<T, F>(
        &self,
        mut handler: F,
    ) -> futures::channel::mpsc::UnboundedReceiver<()>
    where
        T: 'static + notification::Notification,
        T::Params: 'static + Send,
        F: 'static + Send + FnMut(T::Params, gpui::AsyncAppContext),
    {
        let (handled_tx, handled_rx) = futures::channel::mpsc::unbounded();
        self.server.remove_notification_handler::<T>();
        self.server
            .on_notification::<T, _>(move |params, cx| {
                handler(params, cx.clone());
                handled_tx.unbounded_send(()).ok();
            })
            .detach();
        handled_rx
    }

    /// Removes any handler registered for request `T`.
    pub fn remove_request_handler<T>(&mut self)
    where
        T: 'static + request::Request,
    {
        self.server.remove_request_handler::<T>();
    }

    /// Creates a work-done progress `token` on the client (panicking if that
    /// request fails) and then reports its `Begin` notification.
    pub async fn start_progress(&self, token: impl Into<String>) {
        let token = token.into();
        self.request::<request::WorkDoneProgressCreate>(WorkDoneProgressCreateParams {
            token: NumberOrString::String(token.clone()),
        })
        .await
        .unwrap();
        self.notify::<notification::Progress>(ProgressParams {
            token: NumberOrString::String(token),
            value: ProgressParamsValue::WorkDone(WorkDoneProgress::Begin(Default::default())),
        });
    }

    /// Reports the `End` notification for work-done progress `token`.
    pub fn end_progress(&self, token: impl Into<String>) {
        self.notify::<notification::Progress>(ProgressParams {
            token: NumberOrString::String(token.into()),
            value: ProgressParamsValue::WorkDone(WorkDoneProgress::End(Default::default())),
        });
    }
}
865
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;

    /// Initializes `env_logger` for the test binary when `RUST_LOG` is set.
    #[ctor::ctor]
    fn init_logger() {
        if std::env::var("RUST_LOG").is_ok() {
            env_logger::init();
        }
    }

    /// Round-trips messages between a client and its fake server.
    ///
    /// The test checks, in order: initialization, a client->fake notification,
    /// fake->client notifications, and the shutdown/exit sequence triggered by
    /// dropping the client.
    #[gpui::test]
    async fn test_fake(cx: &mut TestAppContext) {
        let (server, mut fake) =
            LanguageServer::fake("the-lsp".to_string(), Default::default(), cx.to_async());

        // Client-side handlers that forward received params into channels.
        let (message_tx, message_rx) = channel::unbounded();
        let (diagnostics_tx, diagnostics_rx) = channel::unbounded();
        server
            .on_notification::<notification::ShowMessage, _>(move |params, _| {
                message_tx.try_send(params).unwrap()
            })
            .detach();
        server
            .on_notification::<notification::PublishDiagnostics, _>(move |params, _| {
                diagnostics_tx.try_send(params).unwrap()
            })
            .detach();

        // Answered by the `initialize` handler installed in `LanguageServer::fake`.
        let server = server.initialize(None).await.unwrap();
        // Client -> fake: the fake has no handler for this method, so the
        // message reaches its notification channel.
        server
            .notify::<notification::DidOpenTextDocument>(DidOpenTextDocumentParams {
                text_document: TextDocumentItem::new(
                    Url::from_str("file://a/b").unwrap(),
                    "rust".to_string(),
                    0,
                    "".to_string(),
                ),
            })
            .unwrap();
        assert_eq!(
            fake.receive_notification::<notification::DidOpenTextDocument>()
                .await
                .text_document
                .uri
                .as_str(),
            "file://a/b"
        );

        // Fake -> client: delivered to the handlers registered above.
        fake.notify::<notification::ShowMessage>(ShowMessageParams {
            typ: MessageType::ERROR,
            message: "ok".to_string(),
        });
        fake.notify::<notification::PublishDiagnostics>(PublishDiagnosticsParams {
            uri: Url::from_str("file://b/c").unwrap(),
            version: Some(5),
            diagnostics: vec![],
        });
        assert_eq!(message_rx.recv().await.unwrap().message, "ok");
        assert_eq!(
            diagnostics_rx.recv().await.unwrap().uri.as_str(),
            "file://b/c"
        );

        // Let the fake answer `shutdown`; dropping the client sends `shutdown`
        // followed by `exit`, and the `exit` reaches the fake's channel.
        fake.handle_request::<request::Shutdown, _, _>(|_, _| async move { Ok(()) });

        drop(server);
        fake.receive_notification::<notification::Exit>().await;
    }
}