lsp.rs

   1mod input_handler;
   2
   3pub use lsp_types::request::*;
   4pub use lsp_types::*;
   5
   6use anyhow::{anyhow, Context, Result};
   7use collections::HashMap;
   8use futures::{channel::oneshot, io::BufWriter, select, AsyncRead, AsyncWrite, Future, FutureExt};
   9use gpui::{AppContext, AsyncAppContext, BackgroundExecutor, SharedString, Task};
  10use parking_lot::{Mutex, RwLock};
  11use postage::{barrier, prelude::Stream};
  12use schemars::{
  13    gen::SchemaGenerator,
  14    schema::{InstanceType, Schema, SchemaObject},
  15    JsonSchema,
  16};
  17use serde::{de::DeserializeOwned, Deserialize, Serialize};
  18use serde_json::{json, value::RawValue, Value};
  19use smol::{
  20    channel,
  21    io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
  22    process::{self, Child},
  23};
  24
  25#[cfg(target_os = "windows")]
  26use smol::process::windows::CommandExt;
  27
  28use std::{
  29    ffi::{OsStr, OsString},
  30    fmt,
  31    io::Write,
  32    ops::DerefMut,
  33    path::PathBuf,
  34    pin::Pin,
  35    sync::{
  36        atomic::{AtomicI32, Ordering::SeqCst},
  37        Arc, Weak,
  38    },
  39    task::Poll,
  40    time::{Duration, Instant},
  41};
  42use std::{path::Path, process::Stdio};
  43use util::{ResultExt, TryFutureExt};
  44
  45const JSON_RPC_VERSION: &str = "2.0";
  46const CONTENT_LEN_HEADER: &str = "Content-Length: ";
  47
  48const LSP_REQUEST_TIMEOUT: Duration = Duration::from_secs(60 * 2);
  49const SERVER_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
  50
  51type NotificationHandler = Box<dyn Send + FnMut(Option<RequestId>, Value, AsyncAppContext)>;
  52type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
  53type IoHandler = Box<dyn Send + FnMut(IoKind, &str)>;
  54
  55/// Kind of language server stdio given to an IO handler.
  56#[derive(Debug, Clone, Copy)]
  57pub enum IoKind {
  58    StdOut,
  59    StdIn,
  60    StdErr,
  61}
  62
  63/// Represents a launchable language server. This can either be a standalone binary or the path
  64/// to a runtime with arguments to instruct it to launch the actual language server file.
  65#[derive(Debug, Clone, Deserialize)]
  66pub struct LanguageServerBinary {
  67    pub path: PathBuf,
  68    pub arguments: Vec<OsString>,
  69    pub env: Option<HashMap<String, String>>,
  70}
  71
  72/// Configures the search (and installation) of language servers.
  73#[derive(Debug, Clone, Deserialize)]
  74pub struct LanguageServerBinaryOptions {
  75    /// Whether the adapter should look at the users system
  76    pub allow_path_lookup: bool,
  77    /// Whether the adapter should download its own version
  78    pub allow_binary_download: bool,
  79}
  80
/// A running language server process.
///
/// Holds the handles used to talk to the server (outbound channel, handler registries) and
/// the IO tasks that pump the server's stdio streams.
pub struct LanguageServer {
    /// Identifier assigned to this server by the caller of `new`.
    server_id: LanguageServerId,
    /// Counter for request ids.
    // NOTE(review): presumably incremented once per outgoing request — the request code is
    // outside the visible part of the file; confirm.
    next_id: AtomicI32,
    /// Sending half of the outbound queue; queued messages are written to the server's
    /// stdin by the output task.
    outbound_tx: channel::Sender<String>,
    /// Name the server was registered under.
    name: LanguageServerName,
    /// Process name: set from the binary's file name in `new`, and replaced by the name in
    /// the server's `server_info` during `initialize` when the server reports one.
    process_name: Arc<str>,
    /// Capabilities reported by the server; default until `initialize` stores the response.
    capabilities: RwLock<ServerCapabilities>,
    /// Code action kinds supplied by the caller at construction.
    code_action_kinds: Option<Vec<CodeActionKind>>,
    /// Handlers keyed by notification method name.
    notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
    /// Handlers for pending requests keyed by request id. The map is taken (set to `None`)
    /// when the input or output loop ends.
    response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
    /// Observers of raw stdio traffic, keyed by subscription id.
    io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>,
    /// Background executor cloned from the app context at construction.
    executor: BackgroundExecutor,
    /// The `(input, output)` IO tasks; taken by `shutdown`.
    #[allow(clippy::type_complexity)]
    io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
    /// Receiving half of the barrier whose sender is owned by the output loop; taken by
    /// `shutdown`.
    output_done_rx: Mutex<Option<barrier::Receiver>>,
    /// Root path the server was started for.
    root_path: PathBuf,
    /// Directory the process runs in; also used to build the root URI in `initialize`.
    working_dir: PathBuf,
    /// Handle to the child process, if one was provided at construction.
    server: Arc<Mutex<Option<Child>>>,
}
 101
 102/// Identifies a running language server.
 103#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
 104#[repr(transparent)]
 105pub struct LanguageServerId(pub usize);
 106
 107impl LanguageServerId {
 108    pub fn from_proto(id: u64) -> Self {
 109        Self(id as usize)
 110    }
 111
 112    pub fn to_proto(self) -> u64 {
 113        self.0 as u64
 114    }
 115}
 116
 117/// A name of a language server.
 118#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)]
 119pub struct LanguageServerName(pub SharedString);
 120
 121impl std::fmt::Display for LanguageServerName {
 122    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 123        std::fmt::Display::fmt(&self.0, f)
 124    }
 125}
 126
 127impl AsRef<str> for LanguageServerName {
 128    fn as_ref(&self) -> &str {
 129        self.0.as_ref()
 130    }
 131}
 132
 133impl AsRef<OsStr> for LanguageServerName {
 134    fn as_ref(&self) -> &OsStr {
 135        self.0.as_ref().as_ref()
 136    }
 137}
 138
 139impl JsonSchema for LanguageServerName {
 140    fn schema_name() -> String {
 141        "LanguageServerName".into()
 142    }
 143
 144    fn json_schema(_: &mut SchemaGenerator) -> Schema {
 145        SchemaObject {
 146            instance_type: Some(InstanceType::String.into()),
 147            ..Default::default()
 148        }
 149        .into()
 150    }
 151}
 152
 153impl LanguageServerName {
 154    pub const fn new_static(s: &'static str) -> Self {
 155        Self(SharedString::new_static(s))
 156    }
 157
 158    pub fn from_proto(s: String) -> Self {
 159        Self(s.into())
 160    }
 161}
 162
 163impl<'a> From<&'a str> for LanguageServerName {
 164    fn from(str: &'a str) -> LanguageServerName {
 165        LanguageServerName(str.to_string().into())
 166    }
 167}
 168
/// Handle to a language server RPC activity subscription.
// NOTE(review): presumably dropping or detaching the handle controls whether the handler
// stays registered — that code is outside the visible part of this file; confirm.
pub enum Subscription {
    /// Subscription to a notification method.
    Notification {
        /// Method name the handler was registered under.
        method: &'static str,
        /// Strong reference to the registry that holds the handler, when present.
        notification_handlers: Option<Arc<Mutex<HashMap<&'static str, NotificationHandler>>>>,
    },
    /// Subscription to raw stdio traffic.
    Io {
        /// Key of the handler inside the IO handler registry.
        id: i32,
        /// Weak reference to the registry that holds the handler, when present.
        io_handlers: Option<Weak<Mutex<HashMap<i32, IoHandler>>>>,
    },
}
 180
 181/// Language server protocol RPC request message ID.
 182///
 183/// [LSP Specification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#requestMessage)
 184#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
 185#[serde(untagged)]
 186pub enum RequestId {
 187    Int(i32),
 188    Str(String),
 189}
 190
 191/// Language server protocol RPC request message.
 192///
 193/// [LSP Specification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#requestMessage)
 194#[derive(Serialize, Deserialize)]
 195pub struct Request<'a, T> {
 196    jsonrpc: &'static str,
 197    id: RequestId,
 198    method: &'a str,
 199    params: T,
 200}
 201
 202/// Language server protocol RPC request response message before it is deserialized into a concrete type.
 203#[derive(Serialize, Deserialize)]
 204struct AnyResponse<'a> {
 205    jsonrpc: &'a str,
 206    id: RequestId,
 207    #[serde(default)]
 208    error: Option<Error>,
 209    #[serde(borrow)]
 210    result: Option<&'a RawValue>,
 211}
 212
/// Language server protocol RPC request response message.
///
/// [LSP Specification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#responseMessage)
#[derive(Serialize)]
struct Response<T> {
    /// Protocol version member of the message.
    jsonrpc: &'static str,
    /// Id of the request being answered.
    id: RequestId,
    /// Outcome of the request; flattened, so its key is serialized directly into the
    /// response object rather than under a `value` member.
    #[serde(flatten)]
    value: LspResult<T>,
}
 223
/// Outcome of a request as carried inside a [`Response`]: either a result or an error.
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
enum LspResult<T> {
    /// Successful outcome; explicitly renamed so it serializes under the `result` key.
    #[serde(rename = "result")]
    Ok(Option<T>),
    /// Failed outcome; serializes under the `error` key via the snake_case renaming.
    Error(Option<Error>),
}
 231
 232/// Language server protocol RPC notification message.
 233///
 234/// [LSP Specification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#notificationMessage)
 235#[derive(Serialize, Deserialize)]
 236struct Notification<'a, T> {
 237    jsonrpc: &'static str,
 238    #[serde(borrow)]
 239    method: &'a str,
 240    params: T,
 241}
 242
 243/// Language server RPC notification message before it is deserialized into a concrete type.
 244#[derive(Debug, Clone, Deserialize)]
 245struct AnyNotification {
 246    #[serde(default)]
 247    id: Option<RequestId>,
 248    method: String,
 249    #[serde(default)]
 250    params: Option<Value>,
 251}
 252
 253#[derive(Debug, Serialize, Deserialize)]
 254struct Error {
 255    message: String,
 256}
 257
 258pub trait LspRequestFuture<O>: Future<Output = O> {
 259    fn id(&self) -> i32;
 260}
 261
 262struct LspRequest<F> {
 263    id: i32,
 264    request: F,
 265}
 266
 267impl<F> LspRequest<F> {
 268    pub fn new(id: i32, request: F) -> Self {
 269        Self { id, request }
 270    }
 271}
 272
 273impl<F: Future> Future for LspRequest<F> {
 274    type Output = F::Output;
 275
 276    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
 277        // SAFETY: This is standard pin projection, we're pinned so our fields must be pinned.
 278        let inner = unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().request) };
 279        inner.poll(cx)
 280    }
 281}
 282
 283impl<F: Future> LspRequestFuture<F::Output> for LspRequest<F> {
 284    fn id(&self) -> i32 {
 285        self.id
 286    }
 287}
 288
/// Combined capabilities of the server and the adapter.
pub struct AdapterServerCapabilities {
    /// Capabilities reported by the server.
    pub server_capabilities: ServerCapabilities,
    /// Code action kinds supported by the LspAdapter matching the server.
    pub code_action_kinds: Option<Vec<CodeActionKind>>,
}
 296
/// Experimental: Informs the end user about the state of the server.
///
/// This enum has no variants; it only serves as the type that carries the notification's
/// method name and params type through its `Notification` implementation.
///
/// [Rust Analyzer Specification](https://github.com/rust-lang/rust-analyzer/blob/master/docs/dev/lsp-extensions.md#server-status)
#[derive(Debug)]
pub enum ServerStatus {}
 302
 303/// Other(String) variant to handle unknown values due to this still being experimental
 304#[derive(Debug, PartialEq, Deserialize, Serialize, Clone)]
 305#[serde(rename_all = "camelCase")]
 306pub enum ServerHealthStatus {
 307    Ok,
 308    Warning,
 309    Error,
 310    Other(String),
 311}
 312
 313#[derive(Debug, PartialEq, Deserialize, Serialize, Clone)]
 314#[serde(rename_all = "camelCase")]
 315pub struct ServerStatusParams {
 316    pub health: ServerHealthStatus,
 317    pub message: Option<String>,
 318}
 319
 320impl lsp_types::notification::Notification for ServerStatus {
 321    type Params = ServerStatusParams;
 322    const METHOD: &'static str = "experimental/serverStatus";
 323}
 324
 325impl LanguageServer {
 326    /// Starts a language server process.
 327    pub fn new(
 328        stderr_capture: Arc<Mutex<Option<String>>>,
 329        server_id: LanguageServerId,
 330        server_name: LanguageServerName,
 331        binary: LanguageServerBinary,
 332        root_path: &Path,
 333        code_action_kinds: Option<Vec<CodeActionKind>>,
 334        cx: AsyncAppContext,
 335    ) -> Result<Self> {
 336        let working_dir = if root_path.is_dir() {
 337            root_path
 338        } else {
 339            root_path.parent().unwrap_or_else(|| Path::new("/"))
 340        };
 341
 342        log::info!(
 343            "starting language server process. binary path: {:?}, working directory: {:?}, args: {:?}",
 344            binary.path,
 345            working_dir,
 346            &binary.arguments
 347        );
 348
 349        let mut command = process::Command::new(&binary.path);
 350        command
 351            .current_dir(working_dir)
 352            .args(&binary.arguments)
 353            .envs(binary.env.unwrap_or_default())
 354            .stdin(Stdio::piped())
 355            .stdout(Stdio::piped())
 356            .stderr(Stdio::piped())
 357            .kill_on_drop(true);
 358        #[cfg(windows)]
 359        command.creation_flags(windows::Win32::System::Threading::CREATE_NO_WINDOW.0);
 360        let mut server = command.spawn().with_context(|| {
 361            format!(
 362                "failed to spawn command. path: {:?}, working directory: {:?}, args: {:?}",
 363                binary.path, working_dir, &binary.arguments
 364            )
 365        })?;
 366
 367        let stdin = server.stdin.take().unwrap();
 368        let stdout = server.stdout.take().unwrap();
 369        let stderr = server.stderr.take().unwrap();
 370        let mut server = Self::new_internal(
 371            server_id,
 372            server_name,
 373            stdin,
 374            stdout,
 375            Some(stderr),
 376            stderr_capture,
 377            Some(server),
 378            root_path,
 379            working_dir,
 380            code_action_kinds,
 381            cx,
 382            move |notification| {
 383                log::info!(
 384                    "Language server with id {} sent unhandled notification {}:\n{}",
 385                    server_id,
 386                    notification.method,
 387                    serde_json::to_string_pretty(&notification.params).unwrap(),
 388                );
 389            },
 390        );
 391
 392        if let Some(name) = binary.path.file_name() {
 393            server.process_name = name.to_string_lossy().into();
 394        }
 395
 396        Ok(server)
 397    }
 398
 399    #[allow(clippy::too_many_arguments)]
 400    fn new_internal<Stdin, Stdout, Stderr, F>(
 401        server_id: LanguageServerId,
 402        server_name: LanguageServerName,
 403        stdin: Stdin,
 404        stdout: Stdout,
 405        stderr: Option<Stderr>,
 406        stderr_capture: Arc<Mutex<Option<String>>>,
 407        server: Option<Child>,
 408        root_path: &Path,
 409        working_dir: &Path,
 410        code_action_kinds: Option<Vec<CodeActionKind>>,
 411        cx: AsyncAppContext,
 412        on_unhandled_notification: F,
 413    ) -> Self
 414    where
 415        Stdin: AsyncWrite + Unpin + Send + 'static,
 416        Stdout: AsyncRead + Unpin + Send + 'static,
 417        Stderr: AsyncRead + Unpin + Send + 'static,
 418        F: FnMut(AnyNotification) + 'static + Send + Sync + Clone,
 419    {
 420        let (outbound_tx, outbound_rx) = channel::unbounded::<String>();
 421        let (output_done_tx, output_done_rx) = barrier::channel();
 422        let notification_handlers =
 423            Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
 424        let response_handlers =
 425            Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
 426        let io_handlers = Arc::new(Mutex::new(HashMap::default()));
 427
 428        let stdout_input_task = cx.spawn({
 429            let on_unhandled_notification = on_unhandled_notification.clone();
 430            let notification_handlers = notification_handlers.clone();
 431            let response_handlers = response_handlers.clone();
 432            let io_handlers = io_handlers.clone();
 433            move |cx| {
 434                Self::handle_input(
 435                    stdout,
 436                    on_unhandled_notification,
 437                    notification_handlers,
 438                    response_handlers,
 439                    io_handlers,
 440                    cx,
 441                )
 442                .log_err()
 443            }
 444        });
 445        let stderr_input_task = stderr
 446            .map(|stderr| {
 447                let io_handlers = io_handlers.clone();
 448                let stderr_captures = stderr_capture.clone();
 449                cx.spawn(|_| Self::handle_stderr(stderr, io_handlers, stderr_captures).log_err())
 450            })
 451            .unwrap_or_else(|| Task::Ready(Some(None)));
 452        let input_task = cx.spawn(|_| async move {
 453            let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task);
 454            stdout.or(stderr)
 455        });
 456        let output_task = cx.background_executor().spawn({
 457            Self::handle_output(
 458                stdin,
 459                outbound_rx,
 460                output_done_tx,
 461                response_handlers.clone(),
 462                io_handlers.clone(),
 463            )
 464            .log_err()
 465        });
 466
 467        Self {
 468            server_id,
 469            notification_handlers,
 470            response_handlers,
 471            io_handlers,
 472            name: server_name,
 473            process_name: Arc::default(),
 474            capabilities: Default::default(),
 475            code_action_kinds,
 476            next_id: Default::default(),
 477            outbound_tx,
 478            executor: cx.background_executor().clone(),
 479            io_tasks: Mutex::new(Some((input_task, output_task))),
 480            output_done_rx: Mutex::new(Some(output_done_rx)),
 481            root_path: root_path.to_path_buf(),
 482            working_dir: working_dir.to_path_buf(),
 483            server: Arc::new(Mutex::new(server)),
 484        }
 485    }
 486
 487    /// List of code action kinds this language server reports being able to emit.
 488    pub fn code_action_kinds(&self) -> Option<Vec<CodeActionKind>> {
 489        self.code_action_kinds.clone()
 490    }
 491
 492    async fn handle_input<Stdout, F>(
 493        stdout: Stdout,
 494        mut on_unhandled_notification: F,
 495        notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
 496        response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
 497        io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>,
 498        cx: AsyncAppContext,
 499    ) -> anyhow::Result<()>
 500    where
 501        Stdout: AsyncRead + Unpin + Send + 'static,
 502        F: FnMut(AnyNotification) + 'static + Send,
 503    {
 504        use smol::stream::StreamExt;
 505        let stdout = BufReader::new(stdout);
 506        let _clear_response_handlers = util::defer({
 507            let response_handlers = response_handlers.clone();
 508            move || {
 509                response_handlers.lock().take();
 510            }
 511        });
 512        let mut input_handler = input_handler::LspStdoutHandler::new(
 513            stdout,
 514            response_handlers,
 515            io_handlers,
 516            cx.background_executor().clone(),
 517        );
 518
 519        while let Some(msg) = input_handler.notifications_channel.next().await {
 520            {
 521                let mut notification_handlers = notification_handlers.lock();
 522                if let Some(handler) = notification_handlers.get_mut(msg.method.as_str()) {
 523                    handler(msg.id, msg.params.unwrap_or(Value::Null), cx.clone());
 524                } else {
 525                    drop(notification_handlers);
 526                    on_unhandled_notification(msg);
 527                }
 528            }
 529
 530            // Don't starve the main thread when receiving lots of notifications at once.
 531            smol::future::yield_now().await;
 532        }
 533        input_handler.loop_handle.await
 534    }
 535
 536    async fn handle_stderr<Stderr>(
 537        stderr: Stderr,
 538        io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>,
 539        stderr_capture: Arc<Mutex<Option<String>>>,
 540    ) -> anyhow::Result<()>
 541    where
 542        Stderr: AsyncRead + Unpin + Send + 'static,
 543    {
 544        let mut stderr = BufReader::new(stderr);
 545        let mut buffer = Vec::new();
 546
 547        loop {
 548            buffer.clear();
 549
 550            let bytes_read = stderr.read_until(b'\n', &mut buffer).await?;
 551            if bytes_read == 0 {
 552                return Ok(());
 553            }
 554
 555            if let Ok(message) = std::str::from_utf8(&buffer) {
 556                log::trace!("incoming stderr message:{message}");
 557                for handler in io_handlers.lock().values_mut() {
 558                    handler(IoKind::StdErr, message);
 559                }
 560
 561                if let Some(stderr) = stderr_capture.lock().as_mut() {
 562                    stderr.push_str(message);
 563                }
 564            }
 565
 566            // Don't starve the main thread when receiving lots of messages at once.
 567            smol::future::yield_now().await;
 568        }
 569    }
 570
 571    async fn handle_output<Stdin>(
 572        stdin: Stdin,
 573        outbound_rx: channel::Receiver<String>,
 574        output_done_tx: barrier::Sender,
 575        response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
 576        io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>,
 577    ) -> anyhow::Result<()>
 578    where
 579        Stdin: AsyncWrite + Unpin + Send + 'static,
 580    {
 581        let mut stdin = BufWriter::new(stdin);
 582        let _clear_response_handlers = util::defer({
 583            let response_handlers = response_handlers.clone();
 584            move || {
 585                response_handlers.lock().take();
 586            }
 587        });
 588        let mut content_len_buffer = Vec::new();
 589        while let Ok(message) = outbound_rx.recv().await {
 590            log::trace!("outgoing message:{}", message);
 591            for handler in io_handlers.lock().values_mut() {
 592                handler(IoKind::StdIn, &message);
 593            }
 594
 595            content_len_buffer.clear();
 596            write!(content_len_buffer, "{}", message.len()).unwrap();
 597            stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
 598            stdin.write_all(&content_len_buffer).await?;
 599            stdin.write_all("\r\n\r\n".as_bytes()).await?;
 600            stdin.write_all(message.as_bytes()).await?;
 601            stdin.flush().await?;
 602        }
 603        drop(output_done_tx);
 604        Ok(())
 605    }
 606
 607    /// Initializes a language server by sending the `Initialize` request.
 608    /// Note that `options` is used directly to construct [`InitializeParams`], which is why it is owned.
 609    ///
 610    /// [LSP Specification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#initialize)
 611    pub fn initialize(
 612        mut self,
 613        options: Option<Value>,
 614        cx: &AppContext,
 615    ) -> Task<Result<Arc<Self>>> {
 616        let root_uri = Url::from_file_path(&self.working_dir).unwrap();
 617        #[allow(deprecated)]
 618        let params = InitializeParams {
 619            process_id: None,
 620            root_path: None,
 621            root_uri: Some(root_uri.clone()),
 622            initialization_options: options,
 623            capabilities: ClientCapabilities {
 624                workspace: Some(WorkspaceClientCapabilities {
 625                    configuration: Some(true),
 626                    did_change_watched_files: Some(DidChangeWatchedFilesClientCapabilities {
 627                        dynamic_registration: Some(true),
 628                        relative_pattern_support: Some(true),
 629                    }),
 630                    did_change_configuration: Some(DynamicRegistrationClientCapabilities {
 631                        dynamic_registration: Some(true),
 632                    }),
 633                    workspace_folders: Some(true),
 634                    symbol: Some(WorkspaceSymbolClientCapabilities {
 635                        resolve_support: None,
 636                        ..WorkspaceSymbolClientCapabilities::default()
 637                    }),
 638                    inlay_hint: Some(InlayHintWorkspaceClientCapabilities {
 639                        refresh_support: Some(true),
 640                    }),
 641                    diagnostic: Some(DiagnosticWorkspaceClientCapabilities {
 642                        refresh_support: None,
 643                    }),
 644                    workspace_edit: Some(WorkspaceEditClientCapabilities {
 645                        resource_operations: Some(vec![
 646                            ResourceOperationKind::Create,
 647                            ResourceOperationKind::Rename,
 648                            ResourceOperationKind::Delete,
 649                        ]),
 650                        document_changes: Some(true),
 651                        snippet_edit_support: Some(true),
 652                        ..WorkspaceEditClientCapabilities::default()
 653                    }),
 654                    ..Default::default()
 655                }),
 656                text_document: Some(TextDocumentClientCapabilities {
 657                    definition: Some(GotoCapability {
 658                        link_support: Some(true),
 659                        dynamic_registration: None,
 660                    }),
 661                    code_action: Some(CodeActionClientCapabilities {
 662                        code_action_literal_support: Some(CodeActionLiteralSupport {
 663                            code_action_kind: CodeActionKindLiteralSupport {
 664                                value_set: vec![
 665                                    CodeActionKind::REFACTOR.as_str().into(),
 666                                    CodeActionKind::QUICKFIX.as_str().into(),
 667                                    CodeActionKind::SOURCE.as_str().into(),
 668                                ],
 669                            },
 670                        }),
 671                        data_support: Some(true),
 672                        resolve_support: Some(CodeActionCapabilityResolveSupport {
 673                            properties: vec![
 674                                "kind".to_string(),
 675                                "diagnostics".to_string(),
 676                                "isPreferred".to_string(),
 677                                "disabled".to_string(),
 678                                "edit".to_string(),
 679                                "command".to_string(),
 680                            ],
 681                        }),
 682                        ..Default::default()
 683                    }),
 684                    completion: Some(CompletionClientCapabilities {
 685                        completion_item: Some(CompletionItemCapability {
 686                            snippet_support: Some(true),
 687                            resolve_support: Some(CompletionItemCapabilityResolveSupport {
 688                                properties: vec![
 689                                    "additionalTextEdits".to_string(),
 690                                    "command".to_string(),
 691                                    "documentation".to_string(),
 692                                    // NB: Do not have this resolved, otherwise Zed becomes slow to complete things
 693                                    // "textEdit".to_string(),
 694                                ],
 695                            }),
 696                            insert_replace_support: Some(true),
 697                            label_details_support: Some(true),
 698                            ..Default::default()
 699                        }),
 700                        completion_list: Some(CompletionListCapability {
 701                            item_defaults: Some(vec![
 702                                "commitCharacters".to_owned(),
 703                                "editRange".to_owned(),
 704                                "insertTextMode".to_owned(),
 705                                "data".to_owned(),
 706                            ]),
 707                        }),
 708                        context_support: Some(true),
 709                        ..Default::default()
 710                    }),
 711                    rename: Some(RenameClientCapabilities {
 712                        prepare_support: Some(true),
 713                        ..Default::default()
 714                    }),
 715                    hover: Some(HoverClientCapabilities {
 716                        content_format: Some(vec![MarkupKind::Markdown]),
 717                        dynamic_registration: None,
 718                    }),
 719                    inlay_hint: Some(InlayHintClientCapabilities {
 720                        resolve_support: Some(InlayHintResolveClientCapabilities {
 721                            properties: vec![
 722                                "textEdits".to_string(),
 723                                "tooltip".to_string(),
 724                                "label.tooltip".to_string(),
 725                                "label.location".to_string(),
 726                                "label.command".to_string(),
 727                            ],
 728                        }),
 729                        dynamic_registration: Some(false),
 730                    }),
 731                    publish_diagnostics: Some(PublishDiagnosticsClientCapabilities {
 732                        related_information: Some(true),
 733                        ..Default::default()
 734                    }),
 735                    formatting: Some(DynamicRegistrationClientCapabilities {
 736                        dynamic_registration: Some(true),
 737                    }),
 738                    range_formatting: Some(DynamicRegistrationClientCapabilities {
 739                        dynamic_registration: Some(true),
 740                    }),
 741                    on_type_formatting: Some(DynamicRegistrationClientCapabilities {
 742                        dynamic_registration: Some(true),
 743                    }),
 744                    signature_help: Some(SignatureHelpClientCapabilities {
 745                        signature_information: Some(SignatureInformationSettings {
 746                            documentation_format: Some(vec![
 747                                MarkupKind::Markdown,
 748                                MarkupKind::PlainText,
 749                            ]),
 750                            parameter_information: Some(ParameterInformationSettings {
 751                                label_offset_support: Some(true),
 752                            }),
 753                            active_parameter_support: Some(true),
 754                        }),
 755                        ..SignatureHelpClientCapabilities::default()
 756                    }),
 757                    synchronization: Some(TextDocumentSyncClientCapabilities {
 758                        did_save: Some(true),
 759                        ..TextDocumentSyncClientCapabilities::default()
 760                    }),
 761                    ..TextDocumentClientCapabilities::default()
 762                }),
 763                experimental: Some(json!({
 764                    "serverStatusNotification": true,
 765                    "localDocs": true,
 766                })),
 767                window: Some(WindowClientCapabilities {
 768                    work_done_progress: Some(true),
 769                    ..Default::default()
 770                }),
 771                general: None,
 772            },
 773            trace: None,
 774            workspace_folders: Some(vec![WorkspaceFolder {
 775                uri: root_uri,
 776                name: Default::default(),
 777            }]),
 778            client_info: release_channel::ReleaseChannel::try_global(cx).map(|release_channel| {
 779                ClientInfo {
 780                    name: release_channel.display_name().to_string(),
 781                    version: Some(release_channel::AppVersion::global(cx).to_string()),
 782                }
 783            }),
 784            locale: None,
 785            ..Default::default()
 786        };
 787
 788        cx.spawn(|_| async move {
 789            let response = self.request::<request::Initialize>(params).await?;
 790            if let Some(info) = response.server_info {
 791                self.process_name = info.name.into();
 792            }
 793            self.capabilities = RwLock::new(response.capabilities);
 794
 795            self.notify::<notification::Initialized>(InitializedParams {})?;
 796            Ok(Arc::new(self))
 797        })
 798    }
 799
 800    /// Sends a shutdown request to the language server process and prepares the [`LanguageServer`] to be dropped.
 801    pub fn shutdown(&self) -> Option<impl 'static + Send + Future<Output = Option<()>>> {
 802        if let Some(tasks) = self.io_tasks.lock().take() {
 803            let response_handlers = self.response_handlers.clone();
 804            let next_id = AtomicI32::new(self.next_id.load(SeqCst));
 805            let outbound_tx = self.outbound_tx.clone();
 806            let executor = self.executor.clone();
 807            let mut output_done = self.output_done_rx.lock().take().unwrap();
 808            let shutdown_request = Self::request_internal::<request::Shutdown>(
 809                &next_id,
 810                &response_handlers,
 811                &outbound_tx,
 812                &executor,
 813                (),
 814            );
 815            let exit = Self::notify_internal::<notification::Exit>(&outbound_tx, ());
 816            outbound_tx.close();
 817
 818            let server = self.server.clone();
 819            let name = self.name.clone();
 820            let mut timer = self.executor.timer(SERVER_SHUTDOWN_TIMEOUT).fuse();
 821            Some(
 822                async move {
 823                    log::debug!("language server shutdown started");
 824
 825                    select! {
 826                        request_result = shutdown_request.fuse() => {
 827                            request_result?;
 828                        }
 829
 830                        _ = timer => {
 831                            log::info!("timeout waiting for language server {name} to shutdown");
 832                        },
 833                    }
 834
 835                    response_handlers.lock().take();
 836                    exit?;
 837                    output_done.recv().await;
 838                    server.lock().take().map(|mut child| child.kill());
 839                    log::debug!("language server shutdown finished");
 840
 841                    drop(tasks);
 842                    anyhow::Ok(())
 843                }
 844                .log_err(),
 845            )
 846        } else {
 847            None
 848        }
 849    }
 850
    /// Register a handler to handle incoming LSP notifications.
    ///
    /// The handler is registered under `T::METHOD` by forwarding to
    /// `on_custom_notification`, which panics if a handler already exists for
    /// that method. The returned [`Subscription`] removes the handler when
    /// dropped, unless it has been detached.
    ///
    /// [LSP Specification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#notificationMessage)
    #[must_use]
    pub fn on_notification<T, F>(&self, f: F) -> Subscription
    where
        T: notification::Notification,
        F: 'static + Send + FnMut(T::Params, AsyncAppContext),
    {
        self.on_custom_notification(T::METHOD, f)
    }
 862
    /// Register a handler to handle incoming LSP requests.
    ///
    /// The handler is registered under `T::METHOD` by forwarding to
    /// `on_custom_request`, which panics if a handler already exists for that
    /// method. The future returned by `f` produces the result (or error) that is
    /// sent back to the server as the response. The returned [`Subscription`]
    /// removes the handler when dropped, unless it has been detached.
    ///
    /// [LSP Specification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#requestMessage)
    #[must_use]
    pub fn on_request<T, F, Fut>(&self, f: F) -> Subscription
    where
        T: request::Request,
        T::Params: 'static + Send,
        F: 'static + FnMut(T::Params, AsyncAppContext) -> Fut + Send,
        Fut: 'static + Future<Output = Result<T::Result>>,
    {
        self.on_custom_request(T::METHOD, f)
    }
 876
 877    /// Registers a handler to inspect all language server process stdio.
 878    #[must_use]
 879    pub fn on_io<F>(&self, f: F) -> Subscription
 880    where
 881        F: 'static + Send + FnMut(IoKind, &str),
 882    {
 883        let id = self.next_id.fetch_add(1, SeqCst);
 884        self.io_handlers.lock().insert(id, Box::new(f));
 885        Subscription::Io {
 886            id,
 887            io_handlers: Some(Arc::downgrade(&self.io_handlers)),
 888        }
 889    }
 890
    /// Removes a request handler registered via [`Self::on_request`].
    ///
    /// Request handlers are stored in the same table as notification handlers,
    /// keyed by method name, which is why the notification table is used here.
    pub fn remove_request_handler<T: request::Request>(&self) {
        self.notification_handlers.lock().remove(T::METHOD);
    }
 895
    /// Removes a notification handler registered via [`Self::on_notification`].
    ///
    /// Does nothing if no handler is registered for `T::METHOD`.
    pub fn remove_notification_handler<T: notification::Notification>(&self) {
        self.notification_handlers.lock().remove(T::METHOD);
    }
 900
    /// Checks if a notification handler has been registered via [`Self::on_notification`].
    ///
    /// Returns `true` when the handler table contains an entry for `T::METHOD`.
    pub fn has_notification_handler<T: notification::Notification>(&self) -> bool {
        self.notification_handlers.lock().contains_key(T::METHOD)
    }
 905
 906    #[must_use]
 907    fn on_custom_notification<Params, F>(&self, method: &'static str, mut f: F) -> Subscription
 908    where
 909        F: 'static + FnMut(Params, AsyncAppContext) + Send,
 910        Params: DeserializeOwned,
 911    {
 912        let prev_handler = self.notification_handlers.lock().insert(
 913            method,
 914            Box::new(move |_, params, cx| {
 915                if let Some(params) = serde_json::from_value(params).log_err() {
 916                    f(params, cx);
 917                }
 918            }),
 919        );
 920        assert!(
 921            prev_handler.is_none(),
 922            "registered multiple handlers for the same LSP method"
 923        );
 924        Subscription::Notification {
 925            method,
 926            notification_handlers: Some(self.notification_handlers.clone()),
 927        }
 928    }
 929
 930    #[must_use]
 931    fn on_custom_request<Params, Res, Fut, F>(&self, method: &'static str, mut f: F) -> Subscription
 932    where
 933        F: 'static + FnMut(Params, AsyncAppContext) -> Fut + Send,
 934        Fut: 'static + Future<Output = Result<Res>>,
 935        Params: DeserializeOwned + Send + 'static,
 936        Res: Serialize,
 937    {
 938        let outbound_tx = self.outbound_tx.clone();
 939        let prev_handler = self.notification_handlers.lock().insert(
 940            method,
 941            Box::new(move |id, params, cx| {
 942                if let Some(id) = id {
 943                    match serde_json::from_value(params) {
 944                        Ok(params) => {
 945                            let response = f(params, cx.clone());
 946                            cx.foreground_executor()
 947                                .spawn({
 948                                    let outbound_tx = outbound_tx.clone();
 949                                    async move {
 950                                        let response = match response.await {
 951                                            Ok(result) => Response {
 952                                                jsonrpc: JSON_RPC_VERSION,
 953                                                id,
 954                                                value: LspResult::Ok(Some(result)),
 955                                            },
 956                                            Err(error) => Response {
 957                                                jsonrpc: JSON_RPC_VERSION,
 958                                                id,
 959                                                value: LspResult::Error(Some(Error {
 960                                                    message: error.to_string(),
 961                                                })),
 962                                            },
 963                                        };
 964                                        if let Some(response) =
 965                                            serde_json::to_string(&response).log_err()
 966                                        {
 967                                            outbound_tx.try_send(response).ok();
 968                                        }
 969                                    }
 970                                })
 971                                .detach();
 972                        }
 973
 974                        Err(error) => {
 975                            log::error!("error deserializing {} request: {:?}", method, error);
 976                            let response = AnyResponse {
 977                                jsonrpc: JSON_RPC_VERSION,
 978                                id,
 979                                result: None,
 980                                error: Some(Error {
 981                                    message: error.to_string(),
 982                                }),
 983                            };
 984                            if let Some(response) = serde_json::to_string(&response).log_err() {
 985                                outbound_tx.try_send(response).ok();
 986                            }
 987                        }
 988                    }
 989                }
 990            }),
 991        );
 992        assert!(
 993            prev_handler.is_none(),
 994            "registered multiple handlers for the same LSP method"
 995        );
 996        Subscription::Notification {
 997            method,
 998            notification_handlers: Some(self.notification_handlers.clone()),
 999        }
1000    }
1001
    /// Get the name of the running language server.
    ///
    /// Returns a clone of the stored [`LanguageServerName`].
    pub fn name(&self) -> LanguageServerName {
        self.name.clone()
    }
1006
    /// Get the process name recorded for this language server.
    ///
    /// `initialize` replaces this value with the name from the server's
    /// `server_info`, when the server reports one.
    pub fn process_name(&self) -> &str {
        &self.process_name
    }
1010
    /// Get the reported capabilities of the running language server.
    ///
    /// Returns a clone taken under the read lock, so later calls to
    /// [`Self::update_capabilities`] do not affect the returned value.
    pub fn capabilities(&self) -> ServerCapabilities {
        self.capabilities.read().clone()
    }
1015
1016    /// Get the reported capabilities of the running language server and
1017    /// what we know on the client/adapter-side of its capabilities.
1018    pub fn adapter_server_capabilities(&self) -> AdapterServerCapabilities {
1019        AdapterServerCapabilities {
1020            server_capabilities: self.capabilities(),
1021            code_action_kinds: self.code_action_kinds(),
1022        }
1023    }
1024
1025    pub fn update_capabilities(&self, update: impl FnOnce(&mut ServerCapabilities)) {
1026        update(self.capabilities.write().deref_mut());
1027    }
1028
    /// Get the id of the running language server.
    ///
    /// This is the [`LanguageServerId`] stored on this instance.
    pub fn server_id(&self) -> LanguageServerId {
        self.server_id
    }
1033
    /// Get the root path of the project the language server is running against.
    ///
    /// The path is borrowed from this instance; clone it if an owned value is needed.
    pub fn root_path(&self) -> &PathBuf {
        &self.root_path
    }
1038
    /// Sends a RPC request to the language server.
    ///
    /// The request is serialized and queued immediately; the returned future
    /// resolves to the deserialized response. It fails if the server has been
    /// shut down, if the message cannot be queued, if the server answers with an
    /// error, or if no response arrives within [`LSP_REQUEST_TIMEOUT`].
    ///
    /// [LSP Specification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#requestMessage)
    pub fn request<T: request::Request>(
        &self,
        params: T::Params,
    ) -> impl LspRequestFuture<Result<T::Result>>
    where
        T::Result: 'static + Send,
    {
        Self::request_internal::<T>(
            &self.next_id,
            &self.response_handlers,
            &self.outbound_tx,
            &self.executor,
            params,
        )
    }
1057
    /// Builds and queues a request, returning a future for its response.
    ///
    /// This is the `self`-free implementation behind [`Self::request`]; it takes
    /// the pieces of state it needs explicitly so it can also be used where no
    /// `&self` borrow may be held.
    ///
    /// Work done eagerly, before the future is polled:
    /// - a fresh id is taken from `next_id`;
    /// - the request is serialized (panics if serialization fails);
    /// - a response handler is registered under that id, unless the handler table
    ///   has already been taken ("server shut down");
    /// - the message is queued on `outbound_tx`.
    ///
    /// Errors from the last two steps are stored and surface when the returned
    /// future is awaited. The future then waits for either the response or
    /// [`LSP_REQUEST_TIMEOUT`], whichever comes first.
    fn request_internal<T: request::Request>(
        next_id: &AtomicI32,
        response_handlers: &Mutex<Option<HashMap<RequestId, ResponseHandler>>>,
        outbound_tx: &channel::Sender<String>,
        executor: &BackgroundExecutor,
        params: T::Params,
    ) -> impl LspRequestFuture<Result<T::Result>>
    where
        T::Result: 'static + Send,
    {
        let id = next_id.fetch_add(1, SeqCst);
        let message = serde_json::to_string(&Request {
            jsonrpc: JSON_RPC_VERSION,
            id: RequestId::Int(id),
            method: T::METHOD,
            params,
        })
        .unwrap();

        // Register the handler before sending, so a response can never arrive
        // for an id that has no handler yet.
        let (tx, rx) = oneshot::channel();
        let handle_response = response_handlers
            .lock()
            .as_mut()
            .ok_or_else(|| anyhow!("server shut down"))
            .map(|handlers| {
                let executor = executor.clone();
                handlers.insert(
                    RequestId::Int(id),
                    Box::new(move |result| {
                        // Deserialize on the executor rather than inline in the
                        // caller of this handler.
                        executor
                            .spawn(async move {
                                let response = match result {
                                    Ok(response) => match serde_json::from_str(&response) {
                                        Ok(deserialized) => Ok(deserialized),
                                        Err(error) => {
                                            log::error!("failed to deserialize response from language server: {}. response from language server: {:?}", error, response);
                                            Err(error).context("failed to deserialize response")
                                        }
                                    }
                                    Err(error) => Err(anyhow!("{}", error.message)),
                                };
                                // The receiver may be gone if the request future was dropped.
                                _ = tx.send(response);
                            })
                            .detach();
                    }),
                );
            });

        // Note: the message is queued even when registering the handler failed;
        // both results are checked inside the future below.
        let send = outbound_tx
            .try_send(message)
            .context("failed to write to language server's stdin");

        // Keep only a weak sender so a pending request does not keep the
        // outbound channel alive.
        let outbound_tx = outbound_tx.downgrade();
        let mut timeout = executor.timer(LSP_REQUEST_TIMEOUT).fuse();
        let started = Instant::now();
        LspRequest::new(id, async move {
            handle_response?;
            send?;

            // Sends a `Cancel` notification for this id if the outbound channel
            // is still alive. NOTE(review): presumably `util::defer` runs the
            // closure when the guard is dropped, and `abort()` below disarms it.
            // Confirm against `util`.
            let cancel_on_drop = util::defer(move || {
                if let Some(outbound_tx) = outbound_tx.upgrade() {
                    Self::notify_internal::<notification::Cancel>(
                        &outbound_tx,
                        CancelParams {
                            id: NumberOrString::Number(id),
                        },
                    )
                    .log_err();
                }
            });

            let method = T::METHOD;
            select! {
                response = rx.fuse() => {
                    let elapsed = started.elapsed();
                    log::trace!("Took {elapsed:?} to receive response to {method:?} id {id}");
                    cancel_on_drop.abort();
                    // Outer `?` handles a dropped sender; the inner value is the
                    // handler's own result.
                    response?
                }

                _ = timeout => {
                    log::error!("Cancelled LSP request task for {method:?} id {id} which took over {LSP_REQUEST_TIMEOUT:?}");
                    anyhow::bail!("LSP request timeout");
                }
            }
        })
    }
1145
    /// Sends a RPC notification to the language server.
    ///
    /// The message is queued on the outbound channel without waiting for it to
    /// be written; an error is returned if it cannot be queued.
    ///
    /// [LSP Specification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#notificationMessage)
    pub fn notify<T: notification::Notification>(&self, params: T::Params) -> Result<()> {
        Self::notify_internal::<T>(&self.outbound_tx, params)
    }
1152
1153    fn notify_internal<T: notification::Notification>(
1154        outbound_tx: &channel::Sender<String>,
1155        params: T::Params,
1156    ) -> Result<()> {
1157        let message = serde_json::to_string(&Notification {
1158            jsonrpc: JSON_RPC_VERSION,
1159            method: T::METHOD,
1160            params,
1161        })
1162        .unwrap();
1163        outbound_tx.try_send(message)?;
1164        Ok(())
1165    }
1166}
1167
1168impl Drop for LanguageServer {
1169    fn drop(&mut self) {
1170        if let Some(shutdown) = self.shutdown() {
1171            self.executor.spawn(shutdown).detach();
1172        }
1173    }
1174}
1175
1176impl Subscription {
1177    /// Detaching a subscription handle prevents it from unsubscribing on drop.
1178    pub fn detach(&mut self) {
1179        match self {
1180            Subscription::Notification {
1181                notification_handlers,
1182                ..
1183            } => *notification_handlers = None,
1184            Subscription::Io { io_handlers, .. } => *io_handlers = None,
1185        }
1186    }
1187}
1188
/// Formats the id by delegating to the formatting of the wrapped value.
impl fmt::Display for LanguageServerId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Forward the caller's formatter so width/fill flags apply to the inner value.
        self.0.fmt(f)
    }
}
1194
1195impl fmt::Debug for LanguageServer {
1196    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1197        f.debug_struct("LanguageServer")
1198            .field("id", &self.server_id.0)
1199            .field("name", &self.name)
1200            .finish_non_exhaustive()
1201    }
1202}
1203
1204impl Drop for Subscription {
1205    fn drop(&mut self) {
1206        match self {
1207            Subscription::Notification {
1208                method,
1209                notification_handlers,
1210            } => {
1211                if let Some(handlers) = notification_handlers {
1212                    handlers.lock().remove(method);
1213                }
1214            }
1215            Subscription::Io { id, io_handlers } => {
1216                if let Some(io_handlers) = io_handlers.as_ref().and_then(|h| h.upgrade()) {
1217                    io_handlers.lock().remove(id);
1218                }
1219            }
1220        }
1221    }
1222}
1223
/// Mock language server for use in tests.
///
/// Cloning yields another handle to the same underlying fake server.
#[cfg(any(test, feature = "test-support"))]
#[derive(Clone)]
pub struct FakeLanguageServer {
    /// The binary description this fake server was created with.
    pub binary: LanguageServerBinary,
    /// The server-side end of the connection, itself a [`LanguageServer`].
    pub server: Arc<LanguageServer>,
    // Messages forwarded as (method, serialized params) pairs.
    notifications_rx: channel::Receiver<(String, String)>,
}
1232
#[cfg(any(test, feature = "test-support"))]
impl FakeLanguageServer {
    /// Construct a fake language server.
    ///
    /// Returns a pair: the client-side [`LanguageServer`] and the
    /// [`FakeLanguageServer`] that plays the server role. The two are connected
    /// through in-memory pipes, each one reading what the other writes. The fake
    /// side answers `initialize` with the given `capabilities` and reports `name`
    /// as its server name.
    pub fn new(
        server_id: LanguageServerId,
        binary: LanguageServerBinary,
        name: String,
        capabilities: ServerCapabilities,
        cx: AsyncAppContext,
    ) -> (LanguageServer, FakeLanguageServer) {
        // One pipe per direction, standing in for the process's stdin and stdout.
        let (stdin_writer, stdin_reader) = async_pipe::pipe();
        let (stdout_writer, stdout_reader) = async_pipe::pipe();
        let (notifications_tx, notifications_rx) = channel::unbounded();

        let root = Self::root_path();
        let server_name = LanguageServerName(name.clone().into());

        // Client side: writes to "stdin", reads from "stdout".
        let mut server = LanguageServer::new_internal(
            server_id,
            server_name.clone(),
            stdin_writer,
            stdout_reader,
            None::<async_pipe::PipeReader>,
            Arc::new(Mutex::new(None)),
            None,
            root,
            root,
            None,
            cx.clone(),
            |_| {},
        );
        server.process_name = Arc::from(name.as_str());

        // Server side: the mirror image. Messages passed to the callback are
        // forwarded to the notifications channel as (method, params) strings.
        let mut fake_server = LanguageServer::new_internal(
            server_id,
            server_name,
            stdout_writer,
            stdin_reader,
            None::<async_pipe::PipeReader>,
            Arc::new(Mutex::new(None)),
            None,
            root,
            root,
            None,
            cx,
            move |msg| {
                let method = msg.method.to_string();
                let params = msg.params.unwrap_or(Value::Null).to_string();
                notifications_tx.try_send((method, params)).ok();
            },
        );
        fake_server.process_name = name.as_str().into();

        let fake = FakeLanguageServer {
            binary,
            server: Arc::new(fake_server),
            notifications_rx,
        };
        fake.handle_request::<request::Initialize, _, _>(move |_, _| {
            let capabilities = capabilities.clone();
            let name = name.clone();
            async move {
                Ok(InitializeResult {
                    capabilities,
                    server_info: Some(ServerInfo {
                        name,
                        ..Default::default()
                    }),
                })
            }
        });

        (server, fake)
    }

    /// Root path handed to both ends of the fake connection (Windows variant).
    #[cfg(target_os = "windows")]
    fn root_path() -> &'static Path {
        Path::new("C:\\")
    }

    /// Root path handed to both ends of the fake connection (non-Windows variant).
    #[cfg(not(target_os = "windows"))]
    fn root_path() -> &'static Path {
        Path::new("/")
    }
}
1325
#[cfg(any(test, feature = "test-support"))]
impl LanguageServer {
    /// Capabilities for tests that want a feature-rich server.
    ///
    /// Enables definition, type definition, implementation, document highlight,
    /// code action, and (range) formatting providers; every other capability is
    /// left at its default.
    pub fn full_capabilities() -> ServerCapabilities {
        ServerCapabilities {
            // Navigation.
            definition_provider: Some(OneOf::Left(true)),
            type_definition_provider: Some(TypeDefinitionProviderCapability::Simple(true)),
            implementation_provider: Some(ImplementationProviderCapability::Simple(true)),
            document_highlight_provider: Some(OneOf::Left(true)),
            // Editing.
            code_action_provider: Some(CodeActionProviderCapability::Simple(true)),
            document_formatting_provider: Some(OneOf::Left(true)),
            document_range_formatting_provider: Some(OneOf::Left(true)),
            ..Default::default()
        }
    }
}
1341
#[cfg(any(test, feature = "test-support"))]
impl FakeLanguageServer {
    /// See [`LanguageServer::notify`].
    ///
    /// Failures to queue the notification are ignored.
    pub fn notify<T: notification::Notification>(&self, params: T::Params) {
        let _ = self.server.notify::<T>(params);
    }

    /// See [`LanguageServer::request`].
    ///
    /// Marks the executor as waiting before issuing the request.
    pub async fn request<T>(&self, params: T::Params) -> Result<T::Result>
    where
        T: request::Request,
        T::Result: 'static + Send,
    {
        self.server.executor.start_waiting();
        let response = self.server.request::<T>(params);
        response.await
    }

    /// Attempts [`Self::try_receive_notification`], unwrapping if it has not received the specified type yet.
    ///
    /// # Panics
    ///
    /// Panics if the notification channel ends before a `T` notification arrives.
    pub async fn receive_notification<T: notification::Notification>(&mut self) -> T::Params {
        self.server.executor.start_waiting();
        let params = self.try_receive_notification::<T>().await;
        params.unwrap()
    }

    /// Consumes the notification channel until it finds a notification for the specified type.
    ///
    /// Messages for other methods are logged and discarded. Returns `None` once
    /// the channel yields no more messages.
    ///
    /// # Panics
    ///
    /// Panics if the matching notification's params do not deserialize as `T::Params`.
    pub async fn try_receive_notification<T: notification::Notification>(
        &mut self,
    ) -> Option<T::Params> {
        use futures::StreamExt as _;

        while let Some((method, params)) = self.notifications_rx.next().await {
            if method != T::METHOD {
                log::info!("skipping message in fake language server {:?}", params);
                continue;
            }
            return Some(serde_json::from_str::<T::Params>(&params).unwrap());
        }
        None
    }

    /// Registers a handler for a specific kind of request. Removes any existing handler for specified request type.
    ///
    /// The returned receiver gets one `()` each time the handler has produced a
    /// result. A simulated random delay is awaited before each result is resolved.
    pub fn handle_request<T, F, Fut>(
        &self,
        mut handler: F,
    ) -> futures::channel::mpsc::UnboundedReceiver<()>
    where
        T: 'static + request::Request,
        T::Params: 'static + Send,
        F: 'static + Send + FnMut(T::Params, gpui::AsyncAppContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<T::Result>>,
    {
        let (responded_tx, responded_rx) = futures::channel::mpsc::unbounded();
        // Registering twice for the same method would panic, so clear first.
        self.server.remove_request_handler::<T>();
        let mut subscription = self.server.on_request::<T, _, _>(move |params, cx| {
            let pending = handler(params, cx.clone());
            let responded_tx = responded_tx.clone();
            let executor = cx.background_executor().clone();
            async move {
                executor.simulate_random_delay().await;
                let outcome = pending.await;
                responded_tx.unbounded_send(()).ok();
                outcome
            }
        });
        // Keep the handler installed after this function returns.
        subscription.detach();
        responded_rx
    }

    /// Registers a handler for a specific kind of notification. Removes any existing handler for specified notification type.
    ///
    /// The returned receiver gets one `()` each time the handler has run.
    pub fn handle_notification<T, F>(
        &self,
        mut handler: F,
    ) -> futures::channel::mpsc::UnboundedReceiver<()>
    where
        T: 'static + notification::Notification,
        T::Params: 'static + Send,
        F: 'static + Send + FnMut(T::Params, gpui::AsyncAppContext),
    {
        let (handled_tx, handled_rx) = futures::channel::mpsc::unbounded();
        // Registering twice for the same method would panic, so clear first.
        self.server.remove_notification_handler::<T>();
        let mut subscription = self.server.on_notification::<T, _>(move |params, cx| {
            handler(params, cx.clone());
            handled_tx.unbounded_send(()).ok();
        });
        // Keep the handler installed after this function returns.
        subscription.detach();
        handled_rx
    }

    /// Removes any existing handler for specified request type.
    pub fn remove_request_handler<T>(&mut self)
    where
        T: 'static + request::Request,
    {
        let server = &self.server;
        server.remove_request_handler::<T>();
    }

    /// Simulate that the server has started work and notifies about its progress with the specified token.
    ///
    /// Uses a default `WorkDoneProgressBegin` payload.
    pub async fn start_progress(&self, token: impl Into<String>) {
        let progress = Default::default();
        self.start_progress_with(token, progress).await
    }

    /// Like [`Self::start_progress`], but with a caller-supplied begin payload.
    ///
    /// First asks the client to create the progress token, then sends the
    /// `Begin` progress notification for it.
    ///
    /// # Panics
    ///
    /// Panics if the create request fails.
    pub async fn start_progress_with(
        &self,
        token: impl Into<String>,
        progress: WorkDoneProgressBegin,
    ) {
        let token: String = token.into();
        let create_params = WorkDoneProgressCreateParams {
            token: NumberOrString::String(token.clone()),
        };
        self.request::<request::WorkDoneProgressCreate>(create_params)
            .await
            .unwrap();

        let begin = ProgressParams {
            token: NumberOrString::String(token),
            value: ProgressParamsValue::WorkDone(WorkDoneProgress::Begin(progress)),
        };
        self.notify::<notification::Progress>(begin);
    }

    /// Simulate that the server has completed work and notifies about that with the specified token.
    pub fn end_progress(&self, token: impl Into<String>) {
        let end = ProgressParams {
            token: NumberOrString::String(token.into()),
            value: ProgressParamsValue::WorkDone(WorkDoneProgress::End(Default::default())),
        };
        self.notify::<notification::Progress>(end);
    }
}
1469
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{SemanticVersion, TestAppContext};
    use std::str::FromStr;

    /// Enables logging for the test binary when `RUST_LOG` is set.
    #[ctor::ctor]
    fn init_logger() {
        if std::env::var("RUST_LOG").is_ok() {
            env_logger::init();
        }
    }

    /// End-to-end exercise of a client/fake-server pair: initialization,
    /// notifications in both directions, and shutdown on drop.
    #[gpui::test]
    async fn test_fake(cx: &mut TestAppContext) {
        cx.update(|cx| {
            release_channel::init(SemanticVersion::default(), cx);
        });
        let (server, mut fake) = FakeLanguageServer::new(
            LanguageServerId(0),
            LanguageServerBinary {
                path: "path/to/language-server".into(),
                arguments: vec![],
                env: None,
            },
            "the-lsp".to_string(),
            Default::default(),
            cx.to_async(),
        );

        let (message_tx, message_rx) = channel::unbounded();
        let (diagnostics_tx, diagnostics_rx) = channel::unbounded();
        server
            .on_notification::<notification::ShowMessage, _>(move |params, _| {
                message_tx.try_send(params).unwrap()
            })
            .detach();
        server
            .on_notification::<notification::PublishDiagnostics, _>(move |params, _| {
                diagnostics_tx.try_send(params).unwrap()
            })
            .detach();

        // Client -> server notification.
        let server = cx.update(|cx| server.initialize(None, cx)).await.unwrap();
        server
            .notify::<notification::DidOpenTextDocument>(DidOpenTextDocumentParams {
                text_document: TextDocumentItem::new(
                    Url::from_str("file://a/b").unwrap(),
                    "rust".to_string(),
                    0,
                    "".to_string(),
                ),
            })
            .unwrap();
        assert_eq!(
            fake.receive_notification::<notification::DidOpenTextDocument>()
                .await
                .text_document
                .uri
                .as_str(),
            "file://a/b"
        );

        // Server -> client notifications.
        fake.notify::<notification::ShowMessage>(ShowMessageParams {
            typ: MessageType::ERROR,
            message: "ok".to_string(),
        });
        fake.notify::<notification::PublishDiagnostics>(PublishDiagnosticsParams {
            uri: Url::from_str("file://b/c").unwrap(),
            version: Some(5),
            diagnostics: vec![],
        });
        assert_eq!(message_rx.recv().await.unwrap().message, "ok");
        assert_eq!(
            diagnostics_rx.recv().await.unwrap().uri.as_str(),
            "file://b/c"
        );

        fake.handle_request::<request::Shutdown, _, _>(|_, _| async move { Ok(()) });

        // Dropping the client starts the shutdown handshake, which ends with `exit`.
        drop(server);
        fake.receive_notification::<notification::Exit>().await;
    }

    /// A string id that happens to contain only digits must stay a string.
    #[gpui::test]
    fn test_deserialize_string_digit_id() {
        let json = r#"{"jsonrpc":"2.0","id":"2","method":"workspace/configuration","params":{"items":[{"scopeUri":"file:///Users/mph/Devel/personal/hello-scala/","section":"metals"}]}}"#;
        let notification = serde_json::from_str::<AnyNotification>(json)
            .expect("message with string id should be parsed");
        let expected_id = RequestId::Str("2".to_string());
        assert_eq!(notification.id, Some(expected_id));
    }

    /// Arbitrary string ids are accepted.
    #[gpui::test]
    fn test_deserialize_string_id() {
        let json = r#"{"jsonrpc":"2.0","id":"anythingAtAll","method":"workspace/configuration","params":{"items":[{"scopeUri":"file:///Users/mph/Devel/personal/hello-scala/","section":"metals"}]}}"#;
        let notification = serde_json::from_str::<AnyNotification>(json)
            .expect("message with string id should be parsed");
        let expected_id = RequestId::Str("anythingAtAll".to_string());
        assert_eq!(notification.id, Some(expected_id));
    }

    /// Integer ids are accepted and kept as integers.
    #[gpui::test]
    fn test_deserialize_int_id() {
        let json = r#"{"jsonrpc":"2.0","id":2,"method":"workspace/configuration","params":{"items":[{"scopeUri":"file:///Users/mph/Devel/personal/hello-scala/","section":"metals"}]}}"#;
        let notification = serde_json::from_str::<AnyNotification>(json)
            .expect("message with int id should be parsed");
        let expected_id = RequestId::Int(2);
        assert_eq!(notification.id, Some(expected_id));
    }

    #[test]
    fn test_serialize_has_no_nulls() {
        // Ensure we're not setting both result and error variants. (ticket #10595)
        let no_tag = Response::<u32> {
            jsonrpc: "",
            id: RequestId::Int(0),
            value: LspResult::Ok(None),
        };
        assert_eq!(
            serde_json::to_string(&no_tag).unwrap(),
            "{\"jsonrpc\":\"\",\"id\":0,\"result\":null}"
        );
        let no_tag = Response::<u32> {
            jsonrpc: "",
            id: RequestId::Int(0),
            value: LspResult::Error(None),
        };
        assert_eq!(
            serde_json::to_string(&no_tag).unwrap(),
            "{\"jsonrpc\":\"\",\"id\":0,\"error\":null}"
        );
    }
}