Delete pulled diagnostics when the source registration is unregistered (#46105)

John Tur created

Additionally, fix a race condition where we'd still insert diagnostics
from a document or workspace pull even if the registration had been
unregistered in the time since the request was issued.

And, as a bonus: when a new pull diagnostics registration is added,
issue document pulls immediately.

This should fix regressions with basedpyright caused by
https://github.com/zed-industries/zed/pull/43703.

Release Notes:

- N/A

Change summary

crates/project/src/lsp_store.rs | 133 ++++++++++++++++++++++++++++++++++
1 file changed, 132 insertions(+), 1 deletion(-)

Detailed changes

crates/project/src/lsp_store.rs 🔗

@@ -7263,6 +7263,22 @@ impl LspStore {
         }
     }
 
+    fn diagnostic_registration_exists(
+        &self,
+        server_id: LanguageServerId,
+        registration_id: &Option<SharedString>,
+    ) -> bool {
+        let Some(local) = self.as_local() else {
+            return false;
+        };
+        let Some(registrations) = local.language_server_dynamic_registrations.get(&server_id)
+        else {
+            return false;
+        };
+        let registration_key = registration_id.as_ref().map(|s| s.to_string());
+        registrations.diagnostics.contains_key(&registration_key)
+    }
+
     pub fn pull_diagnostics_for_buffer(
         &mut self,
         buffer: Entity<Buffer>,
@@ -7290,6 +7306,9 @@ impl LspStore {
                         } => Some((server_id, uri, diagnostics, registration_id)),
                         LspPullDiagnostics::Default => None,
                     })
+                    .filter(|(server_id, _, _, registration_id)| {
+                        lsp_store.diagnostic_registration_exists(*server_id, registration_id)
+                    })
                     .fold(
                         HashMap::default(),
                         |mut acc, (server_id, uri, diagnostics, new_registration_id)| {
@@ -12194,12 +12213,20 @@ impl LspStore {
         registration_id: Option<SharedString>,
         cx: &mut Context<Self>,
     ) {
-        let workspace_diagnostics =
+        let mut workspace_diagnostics =
             GetDocumentDiagnostics::deserialize_workspace_diagnostics_report(
                 report,
                 server_id,
                 registration_id,
             );
+        workspace_diagnostics.retain(|d| match &d.diagnostics {
+            LspPullDiagnostics::Response {
+                server_id,
+                registration_id,
+                ..
+            } => self.diagnostic_registration_exists(*server_id, registration_id),
+            LspPullDiagnostics::Default => false,
+        });
         let mut unchanged_buffers = HashMap::default();
         let workspace_diagnostics_updates = workspace_diagnostics
             .into_iter()
@@ -12618,6 +12645,23 @@ impl LspStore {
                         });
 
                         notify_server_capabilities_updated(&server, cx);
+
+                        let buffers_to_pull: Vec<_> = self
+                            .as_local()
+                            .into_iter()
+                            .flat_map(|local| {
+                                self.buffer_store.read(cx).buffers().filter(|buffer| {
+                                    let buffer_id = buffer.read(cx).remote_id();
+                                    local
+                                        .buffers_opened_in_servers
+                                        .get(&buffer_id)
+                                        .is_some_and(|servers| servers.contains(&server_id))
+                                })
+                            })
+                            .collect();
+                        for buffer in buffers_to_pull {
+                            self.pull_diagnostics_for_buffer(buffer, cx).detach();
+                        }
                     }
                 }
                 "textDocument/documentColor" => {
@@ -12807,6 +12851,12 @@ impl LspStore {
                         workspace_diagnostics_refresh_tasks.remove(&Some(unreg.id.clone()));
                     }
 
+                    self.clear_unregistered_diagnostics(
+                        server_id,
+                        SharedString::from(unreg.id.clone()),
+                        cx,
+                    )?;
+
                     if removed_last_diagnostic_provider {
                         server.update_capabilities(|capabilities| {
                             debug_assert!(capabilities.diagnostic_provider.is_some());
@@ -12829,6 +12879,87 @@ impl LspStore {
         Ok(())
     }
 
+    fn clear_unregistered_diagnostics(
+        &mut self,
+        server_id: LanguageServerId,
+        cleared_registration_id: SharedString,
+        cx: &mut Context<Self>,
+    ) -> anyhow::Result<()> {
+        let mut affected_abs_paths: HashSet<PathBuf> = HashSet::default();
+
+        self.buffer_store.update(cx, |buffer_store, cx| {
+            for buffer_handle in buffer_store.buffers() {
+                let buffer = buffer_handle.read(cx);
+                let abs_path = File::from_dyn(buffer.file()).map(|f| f.abs_path(cx));
+                let Some(abs_path) = abs_path else {
+                    continue;
+                };
+                affected_abs_paths.insert(abs_path);
+            }
+        });
+
+        let local = self.as_local().context("Expected LSP Store to be local")?;
+        for (worktree_id, diagnostics_for_tree) in local.diagnostics.iter() {
+            let Some(worktree) = self
+                .worktree_store
+                .read(cx)
+                .worktree_for_id(*worktree_id, cx)
+            else {
+                continue;
+            };
+
+            for (rel_path, diagnostics_by_server_id) in diagnostics_for_tree.iter() {
+                if let Ok(ix) = diagnostics_by_server_id.binary_search_by_key(&server_id, |e| e.0) {
+                    let has_matching_registration =
+                        diagnostics_by_server_id[ix].1.iter().any(|entry| {
+                            entry.diagnostic.registration_id.as_ref()
+                                == Some(&cleared_registration_id)
+                        });
+                    if has_matching_registration {
+                        let abs_path = worktree.read(cx).absolutize(rel_path);
+                        affected_abs_paths.insert(abs_path);
+                    }
+                }
+            }
+        }
+
+        if affected_abs_paths.is_empty() {
+            return Ok(());
+        }
+
+        // Send a fake diagnostic update which clears the state for the registration ID
+        let clears: Vec<DocumentDiagnosticsUpdate<'static, DocumentDiagnostics>> =
+            affected_abs_paths
+                .into_iter()
+                .map(|abs_path| DocumentDiagnosticsUpdate {
+                    diagnostics: DocumentDiagnostics {
+                        diagnostics: Vec::new(),
+                        document_abs_path: abs_path,
+                        version: None,
+                    },
+                    result_id: None,
+                    registration_id: Some(cleared_registration_id.clone()),
+                    server_id,
+                    disk_based_sources: Cow::Borrowed(&[]),
+                })
+                .collect();
+
+        let merge_registration_id = cleared_registration_id.clone();
+        self.merge_diagnostic_entries(
+            clears,
+            move |_, diagnostic, _| {
+                if diagnostic.source_kind == DiagnosticSourceKind::Pulled {
+                    diagnostic.registration_id != Some(merge_registration_id.clone())
+                } else {
+                    true
+                }
+            },
+            cx,
+        )?;
+
+        Ok(())
+    }
+
     async fn deduplicate_range_based_lsp_requests<T>(
         lsp_store: &Entity<Self>,
         server_id: Option<LanguageServerId>,