lsp.rs

  1pub use lsp_types::request::*;
  2pub use lsp_types::*;
  3
  4use anyhow::{anyhow, Context, Result};
  5use collections::HashMap;
  6use futures::{channel::oneshot, io::BufWriter, AsyncRead, AsyncWrite};
  7use gpui::{executor, AsyncAppContext, Task};
  8use parking_lot::Mutex;
  9use postage::{barrier, prelude::Stream};
 10use serde::{de::DeserializeOwned, Deserialize, Serialize};
 11use serde_json::{json, value::RawValue, Value};
 12use smol::{
 13    channel,
 14    io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
 15    process::{self, Child},
 16};
 17use std::{
 18    future::Future,
 19    io::Write,
 20    path::PathBuf,
 21    str::FromStr,
 22    sync::{
 23        atomic::{AtomicUsize, Ordering::SeqCst},
 24        Arc,
 25    },
 26};
 27use std::{path::Path, process::Stdio};
 28use util::{ResultExt, TryFutureExt};
 29
/// Value written to the `jsonrpc` field of every outgoing message.
const JSON_RPC_VERSION: &str = "2.0";
/// Prefix of the `Content-Length` header that frames each message on the wire.
const CONTENT_LEN_HEADER: &str = "Content-Length: ";

/// Handler for an incoming message that carries a `method`, registered per method name.
/// Called with the message `id` (present for server-to-client requests, absent for
/// notifications), the raw JSON text of `params`, and an async app context.
type NotificationHandler = Box<dyn Send + FnMut(Option<usize>, &str, AsyncAppContext)>;
/// One-shot callback for the reply to a request this client sent: receives the raw
/// JSON text of `result` (`"null"` when absent) or the response's `error`.
type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
 35
/// A client-side connection to a language server, exchanging JSON-RPC messages over a
/// pair of byte streams (the child process's stdin/stdout when created with
/// `LanguageServer::new`).
pub struct LanguageServer {
    /// Caller-assigned identifier for this server instance.
    server_id: usize,
    /// Id to assign to the next outgoing request.
    next_id: AtomicUsize,
    /// Queue of serialized messages; the output task frames and writes them to the server.
    outbound_tx: channel::Sender<Vec<u8>>,
    /// Display name: the binary's file name from `new`, replaced during `initialize`
    /// by the name the server reports (if it reports one).
    name: String,
    /// Capabilities returned by the server in its `initialize` response.
    capabilities: ServerCapabilities,
    /// Handlers for incoming notifications and server-to-client requests, keyed by method.
    notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
    /// Callbacks for requests still awaiting a response, keyed by request id.
    response_handlers: Arc<Mutex<HashMap<usize, ResponseHandler>>>,
    /// Background executor; used by `Drop` to run the shutdown sequence.
    executor: Arc<executor::Background>,
    /// The (input, output) I/O tasks; taken (left as `None`) once shutdown starts.
    #[allow(clippy::type_complexity)]
    io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
    /// Resolves once the output task's write loop has finished; taken by `shutdown`.
    output_done_rx: Mutex<Option<barrier::Receiver>>,
    root_path: PathBuf,
    /// The spawned server process, if any (spawned with `kill_on_drop(true)` by `new`).
    _server: Option<Child>,
}
 51
/// Guard returned when registering a notification or request handler. Dropping it
/// removes the handler registered under `method`; `Subscription::detach` clears
/// `method` first so the handler stays registered.
pub struct Subscription {
    /// Method whose handler is removed on drop (empty after `detach`).
    method: &'static str,
    /// Shared handler table of the owning `LanguageServer`.
    notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
}
 56
/// A JSON-RPC request message. Fields serialize in declaration order.
#[derive(Serialize, Deserialize)]
struct Request<'a, T> {
    jsonrpc: &'static str,
    /// Correlates the eventual response with this request.
    id: usize,
    method: &'a str,
    params: T,
}
 64
/// A response to one of this client's requests. `result` is kept as raw JSON so that the
/// matching response handler can deserialize it into the request-specific type.
#[derive(Serialize, Deserialize)]
struct AnyResponse<'a> {
    /// Id of the request this response answers.
    id: usize,
    /// Set when the request failed; treated as `None` when the field is missing.
    #[serde(default)]
    error: Option<Error>,
    /// Raw JSON of the result; `None` when the field is missing or `null`.
    #[serde(borrow)]
    result: Option<&'a RawValue>,
}
 73
 74#[derive(Serialize)]
 75struct Response<T> {
 76    jsonrpc: &'static str,
 77    id: usize,
 78    result: Option<T>,
 79    error: Option<Error>,
 80}
 81
/// A JSON-RPC notification message (a request without an `id`, which gets no reply).
#[derive(Serialize, Deserialize)]
struct Notification<'a, T> {
    jsonrpc: &'static str,
    #[serde(borrow)]
    method: &'a str,
    params: T,
}
 89
/// An incoming message that carries a `method`. It is a server-to-client notification,
/// or a server-to-client request when `id` is present. `params` is kept as raw JSON
/// and parsed later by whichever handler is registered for `method`.
///
/// NOTE(review): `params` has no `#[serde(default)]`, so a message that omits `params`
/// entirely fails to deserialize as this type (the input loop then tries it as a
/// response). Confirm whether any server in use sends parameterless requests.
#[derive(Deserialize)]
struct AnyNotification<'a> {
    /// Present only for requests; the reply must echo it.
    #[serde(default)]
    id: Option<usize>,
    #[serde(borrow)]
    method: &'a str,
    #[serde(borrow)]
    params: &'a RawValue,
}
 99
/// The `error` member of a JSON-RPC response.
///
/// Only `message` is modeled. The struct does not deny unknown fields, so other members
/// (such as `code`) are ignored when deserializing. NOTE(review): JSON-RPC 2.0 requires
/// an integer `code` in error objects, and errors serialized from this type omit it.
#[derive(Debug, Serialize, Deserialize)]
struct Error {
    message: String,
}
104
105impl LanguageServer {
106    pub fn new<T: AsRef<std::ffi::OsStr>>(
107        server_id: usize,
108        binary_path: &Path,
109        args: &[T],
110        root_path: &Path,
111        cx: AsyncAppContext,
112    ) -> Result<Self> {
113        let working_dir = if root_path.is_dir() {
114            root_path
115        } else {
116            root_path.parent().unwrap_or_else(|| Path::new("/"))
117        };
118        let mut server = process::Command::new(binary_path)
119            .current_dir(working_dir)
120            .args(args)
121            .stdin(Stdio::piped())
122            .stdout(Stdio::piped())
123            .stderr(Stdio::inherit())
124            .kill_on_drop(true)
125            .spawn()?;
126
127        let stdin = server.stdin.take().unwrap();
128        let stout = server.stdout.take().unwrap();
129
130        let mut server = Self::new_internal(
131            server_id,
132            stdin,
133            stout,
134            Some(server),
135            root_path,
136            cx,
137            |notification| {
138                log::info!(
139                    "unhandled notification {}:\n{}",
140                    notification.method,
141                    serde_json::to_string_pretty(
142                        &Value::from_str(notification.params.get()).unwrap()
143                    )
144                    .unwrap()
145                );
146            },
147        );
148        if let Some(name) = binary_path.file_name() {
149            server.name = name.to_string_lossy().to_string();
150        }
151        Ok(server)
152    }
153
154    fn new_internal<Stdin, Stdout, F>(
155        server_id: usize,
156        stdin: Stdin,
157        stdout: Stdout,
158        server: Option<Child>,
159        root_path: &Path,
160        cx: AsyncAppContext,
161        mut on_unhandled_notification: F,
162    ) -> Self
163    where
164        Stdin: AsyncWrite + Unpin + Send + 'static,
165        Stdout: AsyncRead + Unpin + Send + 'static,
166        F: FnMut(AnyNotification) + 'static + Send,
167    {
168        let mut stdin = BufWriter::new(stdin);
169        let mut stdout = BufReader::new(stdout);
170        let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
171        let notification_handlers =
172            Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
173        let response_handlers = Arc::new(Mutex::new(HashMap::<_, ResponseHandler>::default()));
174        let input_task = cx.spawn(|cx| {
175            let notification_handlers = notification_handlers.clone();
176            let response_handlers = response_handlers.clone();
177            async move {
178                let _clear_response_handlers = ClearResponseHandlers(response_handlers.clone());
179                let mut buffer = Vec::new();
180                loop {
181                    buffer.clear();
182                    stdout.read_until(b'\n', &mut buffer).await?;
183                    stdout.read_until(b'\n', &mut buffer).await?;
184                    let message_len: usize = std::str::from_utf8(&buffer)?
185                        .strip_prefix(CONTENT_LEN_HEADER)
186                        .ok_or_else(|| anyhow!("invalid header"))?
187                        .trim_end()
188                        .parse()?;
189
190                    buffer.resize(message_len, 0);
191                    stdout.read_exact(&mut buffer).await?;
192                    log::trace!("incoming message:{}", String::from_utf8_lossy(&buffer));
193
194                    if let Ok(msg) = serde_json::from_slice::<AnyNotification>(&buffer) {
195                        if let Some(handler) = notification_handlers.lock().get_mut(msg.method) {
196                            handler(msg.id, msg.params.get(), cx.clone());
197                        } else {
198                            on_unhandled_notification(msg);
199                        }
200                    } else if let Ok(AnyResponse { id, error, result }) =
201                        serde_json::from_slice(&buffer)
202                    {
203                        if let Some(handler) = response_handlers.lock().remove(&id) {
204                            if let Some(error) = error {
205                                handler(Err(error));
206                            } else if let Some(result) = result {
207                                handler(Ok(result.get()));
208                            } else {
209                                handler(Ok("null"));
210                            }
211                        }
212                    } else {
213                        return Err(anyhow!(
214                            "failed to deserialize message:\n{}",
215                            std::str::from_utf8(&buffer)?
216                        ));
217                    }
218
219                    // Don't starve the main thread when receiving lots of messages at once.
220                    smol::future::yield_now().await;
221                }
222            }
223            .log_err()
224        });
225        let (output_done_tx, output_done_rx) = barrier::channel();
226        let output_task = cx.background().spawn({
227            let response_handlers = response_handlers.clone();
228            async move {
229                let _clear_response_handlers = ClearResponseHandlers(response_handlers);
230                let mut content_len_buffer = Vec::new();
231                while let Ok(message) = outbound_rx.recv().await {
232                    log::trace!("outgoing message:{}", String::from_utf8_lossy(&message));
233                    content_len_buffer.clear();
234                    write!(content_len_buffer, "{}", message.len()).unwrap();
235                    stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
236                    stdin.write_all(&content_len_buffer).await?;
237                    stdin.write_all("\r\n\r\n".as_bytes()).await?;
238                    stdin.write_all(&message).await?;
239                    stdin.flush().await?;
240                }
241                drop(output_done_tx);
242                Ok(())
243            }
244            .log_err()
245        });
246
247        Self {
248            server_id,
249            notification_handlers,
250            response_handlers,
251            name: Default::default(),
252            capabilities: Default::default(),
253            next_id: Default::default(),
254            outbound_tx,
255            executor: cx.background(),
256            io_tasks: Mutex::new(Some((input_task, output_task))),
257            output_done_rx: Mutex::new(Some(output_done_rx)),
258            root_path: root_path.to_path_buf(),
259            _server: server,
260        }
261    }
262
263    /// Initializes a language server.
264    /// Note that `options` is used directly to construct [`InitializeParams`],
265    /// which is why it is owned.
266    pub async fn initialize(mut self, options: Option<Value>) -> Result<Arc<Self>> {
267        let root_uri = Url::from_file_path(&self.root_path).unwrap();
268        #[allow(deprecated)]
269        let params = InitializeParams {
270            process_id: Default::default(),
271            root_path: Default::default(),
272            root_uri: Some(root_uri.clone()),
273            initialization_options: options,
274            capabilities: ClientCapabilities {
275                workspace: Some(WorkspaceClientCapabilities {
276                    configuration: Some(true),
277                    did_change_configuration: Some(DynamicRegistrationClientCapabilities {
278                        dynamic_registration: Some(true),
279                    }),
280                    ..Default::default()
281                }),
282                text_document: Some(TextDocumentClientCapabilities {
283                    definition: Some(GotoCapability {
284                        link_support: Some(true),
285                        ..Default::default()
286                    }),
287                    code_action: Some(CodeActionClientCapabilities {
288                        code_action_literal_support: Some(CodeActionLiteralSupport {
289                            code_action_kind: CodeActionKindLiteralSupport {
290                                value_set: vec![
291                                    CodeActionKind::REFACTOR.as_str().into(),
292                                    CodeActionKind::QUICKFIX.as_str().into(),
293                                    CodeActionKind::SOURCE.as_str().into(),
294                                ],
295                            },
296                        }),
297                        data_support: Some(true),
298                        resolve_support: Some(CodeActionCapabilityResolveSupport {
299                            properties: vec!["edit".to_string(), "command".to_string()],
300                        }),
301                        ..Default::default()
302                    }),
303                    completion: Some(CompletionClientCapabilities {
304                        completion_item: Some(CompletionItemCapability {
305                            snippet_support: Some(true),
306                            resolve_support: Some(CompletionItemCapabilityResolveSupport {
307                                properties: vec!["additionalTextEdits".to_string()],
308                            }),
309                            ..Default::default()
310                        }),
311                        ..Default::default()
312                    }),
313                    rename: Some(RenameClientCapabilities {
314                        prepare_support: Some(true),
315                        ..Default::default()
316                    }),
317                    hover: Some(HoverClientCapabilities {
318                        content_format: Some(vec![MarkupKind::Markdown]),
319                        ..Default::default()
320                    }),
321                    ..Default::default()
322                }),
323                experimental: Some(json!({
324                    "serverStatusNotification": true,
325                })),
326                window: Some(WindowClientCapabilities {
327                    work_done_progress: Some(true),
328                    ..Default::default()
329                }),
330                ..Default::default()
331            },
332            trace: Default::default(),
333            workspace_folders: Some(vec![WorkspaceFolder {
334                uri: root_uri,
335                name: Default::default(),
336            }]),
337            client_info: Default::default(),
338            locale: Default::default(),
339        };
340
341        let response = self.request::<request::Initialize>(params).await?;
342        if let Some(info) = response.server_info {
343            self.name = info.name;
344        }
345        self.capabilities = response.capabilities;
346
347        self.notify::<notification::Initialized>(InitializedParams {})?;
348        Ok(Arc::new(self))
349    }
350
351    pub fn shutdown(&self) -> Option<impl 'static + Send + Future<Output = Option<()>>> {
352        if let Some(tasks) = self.io_tasks.lock().take() {
353            let response_handlers = self.response_handlers.clone();
354            let next_id = AtomicUsize::new(self.next_id.load(SeqCst));
355            let outbound_tx = self.outbound_tx.clone();
356            let mut output_done = self.output_done_rx.lock().take().unwrap();
357            let shutdown_request = Self::request_internal::<request::Shutdown>(
358                &next_id,
359                &response_handlers,
360                &outbound_tx,
361                (),
362            );
363            let exit = Self::notify_internal::<notification::Exit>(&outbound_tx, ());
364            outbound_tx.close();
365            Some(
366                async move {
367                    log::debug!("language server shutdown started");
368                    shutdown_request.await?;
369                    response_handlers.lock().clear();
370                    exit?;
371                    output_done.recv().await;
372                    log::debug!("language server shutdown finished");
373                    drop(tasks);
374                    Ok(())
375                }
376                .log_err(),
377            )
378        } else {
379            None
380        }
381    }
382
383    #[must_use]
384    pub fn on_notification<T, F>(&self, f: F) -> Subscription
385    where
386        T: notification::Notification,
387        F: 'static + Send + FnMut(T::Params, AsyncAppContext),
388    {
389        self.on_custom_notification(T::METHOD, f)
390    }
391
392    #[must_use]
393    pub fn on_request<T, F, Fut>(&self, f: F) -> Subscription
394    where
395        T: request::Request,
396        T::Params: 'static + Send,
397        F: 'static + Send + FnMut(T::Params, AsyncAppContext) -> Fut,
398        Fut: 'static + Future<Output = Result<T::Result>>,
399    {
400        self.on_custom_request(T::METHOD, f)
401    }
402
403    pub fn remove_request_handler<T: request::Request>(&self) {
404        self.notification_handlers.lock().remove(T::METHOD);
405    }
406
407    #[must_use]
408    pub fn on_custom_notification<Params, F>(&self, method: &'static str, mut f: F) -> Subscription
409    where
410        F: 'static + Send + FnMut(Params, AsyncAppContext),
411        Params: DeserializeOwned,
412    {
413        let prev_handler = self.notification_handlers.lock().insert(
414            method,
415            Box::new(move |_, params, cx| {
416                if let Some(params) = serde_json::from_str(params).log_err() {
417                    f(params, cx);
418                }
419            }),
420        );
421        assert!(
422            prev_handler.is_none(),
423            "registered multiple handlers for the same LSP method"
424        );
425        Subscription {
426            method,
427            notification_handlers: self.notification_handlers.clone(),
428        }
429    }
430
431    #[must_use]
432    pub fn on_custom_request<Params, Res, Fut, F>(
433        &self,
434        method: &'static str,
435        mut f: F,
436    ) -> Subscription
437    where
438        F: 'static + Send + FnMut(Params, AsyncAppContext) -> Fut,
439        Fut: 'static + Future<Output = Result<Res>>,
440        Params: DeserializeOwned + Send + 'static,
441        Res: Serialize,
442    {
443        let outbound_tx = self.outbound_tx.clone();
444        let prev_handler = self.notification_handlers.lock().insert(
445            method,
446            Box::new(move |id, params, cx| {
447                if let Some(id) = id {
448                    if let Some(params) = serde_json::from_str(params).log_err() {
449                        let response = f(params, cx.clone());
450                        cx.foreground()
451                            .spawn({
452                                let outbound_tx = outbound_tx.clone();
453                                async move {
454                                    let response = match response.await {
455                                        Ok(result) => Response {
456                                            jsonrpc: JSON_RPC_VERSION,
457                                            id,
458                                            result: Some(result),
459                                            error: None,
460                                        },
461                                        Err(error) => Response {
462                                            jsonrpc: JSON_RPC_VERSION,
463                                            id,
464                                            result: None,
465                                            error: Some(Error {
466                                                message: error.to_string(),
467                                            }),
468                                        },
469                                    };
470                                    if let Some(response) = serde_json::to_vec(&response).log_err()
471                                    {
472                                        outbound_tx.try_send(response).ok();
473                                    }
474                                }
475                            })
476                            .detach();
477                    }
478                }
479            }),
480        );
481        assert!(
482            prev_handler.is_none(),
483            "registered multiple handlers for the same LSP method"
484        );
485        Subscription {
486            method,
487            notification_handlers: self.notification_handlers.clone(),
488        }
489    }
490
491    pub fn name<'a>(self: &'a Arc<Self>) -> &'a str {
492        &self.name
493    }
494
495    pub fn capabilities<'a>(self: &'a Arc<Self>) -> &'a ServerCapabilities {
496        &self.capabilities
497    }
498
499    pub fn server_id(&self) -> usize {
500        self.server_id
501    }
502
503    pub fn root_path(&self) -> &PathBuf {
504        &self.root_path
505    }
506
507    pub fn request<T: request::Request>(
508        &self,
509        params: T::Params,
510    ) -> impl Future<Output = Result<T::Result>>
511    where
512        T::Result: 'static + Send,
513    {
514        Self::request_internal::<T>(
515            &self.next_id,
516            &self.response_handlers,
517            &self.outbound_tx,
518            params,
519        )
520    }
521
522    fn request_internal<T: request::Request>(
523        next_id: &AtomicUsize,
524        response_handlers: &Mutex<HashMap<usize, ResponseHandler>>,
525        outbound_tx: &channel::Sender<Vec<u8>>,
526        params: T::Params,
527    ) -> impl 'static + Future<Output = Result<T::Result>>
528    where
529        T::Result: 'static + Send,
530    {
531        let id = next_id.fetch_add(1, SeqCst);
532        let message = serde_json::to_vec(&Request {
533            jsonrpc: JSON_RPC_VERSION,
534            id,
535            method: T::METHOD,
536            params,
537        })
538        .unwrap();
539
540        let send = outbound_tx
541            .try_send(message)
542            .context("failed to write to language server's stdin");
543
544        let (tx, rx) = oneshot::channel();
545        response_handlers.lock().insert(
546            id,
547            Box::new(move |result| {
548                let response = match result {
549                    Ok(response) => {
550                        serde_json::from_str(response).context("failed to deserialize response")
551                    }
552                    Err(error) => Err(anyhow!("{}", error.message)),
553                };
554                let _ = tx.send(response);
555            }),
556        );
557
558        async move {
559            send?;
560            rx.await?
561        }
562    }
563
564    pub fn notify<T: notification::Notification>(&self, params: T::Params) -> Result<()> {
565        Self::notify_internal::<T>(&self.outbound_tx, params)
566    }
567
568    fn notify_internal<T: notification::Notification>(
569        outbound_tx: &channel::Sender<Vec<u8>>,
570        params: T::Params,
571    ) -> Result<()> {
572        let message = serde_json::to_vec(&Notification {
573            jsonrpc: JSON_RPC_VERSION,
574            method: T::METHOD,
575            params,
576        })
577        .unwrap();
578        outbound_tx.try_send(message)?;
579        Ok(())
580    }
581}
582
583impl Drop for LanguageServer {
584    fn drop(&mut self) {
585        if let Some(shutdown) = self.shutdown() {
586            self.executor.spawn(shutdown).detach();
587        }
588    }
589}
590
591impl Subscription {
592    pub fn detach(mut self) {
593        self.method = "";
594    }
595}
596
597impl Drop for Subscription {
598    fn drop(&mut self) {
599        self.notification_handlers.lock().remove(self.method);
600    }
601}
602
/// Test double for a language server. `server` is the server end of an in-memory
/// connection created by `LanguageServer::fake`. Messages that no handler on `server`
/// claims are forwarded to `notifications_rx` as `(method, raw params JSON)` pairs.
#[cfg(any(test, feature = "test-support"))]
#[derive(Clone)]
pub struct FakeLanguageServer {
    pub server: Arc<LanguageServer>,
    notifications_rx: channel::Receiver<(String, String)>,
}
609
#[cfg(any(test, feature = "test-support"))]
impl LanguageServer {
    /// Capabilities with document highlights, code actions, and whole-document and range
    /// formatting enabled; everything else is left at its default.
    pub fn full_capabilities() -> ServerCapabilities {
        ServerCapabilities {
            code_action_provider: Some(CodeActionProviderCapability::Simple(true)),
            document_formatting_provider: Some(OneOf::Left(true)),
            document_highlight_provider: Some(OneOf::Left(true)),
            document_range_formatting_provider: Some(OneOf::Left(true)),
            ..ServerCapabilities::default()
        }
    }

    /// Creates a client connected to an in-memory fake server.
    ///
    /// The fake answers `initialize` with `capabilities` and reports `name` as its
    /// server name. Messages it receives without a registered handler are made
    /// available through `FakeLanguageServer::receive_notification`.
    pub fn fake(
        name: String,
        capabilities: ServerCapabilities,
        cx: AsyncAppContext,
    ) -> (Self, FakeLanguageServer) {
        // Two pipes cross-connect the client and the fake: bytes the client writes
        // to `to_fake` come out of `from_client`, and vice versa.
        let (to_fake, from_client) = async_pipe::pipe();
        let (to_client, from_fake) = async_pipe::pipe();
        let (notifications_tx, notifications_rx) = channel::unbounded();

        let client = Self::new_internal(
            0,
            to_fake,
            from_fake,
            None,
            Path::new("/"),
            cx.clone(),
            |_| {},
        );
        let fake_server = Self::new_internal(
            0,
            to_client,
            from_client,
            None,
            Path::new("/"),
            cx,
            move |msg| {
                let forwarded = (msg.method.to_string(), msg.params.get().to_string());
                notifications_tx.try_send(forwarded).ok();
            },
        );
        let fake = FakeLanguageServer {
            server: Arc::new(fake_server),
            notifications_rx,
        };

        fake.handle_request::<request::Initialize, _, _>(move |_, _| {
            let capabilities = capabilities.clone();
            let server_info = ServerInfo {
                name: name.clone(),
                ..Default::default()
            };
            async move {
                Ok(InitializeResult {
                    capabilities,
                    server_info: Some(server_info),
                })
            }
        });

        (client, fake)
    }
}
676
#[cfg(any(test, feature = "test-support"))]
impl FakeLanguageServer {
    /// Sends a notification of type `T` from the fake server to the client. A failure
    /// to queue it is ignored.
    pub fn notify<T: notification::Notification>(&self, params: T::Params) {
        let _ = self.server.notify::<T>(params);
    }

    /// Sends a request of type `T` from the fake server and awaits the client's reply.
    pub async fn request<T>(&self, params: T::Params) -> Result<T::Result>
    where
        T: request::Request,
        T::Result: 'static + Send,
    {
        let response = self.server.request::<T>(params);
        response.await
    }

    /// Waits for the next unhandled message with method `T::METHOD`, skipping others.
    ///
    /// # Panics
    /// Panics if the stream of unhandled messages ends first, or if the params do not
    /// deserialize into `T::Params`.
    pub async fn receive_notification<T: notification::Notification>(&mut self) -> T::Params {
        let received = self.try_receive_notification::<T>().await;
        received.unwrap()
    }

    /// Like `receive_notification`, but returns `None` once the stream of unhandled
    /// messages ends. Messages for other methods are logged and discarded.
    pub async fn try_receive_notification<T: notification::Notification>(
        &mut self,
    ) -> Option<T::Params> {
        use futures::StreamExt as _;

        while let Some((method, params)) = self.notifications_rx.next().await {
            if method != T::METHOD {
                log::info!("skipping message in fake language server {:?}", params);
                continue;
            }
            return Some(serde_json::from_str::<T::Params>(&params).unwrap());
        }
        None
    }

    /// Replaces any existing handler for request type `T` with `handler`.
    ///
    /// Each response is delayed by a simulated random delay. The returned receiver gets
    /// a `()` after each request has been answered.
    pub fn handle_request<T, F, Fut>(
        &self,
        mut handler: F,
    ) -> futures::channel::mpsc::UnboundedReceiver<()>
    where
        T: 'static + request::Request,
        T::Params: 'static + Send,
        F: 'static + Send + FnMut(T::Params, gpui::AsyncAppContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<T::Result>>,
    {
        let (responded_tx, responded_rx) = futures::channel::mpsc::unbounded();
        self.server.remove_request_handler::<T>();
        let subscription = self.server.on_request::<T, _, _>(move |params, cx| {
            // The handler itself runs right away; only its result is delayed.
            let pending = handler(params, cx.clone());
            let notify_responded = responded_tx.clone();
            async move {
                cx.background().simulate_random_delay().await;
                let outcome = pending.await;
                notify_responded.unbounded_send(()).ok();
                outcome
            }
        });
        subscription.detach();
        responded_rx
    }

    /// Removes the handler for request type `T`, if there is one.
    pub fn remove_request_handler<T>(&mut self)
    where
        T: 'static + request::Request,
    {
        self.server.remove_request_handler::<T>();
    }

    /// Creates a work-done progress token on the client and sends its `Begin` report.
    ///
    /// # Panics
    /// Panics if the client fails the `window/workDoneProgress/create` request.
    pub async fn start_progress(&self, token: impl Into<String>) {
        let token: String = token.into();
        let create_params = WorkDoneProgressCreateParams {
            token: NumberOrString::String(token.clone()),
        };
        self.request::<request::WorkDoneProgressCreate>(create_params)
            .await
            .unwrap();
        let begin = WorkDoneProgress::Begin(Default::default());
        self.notify::<notification::Progress>(ProgressParams {
            token: NumberOrString::String(token),
            value: ProgressParamsValue::WorkDone(begin),
        });
    }

    /// Sends the `End` report for the progress identified by `token`.
    pub fn end_progress(&self, token: impl Into<String>) {
        let end = WorkDoneProgress::End(Default::default());
        let params = ProgressParams {
            token: NumberOrString::String(token.into()),
            value: ProgressParamsValue::WorkDone(end),
        };
        self.notify::<notification::Progress>(params);
    }
}
764
765struct ClearResponseHandlers(Arc<Mutex<HashMap<usize, ResponseHandler>>>);
766
767impl Drop for ClearResponseHandlers {
768    fn drop(&mut self) {
769        self.0.lock().clear();
770    }
771}
772
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;

    /// Enables `env_logger` output for tests when `RUST_LOG` is set.
    #[ctor::ctor]
    fn init_logger() {
        if std::env::var("RUST_LOG").is_ok() {
            env_logger::init();
        }
    }

    /// Round-trips notifications in both directions between a client and a fake server,
    /// then checks that dropping the client sends `exit`.
    #[gpui::test]
    async fn test_fake(cx: &mut TestAppContext) {
        let (server, mut fake) =
            LanguageServer::fake("the-lsp".to_string(), Default::default(), cx.to_async());

        // Capture notifications the client receives from the fake.
        let (message_tx, message_rx) = channel::unbounded();
        let (diagnostics_tx, diagnostics_rx) = channel::unbounded();
        server
            .on_notification::<notification::ShowMessage, _>(move |params, _| {
                message_tx.try_send(params).unwrap()
            })
            .detach();
        server
            .on_notification::<notification::PublishDiagnostics, _>(move |params, _| {
                diagnostics_tx.try_send(params).unwrap()
            })
            .detach();

        // Client -> fake: a notification with no handler on the fake reaches
        // `receive_notification`.
        let server = server.initialize(None).await.unwrap();
        server
            .notify::<notification::DidOpenTextDocument>(DidOpenTextDocumentParams {
                text_document: TextDocumentItem::new(
                    Url::from_str("file://a/b").unwrap(),
                    "rust".to_string(),
                    0,
                    "".to_string(),
                ),
            })
            .unwrap();
        assert_eq!(
            fake.receive_notification::<notification::DidOpenTextDocument>()
                .await
                .text_document
                .uri
                .as_str(),
            "file://a/b"
        );

        // Fake -> client: notifications are delivered to the handlers registered above.
        fake.notify::<notification::ShowMessage>(ShowMessageParams {
            typ: MessageType::ERROR,
            message: "ok".to_string(),
        });
        fake.notify::<notification::PublishDiagnostics>(PublishDiagnosticsParams {
            uri: Url::from_str("file://b/c").unwrap(),
            version: Some(5),
            diagnostics: vec![],
        });
        assert_eq!(message_rx.recv().await.unwrap().message, "ok");
        assert_eq!(
            diagnostics_rx.recv().await.unwrap().uri.as_str(),
            "file://b/c"
        );

        // Answer the shutdown request so that dropping the client proceeds to `exit`.
        fake.handle_request::<request::Shutdown, _, _>(|_, _| async move { Ok(()) });

        drop(server);
        fake.receive_notification::<notification::Exit>().await;
    }
}