diagnostic_set.rs

  1use crate::{Diagnostic, range_to_lsp};
  2use anyhow::Result;
  3use collections::HashMap;
  4use lsp::LanguageServerId;
  5use serde::Serialize;
  6use std::{
  7    cmp::{Ordering, Reverse},
  8    iter,
  9    ops::Range,
 10};
 11use sum_tree::{self, Bias, SumTree};
 12use text::{Anchor, FromAnchor, PointUtf16, ToOffset};
 13
 14/// A set of diagnostics associated with a given buffer, provided
 15/// by a single language server.
 16///
 17/// The diagnostics are stored in a [`SumTree`], which allows this struct
 18/// to be cheaply copied, and allows for efficient retrieval of the
 19/// diagnostics that intersect a given range of the buffer.
 20#[derive(Clone, Debug)]
 21pub struct DiagnosticSet {
 22    diagnostics: SumTree<DiagnosticEntry<Anchor>>,
 23}
 24
 25/// A single diagnostic in a set. Generic over its range type, because
 26/// the diagnostics are stored internally as [`Anchor`]s, but can be
 27/// resolved to different coordinates types like [`usize`] byte offsets or
 28/// [`Point`](gpui::Point)s.
 29#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
 30pub struct DiagnosticEntry<T> {
 31    /// The range of the buffer where the diagnostic applies.
 32    pub range: Range<T>,
 33    /// The information about the diagnostic.
 34    pub diagnostic: Diagnostic,
 35}
 36
 37/// A group of related diagnostics, ordered by their start position
 38/// in the buffer.
 39#[derive(Debug, Serialize)]
 40pub struct DiagnosticGroup<T> {
 41    /// The diagnostics.
 42    pub entries: Vec<DiagnosticEntry<T>>,
 43    /// The index into `entries` where the primary diagnostic is stored.
 44    pub primary_ix: usize,
 45}
 46
 47impl DiagnosticGroup<Anchor> {
 48    /// Converts the entries in this [`DiagnosticGroup`] to a different buffer coordinate type.
 49    pub fn resolve<O: FromAnchor>(&self, buffer: &text::BufferSnapshot) -> DiagnosticGroup<O> {
 50        DiagnosticGroup {
 51            entries: self
 52                .entries
 53                .iter()
 54                .map(|entry| entry.resolve(buffer))
 55                .collect(),
 56            primary_ix: self.primary_ix,
 57        }
 58    }
 59}
 60
 61#[derive(Clone, Debug)]
 62pub struct Summary {
 63    start: Anchor,
 64    end: Anchor,
 65    min_start: Anchor,
 66    max_end: Anchor,
 67    count: usize,
 68}
 69
 70impl DiagnosticEntry<PointUtf16> {
 71    /// Returns a raw LSP diagnostic used to provide diagnostic context to LSP
 72    /// codeAction request
 73    pub fn to_lsp_diagnostic_stub(&self) -> Result<lsp::Diagnostic> {
 74        let range = range_to_lsp(self.range.clone())?;
 75
 76        Ok(lsp::Diagnostic {
 77            range,
 78            code: self.diagnostic.code.clone(),
 79            severity: Some(self.diagnostic.severity),
 80            source: self.diagnostic.source.clone(),
 81            message: self.diagnostic.message.clone(),
 82            data: self.diagnostic.data.clone(),
 83            ..Default::default()
 84        })
 85    }
 86}
 87
 88impl DiagnosticSet {
 89    /// Constructs a [DiagnosticSet] from a sequence of entries, ordered by
 90    /// their position in the buffer.
 91    pub fn from_sorted_entries<I>(iter: I, buffer: &text::BufferSnapshot) -> Self
 92    where
 93        I: IntoIterator<Item = DiagnosticEntry<Anchor>>,
 94    {
 95        Self {
 96            diagnostics: SumTree::from_iter(iter, buffer),
 97        }
 98    }
 99
100    /// Constructs a [DiagnosticSet] from a sequence of entries in an arbitrary order.
101    pub fn new<I>(iter: I, buffer: &text::BufferSnapshot) -> Self
102    where
103        I: IntoIterator<Item = DiagnosticEntry<PointUtf16>>,
104    {
105        let mut entries = iter.into_iter().collect::<Vec<_>>();
106        entries.sort_unstable_by_key(|entry| (entry.range.start, Reverse(entry.range.end)));
107        Self {
108            diagnostics: SumTree::from_iter(
109                entries.into_iter().map(|entry| DiagnosticEntry {
110                    range: buffer.anchor_before(entry.range.start)
111                        ..buffer.anchor_before(entry.range.end),
112                    diagnostic: entry.diagnostic,
113                }),
114                buffer,
115            ),
116        }
117    }
118
119    /// Returns the number of diagnostics in the set.
120    pub fn len(&self) -> usize {
121        self.diagnostics.summary().count
122    }
123    /// Returns true when there are no diagnostics in this diagnostic set
124    pub fn is_empty(&self) -> bool {
125        self.len() == 0
126    }
127
128    /// Returns an iterator over the diagnostic entries in the set.
129    pub fn iter(&self) -> impl Iterator<Item = &DiagnosticEntry<Anchor>> {
130        self.diagnostics.iter()
131    }
132
133    /// Returns an iterator over the diagnostic entries that intersect the
134    /// given range of the buffer.
135    pub fn range<'a, T, O>(
136        &'a self,
137        range: Range<T>,
138        buffer: &'a text::BufferSnapshot,
139        inclusive: bool,
140        reversed: bool,
141    ) -> impl 'a + Iterator<Item = DiagnosticEntry<O>>
142    where
143        T: 'a + ToOffset,
144        O: FromAnchor,
145    {
146        let end_bias = if inclusive { Bias::Right } else { Bias::Left };
147        let range = buffer.anchor_before(range.start)..buffer.anchor_at(range.end, end_bias);
148        let mut cursor = self.diagnostics.filter::<_, ()>(buffer, {
149            move |summary: &Summary| {
150                let start_cmp = range.start.cmp(&summary.max_end, buffer);
151                let end_cmp = range.end.cmp(&summary.min_start, buffer);
152                if inclusive {
153                    start_cmp <= Ordering::Equal && end_cmp >= Ordering::Equal
154                } else {
155                    start_cmp == Ordering::Less && end_cmp == Ordering::Greater
156                }
157            }
158        });
159
160        if reversed {
161            cursor.prev();
162        } else {
163            cursor.next();
164        }
165        iter::from_fn({
166            move || {
167                if let Some(diagnostic) = cursor.item() {
168                    if reversed {
169                        cursor.prev();
170                    } else {
171                        cursor.next();
172                    }
173                    Some(diagnostic.resolve(buffer))
174                } else {
175                    None
176                }
177            }
178        })
179    }
180
181    /// Adds all of this set's diagnostic groups to the given output vector.
182    pub fn groups(
183        &self,
184        language_server_id: LanguageServerId,
185        output: &mut Vec<(LanguageServerId, DiagnosticGroup<Anchor>)>,
186        buffer: &text::BufferSnapshot,
187    ) {
188        let mut groups = HashMap::default();
189        for entry in self.diagnostics.iter() {
190            groups
191                .entry(entry.diagnostic.group_id)
192                .or_insert(Vec::new())
193                .push(entry.clone());
194        }
195
196        let start_ix = output.len();
197        output.extend(groups.into_values().filter_map(|mut entries| {
198            entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start, buffer));
199            entries
200                .iter()
201                .position(|entry| entry.diagnostic.is_primary)
202                .map(|primary_ix| {
203                    (
204                        language_server_id,
205                        DiagnosticGroup {
206                            entries,
207                            primary_ix,
208                        },
209                    )
210                })
211        }));
212        output[start_ix..].sort_unstable_by(|(id_a, group_a), (id_b, group_b)| {
213            group_a.entries[group_a.primary_ix]
214                .range
215                .start
216                .cmp(&group_b.entries[group_b.primary_ix].range.start, buffer)
217                .then_with(|| id_a.cmp(id_b))
218        });
219    }
220
221    /// Returns all of the diagnostics in a particular diagnostic group,
222    /// in order of their position in the buffer.
223    pub fn group<'a, O: FromAnchor>(
224        &'a self,
225        group_id: usize,
226        buffer: &'a text::BufferSnapshot,
227    ) -> impl 'a + Iterator<Item = DiagnosticEntry<O>> {
228        self.iter()
229            .filter(move |entry| entry.diagnostic.group_id == group_id)
230            .map(|entry| entry.resolve(buffer))
231    }
232}
233
234impl sum_tree::Item for DiagnosticEntry<Anchor> {
235    type Summary = Summary;
236
237    fn summary(&self, _cx: &text::BufferSnapshot) -> Self::Summary {
238        Summary {
239            start: self.range.start,
240            end: self.range.end,
241            min_start: self.range.start,
242            max_end: self.range.end,
243            count: 1,
244        }
245    }
246}
247
248impl DiagnosticEntry<Anchor> {
249    /// Converts the [DiagnosticEntry] to a different buffer coordinate type.
250    pub fn resolve<O: FromAnchor>(&self, buffer: &text::BufferSnapshot) -> DiagnosticEntry<O> {
251        DiagnosticEntry {
252            range: O::from_anchor(&self.range.start, buffer)
253                ..O::from_anchor(&self.range.end, buffer),
254            diagnostic: self.diagnostic.clone(),
255        }
256    }
257}
258
259impl Default for Summary {
260    fn default() -> Self {
261        Self {
262            start: Anchor::MIN,
263            end: Anchor::MAX,
264            min_start: Anchor::MAX,
265            max_end: Anchor::MIN,
266            count: 0,
267        }
268    }
269}
270
271impl sum_tree::Summary for Summary {
272    type Context = text::BufferSnapshot;
273
274    fn zero(_cx: &Self::Context) -> Self {
275        Default::default()
276    }
277
278    fn add_summary(&mut self, other: &Self, buffer: &Self::Context) {
279        if other.min_start.cmp(&self.min_start, buffer).is_lt() {
280            self.min_start = other.min_start;
281        }
282        if other.max_end.cmp(&self.max_end, buffer).is_gt() {
283            self.max_end = other.max_end;
284        }
285        self.start = other.start;
286        self.end = other.end;
287        self.count += other.count;
288    }
289}