lsp.rs

   1mod input_handler;
   2
   3pub use lsp_types::request::*;
   4pub use lsp_types::*;
   5
   6use anyhow::{anyhow, Context, Result};
   7use collections::HashMap;
   8use futures::{channel::oneshot, io::BufWriter, select, AsyncRead, AsyncWrite, Future, FutureExt};
   9use gpui::{AppContext, AsyncAppContext, BackgroundExecutor, SharedString, Task};
  10use parking_lot::{Mutex, RwLock};
  11use postage::{barrier, prelude::Stream};
  12use schemars::{
  13    gen::SchemaGenerator,
  14    schema::{InstanceType, Schema, SchemaObject},
  15    JsonSchema,
  16};
  17use serde::{de::DeserializeOwned, Deserialize, Serialize};
  18use serde_json::{json, value::RawValue, Value};
  19use smol::{
  20    channel,
  21    io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
  22    process::{self, Child},
  23};
  24
  25#[cfg(target_os = "windows")]
  26use smol::process::windows::CommandExt;
  27
  28use std::{
  29    ffi::{OsStr, OsString},
  30    fmt,
  31    io::Write,
  32    ops::DerefMut,
  33    path::PathBuf,
  34    pin::Pin,
  35    sync::{
  36        atomic::{AtomicI32, Ordering::SeqCst},
  37        Arc, Weak,
  38    },
  39    task::Poll,
  40    time::{Duration, Instant},
  41};
  42use std::{path::Path, process::Stdio};
  43use util::{ResultExt, TryFutureExt};
  44
  45const JSON_RPC_VERSION: &str = "2.0";
  46const CONTENT_LEN_HEADER: &str = "Content-Length: ";
  47
  48const LSP_REQUEST_TIMEOUT: Duration = Duration::from_secs(60 * 2);
  49const SERVER_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
  50
  51type NotificationHandler = Box<dyn Send + FnMut(Option<RequestId>, Value, AsyncAppContext)>;
  52type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
  53type IoHandler = Box<dyn Send + FnMut(IoKind, &str)>;
  54
  55/// Kind of language server stdio given to an IO handler.
  56#[derive(Debug, Clone, Copy)]
  57pub enum IoKind {
  58    StdOut,
  59    StdIn,
  60    StdErr,
  61}
  62
  63/// Represents a launchable language server. This can either be a standalone binary or the path
  64/// to a runtime with arguments to instruct it to launch the actual language server file.
  65#[derive(Debug, Clone, Deserialize)]
  66pub struct LanguageServerBinary {
  67    pub path: PathBuf,
  68    pub arguments: Vec<OsString>,
  69    pub env: Option<HashMap<String, String>>,
  70}
  71
  72/// Configures the search (and installation) of language servers.
  73#[derive(Debug, Clone, Deserialize)]
  74pub struct LanguageServerBinaryOptions {
  75    /// Whether the adapter should look at the users system
  76    pub allow_path_lookup: bool,
  77    /// Whether the adapter should download its own version
  78    pub allow_binary_download: bool,
  79}
  80
  81/// A running language server process.
  82pub struct LanguageServer {
  83    server_id: LanguageServerId,
  84    next_id: AtomicI32,
  85    outbound_tx: channel::Sender<String>,
  86    name: LanguageServerName,
  87    process_name: Arc<str>,
  88    capabilities: RwLock<ServerCapabilities>,
  89    code_action_kinds: Option<Vec<CodeActionKind>>,
  90    notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
  91    response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
  92    io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>,
  93    executor: BackgroundExecutor,
  94    #[allow(clippy::type_complexity)]
  95    io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
  96    output_done_rx: Mutex<Option<barrier::Receiver>>,
  97    root_path: PathBuf,
  98    working_dir: PathBuf,
  99    server: Arc<Mutex<Option<Child>>>,
 100}
 101
 102/// Identifies a running language server.
 103#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
 104#[repr(transparent)]
 105pub struct LanguageServerId(pub usize);
 106
 107impl LanguageServerId {
 108    pub fn from_proto(id: u64) -> Self {
 109        Self(id as usize)
 110    }
 111
 112    pub fn to_proto(self) -> u64 {
 113        self.0 as u64
 114    }
 115}
 116
 117/// A name of a language server.
 118#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)]
 119pub struct LanguageServerName(pub SharedString);
 120
 121impl std::fmt::Display for LanguageServerName {
 122    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 123        std::fmt::Display::fmt(&self.0, f)
 124    }
 125}
 126
 127impl AsRef<str> for LanguageServerName {
 128    fn as_ref(&self) -> &str {
 129        self.0.as_ref()
 130    }
 131}
 132
 133impl AsRef<OsStr> for LanguageServerName {
 134    fn as_ref(&self) -> &OsStr {
 135        self.0.as_ref().as_ref()
 136    }
 137}
 138
 139impl JsonSchema for LanguageServerName {
 140    fn schema_name() -> String {
 141        "LanguageServerName".into()
 142    }
 143
 144    fn json_schema(_: &mut SchemaGenerator) -> Schema {
 145        SchemaObject {
 146            instance_type: Some(InstanceType::String.into()),
 147            ..Default::default()
 148        }
 149        .into()
 150    }
 151}
 152
 153impl LanguageServerName {
 154    pub const fn new_static(s: &'static str) -> Self {
 155        Self(SharedString::new_static(s))
 156    }
 157
 158    pub fn from_proto(s: String) -> Self {
 159        Self(s.into())
 160    }
 161}
 162
 163impl<'a> From<&'a str> for LanguageServerName {
 164    fn from(str: &'a str) -> LanguageServerName {
 165        LanguageServerName(str.to_string().into())
 166    }
 167}
 168
 169/// Handle to a language server RPC activity subscription.
 170pub enum Subscription {
 171    Notification {
 172        method: &'static str,
 173        notification_handlers: Option<Arc<Mutex<HashMap<&'static str, NotificationHandler>>>>,
 174    },
 175    Io {
 176        id: i32,
 177        io_handlers: Option<Weak<Mutex<HashMap<i32, IoHandler>>>>,
 178    },
 179}
 180
 181/// Language server protocol RPC request message ID.
 182///
 183/// [LSP Specification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#requestMessage)
 184#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
 185#[serde(untagged)]
 186pub enum RequestId {
 187    Int(i32),
 188    Str(String),
 189}
 190
 191/// Language server protocol RPC request message.
 192///
 193/// [LSP Specification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#requestMessage)
 194#[derive(Serialize, Deserialize)]
 195pub struct Request<'a, T> {
 196    jsonrpc: &'static str,
 197    id: RequestId,
 198    method: &'a str,
 199    params: T,
 200}
 201
 202/// Language server protocol RPC request response message before it is deserialized into a concrete type.
 203#[derive(Serialize, Deserialize)]
 204struct AnyResponse<'a> {
 205    jsonrpc: &'a str,
 206    id: RequestId,
 207    #[serde(default)]
 208    error: Option<Error>,
 209    #[serde(borrow)]
 210    result: Option<&'a RawValue>,
 211}
 212
 213/// Language server protocol RPC request response message.
 214///
 215/// [LSP Specification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#responseMessage)
 216#[derive(Serialize)]
 217struct Response<T> {
 218    jsonrpc: &'static str,
 219    id: RequestId,
 220    #[serde(flatten)]
 221    value: LspResult<T>,
 222}
 223
 224#[derive(Serialize)]
 225#[serde(rename_all = "snake_case")]
 226enum LspResult<T> {
 227    #[serde(rename = "result")]
 228    Ok(Option<T>),
 229    Error(Option<Error>),
 230}
 231
 232/// Language server protocol RPC notification message.
 233///
 234/// [LSP Specification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#notificationMessage)
 235#[derive(Serialize, Deserialize)]
 236struct Notification<'a, T> {
 237    jsonrpc: &'static str,
 238    #[serde(borrow)]
 239    method: &'a str,
 240    params: T,
 241}
 242
 243/// Language server RPC notification message before it is deserialized into a concrete type.
 244#[derive(Debug, Clone, Deserialize)]
 245struct AnyNotification {
 246    #[serde(default)]
 247    id: Option<RequestId>,
 248    method: String,
 249    #[serde(default)]
 250    params: Option<Value>,
 251}
 252
 253#[derive(Debug, Serialize, Deserialize)]
 254struct Error {
 255    message: String,
 256}
 257
 258pub trait LspRequestFuture<O>: Future<Output = O> {
 259    fn id(&self) -> i32;
 260}
 261
 262struct LspRequest<F> {
 263    id: i32,
 264    request: F,
 265}
 266
 267impl<F> LspRequest<F> {
 268    pub fn new(id: i32, request: F) -> Self {
 269        Self { id, request }
 270    }
 271}
 272
 273impl<F: Future> Future for LspRequest<F> {
 274    type Output = F::Output;
 275
 276    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
 277        // SAFETY: This is standard pin projection, we're pinned so our fields must be pinned.
 278        let inner = unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().request) };
 279        inner.poll(cx)
 280    }
 281}
 282
 283impl<F: Future> LspRequestFuture<F::Output> for LspRequest<F> {
 284    fn id(&self) -> i32 {
 285        self.id
 286    }
 287}
 288
 289/// Combined capabilities of the server and the adapter.
 290pub struct AdapterServerCapabilities {
 291    // Reported capabilities by the server
 292    pub server_capabilities: ServerCapabilities,
 293    // List of code actions supported by the LspAdapter matching the server
 294    pub code_action_kinds: Option<Vec<CodeActionKind>>,
 295}
 296
 297/// Experimental: Informs the end user about the state of the server
 298///
 299/// [Rust Analyzer Specification](https://github.com/rust-lang/rust-analyzer/blob/master/docs/dev/lsp-extensions.md#server-status)
 300#[derive(Debug)]
 301pub enum ServerStatus {}
 302
 303/// Other(String) variant to handle unknown values due to this still being experimental
 304#[derive(Debug, PartialEq, Deserialize, Serialize, Clone)]
 305#[serde(rename_all = "camelCase")]
 306pub enum ServerHealthStatus {
 307    Ok,
 308    Warning,
 309    Error,
 310    Other(String),
 311}
 312
 313#[derive(Debug, PartialEq, Deserialize, Serialize, Clone)]
 314#[serde(rename_all = "camelCase")]
 315pub struct ServerStatusParams {
 316    pub health: ServerHealthStatus,
 317    pub message: Option<String>,
 318}
 319
 320impl lsp_types::notification::Notification for ServerStatus {
 321    type Params = ServerStatusParams;
 322    const METHOD: &'static str = "experimental/serverStatus";
 323}
 324
 325impl LanguageServer {
 326    /// Starts a language server process.
 327    pub fn new(
 328        stderr_capture: Arc<Mutex<Option<String>>>,
 329        server_id: LanguageServerId,
 330        server_name: LanguageServerName,
 331        binary: LanguageServerBinary,
 332        root_path: &Path,
 333        code_action_kinds: Option<Vec<CodeActionKind>>,
 334        cx: AsyncAppContext,
 335    ) -> Result<Self> {
 336        let working_dir = if root_path.is_dir() {
 337            root_path
 338        } else {
 339            root_path.parent().unwrap_or_else(|| Path::new("/"))
 340        };
 341
 342        log::info!(
 343            "starting language server process. binary path: {:?}, working directory: {:?}, args: {:?}",
 344            binary.path,
 345            working_dir,
 346            &binary.arguments
 347        );
 348
 349        let mut command = process::Command::new(&binary.path);
 350        command
 351            .current_dir(working_dir)
 352            .args(&binary.arguments)
 353            .envs(binary.env.unwrap_or_default())
 354            .stdin(Stdio::piped())
 355            .stdout(Stdio::piped())
 356            .stderr(Stdio::piped())
 357            .kill_on_drop(true);
 358        #[cfg(windows)]
 359        command.creation_flags(windows::Win32::System::Threading::CREATE_NO_WINDOW.0);
 360        let mut server = command.spawn().with_context(|| {
 361            format!(
 362                "failed to spawn command. path: {:?}, working directory: {:?}, args: {:?}",
 363                binary.path, working_dir, &binary.arguments
 364            )
 365        })?;
 366
 367        let stdin = server.stdin.take().unwrap();
 368        let stdout = server.stdout.take().unwrap();
 369        let stderr = server.stderr.take().unwrap();
 370        let mut server = Self::new_internal(
 371            server_id,
 372            server_name,
 373            stdin,
 374            stdout,
 375            Some(stderr),
 376            stderr_capture,
 377            Some(server),
 378            root_path,
 379            working_dir,
 380            code_action_kinds,
 381            cx,
 382            move |notification| {
 383                log::info!(
 384                    "Language server with id {} sent unhandled notification {}:\n{}",
 385                    server_id,
 386                    notification.method,
 387                    serde_json::to_string_pretty(&notification.params).unwrap(),
 388                );
 389            },
 390        );
 391
 392        if let Some(name) = binary.path.file_name() {
 393            server.process_name = name.to_string_lossy().into();
 394        }
 395
 396        Ok(server)
 397    }
 398
 399    #[allow(clippy::too_many_arguments)]
 400    fn new_internal<Stdin, Stdout, Stderr, F>(
 401        server_id: LanguageServerId,
 402        server_name: LanguageServerName,
 403        stdin: Stdin,
 404        stdout: Stdout,
 405        stderr: Option<Stderr>,
 406        stderr_capture: Arc<Mutex<Option<String>>>,
 407        server: Option<Child>,
 408        root_path: &Path,
 409        working_dir: &Path,
 410        code_action_kinds: Option<Vec<CodeActionKind>>,
 411        cx: AsyncAppContext,
 412        on_unhandled_notification: F,
 413    ) -> Self
 414    where
 415        Stdin: AsyncWrite + Unpin + Send + 'static,
 416        Stdout: AsyncRead + Unpin + Send + 'static,
 417        Stderr: AsyncRead + Unpin + Send + 'static,
 418        F: FnMut(AnyNotification) + 'static + Send + Sync + Clone,
 419    {
 420        let (outbound_tx, outbound_rx) = channel::unbounded::<String>();
 421        let (output_done_tx, output_done_rx) = barrier::channel();
 422        let notification_handlers =
 423            Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
 424        let response_handlers =
 425            Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
 426        let io_handlers = Arc::new(Mutex::new(HashMap::default()));
 427
 428        let stdout_input_task = cx.spawn({
 429            let on_unhandled_notification = on_unhandled_notification.clone();
 430            let notification_handlers = notification_handlers.clone();
 431            let response_handlers = response_handlers.clone();
 432            let io_handlers = io_handlers.clone();
 433            move |cx| {
 434                Self::handle_input(
 435                    stdout,
 436                    on_unhandled_notification,
 437                    notification_handlers,
 438                    response_handlers,
 439                    io_handlers,
 440                    cx,
 441                )
 442                .log_err()
 443            }
 444        });
 445        let stderr_input_task = stderr
 446            .map(|stderr| {
 447                let io_handlers = io_handlers.clone();
 448                let stderr_captures = stderr_capture.clone();
 449                cx.spawn(|_| Self::handle_stderr(stderr, io_handlers, stderr_captures).log_err())
 450            })
 451            .unwrap_or_else(|| Task::Ready(Some(None)));
 452        let input_task = cx.spawn(|_| async move {
 453            let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task);
 454            stdout.or(stderr)
 455        });
 456        let output_task = cx.background_executor().spawn({
 457            Self::handle_output(
 458                stdin,
 459                outbound_rx,
 460                output_done_tx,
 461                response_handlers.clone(),
 462                io_handlers.clone(),
 463            )
 464            .log_err()
 465        });
 466
 467        Self {
 468            server_id,
 469            notification_handlers,
 470            response_handlers,
 471            io_handlers,
 472            name: server_name,
 473            process_name: Arc::default(),
 474            capabilities: Default::default(),
 475            code_action_kinds,
 476            next_id: Default::default(),
 477            outbound_tx,
 478            executor: cx.background_executor().clone(),
 479            io_tasks: Mutex::new(Some((input_task, output_task))),
 480            output_done_rx: Mutex::new(Some(output_done_rx)),
 481            root_path: root_path.to_path_buf(),
 482            working_dir: working_dir.to_path_buf(),
 483            server: Arc::new(Mutex::new(server)),
 484        }
 485    }
 486
 487    /// List of code action kinds this language server reports being able to emit.
 488    pub fn code_action_kinds(&self) -> Option<Vec<CodeActionKind>> {
 489        self.code_action_kinds.clone()
 490    }
 491
 492    async fn handle_input<Stdout, F>(
 493        stdout: Stdout,
 494        mut on_unhandled_notification: F,
 495        notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
 496        response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
 497        io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>,
 498        cx: AsyncAppContext,
 499    ) -> anyhow::Result<()>
 500    where
 501        Stdout: AsyncRead + Unpin + Send + 'static,
 502        F: FnMut(AnyNotification) + 'static + Send,
 503    {
 504        use smol::stream::StreamExt;
 505        let stdout = BufReader::new(stdout);
 506        let _clear_response_handlers = util::defer({
 507            let response_handlers = response_handlers.clone();
 508            move || {
 509                response_handlers.lock().take();
 510            }
 511        });
 512        let mut input_handler = input_handler::LspStdoutHandler::new(
 513            stdout,
 514            response_handlers,
 515            io_handlers,
 516            cx.background_executor().clone(),
 517        );
 518
 519        while let Some(msg) = input_handler.notifications_channel.next().await {
 520            {
 521                let mut notification_handlers = notification_handlers.lock();
 522                if let Some(handler) = notification_handlers.get_mut(msg.method.as_str()) {
 523                    handler(msg.id, msg.params.unwrap_or(Value::Null), cx.clone());
 524                } else {
 525                    drop(notification_handlers);
 526                    on_unhandled_notification(msg);
 527                }
 528            }
 529
 530            // Don't starve the main thread when receiving lots of notifications at once.
 531            smol::future::yield_now().await;
 532        }
 533        input_handler.loop_handle.await
 534    }
 535
 536    async fn handle_stderr<Stderr>(
 537        stderr: Stderr,
 538        io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>,
 539        stderr_capture: Arc<Mutex<Option<String>>>,
 540    ) -> anyhow::Result<()>
 541    where
 542        Stderr: AsyncRead + Unpin + Send + 'static,
 543    {
 544        let mut stderr = BufReader::new(stderr);
 545        let mut buffer = Vec::new();
 546
 547        loop {
 548            buffer.clear();
 549
 550            let bytes_read = stderr.read_until(b'\n', &mut buffer).await?;
 551            if bytes_read == 0 {
 552                return Ok(());
 553            }
 554
 555            if let Ok(message) = std::str::from_utf8(&buffer) {
 556                log::trace!("incoming stderr message:{message}");
 557                for handler in io_handlers.lock().values_mut() {
 558                    handler(IoKind::StdErr, message);
 559                }
 560
 561                if let Some(stderr) = stderr_capture.lock().as_mut() {
 562                    stderr.push_str(message);
 563                }
 564            }
 565
 566            // Don't starve the main thread when receiving lots of messages at once.
 567            smol::future::yield_now().await;
 568        }
 569    }
 570
 571    async fn handle_output<Stdin>(
 572        stdin: Stdin,
 573        outbound_rx: channel::Receiver<String>,
 574        output_done_tx: barrier::Sender,
 575        response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
 576        io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>,
 577    ) -> anyhow::Result<()>
 578    where
 579        Stdin: AsyncWrite + Unpin + Send + 'static,
 580    {
 581        let mut stdin = BufWriter::new(stdin);
 582        let _clear_response_handlers = util::defer({
 583            let response_handlers = response_handlers.clone();
 584            move || {
 585                response_handlers.lock().take();
 586            }
 587        });
 588        let mut content_len_buffer = Vec::new();
 589        while let Ok(message) = outbound_rx.recv().await {
 590            log::trace!("outgoing message:{}", message);
 591            for handler in io_handlers.lock().values_mut() {
 592                handler(IoKind::StdIn, &message);
 593            }
 594
 595            content_len_buffer.clear();
 596            write!(content_len_buffer, "{}", message.len()).unwrap();
 597            stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
 598            stdin.write_all(&content_len_buffer).await?;
 599            stdin.write_all("\r\n\r\n".as_bytes()).await?;
 600            stdin.write_all(message.as_bytes()).await?;
 601            stdin.flush().await?;
 602        }
 603        drop(output_done_tx);
 604        Ok(())
 605    }
 606
 607    /// Initializes a language server by sending the `Initialize` request.
 608    /// Note that `options` is used directly to construct [`InitializeParams`], which is why it is owned.
 609    ///
 610    /// [LSP Specification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#initialize)
 611    pub fn initialize(
 612        mut self,
 613        options: Option<Value>,
 614        cx: &AppContext,
 615    ) -> Task<Result<Arc<Self>>> {
 616        let root_uri = Url::from_file_path(&self.working_dir).unwrap();
 617        #[allow(deprecated)]
 618        let params = InitializeParams {
 619            process_id: None,
 620            root_path: None,
 621            root_uri: Some(root_uri.clone()),
 622            initialization_options: options,
 623            capabilities: ClientCapabilities {
 624                workspace: Some(WorkspaceClientCapabilities {
 625                    configuration: Some(true),
 626                    did_change_watched_files: Some(DidChangeWatchedFilesClientCapabilities {
 627                        dynamic_registration: Some(true),
 628                        relative_pattern_support: Some(true),
 629                    }),
 630                    did_change_configuration: Some(DynamicRegistrationClientCapabilities {
 631                        dynamic_registration: Some(true),
 632                    }),
 633                    workspace_folders: Some(true),
 634                    symbol: Some(WorkspaceSymbolClientCapabilities {
 635                        resolve_support: None,
 636                        ..WorkspaceSymbolClientCapabilities::default()
 637                    }),
 638                    inlay_hint: Some(InlayHintWorkspaceClientCapabilities {
 639                        refresh_support: Some(true),
 640                    }),
 641                    diagnostic: Some(DiagnosticWorkspaceClientCapabilities {
 642                        refresh_support: None,
 643                    }),
 644                    workspace_edit: Some(WorkspaceEditClientCapabilities {
 645                        resource_operations: Some(vec![
 646                            ResourceOperationKind::Create,
 647                            ResourceOperationKind::Rename,
 648                            ResourceOperationKind::Delete,
 649                        ]),
 650                        document_changes: Some(true),
 651                        snippet_edit_support: Some(true),
 652                        ..WorkspaceEditClientCapabilities::default()
 653                    }),
 654                    ..Default::default()
 655                }),
 656                text_document: Some(TextDocumentClientCapabilities {
 657                    definition: Some(GotoCapability {
 658                        link_support: Some(true),
 659                        dynamic_registration: None,
 660                    }),
 661                    code_action: Some(CodeActionClientCapabilities {
 662                        code_action_literal_support: Some(CodeActionLiteralSupport {
 663                            code_action_kind: CodeActionKindLiteralSupport {
 664                                value_set: vec![
 665                                    CodeActionKind::REFACTOR.as_str().into(),
 666                                    CodeActionKind::QUICKFIX.as_str().into(),
 667                                    CodeActionKind::SOURCE.as_str().into(),
 668                                ],
 669                            },
 670                        }),
 671                        data_support: Some(true),
 672                        resolve_support: Some(CodeActionCapabilityResolveSupport {
 673                            properties: vec![
 674                                "kind".to_string(),
 675                                "diagnostics".to_string(),
 676                                "isPreferred".to_string(),
 677                                "disabled".to_string(),
 678                                "edit".to_string(),
 679                                "command".to_string(),
 680                            ],
 681                        }),
 682                        ..Default::default()
 683                    }),
 684                    completion: Some(CompletionClientCapabilities {
 685                        completion_item: Some(CompletionItemCapability {
 686                            snippet_support: Some(true),
 687                            resolve_support: Some(CompletionItemCapabilityResolveSupport {
 688                                properties: vec![
 689                                    "additionalTextEdits".to_string(),
 690                                    "command".to_string(),
 691                                    "documentation".to_string(),
 692                                    // NB: Do not have this resolved, otherwise Zed becomes slow to complete things
 693                                    // "textEdit".to_string(),
 694                                ],
 695                            }),
 696                            insert_replace_support: Some(true),
 697                            label_details_support: Some(true),
 698                            ..Default::default()
 699                        }),
 700                        completion_list: Some(CompletionListCapability {
 701                            item_defaults: Some(vec![
 702                                "commitCharacters".to_owned(),
 703                                "editRange".to_owned(),
 704                                "insertTextMode".to_owned(),
 705                                "data".to_owned(),
 706                            ]),
 707                        }),
 708                        context_support: Some(true),
 709                        ..Default::default()
 710                    }),
 711                    rename: Some(RenameClientCapabilities {
 712                        prepare_support: Some(true),
 713                        ..Default::default()
 714                    }),
 715                    hover: Some(HoverClientCapabilities {
 716                        content_format: Some(vec![MarkupKind::Markdown]),
 717                        dynamic_registration: None,
 718                    }),
 719                    inlay_hint: Some(InlayHintClientCapabilities {
 720                        resolve_support: Some(InlayHintResolveClientCapabilities {
 721                            properties: vec![
 722                                "textEdits".to_string(),
 723                                "tooltip".to_string(),
 724                                "label.tooltip".to_string(),
 725                                "label.location".to_string(),
 726                                "label.command".to_string(),
 727                            ],
 728                        }),
 729                        dynamic_registration: Some(false),
 730                    }),
 731                    publish_diagnostics: Some(PublishDiagnosticsClientCapabilities {
 732                        related_information: Some(true),
 733                        ..Default::default()
 734                    }),
 735                    formatting: Some(DynamicRegistrationClientCapabilities {
 736                        dynamic_registration: Some(true),
 737                    }),
 738                    range_formatting: Some(DynamicRegistrationClientCapabilities {
 739                        dynamic_registration: Some(true),
 740                    }),
 741                    on_type_formatting: Some(DynamicRegistrationClientCapabilities {
 742                        dynamic_registration: Some(true),
 743                    }),
 744                    signature_help: Some(SignatureHelpClientCapabilities {
 745                        signature_information: Some(SignatureInformationSettings {
 746                            documentation_format: Some(vec![
 747                                MarkupKind::Markdown,
 748                                MarkupKind::PlainText,
 749                            ]),
 750                            parameter_information: Some(ParameterInformationSettings {
 751                                label_offset_support: Some(true),
 752                            }),
 753                            active_parameter_support: Some(true),
 754                        }),
 755                        ..SignatureHelpClientCapabilities::default()
 756                    }),
 757                    synchronization: Some(TextDocumentSyncClientCapabilities {
 758                        did_save: Some(true),
 759                        ..TextDocumentSyncClientCapabilities::default()
 760                    }),
 761                    ..TextDocumentClientCapabilities::default()
 762                }),
 763                experimental: Some(json!({
 764                    "serverStatusNotification": true,
 765                })),
 766                window: Some(WindowClientCapabilities {
 767                    work_done_progress: Some(true),
 768                    ..Default::default()
 769                }),
 770                general: None,
 771            },
 772            trace: None,
 773            workspace_folders: Some(vec![WorkspaceFolder {
 774                uri: root_uri,
 775                name: Default::default(),
 776            }]),
 777            client_info: release_channel::ReleaseChannel::try_global(cx).map(|release_channel| {
 778                ClientInfo {
 779                    name: release_channel.display_name().to_string(),
 780                    version: Some(release_channel::AppVersion::global(cx).to_string()),
 781                }
 782            }),
 783            locale: None,
 784            ..Default::default()
 785        };
 786
 787        cx.spawn(|_| async move {
 788            let response = self.request::<request::Initialize>(params).await?;
 789            if let Some(info) = response.server_info {
 790                self.process_name = info.name.into();
 791            }
 792            self.capabilities = RwLock::new(response.capabilities);
 793
 794            self.notify::<notification::Initialized>(InitializedParams {})?;
 795            Ok(Arc::new(self))
 796        })
 797    }
 798
 799    /// Sends a shutdown request to the language server process and prepares the [`LanguageServer`] to be dropped.
 800    pub fn shutdown(&self) -> Option<impl 'static + Send + Future<Output = Option<()>>> {
 801        if let Some(tasks) = self.io_tasks.lock().take() {
 802            let response_handlers = self.response_handlers.clone();
 803            let next_id = AtomicI32::new(self.next_id.load(SeqCst));
 804            let outbound_tx = self.outbound_tx.clone();
 805            let executor = self.executor.clone();
 806            let mut output_done = self.output_done_rx.lock().take().unwrap();
 807            let shutdown_request = Self::request_internal::<request::Shutdown>(
 808                &next_id,
 809                &response_handlers,
 810                &outbound_tx,
 811                &executor,
 812                (),
 813            );
 814            let exit = Self::notify_internal::<notification::Exit>(&outbound_tx, ());
 815            outbound_tx.close();
 816
 817            let server = self.server.clone();
 818            let name = self.name.clone();
 819            let mut timer = self.executor.timer(SERVER_SHUTDOWN_TIMEOUT).fuse();
 820            Some(
 821                async move {
 822                    log::debug!("language server shutdown started");
 823
 824                    select! {
 825                        request_result = shutdown_request.fuse() => {
 826                            request_result?;
 827                        }
 828
 829                        _ = timer => {
 830                            log::info!("timeout waiting for language server {name} to shutdown");
 831                        },
 832                    }
 833
 834                    response_handlers.lock().take();
 835                    exit?;
 836                    output_done.recv().await;
 837                    server.lock().take().map(|mut child| child.kill());
 838                    log::debug!("language server shutdown finished");
 839
 840                    drop(tasks);
 841                    anyhow::Ok(())
 842                }
 843                .log_err(),
 844            )
 845        } else {
 846            None
 847        }
 848    }
 849
 850    /// Register a handler to handle incoming LSP notifications.
 851    ///
 852    /// [LSP Specification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#notificationMessage)
 853    #[must_use]
 854    pub fn on_notification<T, F>(&self, f: F) -> Subscription
 855    where
 856        T: notification::Notification,
 857        F: 'static + Send + FnMut(T::Params, AsyncAppContext),
 858    {
 859        self.on_custom_notification(T::METHOD, f)
 860    }
 861
 862    /// Register a handler to handle incoming LSP requests.
 863    ///
 864    /// [LSP Specification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#requestMessage)
 865    #[must_use]
 866    pub fn on_request<T, F, Fut>(&self, f: F) -> Subscription
 867    where
 868        T: request::Request,
 869        T::Params: 'static + Send,
 870        F: 'static + FnMut(T::Params, AsyncAppContext) -> Fut + Send,
 871        Fut: 'static + Future<Output = Result<T::Result>>,
 872    {
 873        self.on_custom_request(T::METHOD, f)
 874    }
 875
 876    /// Registers a handler to inspect all language server process stdio.
 877    #[must_use]
 878    pub fn on_io<F>(&self, f: F) -> Subscription
 879    where
 880        F: 'static + Send + FnMut(IoKind, &str),
 881    {
 882        let id = self.next_id.fetch_add(1, SeqCst);
 883        self.io_handlers.lock().insert(id, Box::new(f));
 884        Subscription::Io {
 885            id,
 886            io_handlers: Some(Arc::downgrade(&self.io_handlers)),
 887        }
 888    }
 889
 890    /// Removes a request handler registers via [`Self::on_request`].
 891    pub fn remove_request_handler<T: request::Request>(&self) {
 892        self.notification_handlers.lock().remove(T::METHOD);
 893    }
 894
 895    /// Removes a notification handler registers via [`Self::on_notification`].
 896    pub fn remove_notification_handler<T: notification::Notification>(&self) {
 897        self.notification_handlers.lock().remove(T::METHOD);
 898    }
 899
 900    /// Checks if a notification handler has been registered via [`Self::on_notification`].
 901    pub fn has_notification_handler<T: notification::Notification>(&self) -> bool {
 902        self.notification_handlers.lock().contains_key(T::METHOD)
 903    }
 904
 905    #[must_use]
 906    fn on_custom_notification<Params, F>(&self, method: &'static str, mut f: F) -> Subscription
 907    where
 908        F: 'static + FnMut(Params, AsyncAppContext) + Send,
 909        Params: DeserializeOwned,
 910    {
 911        let prev_handler = self.notification_handlers.lock().insert(
 912            method,
 913            Box::new(move |_, params, cx| {
 914                if let Some(params) = serde_json::from_value(params).log_err() {
 915                    f(params, cx);
 916                }
 917            }),
 918        );
 919        assert!(
 920            prev_handler.is_none(),
 921            "registered multiple handlers for the same LSP method"
 922        );
 923        Subscription::Notification {
 924            method,
 925            notification_handlers: Some(self.notification_handlers.clone()),
 926        }
 927    }
 928
 929    #[must_use]
 930    fn on_custom_request<Params, Res, Fut, F>(&self, method: &'static str, mut f: F) -> Subscription
 931    where
 932        F: 'static + FnMut(Params, AsyncAppContext) -> Fut + Send,
 933        Fut: 'static + Future<Output = Result<Res>>,
 934        Params: DeserializeOwned + Send + 'static,
 935        Res: Serialize,
 936    {
 937        let outbound_tx = self.outbound_tx.clone();
 938        let prev_handler = self.notification_handlers.lock().insert(
 939            method,
 940            Box::new(move |id, params, cx| {
 941                if let Some(id) = id {
 942                    match serde_json::from_value(params) {
 943                        Ok(params) => {
 944                            let response = f(params, cx.clone());
 945                            cx.foreground_executor()
 946                                .spawn({
 947                                    let outbound_tx = outbound_tx.clone();
 948                                    async move {
 949                                        let response = match response.await {
 950                                            Ok(result) => Response {
 951                                                jsonrpc: JSON_RPC_VERSION,
 952                                                id,
 953                                                value: LspResult::Ok(Some(result)),
 954                                            },
 955                                            Err(error) => Response {
 956                                                jsonrpc: JSON_RPC_VERSION,
 957                                                id,
 958                                                value: LspResult::Error(Some(Error {
 959                                                    message: error.to_string(),
 960                                                })),
 961                                            },
 962                                        };
 963                                        if let Some(response) =
 964                                            serde_json::to_string(&response).log_err()
 965                                        {
 966                                            outbound_tx.try_send(response).ok();
 967                                        }
 968                                    }
 969                                })
 970                                .detach();
 971                        }
 972
 973                        Err(error) => {
 974                            log::error!("error deserializing {} request: {:?}", method, error);
 975                            let response = AnyResponse {
 976                                jsonrpc: JSON_RPC_VERSION,
 977                                id,
 978                                result: None,
 979                                error: Some(Error {
 980                                    message: error.to_string(),
 981                                }),
 982                            };
 983                            if let Some(response) = serde_json::to_string(&response).log_err() {
 984                                outbound_tx.try_send(response).ok();
 985                            }
 986                        }
 987                    }
 988                }
 989            }),
 990        );
 991        assert!(
 992            prev_handler.is_none(),
 993            "registered multiple handlers for the same LSP method"
 994        );
 995        Subscription::Notification {
 996            method,
 997            notification_handlers: Some(self.notification_handlers.clone()),
 998        }
 999    }
1000
1001    /// Get the name of the running language server.
1002    pub fn name(&self) -> LanguageServerName {
1003        self.name.clone()
1004    }
1005
1006    pub fn process_name(&self) -> &str {
1007        &self.process_name
1008    }
1009
1010    /// Get the reported capabilities of the running language server.
1011    pub fn capabilities(&self) -> ServerCapabilities {
1012        self.capabilities.read().clone()
1013    }
1014
1015    /// Get the reported capabilities of the running language server and
1016    /// what we know on the client/adapter-side of its capabilities.
1017    pub fn adapter_server_capabilities(&self) -> AdapterServerCapabilities {
1018        AdapterServerCapabilities {
1019            server_capabilities: self.capabilities(),
1020            code_action_kinds: self.code_action_kinds(),
1021        }
1022    }
1023
1024    pub fn update_capabilities(&self, update: impl FnOnce(&mut ServerCapabilities)) {
1025        update(self.capabilities.write().deref_mut());
1026    }
1027
1028    /// Get the id of the running language server.
1029    pub fn server_id(&self) -> LanguageServerId {
1030        self.server_id
1031    }
1032
1033    /// Get the root path of the project the language server is running against.
1034    pub fn root_path(&self) -> &PathBuf {
1035        &self.root_path
1036    }
1037
1038    /// Sends a RPC request to the language server.
1039    ///
1040    /// [LSP Specification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#requestMessage)
1041    pub fn request<T: request::Request>(
1042        &self,
1043        params: T::Params,
1044    ) -> impl LspRequestFuture<Result<T::Result>>
1045    where
1046        T::Result: 'static + Send,
1047    {
1048        Self::request_internal::<T>(
1049            &self.next_id,
1050            &self.response_handlers,
1051            &self.outbound_tx,
1052            &self.executor,
1053            params,
1054        )
1055    }
1056
1057    fn request_internal<T: request::Request>(
1058        next_id: &AtomicI32,
1059        response_handlers: &Mutex<Option<HashMap<RequestId, ResponseHandler>>>,
1060        outbound_tx: &channel::Sender<String>,
1061        executor: &BackgroundExecutor,
1062        params: T::Params,
1063    ) -> impl LspRequestFuture<Result<T::Result>>
1064    where
1065        T::Result: 'static + Send,
1066    {
1067        let id = next_id.fetch_add(1, SeqCst);
1068        let message = serde_json::to_string(&Request {
1069            jsonrpc: JSON_RPC_VERSION,
1070            id: RequestId::Int(id),
1071            method: T::METHOD,
1072            params,
1073        })
1074        .unwrap();
1075
1076        let (tx, rx) = oneshot::channel();
1077        let handle_response = response_handlers
1078            .lock()
1079            .as_mut()
1080            .ok_or_else(|| anyhow!("server shut down"))
1081            .map(|handlers| {
1082                let executor = executor.clone();
1083                handlers.insert(
1084                    RequestId::Int(id),
1085                    Box::new(move |result| {
1086                        executor
1087                            .spawn(async move {
1088                                let response = match result {
1089                                    Ok(response) => match serde_json::from_str(&response) {
1090                                        Ok(deserialized) => Ok(deserialized),
1091                                        Err(error) => {
1092                                            log::error!("failed to deserialize response from language server: {}. response from language server: {:?}", error, response);
1093                                            Err(error).context("failed to deserialize response")
1094                                        }
1095                                    }
1096                                    Err(error) => Err(anyhow!("{}", error.message)),
1097                                };
1098                                _ = tx.send(response);
1099                            })
1100                            .detach();
1101                    }),
1102                );
1103            });
1104
1105        let send = outbound_tx
1106            .try_send(message)
1107            .context("failed to write to language server's stdin");
1108
1109        let outbound_tx = outbound_tx.downgrade();
1110        let mut timeout = executor.timer(LSP_REQUEST_TIMEOUT).fuse();
1111        let started = Instant::now();
1112        LspRequest::new(id, async move {
1113            handle_response?;
1114            send?;
1115
1116            let cancel_on_drop = util::defer(move || {
1117                if let Some(outbound_tx) = outbound_tx.upgrade() {
1118                    Self::notify_internal::<notification::Cancel>(
1119                        &outbound_tx,
1120                        CancelParams {
1121                            id: NumberOrString::Number(id),
1122                        },
1123                    )
1124                    .log_err();
1125                }
1126            });
1127
1128            let method = T::METHOD;
1129            select! {
1130                response = rx.fuse() => {
1131                    let elapsed = started.elapsed();
1132                    log::trace!("Took {elapsed:?} to receive response to {method:?} id {id}");
1133                    cancel_on_drop.abort();
1134                    response?
1135                }
1136
1137                _ = timeout => {
1138                    log::error!("Cancelled LSP request task for {method:?} id {id} which took over {LSP_REQUEST_TIMEOUT:?}");
1139                    anyhow::bail!("LSP request timeout");
1140                }
1141            }
1142        })
1143    }
1144
1145    /// Sends a RPC notification to the language server.
1146    ///
1147    /// [LSP Specification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#notificationMessage)
1148    pub fn notify<T: notification::Notification>(&self, params: T::Params) -> Result<()> {
1149        Self::notify_internal::<T>(&self.outbound_tx, params)
1150    }
1151
1152    fn notify_internal<T: notification::Notification>(
1153        outbound_tx: &channel::Sender<String>,
1154        params: T::Params,
1155    ) -> Result<()> {
1156        let message = serde_json::to_string(&Notification {
1157            jsonrpc: JSON_RPC_VERSION,
1158            method: T::METHOD,
1159            params,
1160        })
1161        .unwrap();
1162        outbound_tx.try_send(message)?;
1163        Ok(())
1164    }
1165}
1166
1167impl Drop for LanguageServer {
1168    fn drop(&mut self) {
1169        if let Some(shutdown) = self.shutdown() {
1170            self.executor.spawn(shutdown).detach();
1171        }
1172    }
1173}
1174
1175impl Subscription {
1176    /// Detaching a subscription handle prevents it from unsubscribing on drop.
1177    pub fn detach(&mut self) {
1178        match self {
1179            Subscription::Notification {
1180                notification_handlers,
1181                ..
1182            } => *notification_handlers = None,
1183            Subscription::Io { io_handlers, .. } => *io_handlers = None,
1184        }
1185    }
1186}
1187
1188impl fmt::Display for LanguageServerId {
1189    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1190        self.0.fmt(f)
1191    }
1192}
1193
1194impl fmt::Debug for LanguageServer {
1195    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1196        f.debug_struct("LanguageServer")
1197            .field("id", &self.server_id.0)
1198            .field("name", &self.name)
1199            .finish_non_exhaustive()
1200    }
1201}
1202
1203impl Drop for Subscription {
1204    fn drop(&mut self) {
1205        match self {
1206            Subscription::Notification {
1207                method,
1208                notification_handlers,
1209            } => {
1210                if let Some(handlers) = notification_handlers {
1211                    handlers.lock().remove(method);
1212                }
1213            }
1214            Subscription::Io { id, io_handlers } => {
1215                if let Some(io_handlers) = io_handlers.as_ref().and_then(|h| h.upgrade()) {
1216                    io_handlers.lock().remove(id);
1217                }
1218            }
1219        }
1220    }
1221}
1222
/// Mock language server for use in tests.
///
/// Plays the server role opposite a client-side [`LanguageServer`] created alongside it by
/// [`FakeLanguageServer::new`]; the two are connected through in-memory pipes.
#[cfg(any(test, feature = "test-support"))]
#[derive(Clone)]
pub struct FakeLanguageServer {
    /// Binary description this fake was created with; `new` only stores it and does not spawn
    /// a process from it.
    pub binary: LanguageServerBinary,
    /// The fake's own protocol endpoint, used to send messages to and handle messages from the
    /// client side.
    pub server: Arc<LanguageServer>,
    /// `(method, params_json)` pairs forwarded by the fake endpoint's fallback message callback;
    /// drained by `try_receive_notification`.
    notifications_rx: channel::Receiver<(String, String)>,
}
1231
#[cfg(any(test, feature = "test-support"))]
impl FakeLanguageServer {
    /// Construct a fake language server.
    ///
    /// Returns a connected pair:
    /// - a client-side [`LanguageServer`], and
    /// - a [`FakeLanguageServer`] playing the server role.
    ///
    /// Whatever the client writes is read by the fake, and vice versa, via two in-memory pipes.
    /// Messages reaching the fake's fallback callback are forwarded as `(method, params_json)`
    /// pairs for [`Self::receive_notification`]. An `initialize` handler is pre-installed. It
    /// answers with `capabilities` and reports `name` as the server name.
    pub fn new(
        server_id: LanguageServerId,
        binary: LanguageServerBinary,
        name: String,
        capabilities: ServerCapabilities,
        cx: AsyncAppContext,
    ) -> (LanguageServer, FakeLanguageServer) {
        let (client_to_server_tx, client_to_server_rx) = async_pipe::pipe();
        let (server_to_client_tx, server_to_client_rx) = async_pipe::pipe();
        let (notifications_tx, notifications_rx) = channel::unbounded();

        let root = Self::root_path();
        let server_name = LanguageServerName(name.clone().into());

        // Client side: writes into the first pipe, reads from the second.
        let mut client = LanguageServer::new_internal(
            server_id,
            server_name.clone(),
            client_to_server_tx,
            server_to_client_rx,
            None::<async_pipe::PipeReader>,
            Arc::new(Mutex::new(None)),
            None,
            root,
            root,
            None,
            cx.clone(),
            |_| {},
        );
        client.process_name = Arc::from(name.as_str());

        // Fake server side: the mirror image, forwarding fallback messages to the test.
        let mut fake_endpoint = LanguageServer::new_internal(
            server_id,
            server_name,
            server_to_client_tx,
            client_to_server_rx,
            None::<async_pipe::PipeReader>,
            Arc::new(Mutex::new(None)),
            None,
            root,
            root,
            None,
            cx,
            move |msg| {
                notifications_tx
                    .try_send((
                        msg.method.to_string(),
                        msg.params.unwrap_or(Value::Null).to_string(),
                    ))
                    .ok();
            },
        );
        fake_endpoint.process_name = name.as_str().into();

        let fake = FakeLanguageServer {
            binary,
            server: Arc::new(fake_endpoint),
            notifications_rx,
        };
        fake.handle_request::<request::Initialize, _, _>(move |_, _| {
            let capabilities = capabilities.clone();
            let name = name.clone();
            async move {
                Ok(InitializeResult {
                    capabilities,
                    server_info: Some(ServerInfo {
                        name,
                        ..Default::default()
                    }),
                })
            }
        });

        (client, fake)
    }

    /// Filesystem root of the current platform. It is passed for both path arguments of
    /// `new_internal` when building the fake pair.
    fn root_path() -> &'static Path {
        if cfg!(target_os = "windows") {
            Path::new("C:\\")
        } else {
            Path::new("/")
        }
    }
}
1324
#[cfg(any(test, feature = "test-support"))]
impl LanguageServer {
    /// Capabilities for tests.
    ///
    /// Enables definition, type definition, implementation, document highlight, formatting,
    /// range formatting and code action support, each in its simple `true` form. Every other
    /// capability keeps its default.
    pub fn full_capabilities() -> ServerCapabilities {
        ServerCapabilities {
            definition_provider: Some(OneOf::Left(true)),
            type_definition_provider: Some(TypeDefinitionProviderCapability::Simple(true)),
            implementation_provider: Some(ImplementationProviderCapability::Simple(true)),
            document_highlight_provider: Some(OneOf::Left(true)),
            document_formatting_provider: Some(OneOf::Left(true)),
            document_range_formatting_provider: Some(OneOf::Left(true)),
            code_action_provider: Some(CodeActionProviderCapability::Simple(true)),
            ..ServerCapabilities::default()
        }
    }
}
1340
#[cfg(any(test, feature = "test-support"))]
impl FakeLanguageServer {
    /// Sends a notification of type `T` from the fake server to the client, ignoring enqueue
    /// errors. See [`LanguageServer::notify`].
    pub fn notify<T: notification::Notification>(&self, params: T::Params) {
        self.server.notify::<T>(params).ok();
    }

    /// Sends a request of type `T` from the fake server to the client and awaits the response.
    /// See [`LanguageServer::request`].
    pub async fn request<T>(&self, params: T::Params) -> Result<T::Result>
    where
        T: request::Request,
        T::Result: 'static + Send,
    {
        self.server.executor.start_waiting();
        self.server.request::<T>(params).await
    }

    /// Waits for the next notification of type `T` sent by the client.
    ///
    /// # Panics
    /// Panics if the notification channel closes before a matching notification arrives (see
    /// [`Self::try_receive_notification`]).
    pub async fn receive_notification<T: notification::Notification>(&mut self) -> T::Params {
        self.server.executor.start_waiting();
        self.try_receive_notification::<T>().await.unwrap()
    }

    /// Consumes the notification channel until it finds a notification for the specified type.
    /// Other messages are logged and discarded along the way.
    ///
    /// Returns `None` if the channel closes first.
    ///
    /// # Panics
    /// Panics if the matching notification's params fail to deserialize into `T::Params`.
    pub async fn try_receive_notification<T: notification::Notification>(
        &mut self,
    ) -> Option<T::Params> {
        use futures::StreamExt as _;

        loop {
            let (method, params) = self.notifications_rx.next().await?;
            if method == T::METHOD {
                return Some(serde_json::from_str::<T::Params>(&params).unwrap());
            } else {
                log::info!("skipping message in fake language server {:?}", params);
            }
        }
    }

    /// Registers a handler for a specific kind of request. Removes any existing handler for specified request type.
    ///
    /// The handler's future is awaited only after a simulated random delay. The returned
    /// receiver yields `()` each time such a future completes.
    pub fn handle_request<T, F, Fut>(
        &self,
        mut handler: F,
    ) -> futures::channel::mpsc::UnboundedReceiver<()>
    where
        T: 'static + request::Request,
        T::Params: 'static + Send,
        F: 'static + Send + FnMut(T::Params, gpui::AsyncAppContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<T::Result>>,
    {
        let (responded_tx, responded_rx) = futures::channel::mpsc::unbounded();
        self.server.remove_request_handler::<T>();
        self.server
            .on_request::<T, _, _>(move |params, cx| {
                let result = handler(params, cx.clone());
                let responded_tx = responded_tx.clone();
                let executor = cx.background_executor().clone();
                async move {
                    executor.simulate_random_delay().await;
                    let result = result.await;
                    responded_tx.unbounded_send(()).ok();
                    result
                }
            })
            .detach();
        responded_rx
    }

    /// Registers a handler for a specific kind of notification. Removes any existing handler for specified notification type.
    ///
    /// The returned receiver yields `()` after each notification has been handled.
    pub fn handle_notification<T, F>(
        &self,
        mut handler: F,
    ) -> futures::channel::mpsc::UnboundedReceiver<()>
    where
        T: 'static + notification::Notification,
        T::Params: 'static + Send,
        F: 'static + Send + FnMut(T::Params, gpui::AsyncAppContext),
    {
        let (handled_tx, handled_rx) = futures::channel::mpsc::unbounded();
        self.server.remove_notification_handler::<T>();
        self.server
            .on_notification::<T, _>(move |params, cx| {
                handler(params, cx.clone());
                handled_tx.unbounded_send(()).ok();
            })
            .detach();
        handled_rx
    }

    /// Removes any existing handler for specified request type.
    pub fn remove_request_handler<T>(&mut self)
    where
        T: 'static + request::Request,
    {
        self.server.remove_request_handler::<T>();
    }

    /// Simulate that the server has started work and notifies about its progress with the specified token.
    pub async fn start_progress(&self, token: impl Into<String>) {
        self.start_progress_with(token, Default::default()).await
    }

    /// Like [`Self::start_progress`], but sends `progress` as the `Begin` payload.
    ///
    /// First asks the client to create the token with [`request::WorkDoneProgressCreate`].
    /// It then sends a [`notification::Progress`] notification carrying `progress`.
    ///
    /// # Panics
    /// Panics if the `WorkDoneProgressCreate` request fails.
    pub async fn start_progress_with(
        &self,
        token: impl Into<String>,
        progress: WorkDoneProgressBegin,
    ) {
        let token = token.into();
        self.request::<request::WorkDoneProgressCreate>(WorkDoneProgressCreateParams {
            token: NumberOrString::String(token.clone()),
        })
        .await
        .unwrap();
        self.notify::<notification::Progress>(ProgressParams {
            token: NumberOrString::String(token),
            value: ProgressParamsValue::WorkDone(WorkDoneProgress::Begin(progress)),
        });
    }

    /// Simulate that the server has completed work and notifies about that with the specified token.
    pub fn end_progress(&self, token: impl Into<String>) {
        self.notify::<notification::Progress>(ProgressParams {
            token: NumberOrString::String(token.into()),
            value: ProgressParamsValue::WorkDone(WorkDoneProgress::End(Default::default())),
        });
    }
}
1468
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{SemanticVersion, TestAppContext};
    use std::str::FromStr;

    #[ctor::ctor]
    fn init_logger() {
        if std::env::var("RUST_LOG").is_ok() {
            env_logger::init();
        }
    }

    /// End-to-end exercise of the fake server pair: notifications in both directions,
    /// followed by shutdown (`shutdown` request answered, `exit` received).
    #[gpui::test]
    async fn test_fake(cx: &mut TestAppContext) {
        cx.update(|cx| {
            release_channel::init(SemanticVersion::default(), cx);
        });
        let (server, mut fake) = FakeLanguageServer::new(
            LanguageServerId(0),
            LanguageServerBinary {
                path: "path/to/language-server".into(),
                arguments: vec![],
                env: None,
            },
            "the-lsp".to_string(),
            Default::default(),
            cx.to_async(),
        );

        let (message_tx, message_rx) = channel::unbounded();
        let (diagnostics_tx, diagnostics_rx) = channel::unbounded();
        server
            .on_notification::<notification::ShowMessage, _>(move |params, _| {
                message_tx.try_send(params).unwrap()
            })
            .detach();
        server
            .on_notification::<notification::PublishDiagnostics, _>(move |params, _| {
                diagnostics_tx.try_send(params).unwrap()
            })
            .detach();

        let server = cx.update(|cx| server.initialize(None, cx)).await.unwrap();
        server
            .notify::<notification::DidOpenTextDocument>(DidOpenTextDocumentParams {
                text_document: TextDocumentItem::new(
                    Url::from_str("file://a/b").unwrap(),
                    "rust".to_string(),
                    0,
                    "".to_string(),
                ),
            })
            .unwrap();
        assert_eq!(
            fake.receive_notification::<notification::DidOpenTextDocument>()
                .await
                .text_document
                .uri
                .as_str(),
            "file://a/b"
        );

        fake.notify::<notification::ShowMessage>(ShowMessageParams {
            typ: MessageType::ERROR,
            message: "ok".to_string(),
        });
        fake.notify::<notification::PublishDiagnostics>(PublishDiagnosticsParams {
            uri: Url::from_str("file://b/c").unwrap(),
            version: Some(5),
            diagnostics: vec![],
        });
        assert_eq!(message_rx.recv().await.unwrap().message, "ok");
        assert_eq!(
            diagnostics_rx.recv().await.unwrap().uri.as_str(),
            "file://b/c"
        );

        fake.handle_request::<request::Shutdown, _, _>(|_, _| async move { Ok(()) });

        drop(server);
        fake.receive_notification::<notification::Exit>().await;
    }

    #[gpui::test]
    fn test_deserialize_string_digit_id() {
        let json = r#"{"jsonrpc":"2.0","id":"2","method":"workspace/configuration","params":{"items":[{"scopeUri":"file:///Users/mph/Devel/personal/hello-scala/","section":"metals"}]}}"#;
        let notification = serde_json::from_str::<AnyNotification>(json)
            .expect("message with string id should be parsed");
        let expected_id = RequestId::Str("2".to_string());
        assert_eq!(notification.id, Some(expected_id));
    }

    #[gpui::test]
    fn test_deserialize_string_id() {
        let json = r#"{"jsonrpc":"2.0","id":"anythingAtAll","method":"workspace/configuration","params":{"items":[{"scopeUri":"file:///Users/mph/Devel/personal/hello-scala/","section":"metals"}]}}"#;
        let notification = serde_json::from_str::<AnyNotification>(json)
            .expect("message with string id should be parsed");
        let expected_id = RequestId::Str("anythingAtAll".to_string());
        assert_eq!(notification.id, Some(expected_id));
    }

    #[gpui::test]
    fn test_deserialize_int_id() {
        let json = r#"{"jsonrpc":"2.0","id":2,"method":"workspace/configuration","params":{"items":[{"scopeUri":"file:///Users/mph/Devel/personal/hello-scala/","section":"metals"}]}}"#;
        let notification = serde_json::from_str::<AnyNotification>(json)
            .expect("message with integer id should be parsed");
        let expected_id = RequestId::Int(2);
        assert_eq!(notification.id, Some(expected_id));
    }

    #[test]
    fn test_serialize_has_no_nulls() {
        // Ensure we're not setting both result and error variants. (ticket #10595)
        let no_tag = Response::<u32> {
            jsonrpc: "",
            id: RequestId::Int(0),
            value: LspResult::Ok(None),
        };
        assert_eq!(
            serde_json::to_string(&no_tag).unwrap(),
            "{\"jsonrpc\":\"\",\"id\":0,\"result\":null}"
        );
        let no_tag = Response::<u32> {
            jsonrpc: "",
            id: RequestId::Int(0),
            value: LspResult::Error(None),
        };
        assert_eq!(
            serde_json::to_string(&no_tag).unwrap(),
            "{\"jsonrpc\":\"\",\"id\":0,\"error\":null}"
        );
    }
}