Always group diagnostics the way they're grouped in the LSP message

Max Brunsfeld and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

crates/project/src/worktree.rs | 147 +++++++++++++++++++----------------
1 file changed, 79 insertions(+), 68 deletions(-)

Detailed changes

crates/project/src/worktree.rs 🔗

@@ -34,7 +34,6 @@ use std::{
     ffi::{OsStr, OsString},
     fmt,
     future::Future,
-    mem,
     ops::{Deref, Range},
     path::{Path, PathBuf},
     sync::{
@@ -691,7 +690,7 @@ impl Worktree {
 
     pub fn update_diagnostics(
         &mut self,
-        mut params: lsp::PublishDiagnosticsParams,
+        params: lsp::PublishDiagnosticsParams,
         disk_based_sources: &HashSet<String>,
         cx: &mut ModelContext<Worktree>,
     ) -> Result<()> {
@@ -706,58 +705,92 @@ impl Worktree {
                 .context("path is not within worktree")?,
         );
 
-        let mut group_ids_by_diagnostic_range = HashMap::default();
-        let mut diagnostics_by_group_id = HashMap::default();
         let mut next_group_id = 0;
-        for diagnostic in &mut params.diagnostics {
+        let mut diagnostics = Vec::default();
+        let mut primary_diagnostic_group_ids = HashMap::default();
+        let mut sources_by_group_id = HashMap::default();
+        let mut supporting_diagnostic_severities = HashMap::default();
+        for diagnostic in &params.diagnostics {
             let source = diagnostic.source.as_ref();
-            let code = diagnostic.code.as_ref();
-            let group_id = diagnostic_ranges(&diagnostic, &abs_path)
-                .find_map(|range| group_ids_by_diagnostic_range.get(&(source, code, range)))
-                .copied()
-                .unwrap_or_else(|| {
-                    let group_id = post_inc(&mut next_group_id);
-                    for range in diagnostic_ranges(&diagnostic, &abs_path) {
-                        group_ids_by_diagnostic_range.insert((source, code, range), group_id);
-                    }
-                    group_id
+            let code = diagnostic.code.as_ref().map(|code| match code {
+                lsp::NumberOrString::Number(code) => code.to_string(),
+                lsp::NumberOrString::String(code) => code.clone(),
+            });
+            let range = range_from_lsp(diagnostic.range);
+            let is_supporting = diagnostic
+                .related_information
+                .as_ref()
+                .map_or(false, |infos| {
+                    infos.iter().any(|info| {
+                        primary_diagnostic_group_ids.contains_key(&(
+                            source,
+                            code.clone(),
+                            range_from_lsp(info.location.range),
+                        ))
+                    })
                 });
 
-            diagnostics_by_group_id
-                .entry(group_id)
-                .or_insert(Vec::new())
-                .push(DiagnosticEntry {
-                    range: diagnostic.range.start.to_point_utf16()
-                        ..diagnostic.range.end.to_point_utf16(),
+            if is_supporting {
+                if let Some(severity) = diagnostic.severity {
+                    supporting_diagnostic_severities
+                        .insert((source, code.clone(), range), severity);
+                }
+            } else {
+                let group_id = post_inc(&mut next_group_id);
+                let is_disk_based =
+                    source.map_or(false, |source| disk_based_sources.contains(source));
+
+                sources_by_group_id.insert(group_id, source);
+                primary_diagnostic_group_ids
+                    .insert((source, code.clone(), range.clone()), group_id);
+
+                diagnostics.push(DiagnosticEntry {
+                    range,
                     diagnostic: Diagnostic {
-                        code: diagnostic.code.clone().map(|code| match code {
-                            lsp::NumberOrString::Number(code) => code.to_string(),
-                            lsp::NumberOrString::String(code) => code,
-                        }),
+                        code: code.clone(),
                         severity: diagnostic.severity.unwrap_or(DiagnosticSeverity::ERROR),
-                        message: mem::take(&mut diagnostic.message),
+                        message: diagnostic.message.clone(),
                         group_id,
-                        is_primary: false,
+                        is_primary: true,
                         is_valid: true,
-                        is_disk_based: diagnostic
-                            .source
-                            .as_ref()
-                            .map_or(false, |source| disk_based_sources.contains(source)),
+                        is_disk_based,
                     },
                 });
+                if let Some(infos) = &diagnostic.related_information {
+                    for info in infos {
+                        if info.location.uri == params.uri {
+                            let range = range_from_lsp(info.location.range);
+                            diagnostics.push(DiagnosticEntry {
+                                range,
+                                diagnostic: Diagnostic {
+                                    code: code.clone(),
+                                    severity: DiagnosticSeverity::INFORMATION,
+                                    message: info.message.clone(),
+                                    group_id,
+                                    is_primary: false,
+                                    is_valid: true,
+                                    is_disk_based,
+                                },
+                            });
+                        }
+                    }
+                }
+            }
         }
 
-        let diagnostics = diagnostics_by_group_id
-            .into_values()
-            .flat_map(|mut diagnostics| {
-                let primary = diagnostics
-                    .iter_mut()
-                    .min_by_key(|entry| entry.diagnostic.severity)
-                    .unwrap();
-                primary.diagnostic.is_primary = true;
-                diagnostics
-            })
-            .collect::<Vec<_>>();
+        for entry in &mut diagnostics {
+            let diagnostic = &mut entry.diagnostic;
+            if !diagnostic.is_primary {
+                let source = *sources_by_group_id.get(&diagnostic.group_id).unwrap();
+                if let Some(&severity) = supporting_diagnostic_severities.get(&(
+                    source,
+                    diagnostic.code.clone(),
+                    entry.range.clone(),
+                )) {
+                    diagnostic.severity = severity;
+                }
+            }
+        }
 
         self.update_diagnostic_entries(worktree_path, params.version, diagnostics, cx)?;
         Ok(())
@@ -3103,32 +3136,10 @@ impl ToPointUtf16 for lsp::Position {
     }
 }
 
-fn diagnostic_ranges<'a>(
-    diagnostic: &'a lsp::Diagnostic,
-    abs_path: &'a Path,
-) -> impl 'a + Iterator<Item = Range<PointUtf16>> {
-    diagnostic
-        .related_information
-        .iter()
-        .flatten()
-        .filter_map(move |info| {
-            if info.location.uri.to_file_path().ok()? == abs_path {
-                let info_start = PointUtf16::new(
-                    info.location.range.start.line,
-                    info.location.range.start.character,
-                );
-                let info_end = PointUtf16::new(
-                    info.location.range.end.line,
-                    info.location.range.end.character,
-                );
-                Some(info_start..info_end)
-            } else {
-                None
-            }
-        })
-        .chain(Some(
-            diagnostic.range.start.to_point_utf16()..diagnostic.range.end.to_point_utf16(),
-        ))
+fn range_from_lsp(range: lsp::Range) -> Range<PointUtf16> {
+    let start = PointUtf16::new(range.start.line, range.start.character);
+    let end = PointUtf16::new(range.end.line, range.end.character);
+    start..end
 }
 
 #[cfg(test)]