language: Make `TreeSitterData` only shared between snapshots of the same version (#44198)

Lukas Wirth created

Currently we have a single cache for this data shared between all
snapshots which is incorrect, as we might update the cache to a new
version while having old snapshots around which then may try to access
new data with old offsets/rows.

Release Notes:

- N/A *or* Added/Fixed/Improved ...

Change summary

crates/editor/src/bracket_colorization.rs |   4 
crates/language/src/buffer.rs             | 273 ++++++++++++------------
crates/language/src/buffer/row_chunk.rs   |   2 
crates/text/src/text.rs                   |   9 
4 files changed, 145 insertions(+), 143 deletions(-)

Detailed changes

crates/editor/src/bracket_colorization.rs 🔗

@@ -45,7 +45,7 @@ impl Editor {
 
         let bracket_matches_by_accent = self.visible_excerpts(false, cx).into_iter().fold(
             HashMap::default(),
-            |mut acc, (excerpt_id, (buffer, buffer_version, buffer_range))| {
+            |mut acc, (excerpt_id, (buffer, _, buffer_range))| {
                 let buffer_snapshot = buffer.read(cx).snapshot();
                 if language_settings::language_settings(
                     buffer_snapshot.language().map(|language| language.name()),
@@ -62,7 +62,7 @@ impl Editor {
                     let brackets_by_accent = buffer_snapshot
                         .fetch_bracket_ranges(
                             buffer_range.start..buffer_range.end,
-                            Some((&buffer_version, fetched_chunks)),
+                            Some(fetched_chunks),
                         )
                         .into_iter()
                         .flat_map(|(chunk_range, pairs)| {

crates/language/src/buffer.rs 🔗

@@ -22,8 +22,8 @@ pub use crate::{
     proto,
 };
 use anyhow::{Context as _, Result};
+use clock::Lamport;
 pub use clock::ReplicaId;
-use clock::{Global, Lamport};
 use collections::{HashMap, HashSet};
 use fs::MTime;
 use futures::channel::oneshot;
@@ -33,7 +33,7 @@ use gpui::{
 };
 
 use lsp::{LanguageServerId, NumberOrString};
-use parking_lot::{Mutex, RawMutex, lock_api::MutexGuard};
+use parking_lot::Mutex;
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
 use settings::WorktreeId;
@@ -130,29 +130,37 @@ pub struct Buffer {
     has_unsaved_edits: Cell<(clock::Global, bool)>,
     change_bits: Vec<rc::Weak<Cell<bool>>>,
     _subscriptions: Vec<gpui::Subscription>,
-    tree_sitter_data: Arc<Mutex<TreeSitterData>>,
+    tree_sitter_data: Arc<TreeSitterData>,
 }
 
-#[derive(Debug, Clone)]
+#[derive(Debug)]
 pub struct TreeSitterData {
     chunks: RowChunks,
-    brackets_by_chunks: Vec<Option<Vec<BracketMatch<usize>>>>,
+    brackets_by_chunks: Mutex<Vec<Option<Vec<BracketMatch<usize>>>>>,
 }
 
 const MAX_ROWS_IN_A_CHUNK: u32 = 50;
 
 impl TreeSitterData {
-    fn clear(&mut self) {
-        self.brackets_by_chunks = vec![None; self.chunks.len()];
+    fn clear(&mut self, snapshot: text::BufferSnapshot) {
+        self.chunks = RowChunks::new(snapshot, MAX_ROWS_IN_A_CHUNK);
+        self.brackets_by_chunks.get_mut().clear();
+        self.brackets_by_chunks
+            .get_mut()
+            .resize(self.chunks.len(), None);
     }
 
     fn new(snapshot: text::BufferSnapshot) -> Self {
         let chunks = RowChunks::new(snapshot, MAX_ROWS_IN_A_CHUNK);
         Self {
-            brackets_by_chunks: vec![None; chunks.len()],
+            brackets_by_chunks: Mutex::new(vec![None; chunks.len()]),
             chunks,
         }
     }
+
+    fn version(&self) -> &clock::Global {
+        self.chunks.version()
+    }
 }
 
 #[derive(Copy, Clone, Debug, PartialEq, Eq)]
@@ -176,7 +184,7 @@ pub struct BufferSnapshot {
     remote_selections: TreeMap<ReplicaId, SelectionSet>,
     language: Option<Arc<Language>>,
     non_text_state_update_count: usize,
-    tree_sitter_data: Arc<Mutex<TreeSitterData>>,
+    tree_sitter_data: Arc<TreeSitterData>,
 }
 
 /// The kind and amount of indentation in a particular line. For now,
@@ -1062,7 +1070,7 @@ impl Buffer {
         let tree_sitter_data = TreeSitterData::new(snapshot);
         Self {
             saved_mtime,
-            tree_sitter_data: Arc::new(Mutex::new(tree_sitter_data)),
+            tree_sitter_data: Arc::new(tree_sitter_data),
             saved_version: buffer.version(),
             preview_version: buffer.version(),
             reload_task: None,
@@ -1119,7 +1127,7 @@ impl Buffer {
                 file: None,
                 diagnostics: Default::default(),
                 remote_selections: Default::default(),
-                tree_sitter_data: Arc::new(Mutex::new(tree_sitter_data)),
+                tree_sitter_data: Arc::new(tree_sitter_data),
                 language,
                 non_text_state_update_count: 0,
             }
@@ -1141,7 +1149,7 @@ impl Buffer {
         BufferSnapshot {
             text,
             syntax,
-            tree_sitter_data: Arc::new(Mutex::new(tree_sitter_data)),
+            tree_sitter_data: Arc::new(tree_sitter_data),
             file: None,
             diagnostics: Default::default(),
             remote_selections: Default::default(),
@@ -1170,7 +1178,7 @@ impl Buffer {
         BufferSnapshot {
             text,
             syntax,
-            tree_sitter_data: Arc::new(Mutex::new(tree_sitter_data)),
+            tree_sitter_data: Arc::new(tree_sitter_data),
             file: None,
             diagnostics: Default::default(),
             remote_selections: Default::default(),
@@ -1187,10 +1195,16 @@ impl Buffer {
         syntax_map.interpolate(&text);
         let syntax = syntax_map.snapshot();
 
+        let tree_sitter_data = if self.text.version() != *self.tree_sitter_data.version() {
+            Arc::new(TreeSitterData::new(text.clone()))
+        } else {
+            self.tree_sitter_data.clone()
+        };
+
         BufferSnapshot {
             text,
             syntax,
-            tree_sitter_data: self.tree_sitter_data.clone(),
+            tree_sitter_data,
             file: self.file.clone(),
             remote_selections: self.remote_selections.clone(),
             diagnostics: self.diagnostics.clone(),
@@ -1624,6 +1638,16 @@ impl Buffer {
         self.sync_parse_timeout = timeout;
     }
 
+    fn invalidate_tree_sitter_data(&mut self, snapshot: text::BufferSnapshot) {
+        match Arc::get_mut(&mut self.tree_sitter_data) {
+            Some(tree_sitter_data) => tree_sitter_data.clear(snapshot),
+            None => {
+                let tree_sitter_data = TreeSitterData::new(snapshot);
+                self.tree_sitter_data = Arc::new(tree_sitter_data)
+            }
+        }
+    }
+
     /// Called after an edit to synchronize the buffer's main parse tree with
     /// the buffer's new underlying state.
     ///
@@ -1648,6 +1672,9 @@ impl Buffer {
     /// for the same buffer, we only initiate a new parse if we are not already
     /// parsing in the background.
     pub fn reparse(&mut self, cx: &mut Context<Self>, may_block: bool) {
+        if self.text.version() != *self.tree_sitter_data.version() {
+            self.invalidate_tree_sitter_data(self.text.snapshot());
+        }
         if self.reparse.is_some() {
             return;
         }
@@ -1749,7 +1776,9 @@ impl Buffer {
         self.syntax_map.lock().did_parse(syntax_snapshot);
         self.request_autoindent(cx);
         self.parse_status.0.send(ParseStatus::Idle).unwrap();
-        self.tree_sitter_data.lock().clear();
+        if self.text.version() != *self.tree_sitter_data.version() {
+            self.invalidate_tree_sitter_data(self.text.snapshot());
+        }
         cx.emit(BufferEvent::Reparsed);
         cx.notify();
     }
@@ -4281,155 +4310,123 @@ impl BufferSnapshot {
     pub fn fetch_bracket_ranges(
         &self,
         range: Range<usize>,
-        known_chunks: Option<(&Global, &HashSet<Range<BufferRow>>)>,
+        known_chunks: Option<&HashSet<Range<BufferRow>>>,
     ) -> HashMap<Range<BufferRow>, Vec<BracketMatch<usize>>> {
-        let mut tree_sitter_data = self.latest_tree_sitter_data().clone();
-
-        let known_chunks = match known_chunks {
-            Some((known_version, known_chunks)) => {
-                if !tree_sitter_data
-                    .chunks
-                    .version()
-                    .changed_since(known_version)
-                {
-                    known_chunks.clone()
-                } else {
-                    HashSet::default()
-                }
-            }
-            None => HashSet::default(),
-        };
-
-        let mut new_bracket_matches = HashMap::default();
         let mut all_bracket_matches = HashMap::default();
 
-        for chunk in tree_sitter_data
+        for chunk in self
+            .tree_sitter_data
             .chunks
             .applicable_chunks(&[self.anchor_before(range.start)..self.anchor_after(range.end)])
         {
-            if known_chunks.contains(&chunk.row_range()) {
+            if known_chunks.is_some_and(|chunks| chunks.contains(&chunk.row_range())) {
                 continue;
             }
-            let Some(chunk_range) = tree_sitter_data.chunks.chunk_range(chunk) else {
+            let Some(chunk_range) = self.tree_sitter_data.chunks.chunk_range(chunk) else {
                 continue;
             };
-            let chunk_range = chunk_range.to_offset(&tree_sitter_data.chunks.snapshot);
-
-            let bracket_matches = match tree_sitter_data.brackets_by_chunks[chunk.id].take() {
-                Some(cached_brackets) => cached_brackets,
-                None => {
-                    let mut all_brackets = Vec::new();
-                    let mut opens = Vec::new();
-                    let mut color_pairs = Vec::new();
-
-                    let mut matches =
-                        self.syntax
-                            .matches(chunk_range.clone(), &self.text, |grammar| {
-                                grammar.brackets_config.as_ref().map(|c| &c.query)
-                            });
-                    let configs = matches
-                        .grammars()
-                        .iter()
-                        .map(|grammar| grammar.brackets_config.as_ref().unwrap())
-                        .collect::<Vec<_>>();
-
-                    while let Some(mat) = matches.peek() {
-                        let mut open = None;
-                        let mut close = None;
-                        let syntax_layer_depth = mat.depth;
-                        let config = configs[mat.grammar_index];
-                        let pattern = &config.patterns[mat.pattern_index];
-                        for capture in mat.captures {
-                            if capture.index == config.open_capture_ix {
-                                open = Some(capture.node.byte_range());
-                            } else if capture.index == config.close_capture_ix {
-                                close = Some(capture.node.byte_range());
-                            }
-                        }
+            let chunk_range = chunk_range.to_offset(&self);
 
-                        matches.advance();
+            if let Some(cached_brackets) =
+                &self.tree_sitter_data.brackets_by_chunks.lock()[chunk.id]
+            {
+                all_bracket_matches.insert(chunk.row_range(), cached_brackets.clone());
+                continue;
+            }
 
-                        let Some((open_range, close_range)) = open.zip(close) else {
-                            continue;
-                        };
+            let mut all_brackets = Vec::new();
+            let mut opens = Vec::new();
+            let mut color_pairs = Vec::new();
 
-                        let bracket_range = open_range.start..=close_range.end;
-                        if !bracket_range.overlaps(&chunk_range) {
-                            continue;
-                        }
+            let mut matches = self
+                .syntax
+                .matches(chunk_range.clone(), &self.text, |grammar| {
+                    grammar.brackets_config.as_ref().map(|c| &c.query)
+                });
+            let configs = matches
+                .grammars()
+                .iter()
+                .map(|grammar| grammar.brackets_config.as_ref().unwrap())
+                .collect::<Vec<_>>();
+
+            while let Some(mat) = matches.peek() {
+                let mut open = None;
+                let mut close = None;
+                let syntax_layer_depth = mat.depth;
+                let config = configs[mat.grammar_index];
+                let pattern = &config.patterns[mat.pattern_index];
+                for capture in mat.captures {
+                    if capture.index == config.open_capture_ix {
+                        open = Some(capture.node.byte_range());
+                    } else if capture.index == config.close_capture_ix {
+                        close = Some(capture.node.byte_range());
+                    }
+                }
 
-                        let index = all_brackets.len();
-                        all_brackets.push(BracketMatch {
-                            open_range: open_range.clone(),
-                            close_range: close_range.clone(),
-                            newline_only: pattern.newline_only,
-                            syntax_layer_depth,
-                            color_index: None,
-                        });
+                matches.advance();
 
-                        // Certain languages have "brackets" that are not brackets, e.g. tags. and such
-                        // bracket will match the entire tag with all text inside.
-                        // For now, avoid highlighting any pair that has more than single char in each bracket.
-                        // We need to  colorize `<Element/>` bracket pairs, so cannot make this check stricter.
-                        let should_color = !pattern.rainbow_exclude
-                            && (open_range.len() == 1 || close_range.len() == 1);
-                        if should_color {
-                            opens.push(open_range.clone());
-                            color_pairs.push((open_range, close_range, index));
-                        }
-                    }
+                let Some((open_range, close_range)) = open.zip(close) else {
+                    continue;
+                };
 
-                    opens.sort_by_key(|r| (r.start, r.end));
-                    opens.dedup_by(|a, b| a.start == b.start && a.end == b.end);
-                    color_pairs.sort_by_key(|(_, close, _)| close.end);
+                let bracket_range = open_range.start..=close_range.end;
+                if !bracket_range.overlaps(&chunk_range) {
+                    continue;
+                }
 
-                    let mut open_stack = Vec::new();
-                    let mut open_index = 0;
-                    for (open, close, index) in color_pairs {
-                        while open_index < opens.len() && opens[open_index].start < close.start {
-                            open_stack.push(opens[open_index].clone());
-                            open_index += 1;
-                        }
+                let index = all_brackets.len();
+                all_brackets.push(BracketMatch {
+                    open_range: open_range.clone(),
+                    close_range: close_range.clone(),
+                    newline_only: pattern.newline_only,
+                    syntax_layer_depth,
+                    color_index: None,
+                });
 
-                        if open_stack.last() == Some(&open) {
-                            let depth_index = open_stack.len() - 1;
-                            all_brackets[index].color_index = Some(depth_index);
-                            open_stack.pop();
-                        }
-                    }
+                // Certain languages have "brackets" that are not brackets, e.g. tags. and such
+                // bracket will match the entire tag with all text inside.
+                // For now, avoid highlighting any pair that has more than single char in each bracket.
+                // We need to  colorize `<Element/>` bracket pairs, so cannot make this check stricter.
+                let should_color =
+                    !pattern.rainbow_exclude && (open_range.len() == 1 || close_range.len() == 1);
+                if should_color {
+                    opens.push(open_range.clone());
+                    color_pairs.push((open_range, close_range, index));
+                }
+            }
 
-                    all_brackets.sort_by_key(|bracket_match| {
-                        (bracket_match.open_range.start, bracket_match.open_range.end)
-                    });
-                    new_bracket_matches.insert(chunk.id, all_brackets.clone());
-                    all_brackets
+            opens.sort_by_key(|r| (r.start, r.end));
+            opens.dedup_by(|a, b| a.start == b.start && a.end == b.end);
+            color_pairs.sort_by_key(|(_, close, _)| close.end);
+
+            let mut open_stack = Vec::new();
+            let mut open_index = 0;
+            for (open, close, index) in color_pairs {
+                while open_index < opens.len() && opens[open_index].start < close.start {
+                    open_stack.push(opens[open_index].clone());
+                    open_index += 1;
                 }
-            };
-            all_bracket_matches.insert(chunk.row_range(), bracket_matches);
-        }
 
-        let mut latest_tree_sitter_data = self.latest_tree_sitter_data();
-        if latest_tree_sitter_data.chunks.version() == &self.version {
-            for (chunk_id, new_matches) in new_bracket_matches {
-                let old_chunks = &mut latest_tree_sitter_data.brackets_by_chunks[chunk_id];
-                if old_chunks.is_none() {
-                    *old_chunks = Some(new_matches);
+                if open_stack.last() == Some(&open) {
+                    let depth_index = open_stack.len() - 1;
+                    all_brackets[index].color_index = Some(depth_index);
+                    open_stack.pop();
                 }
             }
-        }
 
-        all_bracket_matches
-    }
+            all_brackets.sort_by_key(|bracket_match| {
+                (bracket_match.open_range.start, bracket_match.open_range.end)
+            });
 
-    fn latest_tree_sitter_data(&self) -> MutexGuard<'_, RawMutex, TreeSitterData> {
-        let mut tree_sitter_data = self.tree_sitter_data.lock();
-        if self
-            .version
-            .changed_since(tree_sitter_data.chunks.version())
-        {
-            *tree_sitter_data = TreeSitterData::new(self.text.clone());
+            if let empty_slot @ None =
+                &mut self.tree_sitter_data.brackets_by_chunks.lock()[chunk.id]
+            {
+                *empty_slot = Some(all_brackets.clone());
+            }
+            all_bracket_matches.insert(chunk.row_range(), all_brackets);
         }
-        tree_sitter_data
+
+        all_bracket_matches
     }
 
     pub fn all_bracket_ranges(

crates/language/src/buffer/row_chunk.rs 🔗

@@ -19,7 +19,7 @@ use crate::BufferRow;
 /// <https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#inlayHintParams>
 #[derive(Clone)]
 pub struct RowChunks {
-    pub(crate) snapshot: text::BufferSnapshot,
+    snapshot: text::BufferSnapshot,
     chunks: Arc<[RowChunk]>,
 }
 

crates/text/src/text.rs 🔗

@@ -2321,8 +2321,13 @@ impl BufferSnapshot {
         } else if anchor.is_max() {
             self.visible_text.len()
         } else {
-            debug_assert!(anchor.buffer_id == Some(self.remote_id));
-            debug_assert!(self.version.observed(anchor.timestamp));
+            debug_assert_eq!(anchor.buffer_id, Some(self.remote_id));
+            debug_assert!(
+                self.version.observed(anchor.timestamp),
+                "Anchor timestamp {:?} not observed by buffer {:?}",
+                anchor.timestamp,
+                self.version
+            );
             let anchor_key = InsertionFragmentKey {
                 timestamp: anchor.timestamp,
                 split_offset: anchor.offset,