lsp.rs

  1use log::warn;
  2pub use lsp_types::request::*;
  3pub use lsp_types::*;
  4
  5use anyhow::{anyhow, Context, Result};
  6use collections::HashMap;
  7use futures::{channel::oneshot, io::BufWriter, AsyncRead, AsyncWrite};
  8use gpui::{executor, AsyncAppContext, Task};
  9use parking_lot::Mutex;
 10use postage::{barrier, prelude::Stream};
 11use serde::{de::DeserializeOwned, Deserialize, Serialize};
 12use serde_json::{json, value::RawValue, Value};
 13use smol::{
 14    channel,
 15    io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
 16    process::{self, Child},
 17};
 18use std::{
 19    fmt,
 20    future::Future,
 21    io::Write,
 22    path::PathBuf,
 23    str::FromStr,
 24    sync::{
 25        atomic::{AtomicUsize, Ordering::SeqCst},
 26        Arc,
 27    },
 28};
 29use std::{path::Path, process::Stdio};
 30use util::{ResultExt, TryFutureExt};
 31
 32const JSON_RPC_VERSION: &str = "2.0";
 33const CONTENT_LEN_HEADER: &str = "Content-Length: ";
 34
 35type NotificationHandler = Box<dyn Send + FnMut(Option<usize>, &str, AsyncAppContext)>;
 36type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
 37
/// A JSON-RPC connection to a language server.
///
/// Outgoing messages are serialized and queued on `outbound_tx`, then written
/// to the server's stdin by a background output task. Incoming messages are
/// read from its stdout by an input task and dispatched to the registered
/// notification/request handlers or to the callback of a pending request.
pub struct LanguageServer {
    server_id: LanguageServerId,
    /// Id to assign to the next outgoing request.
    next_id: AtomicUsize,
    /// Queue of serialized messages awaiting the output task.
    outbound_tx: channel::Sender<Vec<u8>>,
    /// Display name: the binary's file name, replaced by the name the server
    /// reports in its `initialize` response (if any).
    name: String,
    /// Capabilities the server reported in its `initialize` response.
    capabilities: ServerCapabilities,
    code_action_kinds: Option<Vec<CodeActionKind>>,
    /// Handlers for incoming notifications and server-to-client requests,
    /// keyed by LSP method name.
    notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
    /// Callbacks for requests awaiting a reply, keyed by request id. Becomes
    /// `None` when the connection shuts down, after which new requests fail.
    response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
    executor: Arc<executor::Background>,
    /// The (input, output) I/O tasks; taken (once) by `shutdown`.
    // The nested tuple type trips clippy's type_complexity lint; it is used
    // only here, so an alias would not make it clearer.
    #[allow(clippy::type_complexity)]
    io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
    /// Released when the output task finishes; taken (once) by `shutdown`.
    output_done_rx: Mutex<Option<barrier::Receiver>>,
    root_path: PathBuf,
    /// The spawned server process; `None` for in-memory fake servers.
    _server: Option<Child>,
}
 54
 55#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
 56#[repr(transparent)]
 57pub struct LanguageServerId(pub usize);
 58
/// Guard for a handler registered via `on_notification` / `on_request` (and
/// their `custom` variants).
///
/// Dropping it removes the handler stored under `method`; call `detach` to
/// keep the handler registered instead.
pub struct Subscription {
    /// The LSP method whose handler this guard owns (emptied by `detach`).
    method: &'static str,
    /// Shared handler table of the owning `LanguageServer`.
    notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
}
 63
/// An outgoing JSON-RPC request message.
// Field order is the key order of the serialized JSON.
#[derive(Serialize, Deserialize)]
struct Request<'a, T> {
    /// Always `JSON_RPC_VERSION` for messages this module builds.
    jsonrpc: &'static str,
    /// Id the server echoes back in its response.
    id: usize,
    /// LSP method name, e.g. the `METHOD` of an `lsp_types` request type.
    method: &'a str,
    params: T,
}
 71
 72#[derive(Serialize, Deserialize)]
 73struct AnyResponse<'a> {
 74    jsonrpc: &'a str,
 75    id: usize,
 76    #[serde(default)]
 77    error: Option<Error>,
 78    #[serde(borrow)]
 79    result: Option<&'a RawValue>,
 80}
 81
 82#[derive(Serialize)]
 83struct Response<T> {
 84    jsonrpc: &'static str,
 85    id: usize,
 86    result: Option<T>,
 87    error: Option<Error>,
 88}
 89
/// An outgoing JSON-RPC notification (a message without an id, which the
/// server does not reply to).
// Field order is the key order of the serialized JSON.
#[derive(Serialize, Deserialize)]
struct Notification<'a, T> {
    /// Always `JSON_RPC_VERSION` for messages this module builds.
    jsonrpc: &'static str,
    #[serde(borrow)]
    method: &'a str,
    params: T,
}
 97
/// An incoming notification or server-to-client request, borrowed from the
/// message buffer with its `params` left as unparsed JSON.
#[derive(Deserialize)]
struct AnyNotification<'a> {
    /// Present for requests (which expect a response), absent for
    /// notifications.
    #[serde(default)]
    id: Option<usize>,
    #[serde(borrow)]
    method: &'a str,
    // NOTE: required by this derive, so messages that omit `params` do not
    // decode as this type on their own.
    #[serde(borrow)]
    params: &'a RawValue,
}
107
108#[derive(Debug, Serialize, Deserialize)]
109struct Error {
110    message: String,
111}
112
113impl LanguageServer {
114    pub fn new<T: AsRef<std::ffi::OsStr>>(
115        server_id: LanguageServerId,
116        binary_path: &Path,
117        arguments: &[T],
118        root_path: &Path,
119        code_action_kinds: Option<Vec<CodeActionKind>>,
120        cx: AsyncAppContext,
121    ) -> Result<Self> {
122        let working_dir = if root_path.is_dir() {
123            root_path
124        } else {
125            root_path.parent().unwrap_or_else(|| Path::new("/"))
126        };
127
128        let mut server = process::Command::new(binary_path)
129            .current_dir(working_dir)
130            .args(arguments)
131            .stdin(Stdio::piped())
132            .stdout(Stdio::piped())
133            .stderr(Stdio::inherit())
134            .kill_on_drop(true)
135            .spawn()?;
136
137        let stdin = server.stdin.take().unwrap();
138        let stout = server.stdout.take().unwrap();
139        let mut server = Self::new_internal(
140            server_id,
141            stdin,
142            stout,
143            Some(server),
144            root_path,
145            code_action_kinds,
146            cx,
147            |notification| {
148                log::info!(
149                    "unhandled notification {}:\n{}",
150                    notification.method,
151                    serde_json::to_string_pretty(
152                        &Value::from_str(notification.params.get()).unwrap()
153                    )
154                    .unwrap()
155                );
156            },
157        );
158
159        if let Some(name) = binary_path.file_name() {
160            server.name = name.to_string_lossy().to_string();
161        }
162        Ok(server)
163    }
164
165    fn new_internal<Stdin, Stdout, F>(
166        server_id: LanguageServerId,
167        stdin: Stdin,
168        stdout: Stdout,
169        server: Option<Child>,
170        root_path: &Path,
171        code_action_kinds: Option<Vec<CodeActionKind>>,
172        cx: AsyncAppContext,
173        on_unhandled_notification: F,
174    ) -> Self
175    where
176        Stdin: AsyncWrite + Unpin + Send + 'static,
177        Stdout: AsyncRead + Unpin + Send + 'static,
178        F: FnMut(AnyNotification) + 'static + Send,
179    {
180        let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
181        let notification_handlers =
182            Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
183        let response_handlers =
184            Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
185        let input_task = cx.spawn(|cx| {
186            let notification_handlers = notification_handlers.clone();
187            let response_handlers = response_handlers.clone();
188            Self::handle_input(
189                stdout,
190                on_unhandled_notification,
191                notification_handlers,
192                response_handlers,
193                cx,
194            )
195            .log_err()
196        });
197        let (output_done_tx, output_done_rx) = barrier::channel();
198        let output_task = cx.background().spawn({
199            let response_handlers = response_handlers.clone();
200            Self::handle_output(stdin, outbound_rx, output_done_tx, response_handlers).log_err()
201        });
202
203        Self {
204            server_id,
205            notification_handlers,
206            response_handlers,
207            name: Default::default(),
208            capabilities: Default::default(),
209            code_action_kinds,
210            next_id: Default::default(),
211            outbound_tx,
212            executor: cx.background(),
213            io_tasks: Mutex::new(Some((input_task, output_task))),
214            output_done_rx: Mutex::new(Some(output_done_rx)),
215            root_path: root_path.to_path_buf(),
216            _server: server,
217        }
218    }
219
220    pub fn code_action_kinds(&self) -> Option<Vec<CodeActionKind>> {
221        self.code_action_kinds.clone()
222    }
223
224    async fn handle_input<Stdout, F>(
225        stdout: Stdout,
226        mut on_unhandled_notification: F,
227        notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
228        response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
229        cx: AsyncAppContext,
230    ) -> anyhow::Result<()>
231    where
232        Stdout: AsyncRead + Unpin + Send + 'static,
233        F: FnMut(AnyNotification) + 'static + Send,
234    {
235        let mut stdout = BufReader::new(stdout);
236        let _clear_response_handlers = util::defer({
237            let response_handlers = response_handlers.clone();
238            move || {
239                response_handlers.lock().take();
240            }
241        });
242        let mut buffer = Vec::new();
243        loop {
244            buffer.clear();
245            stdout.read_until(b'\n', &mut buffer).await?;
246            stdout.read_until(b'\n', &mut buffer).await?;
247            let message_len: usize = std::str::from_utf8(&buffer)?
248                .strip_prefix(CONTENT_LEN_HEADER)
249                .ok_or_else(|| anyhow!("invalid header"))?
250                .trim_end()
251                .parse()?;
252
253            buffer.resize(message_len, 0);
254            stdout.read_exact(&mut buffer).await?;
255            log::trace!("incoming message:{}", String::from_utf8_lossy(&buffer));
256
257            if let Ok(msg) = serde_json::from_slice::<AnyNotification>(&buffer) {
258                if let Some(handler) = notification_handlers.lock().get_mut(msg.method) {
259                    handler(msg.id, msg.params.get(), cx.clone());
260                } else {
261                    on_unhandled_notification(msg);
262                }
263            } else if let Ok(AnyResponse {
264                id, error, result, ..
265            }) = serde_json::from_slice(&buffer)
266            {
267                if let Some(handler) = response_handlers
268                    .lock()
269                    .as_mut()
270                    .and_then(|handlers| handlers.remove(&id))
271                {
272                    if let Some(error) = error {
273                        handler(Err(error));
274                    } else if let Some(result) = result {
275                        handler(Ok(result.get()));
276                    } else {
277                        handler(Ok("null"));
278                    }
279                }
280            } else {
281                warn!(
282                    "Failed to deserialize message:\n{}",
283                    std::str::from_utf8(&buffer)?
284                );
285            }
286
287            // Don't starve the main thread when receiving lots of messages at once.
288            smol::future::yield_now().await;
289        }
290    }
291
292    async fn handle_output<Stdin>(
293        stdin: Stdin,
294        outbound_rx: channel::Receiver<Vec<u8>>,
295        output_done_tx: barrier::Sender,
296        response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
297    ) -> anyhow::Result<()>
298    where
299        Stdin: AsyncWrite + Unpin + Send + 'static,
300    {
301        let mut stdin = BufWriter::new(stdin);
302        let _clear_response_handlers = util::defer({
303            let response_handlers = response_handlers.clone();
304            move || {
305                response_handlers.lock().take();
306            }
307        });
308        let mut content_len_buffer = Vec::new();
309        while let Ok(message) = outbound_rx.recv().await {
310            log::trace!("outgoing message:{}", String::from_utf8_lossy(&message));
311            content_len_buffer.clear();
312            write!(content_len_buffer, "{}", message.len()).unwrap();
313            stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
314            stdin.write_all(&content_len_buffer).await?;
315            stdin.write_all("\r\n\r\n".as_bytes()).await?;
316            stdin.write_all(&message).await?;
317            stdin.flush().await?;
318        }
319        drop(output_done_tx);
320        Ok(())
321    }
322
323    /// Initializes a language server.
324    /// Note that `options` is used directly to construct [`InitializeParams`],
325    /// which is why it is owned.
326    pub async fn initialize(mut self, options: Option<Value>) -> Result<Arc<Self>> {
327        let root_uri = Url::from_file_path(&self.root_path).unwrap();
328        #[allow(deprecated)]
329        let params = InitializeParams {
330            process_id: Default::default(),
331            root_path: Default::default(),
332            root_uri: Some(root_uri.clone()),
333            initialization_options: options,
334            capabilities: ClientCapabilities {
335                workspace: Some(WorkspaceClientCapabilities {
336                    configuration: Some(true),
337                    did_change_watched_files: Some(DynamicRegistrationClientCapabilities {
338                        dynamic_registration: Some(true),
339                    }),
340                    did_change_configuration: Some(DynamicRegistrationClientCapabilities {
341                        dynamic_registration: Some(true),
342                    }),
343                    workspace_folders: Some(true),
344                    ..Default::default()
345                }),
346                text_document: Some(TextDocumentClientCapabilities {
347                    definition: Some(GotoCapability {
348                        link_support: Some(true),
349                        ..Default::default()
350                    }),
351                    code_action: Some(CodeActionClientCapabilities {
352                        code_action_literal_support: Some(CodeActionLiteralSupport {
353                            code_action_kind: CodeActionKindLiteralSupport {
354                                value_set: vec![
355                                    CodeActionKind::REFACTOR.as_str().into(),
356                                    CodeActionKind::QUICKFIX.as_str().into(),
357                                    CodeActionKind::SOURCE.as_str().into(),
358                                ],
359                            },
360                        }),
361                        data_support: Some(true),
362                        resolve_support: Some(CodeActionCapabilityResolveSupport {
363                            properties: vec!["edit".to_string(), "command".to_string()],
364                        }),
365                        ..Default::default()
366                    }),
367                    completion: Some(CompletionClientCapabilities {
368                        completion_item: Some(CompletionItemCapability {
369                            snippet_support: Some(true),
370                            resolve_support: Some(CompletionItemCapabilityResolveSupport {
371                                properties: vec!["additionalTextEdits".to_string()],
372                            }),
373                            ..Default::default()
374                        }),
375                        ..Default::default()
376                    }),
377                    rename: Some(RenameClientCapabilities {
378                        prepare_support: Some(true),
379                        ..Default::default()
380                    }),
381                    hover: Some(HoverClientCapabilities {
382                        content_format: Some(vec![MarkupKind::Markdown]),
383                        ..Default::default()
384                    }),
385                    ..Default::default()
386                }),
387                experimental: Some(json!({
388                    "serverStatusNotification": true,
389                })),
390                window: Some(WindowClientCapabilities {
391                    work_done_progress: Some(true),
392                    ..Default::default()
393                }),
394                ..Default::default()
395            },
396            trace: Default::default(),
397            workspace_folders: Some(vec![WorkspaceFolder {
398                uri: root_uri,
399                name: Default::default(),
400            }]),
401            client_info: Default::default(),
402            locale: Default::default(),
403        };
404
405        let response = self.request::<request::Initialize>(params).await?;
406        if let Some(info) = response.server_info {
407            self.name = info.name;
408        }
409        self.capabilities = response.capabilities;
410
411        self.notify::<notification::Initialized>(InitializedParams {})?;
412        Ok(Arc::new(self))
413    }
414
415    pub fn shutdown(&self) -> Option<impl 'static + Send + Future<Output = Option<()>>> {
416        if let Some(tasks) = self.io_tasks.lock().take() {
417            let response_handlers = self.response_handlers.clone();
418            let next_id = AtomicUsize::new(self.next_id.load(SeqCst));
419            let outbound_tx = self.outbound_tx.clone();
420            let mut output_done = self.output_done_rx.lock().take().unwrap();
421            let shutdown_request = Self::request_internal::<request::Shutdown>(
422                &next_id,
423                &response_handlers,
424                &outbound_tx,
425                (),
426            );
427            let exit = Self::notify_internal::<notification::Exit>(&outbound_tx, ());
428            outbound_tx.close();
429            Some(
430                async move {
431                    log::debug!("language server shutdown started");
432                    shutdown_request.await?;
433                    response_handlers.lock().take();
434                    exit?;
435                    output_done.recv().await;
436                    log::debug!("language server shutdown finished");
437                    drop(tasks);
438                    anyhow::Ok(())
439                }
440                .log_err(),
441            )
442        } else {
443            None
444        }
445    }
446
447    #[must_use]
448    pub fn on_notification<T, F>(&self, f: F) -> Subscription
449    where
450        T: notification::Notification,
451        F: 'static + Send + FnMut(T::Params, AsyncAppContext),
452    {
453        self.on_custom_notification(T::METHOD, f)
454    }
455
456    #[must_use]
457    pub fn on_request<T, F, Fut>(&self, f: F) -> Subscription
458    where
459        T: request::Request,
460        T::Params: 'static + Send,
461        F: 'static + Send + FnMut(T::Params, AsyncAppContext) -> Fut,
462        Fut: 'static + Future<Output = Result<T::Result>>,
463    {
464        self.on_custom_request(T::METHOD, f)
465    }
466
467    pub fn remove_request_handler<T: request::Request>(&self) {
468        self.notification_handlers.lock().remove(T::METHOD);
469    }
470
471    pub fn remove_notification_handler<T: notification::Notification>(&self) {
472        self.notification_handlers.lock().remove(T::METHOD);
473    }
474
475    #[must_use]
476    pub fn on_custom_notification<Params, F>(&self, method: &'static str, mut f: F) -> Subscription
477    where
478        F: 'static + Send + FnMut(Params, AsyncAppContext),
479        Params: DeserializeOwned,
480    {
481        let prev_handler = self.notification_handlers.lock().insert(
482            method,
483            Box::new(move |_, params, cx| {
484                if let Some(params) = serde_json::from_str(params).log_err() {
485                    f(params, cx);
486                }
487            }),
488        );
489        assert!(
490            prev_handler.is_none(),
491            "registered multiple handlers for the same LSP method"
492        );
493        Subscription {
494            method,
495            notification_handlers: self.notification_handlers.clone(),
496        }
497    }
498
499    #[must_use]
500    pub fn on_custom_request<Params, Res, Fut, F>(
501        &self,
502        method: &'static str,
503        mut f: F,
504    ) -> Subscription
505    where
506        F: 'static + Send + FnMut(Params, AsyncAppContext) -> Fut,
507        Fut: 'static + Future<Output = Result<Res>>,
508        Params: DeserializeOwned + Send + 'static,
509        Res: Serialize,
510    {
511        let outbound_tx = self.outbound_tx.clone();
512        let prev_handler = self.notification_handlers.lock().insert(
513            method,
514            Box::new(move |id, params, cx| {
515                if let Some(id) = id {
516                    match serde_json::from_str(params) {
517                        Ok(params) => {
518                            let response = f(params, cx.clone());
519                            cx.foreground()
520                                .spawn({
521                                    let outbound_tx = outbound_tx.clone();
522                                    async move {
523                                        let response = match response.await {
524                                            Ok(result) => Response {
525                                                jsonrpc: JSON_RPC_VERSION,
526                                                id,
527                                                result: Some(result),
528                                                error: None,
529                                            },
530                                            Err(error) => Response {
531                                                jsonrpc: JSON_RPC_VERSION,
532                                                id,
533                                                result: None,
534                                                error: Some(Error {
535                                                    message: error.to_string(),
536                                                }),
537                                            },
538                                        };
539                                        if let Some(response) =
540                                            serde_json::to_vec(&response).log_err()
541                                        {
542                                            outbound_tx.try_send(response).ok();
543                                        }
544                                    }
545                                })
546                                .detach();
547                        }
548                        Err(error) => {
549                            log::error!(
550                                "error deserializing {} request: {:?}, message: {:?}",
551                                method,
552                                error,
553                                params
554                            );
555                            let response = AnyResponse {
556                                jsonrpc: JSON_RPC_VERSION,
557                                id,
558                                result: None,
559                                error: Some(Error {
560                                    message: error.to_string(),
561                                }),
562                            };
563                            if let Some(response) = serde_json::to_vec(&response).log_err() {
564                                outbound_tx.try_send(response).ok();
565                            }
566                        }
567                    }
568                }
569            }),
570        );
571        assert!(
572            prev_handler.is_none(),
573            "registered multiple handlers for the same LSP method"
574        );
575        Subscription {
576            method,
577            notification_handlers: self.notification_handlers.clone(),
578        }
579    }
580
581    pub fn name<'a>(self: &'a Arc<Self>) -> &'a str {
582        &self.name
583    }
584
585    pub fn capabilities<'a>(self: &'a Arc<Self>) -> &'a ServerCapabilities {
586        &self.capabilities
587    }
588
589    pub fn server_id(&self) -> LanguageServerId {
590        self.server_id
591    }
592
593    pub fn root_path(&self) -> &PathBuf {
594        &self.root_path
595    }
596
597    pub fn request<T: request::Request>(
598        &self,
599        params: T::Params,
600    ) -> impl Future<Output = Result<T::Result>>
601    where
602        T::Result: 'static + Send,
603    {
604        Self::request_internal::<T>(
605            &self.next_id,
606            &self.response_handlers,
607            &self.outbound_tx,
608            params,
609        )
610    }
611
612    fn request_internal<T: request::Request>(
613        next_id: &AtomicUsize,
614        response_handlers: &Mutex<Option<HashMap<usize, ResponseHandler>>>,
615        outbound_tx: &channel::Sender<Vec<u8>>,
616        params: T::Params,
617    ) -> impl 'static + Future<Output = Result<T::Result>>
618    where
619        T::Result: 'static + Send,
620    {
621        let id = next_id.fetch_add(1, SeqCst);
622        let message = serde_json::to_vec(&Request {
623            jsonrpc: JSON_RPC_VERSION,
624            id,
625            method: T::METHOD,
626            params,
627        })
628        .unwrap();
629
630        let (tx, rx) = oneshot::channel();
631        let handle_response = response_handlers
632            .lock()
633            .as_mut()
634            .ok_or_else(|| anyhow!("server shut down"))
635            .map(|handlers| {
636                handlers.insert(
637                    id,
638                    Box::new(move |result| {
639                        let response = match result {
640                            Ok(response) => serde_json::from_str(response)
641                                .context("failed to deserialize response"),
642                            Err(error) => Err(anyhow!("{}", error.message)),
643                        };
644                        let _ = tx.send(response);
645                    }),
646                );
647            });
648
649        let send = outbound_tx
650            .try_send(message)
651            .context("failed to write to language server's stdin");
652
653        async move {
654            handle_response?;
655            send?;
656            rx.await?
657        }
658    }
659
660    pub fn notify<T: notification::Notification>(&self, params: T::Params) -> Result<()> {
661        Self::notify_internal::<T>(&self.outbound_tx, params)
662    }
663
664    fn notify_internal<T: notification::Notification>(
665        outbound_tx: &channel::Sender<Vec<u8>>,
666        params: T::Params,
667    ) -> Result<()> {
668        let message = serde_json::to_vec(&Notification {
669            jsonrpc: JSON_RPC_VERSION,
670            method: T::METHOD,
671            params,
672        })
673        .unwrap();
674        outbound_tx.try_send(message)?;
675        Ok(())
676    }
677}
678
679impl Drop for LanguageServer {
680    fn drop(&mut self) {
681        if let Some(shutdown) = self.shutdown() {
682            self.executor.spawn(shutdown).detach();
683        }
684    }
685}
686
687impl Subscription {
688    pub fn detach(mut self) {
689        self.method = "";
690    }
691}
692
693impl fmt::Display for LanguageServerId {
694    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
695        self.0.fmt(f)
696    }
697}
698
699impl Drop for Subscription {
700    fn drop(&mut self) {
701        self.notification_handlers.lock().remove(self.method);
702    }
703}
704
/// In-memory language server for tests, created by `LanguageServer::fake`.
///
/// Wraps a `LanguageServer` whose streams are cross-wired to the client's, so
/// handlers registered on `server` answer the client's requests.
#[cfg(any(test, feature = "test-support"))]
#[derive(Clone)]
pub struct FakeLanguageServer {
    pub server: Arc<LanguageServer>,
    /// `(method, raw JSON params)` of client messages that had no handler.
    notifications_rx: channel::Receiver<(String, String)>,
}
711
#[cfg(any(test, feature = "test-support"))]
impl LanguageServer {
    /// Capabilities advertising highlight, code action, and (range)
    /// formatting support; everything else is left at its default.
    pub fn full_capabilities() -> ServerCapabilities {
        ServerCapabilities {
            code_action_provider: Some(CodeActionProviderCapability::Simple(true)),
            document_highlight_provider: Some(OneOf::Left(true)),
            document_range_formatting_provider: Some(OneOf::Left(true)),
            document_formatting_provider: Some(OneOf::Left(true)),
            ..Default::default()
        }
    }

    /// Creates a client connected to an in-memory fake server.
    ///
    /// Both ends use `LanguageServerId(0)` and root `/`. The fake answers
    /// `initialize` with `capabilities` and a server info named `name`. Client
    /// messages that reach the fake without a handler are forwarded to its
    /// notification channel.
    pub fn fake(
        name: String,
        capabilities: ServerCapabilities,
        cx: AsyncAppContext,
    ) -> (Self, FakeLanguageServer) {
        // One pipe per direction: the client writes into `to_fake` and reads
        // `from_fake`; the fake does the opposite.
        let (to_fake, from_client) = async_pipe::pipe();
        let (to_client, from_fake) = async_pipe::pipe();
        let (forwarded_tx, forwarded_rx) = channel::unbounded();

        let client = Self::new_internal(
            LanguageServerId(0),
            to_fake,
            from_fake,
            None,
            Path::new("/"),
            None,
            cx.clone(),
            |_| {},
        );
        let fake_server = Self::new_internal(
            LanguageServerId(0),
            to_client,
            from_client,
            None,
            Path::new("/"),
            None,
            cx,
            move |msg| {
                let method = msg.method.to_string();
                let params = msg.params.get().to_string();
                forwarded_tx.try_send((method, params)).ok();
            },
        );
        let fake = FakeLanguageServer {
            server: Arc::new(fake_server),
            notifications_rx: forwarded_rx,
        };

        fake.handle_request::<request::Initialize, _, _>(move |_params, _cx| {
            let capabilities = capabilities.clone();
            let server_name = name.clone();
            async move {
                let server_info = ServerInfo {
                    name: server_name,
                    ..Default::default()
                };
                Ok(InitializeResult {
                    capabilities,
                    server_info: Some(server_info),
                })
            }
        });

        (client, fake)
    }
}
780
#[cfg(any(test, feature = "test-support"))]
impl FakeLanguageServer {
    /// Sends a notification to the client, ignoring any failure to queue it.
    pub fn notify<T: notification::Notification>(&self, params: T::Params) {
        let _ = self.server.notify::<T>(params);
    }

    /// Sends a request to the client and awaits its response.
    pub async fn request<T>(&self, params: T::Params) -> Result<T::Result>
    where
        T: request::Request,
        T::Result: 'static + Send,
    {
        let pending = self.server.request::<T>(params);
        pending.await
    }

    /// Waits for the next forwarded client notification of type `T`.
    ///
    /// # Panics
    ///
    /// Panics if the channel closes first or the params fail to deserialize.
    pub async fn receive_notification<T: notification::Notification>(&mut self) -> T::Params {
        self.try_receive_notification::<T>().await.unwrap()
    }

    /// Waits for the next forwarded client notification of type `T`, logging
    /// and skipping messages with other methods. Returns `None` once the
    /// channel is closed and empty.
    ///
    /// # Panics
    ///
    /// Panics if a matching notification's params fail to deserialize.
    pub async fn try_receive_notification<T: notification::Notification>(
        &mut self,
    ) -> Option<T::Params> {
        use futures::StreamExt as _;

        while let Some((method, params)) = self.notifications_rx.next().await {
            if method != T::METHOD {
                log::info!("skipping message in fake language server {:?}", params);
                continue;
            }
            return Some(serde_json::from_str::<T::Params>(&params).unwrap());
        }
        None
    }

    /// Replaces the fake's handler for request type `T` with `handler`.
    ///
    /// Each response is delayed by a simulated random delay. The returned
    /// receiver yields `()` every time a response has been produced.
    pub fn handle_request<T, F, Fut>(
        &self,
        mut handler: F,
    ) -> futures::channel::mpsc::UnboundedReceiver<()>
    where
        T: 'static + request::Request,
        T::Params: 'static + Send,
        F: 'static + Send + FnMut(T::Params, gpui::AsyncAppContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<T::Result>>,
    {
        let (done_tx, done_rx) = futures::channel::mpsc::unbounded();
        self.server.remove_request_handler::<T>();
        let subscription = self.server.on_request::<T, _, _>(move |params, cx| {
            let pending = handler(params, cx.clone());
            let done_tx = done_tx.clone();
            async move {
                cx.background().simulate_random_delay().await;
                let outcome = pending.await;
                done_tx.unbounded_send(()).ok();
                outcome
            }
        });
        subscription.detach();
        done_rx
    }

    /// Replaces the fake's handler for notification type `T` with `handler`.
    ///
    /// The returned receiver yields `()` after each handled notification.
    pub fn handle_notification<T, F>(
        &self,
        mut handler: F,
    ) -> futures::channel::mpsc::UnboundedReceiver<()>
    where
        T: 'static + notification::Notification,
        T::Params: 'static + Send,
        F: 'static + Send + FnMut(T::Params, gpui::AsyncAppContext),
    {
        let (done_tx, done_rx) = futures::channel::mpsc::unbounded();
        self.server.remove_notification_handler::<T>();
        let subscription = self.server.on_notification::<T, _>(move |params, cx| {
            handler(params, cx.clone());
            done_tx.unbounded_send(()).ok();
        });
        subscription.detach();
        done_rx
    }

    /// Removes the fake's handler for request type `T`, if any.
    pub fn remove_request_handler<T>(&mut self)
    where
        T: 'static + request::Request,
    {
        self.server.remove_request_handler::<T>();
    }

    /// Creates a work-done progress token on the client, then reports its
    /// `Begin` event.
    ///
    /// # Panics
    ///
    /// Panics if the client rejects the `window/workDoneProgress/create`
    /// request.
    pub async fn start_progress(&self, token: impl Into<String>) {
        let token: String = token.into();
        let create_params = WorkDoneProgressCreateParams {
            token: NumberOrString::String(token.clone()),
        };
        self.request::<request::WorkDoneProgressCreate>(create_params)
            .await
            .unwrap();
        let begin = WorkDoneProgress::Begin(Default::default());
        self.notify::<notification::Progress>(ProgressParams {
            token: NumberOrString::String(token),
            value: ProgressParamsValue::WorkDone(begin),
        });
    }

    /// Reports the `End` event for a progress token.
    pub fn end_progress(&self, token: impl Into<String>) {
        let end = WorkDoneProgress::End(Default::default());
        self.notify::<notification::Progress>(ProgressParams {
            token: NumberOrString::String(token.into()),
            value: ProgressParamsValue::WorkDone(end),
        });
    }
}
888
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;

    // Runs once when the test binary starts (via `ctor`). It installs
    // `env_logger` only when `RUST_LOG` is set, so tests stay quiet by default.
    #[ctor::ctor]
    fn init_logger() {
        if std::env::var("RUST_LOG").is_ok() {
            env_logger::init();
        }
    }

    /// End-to-end check of a `LanguageServer` paired with its in-process fake.
    ///
    /// It verifies three things:
    /// - a client `DidOpenTextDocument` notification reaches the fake;
    /// - `ShowMessage` and `PublishDiagnostics` notifications sent by the fake
    ///   reach the client's handlers;
    /// - after the client is dropped, the fake receives an `Exit` notification.
    #[gpui::test]
    async fn test_fake(cx: &mut TestAppContext) {
        let (server, mut fake) =
            LanguageServer::fake("the-lsp".to_string(), Default::default(), cx.to_async());

        // Client-side handlers forward received params into channels that the
        // test awaits below. They are registered on the uninitialized server,
        // before `initialize` is called.
        let (message_tx, message_rx) = channel::unbounded();
        let (diagnostics_tx, diagnostics_rx) = channel::unbounded();
        server
            .on_notification::<notification::ShowMessage, _>(move |params, _| {
                message_tx.try_send(params).unwrap()
            })
            .detach();
        server
            .on_notification::<notification::PublishDiagnostics, _>(move |params, _| {
                diagnostics_tx.try_send(params).unwrap()
            })
            .detach();

        // Client -> fake: the opened document's URI must arrive unchanged.
        let server = server.initialize(None).await.unwrap();
        server
            .notify::<notification::DidOpenTextDocument>(DidOpenTextDocumentParams {
                text_document: TextDocumentItem::new(
                    Url::from_str("file://a/b").unwrap(),
                    "rust".to_string(),
                    0,
                    "".to_string(),
                ),
            })
            .unwrap();
        assert_eq!(
            fake.receive_notification::<notification::DidOpenTextDocument>()
                .await
                .text_document
                .uri
                .as_str(),
            "file://a/b"
        );

        // Fake -> client: both notifications must be delivered to the handlers
        // registered above.
        fake.notify::<notification::ShowMessage>(ShowMessageParams {
            typ: MessageType::ERROR,
            message: "ok".to_string(),
        });
        fake.notify::<notification::PublishDiagnostics>(PublishDiagnosticsParams {
            uri: Url::from_str("file://b/c").unwrap(),
            version: Some(5),
            diagnostics: vec![],
        });
        assert_eq!(message_rx.recv().await.unwrap().message, "ok");
        assert_eq!(
            diagnostics_rx.recv().await.unwrap().uri.as_str(),
            "file://b/c"
        );

        // Answer `Shutdown` successfully before dropping the client.
        // NOTE(review): presumably `LanguageServer`'s drop logic (not visible in
        // this chunk) sends `Shutdown` and then `Exit`. This test only asserts
        // that `Exit` eventually reaches the fake.
        fake.handle_request::<request::Shutdown, _, _>(|_, _| async move { Ok(()) });

        drop(server);
        fake.receive_notification::<notification::Exit>().await;
    }
}