lsp: Support tracking multiple registrations of diagnostic providers (#41096)

Piotr Osiewicz , Smit Barmase , Anthony Eid , and dino created

Closes #40966
Closes #41195
Closes #40980 

Release Notes:

- Fixed diagnostics not working with basedpyright/pyright beyond an
initial version of the document

---------

Co-authored-by: Smit Barmase <heysmitbarmase@gmail.com>
Co-authored-by: Anthony Eid <hello@anthonyeid.me>
Co-authored-by: dino <dinojoaocosta@gmail.com>

Change summary

crates/collab/src/tests/editor_tests.rs |   2 
crates/project/src/lsp_command.rs       |  25 +-
crates/project/src/lsp_store.rs         | 302 +++++++++++++++++++-------
3 files changed, 232 insertions(+), 97 deletions(-)

Detailed changes

crates/collab/src/tests/editor_tests.rs 🔗

@@ -2585,7 +2585,7 @@ async fn test_lsp_pull_diagnostics(
             capabilities: capabilities.clone(),
             initializer: Some(Box::new(move |fake_language_server| {
                 let expected_workspace_diagnostic_token = lsp::ProgressToken::String(format!(
-                    "workspace/diagnostic-{}-1",
+                    "workspace/diagnostic/{}/1",
                     fake_language_server.server.server_id()
                 ));
                 let closure_workspace_diagnostics_pulls_result_ids = closure_workspace_diagnostics_pulls_result_ids.clone();

crates/project/src/lsp_command.rs 🔗

@@ -26,8 +26,8 @@ use language::{
 use lsp::{
     AdapterServerCapabilities, CodeActionKind, CodeActionOptions, CodeDescription,
     CompletionContext, CompletionListItemDefaultsEditRange, CompletionTriggerKind,
-    DocumentHighlightKind, LanguageServer, LanguageServerId, LinkedEditingRangeServerCapabilities,
-    OneOf, RenameOptions, ServerCapabilities,
+    DiagnosticServerCapabilities, DocumentHighlightKind, LanguageServer, LanguageServerId,
+    LinkedEditingRangeServerCapabilities, OneOf, RenameOptions, ServerCapabilities,
 };
 use serde_json::Value;
 use signature_help::{lsp_to_proto_signature, proto_to_lsp_signature};
@@ -262,6 +262,9 @@ pub(crate) struct LinkedEditingRange {
 
 #[derive(Clone, Debug)]
 pub(crate) struct GetDocumentDiagnostics {
+    /// We cannot blindly rely on server's capabilities.diagnostic_provider, as they're a singular field, whereas
+    /// a server can register multiple diagnostic providers post-mortem.
+    pub dynamic_caps: DiagnosticServerCapabilities,
     pub previous_result_id: Option<String>,
 }
 
@@ -4031,26 +4034,22 @@ impl LspCommand for GetDocumentDiagnostics {
         "Get diagnostics"
     }
 
-    fn check_capabilities(&self, server_capabilities: AdapterServerCapabilities) -> bool {
-        server_capabilities
-            .server_capabilities
-            .diagnostic_provider
-            .is_some()
+    fn check_capabilities(&self, _: AdapterServerCapabilities) -> bool {
+        true
     }
 
     fn to_lsp(
         &self,
         path: &Path,
         _: &Buffer,
-        language_server: &Arc<LanguageServer>,
+        _: &Arc<LanguageServer>,
         _: &App,
     ) -> Result<lsp::DocumentDiagnosticParams> {
-        let identifier = match language_server.capabilities().diagnostic_provider {
-            Some(lsp::DiagnosticServerCapabilities::Options(options)) => options.identifier,
-            Some(lsp::DiagnosticServerCapabilities::RegistrationOptions(options)) => {
-                options.diagnostic_options.identifier
+        let identifier = match &self.dynamic_caps {
+            lsp::DiagnosticServerCapabilities::Options(options) => options.identifier.clone(),
+            lsp::DiagnosticServerCapabilities::RegistrationOptions(options) => {
+                options.diagnostic_options.identifier.clone()
             }
-            None => None,
         };
 
         Ok(lsp::DocumentDiagnosticParams {

crates/project/src/lsp_store.rs 🔗

@@ -75,12 +75,12 @@ use language::{
     range_from_lsp, range_to_lsp,
 };
 use lsp::{
-    AdapterServerCapabilities, CodeActionKind, CompletionContext, DiagnosticSeverity,
-    DiagnosticTag, DidChangeWatchedFilesRegistrationOptions, Edit, FileOperationFilter,
-    FileOperationPatternKind, FileOperationRegistrationOptions, FileRename, FileSystemWatcher,
-    LSP_REQUEST_TIMEOUT, LanguageServer, LanguageServerBinary, LanguageServerBinaryOptions,
-    LanguageServerId, LanguageServerName, LanguageServerSelector, LspRequestFuture,
-    MessageActionItem, MessageType, OneOf, RenameFilesParams, SymbolKind,
+    AdapterServerCapabilities, CodeActionKind, CompletionContext, DiagnosticServerCapabilities,
+    DiagnosticSeverity, DiagnosticTag, DidChangeWatchedFilesRegistrationOptions, Edit,
+    FileOperationFilter, FileOperationPatternKind, FileOperationRegistrationOptions, FileRename,
+    FileSystemWatcher, LSP_REQUEST_TIMEOUT, LanguageServer, LanguageServerBinary,
+    LanguageServerBinaryOptions, LanguageServerId, LanguageServerName, LanguageServerSelector,
+    LspRequestFuture, MessageActionItem, MessageType, OneOf, RenameFilesParams, SymbolKind,
     TextDocumentSyncSaveOptions, TextEdit, Uri, WillRenameFiles, WorkDoneProgressCancelParams,
     WorkspaceFolder, notification::DidRenameFiles,
 };
@@ -190,6 +190,12 @@ pub struct DocumentDiagnostics {
     version: Option<i32>,
 }
 
+#[derive(Default)]
+struct DynamicRegistrations {
+    did_change_watched_files: HashMap<String, Vec<FileSystemWatcher>>,
+    diagnostics: HashMap<Option<String>, DiagnosticServerCapabilities>,
+}
+
 pub struct LocalLspStore {
     weak: WeakEntity<LspStore>,
     worktree_store: Entity<WorktreeStore>,
@@ -207,8 +213,7 @@ pub struct LocalLspStore {
     watched_manifest_filenames: HashSet<ManifestName>,
     language_server_paths_watched_for_rename:
         HashMap<LanguageServerId, RenamePathsWatchedForServer>,
-    language_server_watcher_registrations:
-        HashMap<LanguageServerId, HashMap<String, Vec<FileSystemWatcher>>>,
+    language_server_dynamic_registrations: HashMap<LanguageServerId, DynamicRegistrations>,
     supplementary_language_servers:
         HashMap<LanguageServerId, (LanguageServerName, Arc<LanguageServer>)>,
     prettier_store: Entity<PrettierStore>,
@@ -3184,7 +3189,7 @@ impl LocalLspStore {
 
         for watcher in watchers {
             if let Some((worktree, literal_prefix, pattern)) =
-                self.worktree_and_path_for_file_watcher(&worktrees, watcher, cx)
+                Self::worktree_and_path_for_file_watcher(&worktrees, watcher, cx)
             {
                 worktree.update(cx, |worktree, _| {
                     if let Some((tree, glob)) =
@@ -3282,7 +3287,6 @@ impl LocalLspStore {
     }
 
     fn worktree_and_path_for_file_watcher(
-        &self,
         worktrees: &[Entity<Worktree>],
         watcher: &FileSystemWatcher,
         cx: &App,
@@ -3330,15 +3334,18 @@ impl LocalLspStore {
         language_server_id: LanguageServerId,
         cx: &mut Context<LspStore>,
     ) {
-        let Some(watchers) = self
-            .language_server_watcher_registrations
+        let Some(registrations) = self
+            .language_server_dynamic_registrations
             .get(&language_server_id)
         else {
             return;
         };
 
-        let watch_builder =
-            self.rebuild_watched_paths_inner(language_server_id, watchers.values().flatten(), cx);
+        let watch_builder = self.rebuild_watched_paths_inner(
+            language_server_id,
+            registrations.did_change_watched_files.values().flatten(),
+            cx,
+        );
         let watcher = watch_builder.build(self.fs.clone(), language_server_id, cx);
         self.language_server_watched_paths
             .insert(language_server_id, watcher);
@@ -3354,11 +3361,13 @@ impl LocalLspStore {
         cx: &mut Context<LspStore>,
     ) {
         let registrations = self
-            .language_server_watcher_registrations
+            .language_server_dynamic_registrations
             .entry(language_server_id)
             .or_default();
 
-        registrations.insert(registration_id.to_string(), params.watchers);
+        registrations
+            .did_change_watched_files
+            .insert(registration_id.to_string(), params.watchers);
 
         self.rebuild_watched_paths(language_server_id, cx);
     }
@@ -3370,11 +3379,15 @@ impl LocalLspStore {
         cx: &mut Context<LspStore>,
     ) {
         let registrations = self
-            .language_server_watcher_registrations
+            .language_server_dynamic_registrations
             .entry(language_server_id)
             .or_default();
 
-        if registrations.remove(registration_id).is_some() {
+        if registrations
+            .did_change_watched_files
+            .remove(registration_id)
+            .is_some()
+        {
             log::info!(
                 "language server {}: unregistered workspace/DidChangeWatchedFiles capability with id {}",
                 language_server_id,
@@ -3782,7 +3795,7 @@ impl LspStore {
                 last_workspace_edits_by_language_server: Default::default(),
                 language_server_watched_paths: Default::default(),
                 language_server_paths_watched_for_rename: Default::default(),
-                language_server_watcher_registrations: Default::default(),
+                language_server_dynamic_registrations: Default::default(),
                 buffers_being_formatted: Default::default(),
                 buffer_snapshots: Default::default(),
                 prettier_store,
@@ -4367,7 +4380,7 @@ impl LspStore {
         cx: &App,
     ) -> bool
     where
-        F: Fn(&lsp::ServerCapabilities) -> bool,
+        F: FnMut(&lsp::ServerCapabilities) -> bool,
     {
         let Some(language) = buffer.read(cx).language().cloned() else {
             return false;
@@ -6447,12 +6460,30 @@ impl LspStore {
         let buffer_id = buffer.read(cx).remote_id();
 
         if let Some((client, upstream_project_id)) = self.upstream_client() {
+            let mut suitable_capabilities = None;
+            // Are we capable for proto request?
+            let any_server_has_diagnostics_provider = self.check_if_capable_for_proto_request(
+                &buffer,
+                |capabilities| {
+                    if let Some(caps) = &capabilities.diagnostic_provider {
+                        suitable_capabilities = Some(caps.clone());
+                        true
+                    } else {
+                        false
+                    }
+                },
+                cx,
+            );
+            // We don't really care which caps are passed into the request, as they're ignored by RPC anyways.
+            let Some(dynamic_caps) = suitable_capabilities else {
+                return Task::ready(Ok(None));
+            };
+            assert!(any_server_has_diagnostics_provider);
+
             let request = GetDocumentDiagnostics {
                 previous_result_id: None,
+                dynamic_caps,
             };
-            if !self.is_capable_for_proto_request(&buffer, &request, cx) {
-                return Task::ready(Ok(None));
-            }
             let request_task = client.request_lsp(
                 upstream_project_id,
                 None,
@@ -6468,23 +6499,44 @@ impl LspStore {
                 Ok(None)
             })
         } else {
-            let server_ids = buffer.update(cx, |buffer, cx| {
+            let servers = buffer.update(cx, |buffer, cx| {
                 self.language_servers_for_local_buffer(buffer, cx)
-                    .map(|(_, server)| server.server_id())
+                    .map(|(_, server)| server.clone())
                     .collect::<Vec<_>>()
             });
-            let pull_diagnostics = server_ids
+
+            let pull_diagnostics = servers
                 .into_iter()
-                .map(|server_id| {
-                    let result_id = self.result_id(server_id, buffer_id, cx);
-                    self.request_lsp(
-                        buffer.clone(),
-                        LanguageServerToQuery::Other(server_id),
-                        GetDocumentDiagnostics {
-                            previous_result_id: result_id,
-                        },
-                        cx,
-                    )
+                .flat_map(|server| {
+                    let result = maybe!({
+                        let local = self.as_local()?;
+                        let server_id = server.server_id();
+                        let providers_with_identifiers = local
+                            .language_server_dynamic_registrations
+                            .get(&server_id)
+                            .into_iter()
+                            .flat_map(|registrations| registrations.diagnostics.values().cloned())
+                            .collect::<Vec<_>>();
+                        Some(
+                            providers_with_identifiers
+                                .into_iter()
+                                .map(|dynamic_caps| {
+                                    let result_id = self.result_id(server_id, buffer_id, cx);
+                                    self.request_lsp(
+                                        buffer.clone(),
+                                        LanguageServerToQuery::Other(server_id),
+                                        GetDocumentDiagnostics {
+                                            previous_result_id: result_id,
+                                            dynamic_caps,
+                                        },
+                                        cx,
+                                    )
+                                })
+                                .collect::<Vec<_>>(),
+                        )
+                    });
+
+                    result.unwrap_or_default()
                 })
                 .collect::<Vec<_>>();
 
@@ -9322,14 +9374,17 @@ impl LspStore {
                 );
             }
             lsp::ProgressParamsValue::WorkspaceDiagnostic(report) => {
+                let identifier = token.split_once("id:").map(|(_, id)| id.to_owned());
                 if let Some(LanguageServerState::Running {
-                    workspace_refresh_task: Some(workspace_refresh_task),
+                    workspace_diagnostics_refresh_tasks,
                     ..
                 }) = self
                     .as_local_mut()
                     .and_then(|local| local.language_servers.get_mut(&language_server_id))
+                    && let Some(workspace_diagnostics) =
+                        workspace_diagnostics_refresh_tasks.get_mut(&identifier)
                 {
-                    workspace_refresh_task.progress_tx.try_send(()).ok();
+                    workspace_diagnostics.progress_tx.try_send(()).ok();
                     self.apply_workspace_diagnostic_report(language_server_id, report, cx)
                 }
             }
@@ -10784,13 +10839,31 @@ impl LspStore {
         let workspace_folders = workspace_folders.lock().clone();
         language_server.set_workspace_folders(workspace_folders);
 
+        let workspace_diagnostics_refresh_tasks = language_server
+            .capabilities()
+            .diagnostic_provider
+            .and_then(|provider| {
+                let workspace_refresher = lsp_workspace_diagnostics_refresh(
+                    None,
+                    provider.clone(),
+                    language_server.clone(),
+                    cx,
+                )?;
+                local
+                    .language_server_dynamic_registrations
+                    .entry(server_id)
+                    .or_default()
+                    .diagnostics
+                    .entry(None)
+                    .or_insert(provider);
+                Some((None, workspace_refresher))
+            })
+            .into_iter()
+            .collect();
         local.language_servers.insert(
             server_id,
             LanguageServerState::Running {
-                workspace_refresh_task: lsp_workspace_diagnostics_refresh(
-                    language_server.clone(),
-                    cx,
-                ),
+                workspace_diagnostics_refresh_tasks,
                 adapter: adapter.clone(),
                 server: language_server.clone(),
                 simulate_disk_based_diagnostics_completion: None,
@@ -11495,13 +11568,15 @@ impl LspStore {
 
     pub fn pull_workspace_diagnostics(&mut self, server_id: LanguageServerId) {
         if let Some(LanguageServerState::Running {
-            workspace_refresh_task: Some(workspace_refresh_task),
+            workspace_diagnostics_refresh_tasks,
             ..
         }) = self
             .as_local_mut()
             .and_then(|local| local.language_servers.get_mut(&server_id))
         {
-            workspace_refresh_task.refresh_tx.try_send(()).ok();
+            for diagnostics in workspace_diagnostics_refresh_tasks.values_mut() {
+                diagnostics.refresh_tx.try_send(()).ok();
+            }
         }
     }
 
@@ -11517,11 +11592,13 @@ impl LspStore {
             local.language_server_ids_for_buffer(buffer, cx)
         }) {
             if let Some(LanguageServerState::Running {
-                workspace_refresh_task: Some(workspace_refresh_task),
+                workspace_diagnostics_refresh_tasks,
                 ..
             }) = local.language_servers.get_mut(&server_id)
             {
-                workspace_refresh_task.refresh_tx.try_send(()).ok();
+                for diagnostics in workspace_diagnostics_refresh_tasks.values_mut() {
+                    diagnostics.refresh_tx.try_send(()).ok();
+                }
             }
         }
     }
@@ -11847,26 +11924,49 @@ impl LspStore {
                 "textDocument/diagnostic" => {
                     if let Some(caps) = reg
                         .register_options
-                        .map(serde_json::from_value)
+                        .map(serde_json::from_value::<DiagnosticServerCapabilities>)
                         .transpose()?
                     {
-                        let state = self
+                        let local = self
                             .as_local_mut()
-                            .context("Expected LSP Store to be local")?
+                            .context("Expected LSP Store to be local")?;
+                        let state = local
                             .language_servers
                             .get_mut(&server_id)
                             .context("Could not obtain Language Servers state")?;
-                        server.update_capabilities(|capabilities| {
-                            capabilities.diagnostic_provider = Some(caps);
-                        });
+                        local
+                            .language_server_dynamic_registrations
+                            .get_mut(&server_id)
+                            .and_then(|registrations| {
+                                registrations
+                                    .diagnostics
+                                    .insert(Some(reg.id.clone()), caps.clone())
+                            });
+
+                        let mut can_now_provide_diagnostics = false;
                         if let LanguageServerState::Running {
-                            workspace_refresh_task,
+                            workspace_diagnostics_refresh_tasks,
                             ..
                         } = state
-                            && workspace_refresh_task.is_none()
+                            && let Some(task) = lsp_workspace_diagnostics_refresh(
+                                Some(reg.id.clone()),
+                                caps.clone(),
+                                server.clone(),
+                                cx,
+                            )
                         {
-                            *workspace_refresh_task =
-                                lsp_workspace_diagnostics_refresh(server.clone(), cx)
+                            workspace_diagnostics_refresh_tasks.insert(Some(reg.id), task);
+                            can_now_provide_diagnostics = true;
+                        }
+
+                        // We don't actually care about capabilities.diagnostic_provider, but it IS relevant for the remote peer
+                        // to know that there's at least one provider. Otherwise, it will never ask us to issue documentdiagnostic calls on their behalf,
+                        // as it'll think that they're not supported.
+                        if can_now_provide_diagnostics {
+                            server.update_capabilities(|capabilities| {
+                                debug_assert!(capabilities.diagnostic_provider.is_none());
+                                capabilities.diagnostic_provider = Some(caps);
+                            });
                         }
 
                         notify_server_capabilities_updated(&server, cx);
@@ -12029,22 +12129,45 @@ impl LspStore {
                     notify_server_capabilities_updated(&server, cx);
                 }
                 "textDocument/diagnostic" => {
-                    server.update_capabilities(|capabilities| {
-                        capabilities.diagnostic_provider = None;
-                    });
-                    let state = self
+                    let local = self
                         .as_local_mut()
-                        .context("Expected LSP Store to be local")?
+                        .context("Expected LSP Store to be local")?;
+
+                    let state = local
                         .language_servers
                         .get_mut(&server_id)
                         .context("Could not obtain Language Servers state")?;
-                    if let LanguageServerState::Running {
-                        workspace_refresh_task,
-                        ..
-                    } = state
+                    let options = local
+                        .language_server_dynamic_registrations
+                        .get_mut(&server_id)
+                        .with_context(|| {
+                            format!("Expected dynamic registration to exist for server {server_id}")
+                        })?.diagnostics
+                        .remove(&Some(unreg.id.clone()))
+                        .with_context(|| format!(
+                            "Attempted to unregister non-existent diagnostic registration with ID {}",
+                            unreg.id)
+                        )?;
+
+                    let mut has_any_diagnostic_providers_still = true;
+                    if let Some(identifier) = diagnostic_identifier(&options)
+                        && let LanguageServerState::Running {
+                            workspace_diagnostics_refresh_tasks,
+                            ..
+                        } = state
                     {
-                        _ = workspace_refresh_task.take();
+                        workspace_diagnostics_refresh_tasks.remove(&identifier);
+                        has_any_diagnostic_providers_still =
+                            !workspace_diagnostics_refresh_tasks.is_empty();
                     }
+
+                    if !has_any_diagnostic_providers_still {
+                        server.update_capabilities(|capabilities| {
+                            debug_assert!(capabilities.diagnostic_provider.is_some());
+                            capabilities.diagnostic_provider = None;
+                        });
+                    }
+
                     notify_server_capabilities_updated(&server, cx);
                 }
                 "textDocument/documentColor" => {
@@ -12333,24 +12456,12 @@ fn subscribe_to_binary_statuses(
 }
 
 fn lsp_workspace_diagnostics_refresh(
+    registration_id: Option<String>,
+    options: DiagnosticServerCapabilities,
     server: Arc<LanguageServer>,
     cx: &mut Context<'_, LspStore>,
 ) -> Option<WorkspaceRefreshTask> {
-    let identifier = match server.capabilities().diagnostic_provider? {
-        lsp::DiagnosticServerCapabilities::Options(diagnostic_options) => {
-            if !diagnostic_options.workspace_diagnostics {
-                return None;
-            }
-            diagnostic_options.identifier
-        }
-        lsp::DiagnosticServerCapabilities::RegistrationOptions(registration_options) => {
-            let diagnostic_options = registration_options.diagnostic_options;
-            if !diagnostic_options.workspace_diagnostics {
-                return None;
-            }
-            diagnostic_options.identifier
-        }
-    };
+    let identifier = diagnostic_identifier(&options)?;
 
     let (progress_tx, mut progress_rx) = mpsc::channel(1);
     let (mut refresh_tx, mut refresh_rx) = mpsc::channel(1);
@@ -12396,7 +12507,14 @@ fn lsp_workspace_diagnostics_refresh(
                     return;
                 };
 
-                let token = format!("workspace/diagnostic-{}-{}", server.server_id(), requests);
+                let token = if let Some(identifier) = &registration_id {
+                    format!(
+                        "workspace/diagnostic/{}/{requests}/id:{identifier}",
+                        server.server_id(),
+                    )
+                } else {
+                    format!("workspace/diagnostic/{}/{requests}", server.server_id())
+                };
 
                 progress_rx.try_recv().ok();
                 let timer =
@@ -12462,6 +12580,24 @@ fn lsp_workspace_diagnostics_refresh(
     })
 }
 
+fn diagnostic_identifier(options: &DiagnosticServerCapabilities) -> Option<Option<String>> {
+    match &options {
+        lsp::DiagnosticServerCapabilities::Options(diagnostic_options) => {
+            if !diagnostic_options.workspace_diagnostics {
+                return None;
+            }
+            Some(diagnostic_options.identifier.clone())
+        }
+        lsp::DiagnosticServerCapabilities::RegistrationOptions(registration_options) => {
+            let diagnostic_options = &registration_options.diagnostic_options;
+            if !diagnostic_options.workspace_diagnostics {
+                return None;
+            }
+            Some(diagnostic_options.identifier.clone())
+        }
+    }
+}
+
 fn resolve_word_completion(snapshot: &BufferSnapshot, completion: &mut Completion) {
     let CompletionSource::BufferWord {
         word_range,
@@ -12866,7 +13002,7 @@ pub enum LanguageServerState {
         adapter: Arc<CachedLspAdapter>,
         server: Arc<LanguageServer>,
         simulate_disk_based_diagnostics_completion: Option<Task<()>>,
-        workspace_refresh_task: Option<WorkspaceRefreshTask>,
+        workspace_diagnostics_refresh_tasks: HashMap<Option<String>, WorkspaceRefreshTask>,
     },
 }