lsp.rs

  1use anyhow::{anyhow, Context, Result};
  2use collections::HashMap;
  3use futures::{channel::oneshot, io::BufWriter, AsyncRead, AsyncWrite};
  4use gpui::{executor, Task};
  5use parking_lot::{Mutex, RwLock};
  6use postage::{barrier, prelude::Stream};
  7use serde::{de::DeserializeOwned, Deserialize, Serialize};
  8use serde_json::{json, value::RawValue, Value};
  9use smol::{
 10    channel,
 11    io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
 12    process::Command,
 13};
 14use std::{
 15    future::Future,
 16    io::Write,
 17    path::PathBuf,
 18    str::FromStr,
 19    sync::{
 20        atomic::{AtomicUsize, Ordering::SeqCst},
 21        Arc,
 22    },
 23};
 24use std::{path::Path, process::Stdio};
 25use util::TryFutureExt;
 26
 27pub use lsp_types::*;
 28
 29const JSON_RPC_VERSION: &'static str = "2.0";
 30const CONTENT_LEN_HEADER: &'static str = "Content-Length: ";
 31
 32type NotificationHandler =
 33    Box<dyn Send + Sync + FnMut(Option<usize>, &str, &mut channel::Sender<Vec<u8>>) -> Result<()>>;
 34type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
 35
 36pub struct LanguageServer {
 37    server_id: usize,
 38    next_id: AtomicUsize,
 39    outbound_tx: channel::Sender<Vec<u8>>,
 40    name: String,
 41    capabilities: ServerCapabilities,
 42    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
 43    response_handlers: Arc<Mutex<HashMap<usize, ResponseHandler>>>,
 44    executor: Arc<executor::Background>,
 45    io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
 46    output_done_rx: Mutex<Option<barrier::Receiver>>,
 47    root_path: PathBuf,
 48    options: Option<Value>,
 49}
 50
 51pub struct Subscription {
 52    method: &'static str,
 53    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
 54}
 55
 56#[derive(Serialize, Deserialize)]
 57struct Request<'a, T> {
 58    jsonrpc: &'a str,
 59    id: usize,
 60    method: &'a str,
 61    params: T,
 62}
 63
 64#[cfg(any(test, feature = "test-support"))]
 65#[derive(Deserialize)]
 66struct AnyRequest<'a> {
 67    id: usize,
 68    #[serde(borrow)]
 69    jsonrpc: &'a str,
 70    #[serde(borrow)]
 71    method: &'a str,
 72    #[serde(borrow)]
 73    params: &'a RawValue,
 74}
 75
 76#[derive(Serialize, Deserialize)]
 77struct AnyResponse<'a> {
 78    id: usize,
 79    #[serde(default)]
 80    error: Option<Error>,
 81    #[serde(borrow)]
 82    result: Option<&'a RawValue>,
 83}
 84
 85#[derive(Serialize)]
 86struct Response<T> {
 87    id: usize,
 88    result: T,
 89}
 90
 91#[derive(Serialize, Deserialize)]
 92struct Notification<'a, T> {
 93    #[serde(borrow)]
 94    jsonrpc: &'a str,
 95    #[serde(borrow)]
 96    method: &'a str,
 97    params: T,
 98}
 99
100#[derive(Deserialize)]
101struct AnyNotification<'a> {
102    #[serde(default)]
103    id: Option<usize>,
104    #[serde(borrow)]
105    method: &'a str,
106    #[serde(borrow)]
107    params: &'a RawValue,
108}
109
110#[derive(Debug, Serialize, Deserialize)]
111struct Error {
112    message: String,
113}
114
115impl LanguageServer {
116    pub fn new(
117        server_id: usize,
118        binary_path: &Path,
119        args: &[&str],
120        root_path: &Path,
121        options: Option<Value>,
122        background: Arc<executor::Background>,
123    ) -> Result<Self> {
124        let working_dir = if root_path.is_dir() {
125            root_path
126        } else {
127            root_path.parent().unwrap_or(Path::new("/"))
128        };
129        let mut server = Command::new(binary_path)
130            .current_dir(working_dir)
131            .args(args)
132            .stdin(Stdio::piped())
133            .stdout(Stdio::piped())
134            .stderr(Stdio::inherit())
135            .spawn()?;
136        let stdin = server.stdin.take().unwrap();
137        let stdout = server.stdout.take().unwrap();
138        let mut server =
139            Self::new_internal(server_id, stdin, stdout, root_path, options, background);
140        if let Some(name) = binary_path.file_name() {
141            server.name = name.to_string_lossy().to_string();
142        }
143        Ok(server)
144    }
145
146    fn new_internal<Stdin, Stdout>(
147        server_id: usize,
148        stdin: Stdin,
149        stdout: Stdout,
150        root_path: &Path,
151        options: Option<Value>,
152        executor: Arc<executor::Background>,
153    ) -> Self
154    where
155        Stdin: AsyncWrite + Unpin + Send + 'static,
156        Stdout: AsyncRead + Unpin + Send + 'static,
157    {
158        let mut stdin = BufWriter::new(stdin);
159        let mut stdout = BufReader::new(stdout);
160        let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
161        let notification_handlers =
162            Arc::new(RwLock::new(HashMap::<_, NotificationHandler>::default()));
163        let response_handlers = Arc::new(Mutex::new(HashMap::<_, ResponseHandler>::default()));
164        let input_task = executor.spawn(
165            {
166                let notification_handlers = notification_handlers.clone();
167                let response_handlers = response_handlers.clone();
168                let mut outbound_tx = outbound_tx.clone();
169                async move {
170                    let _clear_response_handlers = ClearResponseHandlers(response_handlers.clone());
171                    let mut buffer = Vec::new();
172                    loop {
173                        buffer.clear();
174                        stdout.read_until(b'\n', &mut buffer).await?;
175                        stdout.read_until(b'\n', &mut buffer).await?;
176                        let message_len: usize = std::str::from_utf8(&buffer)?
177                            .strip_prefix(CONTENT_LEN_HEADER)
178                            .ok_or_else(|| anyhow!("invalid header"))?
179                            .trim_end()
180                            .parse()?;
181
182                        buffer.resize(message_len, 0);
183                        stdout.read_exact(&mut buffer).await?;
184                        log::trace!("incoming message:{}", String::from_utf8_lossy(&buffer));
185
186                        if let Ok(AnyNotification { id, method, params }) =
187                            serde_json::from_slice(&buffer)
188                        {
189                            if let Some(handler) = notification_handlers.write().get_mut(method) {
190                                if let Err(e) = handler(id, params.get(), &mut outbound_tx) {
191                                    log::error!("error handling {} message: {:?}", method, e);
192                                }
193                            } else {
194                                log::info!(
195                                    "unhandled notification {}:\n{}",
196                                    method,
197                                    serde_json::to_string_pretty(
198                                        &Value::from_str(params.get()).unwrap()
199                                    )
200                                    .unwrap()
201                                );
202                            }
203                        } else if let Ok(AnyResponse { id, error, result }) =
204                            serde_json::from_slice(&buffer)
205                        {
206                            if let Some(handler) = response_handlers.lock().remove(&id) {
207                                if let Some(error) = error {
208                                    handler(Err(error));
209                                } else if let Some(result) = result {
210                                    handler(Ok(result.get()));
211                                } else {
212                                    handler(Ok("null"));
213                                }
214                            }
215                        } else {
216                            return Err(anyhow!(
217                                "failed to deserialize message:\n{}",
218                                std::str::from_utf8(&buffer)?
219                            ));
220                        }
221                    }
222                }
223            }
224            .log_err(),
225        );
226        let (output_done_tx, output_done_rx) = barrier::channel();
227        let output_task = executor.spawn({
228            let response_handlers = response_handlers.clone();
229            async move {
230                let _clear_response_handlers = ClearResponseHandlers(response_handlers);
231                let mut content_len_buffer = Vec::new();
232                while let Ok(message) = outbound_rx.recv().await {
233                    log::trace!("outgoing message:{}", String::from_utf8_lossy(&message));
234                    content_len_buffer.clear();
235                    write!(content_len_buffer, "{}", message.len()).unwrap();
236                    stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
237                    stdin.write_all(&content_len_buffer).await?;
238                    stdin.write_all("\r\n\r\n".as_bytes()).await?;
239                    stdin.write_all(&message).await?;
240                    stdin.flush().await?;
241                }
242                drop(output_done_tx);
243                Ok(())
244            }
245            .log_err()
246        });
247
248        Self {
249            server_id,
250            notification_handlers,
251            response_handlers,
252            name: Default::default(),
253            capabilities: Default::default(),
254            next_id: Default::default(),
255            outbound_tx,
256            executor: executor.clone(),
257            io_tasks: Mutex::new(Some((input_task, output_task))),
258            output_done_rx: Mutex::new(Some(output_done_rx)),
259            root_path: root_path.to_path_buf(),
260            options,
261        }
262    }
263
264    pub async fn initialize(mut self) -> Result<Arc<Self>> {
265        let options = self.options.take();
266        let mut this = Arc::new(self);
267        let root_uri = Url::from_file_path(&this.root_path).unwrap();
268        #[allow(deprecated)]
269        let params = InitializeParams {
270            process_id: Default::default(),
271            root_path: Default::default(),
272            root_uri: Some(root_uri),
273            initialization_options: options,
274            capabilities: ClientCapabilities {
275                workspace: Some(WorkspaceClientCapabilities {
276                    configuration: Some(true),
277                    did_change_configuration: Some(DynamicRegistrationClientCapabilities {
278                        dynamic_registration: Some(true),
279                    }),
280                    ..Default::default()
281                }),
282                text_document: Some(TextDocumentClientCapabilities {
283                    definition: Some(GotoCapability {
284                        link_support: Some(true),
285                        ..Default::default()
286                    }),
287                    code_action: Some(CodeActionClientCapabilities {
288                        code_action_literal_support: Some(CodeActionLiteralSupport {
289                            code_action_kind: CodeActionKindLiteralSupport {
290                                value_set: vec![
291                                    CodeActionKind::REFACTOR.as_str().into(),
292                                    CodeActionKind::QUICKFIX.as_str().into(),
293                                ],
294                            },
295                        }),
296                        data_support: Some(true),
297                        resolve_support: Some(CodeActionCapabilityResolveSupport {
298                            properties: vec!["edit".to_string()],
299                        }),
300                        ..Default::default()
301                    }),
302                    completion: Some(CompletionClientCapabilities {
303                        completion_item: Some(CompletionItemCapability {
304                            snippet_support: Some(true),
305                            resolve_support: Some(CompletionItemCapabilityResolveSupport {
306                                properties: vec!["additionalTextEdits".to_string()],
307                            }),
308                            ..Default::default()
309                        }),
310                        ..Default::default()
311                    }),
312                    ..Default::default()
313                }),
314                experimental: Some(json!({
315                    "serverStatusNotification": true,
316                })),
317                window: Some(WindowClientCapabilities {
318                    work_done_progress: Some(true),
319                    ..Default::default()
320                }),
321                ..Default::default()
322            },
323            trace: Default::default(),
324            workspace_folders: Default::default(),
325            client_info: Default::default(),
326            locale: Default::default(),
327        };
328
329        let response = this.request::<request::Initialize>(params).await?;
330        {
331            let this = Arc::get_mut(&mut this).unwrap();
332            if let Some(info) = response.server_info {
333                this.name = info.name;
334            }
335            this.capabilities = response.capabilities;
336        }
337        this.notify::<notification::Initialized>(InitializedParams {})?;
338        Ok(this)
339    }
340
341    pub fn shutdown(&self) -> Option<impl 'static + Send + Future<Output = Option<()>>> {
342        if let Some(tasks) = self.io_tasks.lock().take() {
343            let response_handlers = self.response_handlers.clone();
344            let next_id = AtomicUsize::new(self.next_id.load(SeqCst));
345            let outbound_tx = self.outbound_tx.clone();
346            let mut output_done = self.output_done_rx.lock().take().unwrap();
347            let shutdown_request = Self::request_internal::<request::Shutdown>(
348                &next_id,
349                &response_handlers,
350                &outbound_tx,
351                (),
352            );
353            let exit = Self::notify_internal::<notification::Exit>(&outbound_tx, ());
354            outbound_tx.close();
355            Some(
356                async move {
357                    log::debug!("language server shutdown started");
358                    shutdown_request.await?;
359                    response_handlers.lock().clear();
360                    exit?;
361                    output_done.recv().await;
362                    log::debug!("language server shutdown finished");
363                    drop(tasks);
364                    Ok(())
365                }
366                .log_err(),
367            )
368        } else {
369            None
370        }
371    }
372
373    pub fn on_notification<T, F>(&mut self, f: F) -> Subscription
374    where
375        T: notification::Notification,
376        F: 'static + Send + Sync + FnMut(T::Params),
377    {
378        self.on_custom_notification(T::METHOD, f)
379    }
380
381    pub fn on_request<T, F>(&mut self, f: F) -> Subscription
382    where
383        T: request::Request,
384        F: 'static + Send + Sync + FnMut(T::Params) -> Result<T::Result>,
385    {
386        self.on_custom_request(T::METHOD, f)
387    }
388
389    pub fn on_custom_notification<Params, F>(
390        &mut self,
391        method: &'static str,
392        mut f: F,
393    ) -> Subscription
394    where
395        F: 'static + Send + Sync + FnMut(Params),
396        Params: DeserializeOwned,
397    {
398        let prev_handler = self.notification_handlers.write().insert(
399            method,
400            Box::new(move |_, params, _| {
401                let params = serde_json::from_str(params)?;
402                f(params);
403                Ok(())
404            }),
405        );
406        assert!(
407            prev_handler.is_none(),
408            "registered multiple handlers for the same LSP method"
409        );
410        Subscription {
411            method,
412            notification_handlers: self.notification_handlers.clone(),
413        }
414    }
415
416    pub fn on_custom_request<Params, Res, F>(
417        &mut self,
418        method: &'static str,
419        mut f: F,
420    ) -> Subscription
421    where
422        F: 'static + Send + Sync + FnMut(Params) -> Result<Res>,
423        Params: DeserializeOwned,
424        Res: Serialize,
425    {
426        let prev_handler = self.notification_handlers.write().insert(
427            method,
428            Box::new(move |id, params, tx| {
429                if let Some(id) = id {
430                    let params = serde_json::from_str(params)?;
431                    let result = f(params)?;
432                    let response = serde_json::to_vec(&Response { id, result })?;
433                    tx.try_send(response)?;
434                }
435                Ok(())
436            }),
437        );
438        assert!(
439            prev_handler.is_none(),
440            "registered multiple handlers for the same LSP method"
441        );
442        Subscription {
443            method,
444            notification_handlers: self.notification_handlers.clone(),
445        }
446    }
447
448    pub fn name<'a>(self: &'a Arc<Self>) -> &'a str {
449        &self.name
450    }
451
452    pub fn capabilities<'a>(self: &'a Arc<Self>) -> &'a ServerCapabilities {
453        &self.capabilities
454    }
455
456    pub fn server_id(&self) -> usize {
457        self.server_id
458    }
459
460    pub fn request<T: request::Request>(
461        self: &Arc<Self>,
462        params: T::Params,
463    ) -> impl Future<Output = Result<T::Result>>
464    where
465        T::Result: 'static + Send,
466    {
467        Self::request_internal::<T>(
468            &self.next_id,
469            &self.response_handlers,
470            &self.outbound_tx,
471            params,
472        )
473    }
474
475    fn request_internal<T: request::Request>(
476        next_id: &AtomicUsize,
477        response_handlers: &Mutex<HashMap<usize, ResponseHandler>>,
478        outbound_tx: &channel::Sender<Vec<u8>>,
479        params: T::Params,
480    ) -> impl 'static + Future<Output = Result<T::Result>>
481    where
482        T::Result: 'static + Send,
483    {
484        let id = next_id.fetch_add(1, SeqCst);
485        let message = serde_json::to_vec(&Request {
486            jsonrpc: JSON_RPC_VERSION,
487            id,
488            method: T::METHOD,
489            params,
490        })
491        .unwrap();
492
493        let send = outbound_tx
494            .try_send(message)
495            .context("failed to write to language server's stdin");
496
497        let (tx, rx) = oneshot::channel();
498        response_handlers.lock().insert(
499            id,
500            Box::new(move |result| {
501                let response = match result {
502                    Ok(response) => {
503                        serde_json::from_str(response).context("failed to deserialize response")
504                    }
505                    Err(error) => Err(anyhow!("{}", error.message)),
506                };
507                let _ = tx.send(response);
508            }),
509        );
510
511        async move {
512            send?;
513            rx.await?
514        }
515    }
516
517    pub fn notify<T: notification::Notification>(&self, params: T::Params) -> Result<()> {
518        Self::notify_internal::<T>(&self.outbound_tx, params)
519    }
520
521    fn notify_internal<T: notification::Notification>(
522        outbound_tx: &channel::Sender<Vec<u8>>,
523        params: T::Params,
524    ) -> Result<()> {
525        let message = serde_json::to_vec(&Notification {
526            jsonrpc: JSON_RPC_VERSION,
527            method: T::METHOD,
528            params,
529        })
530        .unwrap();
531        outbound_tx.try_send(message)?;
532        Ok(())
533    }
534}
535
536impl Drop for LanguageServer {
537    fn drop(&mut self) {
538        if let Some(shutdown) = self.shutdown() {
539            self.executor.spawn(shutdown).detach();
540        }
541    }
542}
543
544impl Subscription {
545    pub fn detach(mut self) {
546        self.method = "";
547    }
548}
549
550impl Drop for Subscription {
551    fn drop(&mut self) {
552        self.notification_handlers.write().remove(self.method);
553    }
554}
555
/// In-process stand-in for a language server, used in tests.
///
/// It talks to a real [`LanguageServer`] over in-memory pipes.
#[cfg(any(test, feature = "test-support"))]
pub struct FakeLanguageServer {
    /// Request handlers, keyed by LSP method.
    handlers: FakeLanguageServerHandlers,
    /// Serialized messages queued for delivery to the client.
    outgoing_tx: futures::channel::mpsc::UnboundedSender<Vec<u8>>,
    /// Client messages that did not parse as requests (e.g. notifications).
    incoming_rx: futures::channel::mpsc::UnboundedReceiver<Vec<u8>>,
    // The I/O tasks are only stored so that they live as long as the fake.
    _input_task: Task<Result<()>>,
    _output_task: Task<Result<()>>,
}
564
/// Request handlers of a [`FakeLanguageServer`], keyed by LSP method.
#[cfg(any(test, feature = "test-support"))]
type FakeLanguageServerHandlers = Arc<Mutex<HashMap<&'static str, FakeRequestHandler>>>;

/// Handler for one method: given the request id, the raw JSON params, and an app context,
/// returns a future producing the serialized response message.
#[cfg(any(test, feature = "test-support"))]
type FakeRequestHandler = Box<
    dyn FnMut(usize, &[u8], gpui::AsyncAppContext) -> futures::future::BoxFuture<'static, Vec<u8>>
        + Send,
>;
581
#[cfg(any(test, feature = "test-support"))]
impl LanguageServer {
    /// Capabilities advertised by [`LanguageServer::fake`].
    ///
    /// They cover document highlights, code actions, and both whole-document and range
    /// formatting.
    pub fn full_capabilities() -> ServerCapabilities {
        ServerCapabilities {
            code_action_provider: Some(CodeActionProviderCapability::Simple(true)),
            document_highlight_provider: Some(OneOf::Left(true)),
            document_formatting_provider: Some(OneOf::Left(true)),
            document_range_formatting_provider: Some(OneOf::Left(true)),
            ..Default::default()
        }
    }

    /// Creates an uninitialized client wired to a fake server advertising
    /// [`LanguageServer::full_capabilities`].
    pub fn fake(cx: &mut gpui::MutableAppContext) -> (Self, FakeLanguageServer) {
        let capabilities = Self::full_capabilities();
        Self::fake_with_capabilities(capabilities, cx)
    }

    /// Creates an uninitialized client wired to a fake server.
    ///
    /// The fake answers `initialize` with `capabilities`. The client uses server id `0` and
    /// root path `/`.
    pub fn fake_with_capabilities(
        capabilities: ServerCapabilities,
        cx: &mut gpui::MutableAppContext,
    ) -> (Self, FakeLanguageServer) {
        // One pipe carries client -> fake traffic, the other fake -> client traffic.
        let (to_fake_writer, to_fake_reader) = async_pipe::pipe();
        let (to_client_writer, to_client_reader) = async_pipe::pipe();

        let mut fake = FakeLanguageServer::new(to_fake_reader, to_client_writer, cx);
        fake.handle_request::<request::Initialize, _, _>(move |_, _| {
            let capabilities = capabilities.clone();
            async move {
                InitializeResult {
                    capabilities,
                    ..Default::default()
                }
            }
        });

        let server = Self::new_internal(
            0,
            to_fake_writer,
            to_client_reader,
            Path::new("/"),
            None,
            cx.background().clone(),
        );
        (server, fake)
    }
}
631
#[cfg(any(test, feature = "test-support"))]
impl FakeLanguageServer {
    /// Creates a fake server that reads client messages from `stdin` and writes its own
    /// messages to `stdout`.
    ///
    /// Requests with a registered handler are answered by that handler. Requests without
    /// one get a "no handler" error response. Every other message is forwarded to
    /// `incoming_rx`, where [`FakeLanguageServer::receive_notification`] can inspect it.
    fn new(
        stdin: async_pipe::PipeReader,
        stdout: async_pipe::PipeWriter,
        cx: &mut gpui::MutableAppContext,
    ) -> Self {
        use futures::StreamExt as _;

        let (incoming_tx, incoming_rx) = futures::channel::mpsc::unbounded();
        let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded();
        let handlers = FakeLanguageServerHandlers::default();

        let input_task = cx.spawn(|cx| {
            let handlers = handlers.clone();
            let outgoing_tx = outgoing_tx.clone();
            async move {
                let mut buffer = Vec::new();
                let mut stdin = smol::io::BufReader::new(stdin);
                while Self::receive(&mut stdin, &mut buffer).await.is_ok() {
                    cx.background().simulate_random_delay().await;

                    if let Ok(request) = serde_json::from_slice::<AnyRequest>(&buffer) {
                        assert_eq!(request.jsonrpc, JSON_RPC_VERSION);

                        // Only create the handler's future while the lock is held. Awaiting
                        // it under the lock would block anyone registering or removing
                        // handlers until the response is ready.
                        let pending_response =
                            handlers.lock().get_mut(request.method).map(|handler| {
                                handler(request.id, request.params.get().as_bytes(), cx.clone())
                            });
                        let response = if let Some(pending_response) = pending_response {
                            let response = pending_response.await;
                            log::debug!("handled lsp request. method:{}", request.method);
                            response
                        } else {
                            let response = serde_json::to_vec(&AnyResponse {
                                id: request.id,
                                error: Some(Error {
                                    message: "no handler".to_string(),
                                }),
                                result: None,
                            })
                            .unwrap();
                            log::debug!("unhandled lsp request. method:{}", request.method);
                            response
                        };
                        outgoing_tx.unbounded_send(response)?;
                    } else {
                        incoming_tx.unbounded_send(buffer.clone())?;
                    }
                }
                Ok::<_, anyhow::Error>(())
            }
        });

        let output_task = cx.background().spawn(async move {
            let mut stdout = smol::io::BufWriter::new(stdout);
            while let Some(message) = outgoing_rx.next().await {
                stdout.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
                stdout
                    .write_all((format!("{}", message.len())).as_bytes())
                    .await?;
                stdout.write_all("\r\n\r\n".as_bytes()).await?;
                stdout.write_all(&message).await?;
                stdout.flush().await?;
            }
            Ok(())
        });

        Self {
            outgoing_tx,
            incoming_rx,
            handlers,
            _input_task: input_task,
            _output_task: output_task,
        }
    }

    /// Sends notification `T` to the client.
    pub fn notify<T: notification::Notification>(&mut self, params: T::Params) {
        let message = serde_json::to_vec(&Notification {
            jsonrpc: JSON_RPC_VERSION,
            method: T::METHOD,
            params,
        })
        .unwrap();
        self.outgoing_tx.unbounded_send(message).unwrap();
    }

    /// Waits for the next client message that parses as notification `T` and returns its
    /// params.
    ///
    /// Messages that do not parse are logged and skipped.
    ///
    /// # Panics
    ///
    /// Panics if the incoming channel closes, or if a message that parses has a different
    /// method.
    pub async fn receive_notification<T: notification::Notification>(&mut self) -> T::Params {
        use futures::StreamExt as _;

        loop {
            let bytes = self.incoming_rx.next().await.unwrap();
            if let Ok(notification) = serde_json::from_slice::<Notification<T::Params>>(&bytes) {
                assert_eq!(notification.method, T::METHOD);
                return notification.params;
            } else {
                log::info!(
                    "skipping message in fake language server {:?}",
                    std::str::from_utf8(&bytes)
                );
            }
        }
    }

    /// Registers `handler` to answer request `T`, replacing any previous handler for it.
    ///
    /// The returned receiver yields `()` each time a response has been produced.
    pub fn handle_request<T, F, Fut>(
        &mut self,
        mut handler: F,
    ) -> futures::channel::mpsc::UnboundedReceiver<()>
    where
        T: 'static + request::Request,
        F: 'static + Send + FnMut(T::Params, gpui::AsyncAppContext) -> Fut,
        Fut: 'static + Send + Future<Output = T::Result>,
    {
        use futures::FutureExt as _;

        let (responded_tx, responded_rx) = futures::channel::mpsc::unbounded();
        self.handlers.lock().insert(
            T::METHOD,
            Box::new(move |id, params, cx| {
                let result = handler(serde_json::from_slice::<T::Params>(params).unwrap(), cx);
                let responded_tx = responded_tx.clone();
                async move {
                    let result = result.await;
                    let result = serde_json::to_string(&result).unwrap();
                    let result = serde_json::from_str::<&RawValue>(&result).unwrap();
                    let response = AnyResponse {
                        id,
                        error: None,
                        result: Some(result),
                    };
                    responded_tx.unbounded_send(()).ok();
                    serde_json::to_vec(&response).unwrap()
                }
                .boxed()
            }),
        );
        responded_rx
    }

    /// Unregisters the handler for request `T`, if any.
    pub fn remove_request_handler<T>(&mut self)
    where
        T: 'static + request::Request,
    {
        self.handlers.lock().remove(T::METHOD);
    }

    /// Sends a `$/progress` "begin" notification for `token`.
    pub async fn start_progress(&mut self, token: impl Into<String>) {
        self.notify::<notification::Progress>(ProgressParams {
            token: NumberOrString::String(token.into()),
            value: ProgressParamsValue::WorkDone(WorkDoneProgress::Begin(Default::default())),
        });
    }

    /// Sends a `$/progress` "end" notification for `token`.
    pub async fn end_progress(&mut self, token: impl Into<String>) {
        self.notify::<notification::Progress>(ProgressParams {
            token: NumberOrString::String(token.into()),
            value: ProgressParamsValue::WorkDone(WorkDoneProgress::End(Default::default())),
        });
    }

    /// Reads one message written by the client into `buffer`.
    ///
    /// The client always frames messages as `Content-Length: N\r\n\r\n`.
    ///
    /// # Errors
    ///
    /// Returns an error when the pipe is closed or the header is malformed. The input loop
    /// treats either as the end of input, rather than panicking.
    async fn receive(
        stdin: &mut smol::io::BufReader<async_pipe::PipeReader>,
        buffer: &mut Vec<u8>,
    ) -> Result<()> {
        buffer.clear();
        stdin.read_until(b'\n', buffer).await?;
        stdin.read_until(b'\n', buffer).await?;
        let message_len: usize = std::str::from_utf8(buffer)?
            .strip_prefix(CONTENT_LEN_HEADER)
            .ok_or_else(|| anyhow!("invalid content length header"))?
            .trim_end()
            .parse()?;
        buffer.resize(message_len, 0);
        stdin.read_exact(buffer).await?;
        Ok(())
    }
}
808
809struct ClearResponseHandlers(Arc<Mutex<HashMap<usize, ResponseHandler>>>);
810
811impl Drop for ClearResponseHandlers {
812    fn drop(&mut self) {
813        self.0.lock().clear();
814    }
815}
816
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;

    /// Turns on `env_logger` output for this test binary when `RUST_LOG` is set.
    #[ctor::ctor]
    fn init_logger() {
        if std::env::var("RUST_LOG").is_ok() {
            env_logger::init();
        }
    }

    /// Exercises a real `LanguageServer` against a `FakeLanguageServer`.
    ///
    /// Notifications must flow in both directions, and dropping the client must end with an
    /// `exit` notification.
    #[gpui::test]
    async fn test_fake(cx: &mut TestAppContext) {
        let (mut client, mut fake) = cx.update(LanguageServer::fake);

        let (shown_tx, shown_rx) = channel::unbounded();
        let (published_tx, published_rx) = channel::unbounded();
        client
            .on_notification::<notification::ShowMessage, _>(move |params| {
                shown_tx.try_send(params).unwrap()
            })
            .detach();
        client
            .on_notification::<notification::PublishDiagnostics, _>(move |params| {
                published_tx.try_send(params).unwrap()
            })
            .detach();

        let client = client.initialize().await.unwrap();

        // Client -> fake.
        let opened_document = TextDocumentItem::new(
            Url::from_str("file://a/b").unwrap(),
            "rust".to_string(),
            0,
            "".to_string(),
        );
        client
            .notify::<notification::DidOpenTextDocument>(DidOpenTextDocumentParams {
                text_document: opened_document,
            })
            .unwrap();
        let opened = fake
            .receive_notification::<notification::DidOpenTextDocument>()
            .await;
        assert_eq!(opened.text_document.uri.as_str(), "file://a/b");

        // Fake -> client.
        fake.notify::<notification::ShowMessage>(ShowMessageParams {
            typ: MessageType::ERROR,
            message: "ok".to_string(),
        });
        fake.notify::<notification::PublishDiagnostics>(PublishDiagnosticsParams {
            uri: Url::from_str("file://b/c").unwrap(),
            version: Some(5),
            diagnostics: vec![],
        });
        let shown = shown_rx.recv().await.unwrap();
        assert_eq!(shown.message, "ok");
        let published = published_rx.recv().await.unwrap();
        assert_eq!(published.uri.as_str(), "file://b/c");

        // Dropping the client runs the shutdown sequence, which ends with `exit`.
        fake.handle_request::<request::Shutdown, _, _>(|_, _| async move {});
        drop(client);
        fake.receive_notification::<notification::Exit>().await;
    }
}