lsp.rs

   1use log::warn;
   2pub use lsp_types::request::*;
   3pub use lsp_types::*;
   4
   5use anyhow::{anyhow, Context, Result};
   6use collections::HashMap;
   7use futures::{channel::oneshot, io::BufWriter, AsyncRead, AsyncWrite, FutureExt};
   8use gpui::{executor, AsyncAppContext, Task};
   9use parking_lot::Mutex;
  10use postage::{barrier, prelude::Stream};
  11use serde::{de::DeserializeOwned, Deserialize, Serialize};
  12use serde_json::{json, value::RawValue, Value};
  13use smol::{
  14    channel,
  15    io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
  16    process::{self, Child},
  17};
  18use std::{
  19    ffi::OsString,
  20    fmt,
  21    future::Future,
  22    io::Write,
  23    path::PathBuf,
  24    str::{self, FromStr as _},
  25    sync::{
  26        atomic::{AtomicUsize, Ordering::SeqCst},
  27        Arc, Weak,
  28    },
  29    time::{Duration, Instant},
  30};
  31use std::{path::Path, process::Stdio};
  32use util::{ResultExt, TryFutureExt};
  33
/// Value of the `jsonrpc` field in every message this client sends.
const JSON_RPC_VERSION: &str = "2.0";
/// Prefix of the header line that carries a message body's length in bytes.
const CONTENT_LEN_HEADER: &str = "Content-Length: ";
/// How long a request waits for its response before failing with a timeout error.
const LSP_REQUEST_TIMEOUT: Duration = Duration::from_secs(60 * 2);

/// Handler for one incoming method: receives the request id (present for requests,
/// absent for notifications), the raw JSON params, and an app context.
type NotificationHandler = Box<dyn Send + FnMut(Option<usize>, &str, AsyncAppContext)>;
/// Called once with the raw JSON result of a request, or with the server's error.
type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
/// Observer of raw text flowing to or from the server process.
type IoHandler = Box<dyn Send + FnMut(IoKind, &str)>;
  41
/// Which stream of the language server process an I/O event belongs to.
///
/// `StdOut` and `StdErr` carry data read from the server; `StdIn` carries data written to it.
#[derive(Debug, Clone, Copy)]
pub enum IoKind {
    StdOut,
    StdIn,
    StdErr,
}
  48
/// Location and command-line arguments of a language server executable.
#[derive(Debug, Clone, Deserialize)]
pub struct LanguageServerBinary {
    /// Path to the executable. `LanguageServer::new` also uses its file name as the
    /// server's name until the server reports one during `initialize`.
    pub path: PathBuf,
    /// Arguments passed to the executable on launch.
    pub arguments: Vec<OsString>,
}
  54
/// A client connection to one language server, speaking JSON-RPC over its stdio.
pub struct LanguageServer {
    /// Identifier supplied by the creator of this server.
    server_id: LanguageServerId,
    /// Source of outgoing request ids; `on_io` also draws subscription ids from it.
    next_id: AtomicUsize,
    /// Queue of serialized JSON messages that the output task writes to the server's stdin.
    outbound_tx: channel::Sender<String>,
    /// Display name: set by `new` to the binary's file name and replaced by the
    /// server-reported name during `initialize`, if one is given.
    name: String,
    /// Capabilities from the `initialize` response (defaults before that).
    capabilities: ServerCapabilities,
    /// Code action kinds this server was configured with, if any.
    code_action_kinds: Option<Vec<CodeActionKind>>,
    /// Handlers for incoming messages keyed by method; serves both notifications and
    /// requests from the server.
    notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
    /// Pending request handlers keyed by request id. Set to `None` when the I/O tasks stop
    /// or shutdown proceeds, after which new requests fail with "server shut down".
    response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
    /// Raw I/O observers keyed by subscription id.
    io_handlers: Arc<Mutex<HashMap<usize, IoHandler>>>,
    /// Background executor used for response processing and request timeouts.
    executor: Arc<executor::Background>,
    /// (input task, output task); taken by `shutdown`, which drops them when it finishes.
    #[allow(clippy::type_complexity)]
    io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
    /// Barrier released when the output task drops its sender, i.e. stops writing.
    output_done_rx: Mutex<Option<barrier::Receiver>>,
    /// Root path the server was started for.
    root_path: PathBuf,
    /// Child process handle kept alive alongside this struct; `new` spawns it with
    /// `kill_on_drop(true)`.
    _server: Option<Mutex<Child>>,
}
  72
/// Identifier of a language server instance, assigned by whoever creates the server.
///
/// `repr(transparent)` gives it the same layout as the wrapped `usize`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct LanguageServerId(pub usize);
  76
/// Handle for a handler registered on a [`LanguageServer`].
///
/// `Notification` refers to the entry for `method` in the server's handler table (which
/// serves both notifications and incoming requests); `Io` refers to an observer added by
/// `on_io`, keyed by `id`.
///
/// NOTE(review): the `Option` fields suggest the handler is removed when the subscription
/// is dropped unless it was detached first, but no `Drop` impl is visible here — confirm.
pub enum Subscription {
    Notification {
        method: &'static str,
        notification_handlers: Option<Arc<Mutex<HashMap<&'static str, NotificationHandler>>>>,
    },
    Io {
        id: usize,
        /// Held weakly, so a subscription does not keep the I/O handler table alive.
        io_handlers: Option<Weak<Mutex<HashMap<usize, IoHandler>>>>,
    },
}
  87
/// JSON-RPC request envelope; `request_internal` serializes one for every outgoing request.
///
/// Field order determines the key order of the serialized JSON.
#[derive(Serialize, Deserialize)]
pub struct Request<'a, T> {
    jsonrpc: &'static str,
    id: usize,
    method: &'a str,
    params: T,
}
  95
  96#[derive(Serialize, Deserialize)]
  97struct AnyResponse<'a> {
  98    jsonrpc: &'a str,
  99    id: usize,
 100    #[serde(default)]
 101    error: Option<Error>,
 102    #[serde(borrow)]
 103    result: Option<&'a RawValue>,
 104}
 105
 106#[derive(Serialize)]
 107struct Response<T> {
 108    jsonrpc: &'static str,
 109    id: usize,
 110    result: Option<T>,
 111    error: Option<Error>,
 112}
 113
/// JSON-RPC notification envelope (a message without an `id`), used for outgoing
/// notifications.
#[derive(Serialize, Deserialize)]
struct Notification<'a, T> {
    jsonrpc: &'static str,
    #[serde(borrow)]
    method: &'a str,
    params: T,
}
 121
/// An incoming message that names a `method`: a notification, or a request when `id` is
/// present.
///
/// `params` is kept as raw JSON so the registered handler can deserialize it into its own
/// type.
#[derive(Debug, Clone, Deserialize)]
struct AnyNotification<'a> {
    /// Present for server requests that expect a response; absent for notifications.
    #[serde(default)]
    id: Option<usize>,
    #[serde(borrow)]
    method: &'a str,
    /// Raw JSON params; `None` when the message carries no `params`.
    #[serde(borrow, default)]
    params: Option<&'a RawValue>,
}
 131
/// Error payload of a JSON-RPC response.
///
/// NOTE(review): JSON-RPC 2.0 error objects also carry a required integer `code` (and an
/// optional `data`); this type only models `message`, so errors serialized from it omit
/// `code`.
#[derive(Debug, Serialize, Deserialize)]
struct Error {
    message: String,
}
 136
 137impl LanguageServer {
 138    pub fn new(
 139        server_id: LanguageServerId,
 140        binary: LanguageServerBinary,
 141        root_path: &Path,
 142        code_action_kinds: Option<Vec<CodeActionKind>>,
 143        cx: AsyncAppContext,
 144    ) -> Result<Self> {
 145        let working_dir = if root_path.is_dir() {
 146            root_path
 147        } else {
 148            root_path.parent().unwrap_or_else(|| Path::new("/"))
 149        };
 150
 151        let mut server = process::Command::new(&binary.path)
 152            .current_dir(working_dir)
 153            .args(binary.arguments)
 154            .stdin(Stdio::piped())
 155            .stdout(Stdio::piped())
 156            .stderr(Stdio::piped())
 157            .kill_on_drop(true)
 158            .spawn()?;
 159
 160        let stdin = server.stdin.take().unwrap();
 161        let stdout = server.stdout.take().unwrap();
 162        let stderr = server.stderr.take().unwrap();
 163        let mut server = Self::new_internal(
 164            server_id.clone(),
 165            stdin,
 166            stdout,
 167            Some(stderr),
 168            Some(server),
 169            root_path,
 170            code_action_kinds,
 171            cx,
 172            move |notification| {
 173                log::info!(
 174                    "{} unhandled notification {}:\n{}",
 175                    server_id,
 176                    notification.method,
 177                    serde_json::to_string_pretty(
 178                        &notification
 179                            .params
 180                            .and_then(|params| Value::from_str(params.get()).ok())
 181                            .unwrap_or(Value::Null)
 182                    )
 183                    .unwrap(),
 184                );
 185            },
 186        );
 187
 188        if let Some(name) = binary.path.file_name() {
 189            server.name = name.to_string_lossy().to_string();
 190        }
 191
 192        Ok(server)
 193    }
 194
 195    fn new_internal<Stdin, Stdout, Stderr, F>(
 196        server_id: LanguageServerId,
 197        stdin: Stdin,
 198        stdout: Stdout,
 199        stderr: Option<Stderr>,
 200        server: Option<Child>,
 201        root_path: &Path,
 202        code_action_kinds: Option<Vec<CodeActionKind>>,
 203        cx: AsyncAppContext,
 204        on_unhandled_notification: F,
 205    ) -> Self
 206    where
 207        Stdin: AsyncWrite + Unpin + Send + 'static,
 208        Stdout: AsyncRead + Unpin + Send + 'static,
 209        Stderr: AsyncRead + Unpin + Send + 'static,
 210        F: FnMut(AnyNotification) + 'static + Send + Clone,
 211    {
 212        let (outbound_tx, outbound_rx) = channel::unbounded::<String>();
 213        let (output_done_tx, output_done_rx) = barrier::channel();
 214        let notification_handlers =
 215            Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
 216        let response_handlers =
 217            Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
 218        let io_handlers = Arc::new(Mutex::new(HashMap::default()));
 219
 220        let stdout_input_task = cx.spawn(|cx| {
 221            {
 222                Self::handle_input(
 223                    stdout,
 224                    on_unhandled_notification.clone(),
 225                    notification_handlers.clone(),
 226                    response_handlers.clone(),
 227                    io_handlers.clone(),
 228                    cx,
 229                )
 230            }
 231            .log_err()
 232        });
 233        let stderr_input_task = stderr
 234            .map(|stderr| cx.spawn(|_| Self::handle_stderr(stderr, io_handlers.clone()).log_err()))
 235            .unwrap_or_else(|| Task::Ready(Some(None)));
 236        let input_task = cx.spawn(|_| async move {
 237            let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task);
 238            stdout.or(stderr)
 239        });
 240        let output_task = cx.background().spawn({
 241            Self::handle_output(
 242                stdin,
 243                outbound_rx,
 244                output_done_tx,
 245                response_handlers.clone(),
 246                io_handlers.clone(),
 247            )
 248            .log_err()
 249        });
 250
 251        Self {
 252            server_id,
 253            notification_handlers,
 254            response_handlers,
 255            io_handlers,
 256            name: Default::default(),
 257            capabilities: Default::default(),
 258            code_action_kinds,
 259            next_id: Default::default(),
 260            outbound_tx,
 261            executor: cx.background(),
 262            io_tasks: Mutex::new(Some((input_task, output_task))),
 263            output_done_rx: Mutex::new(Some(output_done_rx)),
 264            root_path: root_path.to_path_buf(),
 265            _server: server.map(|server| Mutex::new(server)),
 266        }
 267    }
 268
 269    pub fn code_action_kinds(&self) -> Option<Vec<CodeActionKind>> {
 270        self.code_action_kinds.clone()
 271    }
 272
 273    async fn handle_input<Stdout, F>(
 274        stdout: Stdout,
 275        mut on_unhandled_notification: F,
 276        notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
 277        response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
 278        io_handlers: Arc<Mutex<HashMap<usize, IoHandler>>>,
 279        cx: AsyncAppContext,
 280    ) -> anyhow::Result<()>
 281    where
 282        Stdout: AsyncRead + Unpin + Send + 'static,
 283        F: FnMut(AnyNotification) + 'static + Send,
 284    {
 285        let mut stdout = BufReader::new(stdout);
 286        let _clear_response_handlers = util::defer({
 287            let response_handlers = response_handlers.clone();
 288            move || {
 289                response_handlers.lock().take();
 290            }
 291        });
 292        let mut buffer = Vec::new();
 293        loop {
 294            buffer.clear();
 295            stdout.read_until(b'\n', &mut buffer).await?;
 296            stdout.read_until(b'\n', &mut buffer).await?;
 297            let header = std::str::from_utf8(&buffer)?;
 298            let message_len: usize = header
 299                .strip_prefix(CONTENT_LEN_HEADER)
 300                .ok_or_else(|| anyhow!("invalid LSP message header {header:?}"))?
 301                .trim_end()
 302                .parse()?;
 303
 304            buffer.resize(message_len, 0);
 305            stdout.read_exact(&mut buffer).await?;
 306
 307            if let Ok(message) = str::from_utf8(&buffer) {
 308                log::trace!("incoming message: {}", message);
 309                for handler in io_handlers.lock().values_mut() {
 310                    handler(IoKind::StdOut, message);
 311                }
 312            }
 313
 314            if let Ok(msg) = serde_json::from_slice::<AnyNotification>(&buffer) {
 315                if let Some(handler) = notification_handlers.lock().get_mut(msg.method) {
 316                    handler(
 317                        msg.id,
 318                        &msg.params.map(|params| params.get()).unwrap_or("null"),
 319                        cx.clone(),
 320                    );
 321                } else {
 322                    on_unhandled_notification(msg);
 323                }
 324            } else if let Ok(AnyResponse {
 325                id, error, result, ..
 326            }) = serde_json::from_slice(&buffer)
 327            {
 328                if let Some(handler) = response_handlers
 329                    .lock()
 330                    .as_mut()
 331                    .and_then(|handlers| handlers.remove(&id))
 332                {
 333                    if let Some(error) = error {
 334                        handler(Err(error));
 335                    } else if let Some(result) = result {
 336                        handler(Ok(result.get().into()));
 337                    } else {
 338                        handler(Ok("null".into()));
 339                    }
 340                }
 341            } else {
 342                warn!(
 343                    "failed to deserialize LSP message:\n{}",
 344                    std::str::from_utf8(&buffer)?
 345                );
 346            }
 347
 348            // Don't starve the main thread when receiving lots of messages at once.
 349            smol::future::yield_now().await;
 350        }
 351    }
 352
 353    async fn handle_stderr<Stderr>(
 354        stderr: Stderr,
 355        io_handlers: Arc<Mutex<HashMap<usize, IoHandler>>>,
 356    ) -> anyhow::Result<()>
 357    where
 358        Stderr: AsyncRead + Unpin + Send + 'static,
 359    {
 360        let mut stderr = BufReader::new(stderr);
 361        let mut buffer = Vec::new();
 362        loop {
 363            buffer.clear();
 364            stderr.read_until(b'\n', &mut buffer).await?;
 365            if let Ok(message) = str::from_utf8(&buffer) {
 366                log::trace!("incoming stderr message:{message}");
 367                for handler in io_handlers.lock().values_mut() {
 368                    handler(IoKind::StdErr, message);
 369                }
 370            }
 371
 372            // Don't starve the main thread when receiving lots of messages at once.
 373            smol::future::yield_now().await;
 374        }
 375    }
 376
 377    async fn handle_output<Stdin>(
 378        stdin: Stdin,
 379        outbound_rx: channel::Receiver<String>,
 380        output_done_tx: barrier::Sender,
 381        response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
 382        io_handlers: Arc<Mutex<HashMap<usize, IoHandler>>>,
 383    ) -> anyhow::Result<()>
 384    where
 385        Stdin: AsyncWrite + Unpin + Send + 'static,
 386    {
 387        let mut stdin = BufWriter::new(stdin);
 388        let _clear_response_handlers = util::defer({
 389            let response_handlers = response_handlers.clone();
 390            move || {
 391                response_handlers.lock().take();
 392            }
 393        });
 394        let mut content_len_buffer = Vec::new();
 395        while let Ok(message) = outbound_rx.recv().await {
 396            log::trace!("outgoing message:{}", message);
 397            for handler in io_handlers.lock().values_mut() {
 398                handler(IoKind::StdIn, &message);
 399            }
 400
 401            content_len_buffer.clear();
 402            write!(content_len_buffer, "{}", message.len()).unwrap();
 403            stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
 404            stdin.write_all(&content_len_buffer).await?;
 405            stdin.write_all("\r\n\r\n".as_bytes()).await?;
 406            stdin.write_all(message.as_bytes()).await?;
 407            stdin.flush().await?;
 408        }
 409        drop(output_done_tx);
 410        Ok(())
 411    }
 412
 413    /// Initializes a language server.
 414    /// Note that `options` is used directly to construct [`InitializeParams`],
 415    /// which is why it is owned.
 416    pub async fn initialize(mut self, options: Option<Value>) -> Result<Arc<Self>> {
 417        let root_uri = Url::from_file_path(&self.root_path).unwrap();
 418        #[allow(deprecated)]
 419        let params = InitializeParams {
 420            process_id: Default::default(),
 421            root_path: Default::default(),
 422            root_uri: Some(root_uri.clone()),
 423            initialization_options: options,
 424            capabilities: ClientCapabilities {
 425                workspace: Some(WorkspaceClientCapabilities {
 426                    configuration: Some(true),
 427                    did_change_watched_files: Some(DidChangeWatchedFilesClientCapabilities {
 428                        dynamic_registration: Some(true),
 429                        relative_pattern_support: Some(true),
 430                    }),
 431                    did_change_configuration: Some(DynamicRegistrationClientCapabilities {
 432                        dynamic_registration: Some(true),
 433                    }),
 434                    workspace_folders: Some(true),
 435                    symbol: Some(WorkspaceSymbolClientCapabilities {
 436                        resolve_support: None,
 437                        ..WorkspaceSymbolClientCapabilities::default()
 438                    }),
 439                    inlay_hint: Some(InlayHintWorkspaceClientCapabilities {
 440                        refresh_support: Some(true),
 441                    }),
 442                    ..Default::default()
 443                }),
 444                text_document: Some(TextDocumentClientCapabilities {
 445                    definition: Some(GotoCapability {
 446                        link_support: Some(true),
 447                        ..Default::default()
 448                    }),
 449                    code_action: Some(CodeActionClientCapabilities {
 450                        code_action_literal_support: Some(CodeActionLiteralSupport {
 451                            code_action_kind: CodeActionKindLiteralSupport {
 452                                value_set: vec![
 453                                    CodeActionKind::REFACTOR.as_str().into(),
 454                                    CodeActionKind::QUICKFIX.as_str().into(),
 455                                    CodeActionKind::SOURCE.as_str().into(),
 456                                ],
 457                            },
 458                        }),
 459                        data_support: Some(true),
 460                        resolve_support: Some(CodeActionCapabilityResolveSupport {
 461                            properties: vec!["edit".to_string(), "command".to_string()],
 462                        }),
 463                        ..Default::default()
 464                    }),
 465                    completion: Some(CompletionClientCapabilities {
 466                        completion_item: Some(CompletionItemCapability {
 467                            snippet_support: Some(true),
 468                            resolve_support: Some(CompletionItemCapabilityResolveSupport {
 469                                properties: vec![
 470                                    "documentation".to_string(),
 471                                    "additionalTextEdits".to_string(),
 472                                ],
 473                            }),
 474                            ..Default::default()
 475                        }),
 476                        completion_list: Some(CompletionListCapability {
 477                            item_defaults: Some(vec![
 478                                "commitCharacters".to_owned(),
 479                                "editRange".to_owned(),
 480                                "insertTextMode".to_owned(),
 481                                "data".to_owned(),
 482                            ]),
 483                        }),
 484                        ..Default::default()
 485                    }),
 486                    rename: Some(RenameClientCapabilities {
 487                        prepare_support: Some(true),
 488                        ..Default::default()
 489                    }),
 490                    hover: Some(HoverClientCapabilities {
 491                        content_format: Some(vec![MarkupKind::Markdown]),
 492                        ..Default::default()
 493                    }),
 494                    inlay_hint: Some(InlayHintClientCapabilities {
 495                        resolve_support: Some(InlayHintResolveClientCapabilities {
 496                            properties: vec![
 497                                "textEdits".to_string(),
 498                                "tooltip".to_string(),
 499                                "label.tooltip".to_string(),
 500                                "label.location".to_string(),
 501                                "label.command".to_string(),
 502                            ],
 503                        }),
 504                        dynamic_registration: Some(false),
 505                    }),
 506                    ..Default::default()
 507                }),
 508                experimental: Some(json!({
 509                    "serverStatusNotification": true,
 510                })),
 511                window: Some(WindowClientCapabilities {
 512                    work_done_progress: Some(true),
 513                    ..Default::default()
 514                }),
 515                ..Default::default()
 516            },
 517            trace: Default::default(),
 518            workspace_folders: Some(vec![WorkspaceFolder {
 519                uri: root_uri,
 520                name: Default::default(),
 521            }]),
 522            client_info: Default::default(),
 523            locale: Default::default(),
 524        };
 525
 526        let response = self.request::<request::Initialize>(params).await?;
 527        if let Some(info) = response.server_info {
 528            self.name = info.name;
 529        }
 530        self.capabilities = response.capabilities;
 531
 532        self.notify::<notification::Initialized>(InitializedParams {})?;
 533        Ok(Arc::new(self))
 534    }
 535
 536    pub fn shutdown(&self) -> Option<impl 'static + Send + Future<Output = Option<()>>> {
 537        if let Some(tasks) = self.io_tasks.lock().take() {
 538            let response_handlers = self.response_handlers.clone();
 539            let next_id = AtomicUsize::new(self.next_id.load(SeqCst));
 540            let outbound_tx = self.outbound_tx.clone();
 541            let executor = self.executor.clone();
 542            let mut output_done = self.output_done_rx.lock().take().unwrap();
 543            let shutdown_request = Self::request_internal::<request::Shutdown>(
 544                &next_id,
 545                &response_handlers,
 546                &outbound_tx,
 547                &executor,
 548                (),
 549            );
 550            let exit = Self::notify_internal::<notification::Exit>(&outbound_tx, ());
 551            outbound_tx.close();
 552            Some(
 553                async move {
 554                    log::debug!("language server shutdown started");
 555                    shutdown_request.await?;
 556                    response_handlers.lock().take();
 557                    exit?;
 558                    output_done.recv().await;
 559                    log::debug!("language server shutdown finished");
 560                    drop(tasks);
 561                    anyhow::Ok(())
 562                }
 563                .log_err(),
 564            )
 565        } else {
 566            None
 567        }
 568    }
 569
 570    #[must_use]
 571    pub fn on_notification<T, F>(&self, f: F) -> Subscription
 572    where
 573        T: notification::Notification,
 574        F: 'static + Send + FnMut(T::Params, AsyncAppContext),
 575    {
 576        self.on_custom_notification(T::METHOD, f)
 577    }
 578
 579    #[must_use]
 580    pub fn on_request<T, F, Fut>(&self, f: F) -> Subscription
 581    where
 582        T: request::Request,
 583        T::Params: 'static + Send,
 584        F: 'static + Send + FnMut(T::Params, AsyncAppContext) -> Fut,
 585        Fut: 'static + Future<Output = Result<T::Result>>,
 586    {
 587        self.on_custom_request(T::METHOD, f)
 588    }
 589
 590    #[must_use]
 591    pub fn on_io<F>(&self, f: F) -> Subscription
 592    where
 593        F: 'static + Send + FnMut(IoKind, &str),
 594    {
 595        let id = self.next_id.fetch_add(1, SeqCst);
 596        self.io_handlers.lock().insert(id, Box::new(f));
 597        Subscription::Io {
 598            id,
 599            io_handlers: Some(Arc::downgrade(&self.io_handlers)),
 600        }
 601    }
 602
 603    pub fn remove_request_handler<T: request::Request>(&self) {
 604        self.notification_handlers.lock().remove(T::METHOD);
 605    }
 606
 607    pub fn remove_notification_handler<T: notification::Notification>(&self) {
 608        self.notification_handlers.lock().remove(T::METHOD);
 609    }
 610
 611    pub fn has_notification_handler<T: notification::Notification>(&self) -> bool {
 612        self.notification_handlers.lock().contains_key(T::METHOD)
 613    }
 614
 615    #[must_use]
 616    pub fn on_custom_notification<Params, F>(&self, method: &'static str, mut f: F) -> Subscription
 617    where
 618        F: 'static + Send + FnMut(Params, AsyncAppContext),
 619        Params: DeserializeOwned,
 620    {
 621        let prev_handler = self.notification_handlers.lock().insert(
 622            method,
 623            Box::new(move |_, params, cx| {
 624                if let Some(params) = serde_json::from_str(params).log_err() {
 625                    f(params, cx);
 626                }
 627            }),
 628        );
 629        assert!(
 630            prev_handler.is_none(),
 631            "registered multiple handlers for the same LSP method"
 632        );
 633        Subscription::Notification {
 634            method,
 635            notification_handlers: Some(self.notification_handlers.clone()),
 636        }
 637    }
 638
    /// Registers `f` to answer incoming `method` requests from the server.
    ///
    /// When a message with an `id` arrives, its params are deserialized and passed to `f`.
    /// The returned future is spawned on the foreground executor, and its outcome is sent
    /// back as a response: `result` on success, or `error` carrying the error's message on
    /// failure. If the params fail to deserialize, the error is logged and an error response
    /// is sent immediately. Messages without an `id` (notifications) are ignored.
    /// Responses are sent best-effort: serialization failures are logged, and failures to
    /// queue the response are ignored.
    ///
    /// # Panics
    ///
    /// Panics if a handler is already registered for `method`.
    #[must_use]
    pub fn on_custom_request<Params, Res, Fut, F>(
        &self,
        method: &'static str,
        mut f: F,
    ) -> Subscription
    where
        F: 'static + Send + FnMut(Params, AsyncAppContext) -> Fut,
        Fut: 'static + Future<Output = Result<Res>>,
        Params: DeserializeOwned + Send + 'static,
        Res: Serialize,
    {
        let outbound_tx = self.outbound_tx.clone();
        let prev_handler = self.notification_handlers.lock().insert(
            method,
            Box::new(move |id, params, cx| {
                // Only requests carry an id; a notification for this method is ignored.
                if let Some(id) = id {
                    match serde_json::from_str(params) {
                        Ok(params) => {
                            let response = f(params, cx.clone());
                            // Await the handler off the input loop, then queue the reply.
                            cx.foreground()
                                .spawn({
                                    let outbound_tx = outbound_tx.clone();
                                    async move {
                                        let response = match response.await {
                                            Ok(result) => Response {
                                                jsonrpc: JSON_RPC_VERSION,
                                                id,
                                                result: Some(result),
                                                error: None,
                                            },
                                            Err(error) => Response {
                                                jsonrpc: JSON_RPC_VERSION,
                                                id,
                                                result: None,
                                                error: Some(Error {
                                                    message: error.to_string(),
                                                }),
                                            },
                                        };
                                        if let Some(response) =
                                            serde_json::to_string(&response).log_err()
                                        {
                                            outbound_tx.try_send(response).ok();
                                        }
                                    }
                                })
                                .detach();
                        }

                        Err(error) => {
                            log::error!(
                                "error deserializing {} request: {:?}, message: {:?}",
                                method,
                                error,
                                params
                            );
                            // Reject the request so the server is not left waiting for a reply.
                            let response = AnyResponse {
                                jsonrpc: JSON_RPC_VERSION,
                                id,
                                result: None,
                                error: Some(Error {
                                    message: error.to_string(),
                                }),
                            };
                            if let Some(response) = serde_json::to_string(&response).log_err() {
                                outbound_tx.try_send(response).ok();
                            }
                        }
                    }
                }
            }),
        );
        assert!(
            prev_handler.is_none(),
            "registered multiple handlers for the same LSP method"
        );
        Subscription::Notification {
            method,
            notification_handlers: Some(self.notification_handlers.clone()),
        }
    }
 721
 722    pub fn name(&self) -> &str {
 723        &self.name
 724    }
 725
 726    pub fn capabilities(&self) -> &ServerCapabilities {
 727        &self.capabilities
 728    }
 729
 730    pub fn server_id(&self) -> LanguageServerId {
 731        self.server_id
 732    }
 733
 734    pub fn root_path(&self) -> &PathBuf {
 735        &self.root_path
 736    }
 737
 738    pub fn request<T: request::Request>(
 739        &self,
 740        params: T::Params,
 741    ) -> impl Future<Output = Result<T::Result>>
 742    where
 743        T::Result: 'static + Send,
 744    {
 745        Self::request_internal::<T>(
 746            &self.next_id,
 747            &self.response_handlers,
 748            &self.outbound_tx,
 749            &self.executor,
 750            params,
 751        )
 752    }
 753
    // NOTE: the examples below are unrelated to the surrounding code. They list cases
    // where a string literal ("" or ``) is, or ends up as, the value of a JSX attribute:
    //
    // <Foo className="bar" />
    // <Foo className={`bar`} />
    // <Foo className={something + "bar"} />
    // const classes = "awesome ";
    // <Foo className={classes} />
 762
 763    fn request_internal<T: request::Request>(
 764        next_id: &AtomicUsize,
 765        response_handlers: &Mutex<Option<HashMap<usize, ResponseHandler>>>,
 766        outbound_tx: &channel::Sender<String>,
 767        executor: &Arc<executor::Background>,
 768        params: T::Params,
 769    ) -> impl 'static + Future<Output = anyhow::Result<T::Result>>
 770    where
 771        T::Result: 'static + Send,
 772    {
 773        let id = next_id.fetch_add(1, SeqCst);
 774        let message = serde_json::to_string(&Request {
 775            jsonrpc: JSON_RPC_VERSION,
 776            id,
 777            method: T::METHOD,
 778            params,
 779        })
 780        .unwrap();
 781
 782        let (tx, rx) = oneshot::channel();
 783        let handle_response = response_handlers
 784            .lock()
 785            .as_mut()
 786            .ok_or_else(|| anyhow!("server shut down"))
 787            .map(|handlers| {
 788                let executor = executor.clone();
 789                handlers.insert(
 790                    id,
 791                    Box::new(move |result| {
 792                        executor
 793                            .spawn(async move {
 794                                let response = match result {
 795                                    Ok(response) => serde_json::from_str(&response)
 796                                        .context("failed to deserialize response"),
 797                                    Err(error) => Err(anyhow!("{}", error.message)),
 798                                };
 799                                _ = tx.send(response);
 800                            })
 801                            .detach();
 802                    }),
 803                );
 804            });
 805
 806        let send = outbound_tx
 807            .try_send(message)
 808            .context("failed to write to language server's stdin");
 809
 810        let mut timeout = executor.timer(LSP_REQUEST_TIMEOUT).fuse();
 811        let started = Instant::now();
 812        async move {
 813            handle_response?;
 814            send?;
 815
 816            let method = T::METHOD;
 817            futures::select! {
 818                response = rx.fuse() => {
 819                    let elapsed = started.elapsed();
 820                    log::trace!("Took {elapsed:?} to recieve response to {method:?} id {id}");
 821                    response?
 822                }
 823
 824                _ = timeout => {
 825                    log::error!("Cancelled LSP request task for {method:?} id {id} which took over {LSP_REQUEST_TIMEOUT:?}");
 826                    anyhow::bail!("LSP request timeout");
 827                }
 828            }
 829        }
 830    }
 831
 832    pub fn notify<T: notification::Notification>(&self, params: T::Params) -> Result<()> {
 833        Self::notify_internal::<T>(&self.outbound_tx, params)
 834    }
 835
 836    fn notify_internal<T: notification::Notification>(
 837        outbound_tx: &channel::Sender<String>,
 838        params: T::Params,
 839    ) -> Result<()> {
 840        let message = serde_json::to_string(&Notification {
 841            jsonrpc: JSON_RPC_VERSION,
 842            method: T::METHOD,
 843            params,
 844        })
 845        .unwrap();
 846        outbound_tx.try_send(message)?;
 847        Ok(())
 848    }
 849}
 850
 851impl Drop for LanguageServer {
 852    fn drop(&mut self) {
 853        if let Some(shutdown) = self.shutdown() {
 854            self.executor.spawn(shutdown).detach();
 855        }
 856    }
 857}
 858
 859impl Subscription {
 860    pub fn detach(&mut self) {
 861        match self {
 862            Subscription::Notification {
 863                notification_handlers,
 864                ..
 865            } => *notification_handlers = None,
 866            Subscription::Io { io_handlers, .. } => *io_handlers = None,
 867        }
 868    }
 869}
 870
 871impl fmt::Display for LanguageServerId {
 872    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 873        self.0.fmt(f)
 874    }
 875}
 876
 877impl fmt::Debug for LanguageServer {
 878    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 879        f.debug_struct("LanguageServer")
 880            .field("id", &self.server_id.0)
 881            .field("name", &self.name)
 882            .finish_non_exhaustive()
 883    }
 884}
 885
 886impl Drop for Subscription {
 887    fn drop(&mut self) {
 888        match self {
 889            Subscription::Notification {
 890                method,
 891                notification_handlers,
 892            } => {
 893                if let Some(handlers) = notification_handlers {
 894                    handlers.lock().remove(method);
 895                }
 896            }
 897            Subscription::Io { id, io_handlers } => {
 898                if let Some(io_handlers) = io_handlers.as_ref().and_then(|h| h.upgrade()) {
 899                    io_handlers.lock().remove(id);
 900                }
 901            }
 902        }
 903    }
 904}
 905
 906#[cfg(any(test, feature = "test-support"))]
 907#[derive(Clone)]
 908pub struct FakeLanguageServer {
 909    pub server: Arc<LanguageServer>,
 910    notifications_rx: channel::Receiver<(String, String)>,
 911}
 912
 913#[cfg(any(test, feature = "test-support"))]
 914impl LanguageServer {
 915    pub fn full_capabilities() -> ServerCapabilities {
 916        ServerCapabilities {
 917            document_highlight_provider: Some(OneOf::Left(true)),
 918            code_action_provider: Some(CodeActionProviderCapability::Simple(true)),
 919            document_formatting_provider: Some(OneOf::Left(true)),
 920            document_range_formatting_provider: Some(OneOf::Left(true)),
 921            definition_provider: Some(OneOf::Left(true)),
 922            type_definition_provider: Some(TypeDefinitionProviderCapability::Simple(true)),
 923            ..Default::default()
 924        }
 925    }
 926
 927    pub fn fake(
 928        name: String,
 929        capabilities: ServerCapabilities,
 930        cx: AsyncAppContext,
 931    ) -> (Self, FakeLanguageServer) {
 932        let (stdin_writer, stdin_reader) = async_pipe::pipe();
 933        let (stdout_writer, stdout_reader) = async_pipe::pipe();
 934        let (notifications_tx, notifications_rx) = channel::unbounded();
 935
 936        let server = Self::new_internal(
 937            LanguageServerId(0),
 938            stdin_writer,
 939            stdout_reader,
 940            None::<async_pipe::PipeReader>,
 941            None,
 942            Path::new("/"),
 943            None,
 944            cx.clone(),
 945            |_| {},
 946        );
 947        let fake = FakeLanguageServer {
 948            server: Arc::new(Self::new_internal(
 949                LanguageServerId(0),
 950                stdout_writer,
 951                stdin_reader,
 952                None::<async_pipe::PipeReader>,
 953                None,
 954                Path::new("/"),
 955                None,
 956                cx,
 957                move |msg| {
 958                    notifications_tx
 959                        .try_send((
 960                            msg.method.to_string(),
 961                            msg.params
 962                                .map(|raw_value| raw_value.get())
 963                                .unwrap_or("null")
 964                                .to_string(),
 965                        ))
 966                        .ok();
 967                },
 968            )),
 969            notifications_rx,
 970        };
 971        fake.handle_request::<request::Initialize, _, _>({
 972            let capabilities = capabilities;
 973            move |_, _| {
 974                let capabilities = capabilities.clone();
 975                let name = name.clone();
 976                async move {
 977                    Ok(InitializeResult {
 978                        capabilities,
 979                        server_info: Some(ServerInfo {
 980                            name,
 981                            ..Default::default()
 982                        }),
 983                    })
 984                }
 985            }
 986        });
 987
 988        (server, fake)
 989    }
 990}
 991
 992#[cfg(any(test, feature = "test-support"))]
 993impl FakeLanguageServer {
 994    pub fn notify<T: notification::Notification>(&self, params: T::Params) {
 995        self.server.notify::<T>(params).ok();
 996    }
 997
 998    pub async fn request<T>(&self, params: T::Params) -> Result<T::Result>
 999    where
1000        T: request::Request,
1001        T::Result: 'static + Send,
1002    {
1003        self.server.executor.start_waiting();
1004        self.server.request::<T>(params).await
1005    }
1006
1007    pub async fn receive_notification<T: notification::Notification>(&mut self) -> T::Params {
1008        self.server.executor.start_waiting();
1009        self.try_receive_notification::<T>().await.unwrap()
1010    }
1011
1012    pub async fn try_receive_notification<T: notification::Notification>(
1013        &mut self,
1014    ) -> Option<T::Params> {
1015        use futures::StreamExt as _;
1016
1017        loop {
1018            let (method, params) = self.notifications_rx.next().await?;
1019            if method == T::METHOD {
1020                return Some(serde_json::from_str::<T::Params>(&params).unwrap());
1021            } else {
1022                log::info!("skipping message in fake language server {:?}", params);
1023            }
1024        }
1025    }
1026
1027    pub fn handle_request<T, F, Fut>(
1028        &self,
1029        mut handler: F,
1030    ) -> futures::channel::mpsc::UnboundedReceiver<()>
1031    where
1032        T: 'static + request::Request,
1033        T::Params: 'static + Send,
1034        F: 'static + Send + FnMut(T::Params, gpui::AsyncAppContext) -> Fut,
1035        Fut: 'static + Send + Future<Output = Result<T::Result>>,
1036    {
1037        let (responded_tx, responded_rx) = futures::channel::mpsc::unbounded();
1038        self.server.remove_request_handler::<T>();
1039        self.server
1040            .on_request::<T, _, _>(move |params, cx| {
1041                let result = handler(params, cx.clone());
1042                let responded_tx = responded_tx.clone();
1043                async move {
1044                    cx.background().simulate_random_delay().await;
1045                    let result = result.await;
1046                    responded_tx.unbounded_send(()).ok();
1047                    result
1048                }
1049            })
1050            .detach();
1051        responded_rx
1052    }
1053
1054    pub fn handle_notification<T, F>(
1055        &self,
1056        mut handler: F,
1057    ) -> futures::channel::mpsc::UnboundedReceiver<()>
1058    where
1059        T: 'static + notification::Notification,
1060        T::Params: 'static + Send,
1061        F: 'static + Send + FnMut(T::Params, gpui::AsyncAppContext),
1062    {
1063        let (handled_tx, handled_rx) = futures::channel::mpsc::unbounded();
1064        self.server.remove_notification_handler::<T>();
1065        self.server
1066            .on_notification::<T, _>(move |params, cx| {
1067                handler(params, cx.clone());
1068                handled_tx.unbounded_send(()).ok();
1069            })
1070            .detach();
1071        handled_rx
1072    }
1073
1074    pub fn remove_request_handler<T>(&mut self)
1075    where
1076        T: 'static + request::Request,
1077    {
1078        self.server.remove_request_handler::<T>();
1079    }
1080
1081    pub async fn start_progress(&self, token: impl Into<String>) {
1082        let token = token.into();
1083        self.request::<request::WorkDoneProgressCreate>(WorkDoneProgressCreateParams {
1084            token: NumberOrString::String(token.clone()),
1085        })
1086        .await
1087        .unwrap();
1088        self.notify::<notification::Progress>(ProgressParams {
1089            token: NumberOrString::String(token),
1090            value: ProgressParamsValue::WorkDone(WorkDoneProgress::Begin(Default::default())),
1091        });
1092    }
1093
1094    pub fn end_progress(&self, token: impl Into<String>) {
1095        self.notify::<notification::Progress>(ProgressParams {
1096            token: NumberOrString::String(token.into()),
1097            value: ProgressParamsValue::WorkDone(WorkDoneProgress::End(Default::default())),
1098        });
1099    }
1100}
1101
1102#[cfg(test)]
1103mod tests {
1104    use super::*;
1105    use gpui::TestAppContext;
1106
1107    #[ctor::ctor]
1108    fn init_logger() {
1109        if std::env::var("RUST_LOG").is_ok() {
1110            env_logger::init();
1111        }
1112    }
1113
1114    #[gpui::test]
1115    async fn test_fake(cx: &mut TestAppContext) {
1116        let (server, mut fake) =
1117            LanguageServer::fake("the-lsp".to_string(), Default::default(), cx.to_async());
1118
1119        let (message_tx, message_rx) = channel::unbounded();
1120        let (diagnostics_tx, diagnostics_rx) = channel::unbounded();
1121        server
1122            .on_notification::<notification::ShowMessage, _>(move |params, _| {
1123                message_tx.try_send(params).unwrap()
1124            })
1125            .detach();
1126        server
1127            .on_notification::<notification::PublishDiagnostics, _>(move |params, _| {
1128                diagnostics_tx.try_send(params).unwrap()
1129            })
1130            .detach();
1131
1132        let server = server.initialize(None).await.unwrap();
1133        server
1134            .notify::<notification::DidOpenTextDocument>(DidOpenTextDocumentParams {
1135                text_document: TextDocumentItem::new(
1136                    Url::from_str("file://a/b").unwrap(),
1137                    "rust".to_string(),
1138                    0,
1139                    "".to_string(),
1140                ),
1141            })
1142            .unwrap();
1143        assert_eq!(
1144            fake.receive_notification::<notification::DidOpenTextDocument>()
1145                .await
1146                .text_document
1147                .uri
1148                .as_str(),
1149            "file://a/b"
1150        );
1151
1152        fake.notify::<notification::ShowMessage>(ShowMessageParams {
1153            typ: MessageType::ERROR,
1154            message: "ok".to_string(),
1155        });
1156        fake.notify::<notification::PublishDiagnostics>(PublishDiagnosticsParams {
1157            uri: Url::from_str("file://b/c").unwrap(),
1158            version: Some(5),
1159            diagnostics: vec![],
1160        });
1161        assert_eq!(message_rx.recv().await.unwrap().message, "ok");
1162        assert_eq!(
1163            diagnostics_rx.recv().await.unwrap().uri.as_str(),
1164            "file://b/c"
1165        );
1166
1167        fake.handle_request::<request::Shutdown, _, _>(|_, _| async move { Ok(()) });
1168
1169        drop(server);
1170        fake.receive_notification::<notification::Exit>().await;
1171    }
1172}