Restructure handling of changed regions when reparsing

Max Brunsfeld created

Change summary

crates/language/src/syntax_map.rs | 462 ++++++++++++++++----------------
1 file changed, 231 insertions(+), 231 deletions(-)

Detailed changes

crates/language/src/syntax_map.rs 🔗

@@ -2,14 +2,12 @@ use crate::{
     Grammar, InjectionConfig, Language, LanguageRegistry, QueryCursorHandle, TextProvider,
     ToTreeSitterPoint,
 };
-use collections::HashMap;
 use std::{
     borrow::Cow, cell::RefCell, cmp::Ordering, collections::BinaryHeap, ops::Range, sync::Arc,
 };
 use sum_tree::{Bias, SeekTarget, SumTree};
 use text::{Anchor, BufferSnapshot, OffsetRangeExt, Point, Rope, ToOffset, ToPoint};
 use tree_sitter::{Parser, Tree};
-use util::post_inc;
 
 thread_local! {
     static PARSER: RefCell<Parser> = RefCell::new(Parser::new());
@@ -42,28 +40,26 @@ struct SyntaxLayerSummary {
     last_layer_range: Range<Anchor>,
 }
 
-#[derive(Clone, Debug)]
-struct Depth(usize);
+#[derive(Debug)]
+struct DepthAndRange(usize, Range<Anchor>);
 
-#[derive(Clone, Debug)]
-struct MaxPosition(Anchor);
+#[derive(Debug)]
+struct DepthAndMaxPosition(usize, Anchor);
 
-enum ReparseStep {
-    CreateLayer {
-        depth: usize,
-        language: Arc<Language>,
-        ranges: Vec<tree_sitter::Range>,
-    },
-    EnterChangedRange {
-        id: usize,
-        depth: usize,
-        range: Range<usize>,
-    },
-    LeaveChangedRange {
-        id: usize,
-        depth: usize,
-        range: Range<usize>,
-    },
+#[derive(Debug)]
+struct DepthAndRangeOrMaxPosition(usize, Range<Anchor>, Anchor);
+
+struct ReparseStep {
+    depth: usize,
+    language: Arc<Language>,
+    ranges: Vec<tree_sitter::Range>,
+    range: Range<Anchor>,
+}
+
+#[derive(Debug, PartialEq, Eq)]
+struct ChangedRegion {
+    depth: usize,
+    range: Range<Anchor>,
 }
 
 impl SyntaxMap {
@@ -130,7 +126,16 @@ impl SyntaxSnapshot {
 
         for depth in 0..=max_depth {
             let mut edits = &edits[..];
-            layers.push_tree(cursor.slice(&Depth(depth), Bias::Left, text), text);
+            if cursor.start().max_depth < depth {
+                layers.push_tree(
+                    cursor.slice(
+                        &DepthAndRange(depth, Anchor::MIN..Anchor::MAX),
+                        Bias::Left,
+                        text,
+                    ),
+                    text,
+                );
+            }
 
             while let Some(layer) = cursor.item() {
                 let mut endpoints = text.summaries_for_anchors::<(usize, Point), _>([
@@ -150,10 +155,7 @@ impl SyntaxSnapshot {
                 if first_edit.new.start.0 > layer_range.end.0 {
                     layers.push_tree(
                         cursor.slice(
-                            &(
-                                Depth(depth),
-                                MaxPosition(text.anchor_before(first_edit.new.start.0)),
-                            ),
+                            &DepthAndMaxPosition(depth, text.anchor_before(first_edit.new.start.0)),
                             Bias::Left,
                             text,
                         ),
@@ -183,8 +185,8 @@ impl SyntaxSnapshot {
                     }
 
                     // Apply any edits that intersect this layer to the layer's syntax tree.
-                    if edit.new.start.0 >= start_byte {
-                        layer.tree.edit(&tree_sitter::InputEdit {
+                    let tree_edit = if edit.new.start.0 >= start_byte {
+                        tree_sitter::InputEdit {
                             start_byte: edit.new.start.0 - start_byte,
                             old_end_byte: edit.new.start.0 - start_byte
                                 + (edit.old.end.0 - edit.old.start.0),
@@ -194,16 +196,20 @@ impl SyntaxSnapshot {
                                 + (edit.old.end.1 - edit.old.start.1))
                                 .to_ts_point(),
                             new_end_position: (edit.new.end.1 - start_point).to_ts_point(),
-                        });
+                        }
                     } else {
-                        layer.tree.edit(&tree_sitter::InputEdit {
+                        tree_sitter::InputEdit {
                             start_byte: 0,
                             old_end_byte: edit.new.end.0 - start_byte,
                             new_end_byte: 0,
                             start_position: Default::default(),
                             old_end_position: (edit.new.end.1 - start_point).to_ts_point(),
                             new_end_position: Default::default(),
-                        });
+                        }
+                    };
+
+                    layer.tree.edit(&tree_edit);
+                    if edit.new.start.0 < start_byte {
                         break;
                     }
                 }
@@ -228,184 +234,157 @@ impl SyntaxSnapshot {
         cursor.next(&text);
         let mut layers = SumTree::new();
 
-        let mut next_change_id = 0;
-        let mut current_changes = HashMap::default();
+        let mut changed_regions = Vec::<ChangedRegion>::new();
         let mut queue = BinaryHeap::new();
-        queue.push(ReparseStep::CreateLayer {
+        queue.push(ReparseStep {
             depth: 0,
             language: language.clone(),
             ranges: Vec::new(),
+            range: Anchor::MIN..Anchor::MAX,
         });
 
-        while let Some(step) = queue.pop() {
-            match step {
-                ReparseStep::CreateLayer {
-                    depth,
-                    language,
-                    ranges,
-                } => {
-                    let range;
-                    let start_point;
-                    let start_byte;
-                    let end_byte;
-                    if let Some((first, last)) = ranges.first().zip(ranges.last()) {
-                        start_point = first.start_point;
-                        start_byte = first.start_byte;
-                        end_byte = last.end_byte;
-                        range = text.anchor_before(start_byte)..text.anchor_after(end_byte);
-                    } else {
-                        start_point = Point::zero().to_ts_point();
-                        start_byte = 0;
-                        end_byte = text.len();
-                        range = Anchor::MIN..Anchor::MAX;
-                    };
-
-                    let target = (Depth(depth), range.clone());
-                    if target.cmp(cursor.start(), &text).is_gt() {
-                        if current_changes.is_empty() {
-                            let slice = cursor.slice(&target, Bias::Left, text);
-                            layers.push_tree(slice, &text);
-                        } else {
-                            while let Some(layer) = cursor.item() {
-                                if layer.depth > depth
-                                    || layer.depth == depth
-                                        && layer.range.start.cmp(&range.end, text).is_ge()
-                                {
-                                    break;
-                                }
-                                if !layer_is_changed(layer, text, &current_changes) {
-                                    layers.push(layer.clone(), text);
-                                }
-                                cursor.next(text);
-                            }
-                        }
+        loop {
+            let step = queue.pop();
+            let (depth, range) = if let Some(step) = &step {
+                (step.depth, step.range.clone())
+            } else {
+                (cursor.start().max_depth, Anchor::MAX..Anchor::MAX)
+            };
+
+            let target = DepthAndRange(depth, range.clone());
+            if target.cmp(cursor.start(), &text).is_gt() {
+                let change_start_anchor = changed_regions
+                    .first()
+                    .map_or(Anchor::MAX, |region| region.range.start);
+                let seek_target =
+                    DepthAndRangeOrMaxPosition(depth, range.clone(), change_start_anchor);
+                let slice = cursor.slice(&seek_target, Bias::Left, text);
+                layers.push_tree(slice, &text);
+
+                while let Some(layer) = cursor.item() {
+                    if target.cmp(&cursor.end(text), text).is_le() {
+                        break;
                     }
-
-                    let mut old_layer = cursor.item();
-                    if let Some(layer) = old_layer {
-                        if layer.range.to_offset(text) == (start_byte..end_byte) {
-                            cursor.next(&text);
-                        } else {
-                            old_layer = None;
+                    if layer_is_changed(layer, text, &changed_regions) {
+                        let region = ChangedRegion {
+                            depth: depth + 1,
+                            range: layer.range.clone(),
+                        };
+                        if let Err(i) =
+                            changed_regions.binary_search_by(|probe| probe.cmp(&region, text))
+                        {
+                            changed_regions.insert(i, region);
                         }
-                    }
-
-                    let grammar = if let Some(grammar) = language.grammar.as_deref() {
-                        grammar
                     } else {
-                        continue;
-                    };
-
-                    let tree;
-                    let changed_ranges;
-                    if let Some(old_layer) = old_layer {
-                        tree = parse_text(
-                            grammar,
-                            text.as_rope(),
-                            Some(old_layer.tree.clone()),
-                            ranges,
-                        );
-
-                        changed_ranges = old_layer
-                            .tree
-                            .changed_ranges(&tree)
-                            .map(|r| r.start_byte..r.end_byte)
-                            .collect();
-                    } else {
-                        tree = parse_text(grammar, text.as_rope(), None, ranges);
-                        changed_ranges = vec![0..end_byte - start_byte];
+                        layers.push(layer.clone(), text);
                     }
 
-                    layers.push(
-                        SyntaxLayer {
-                            depth,
-                            range,
-                            tree: tree.clone(),
-                            language: language.clone(),
-                        },
-                        &text,
-                    );
-
-                    if let (Some((config, registry)), false) = (
-                        grammar.injection_config.as_ref().zip(registry.as_ref()),
-                        changed_ranges.is_empty(),
-                    ) {
-                        let depth = depth + 1;
-                        queue.extend(changed_ranges.iter().flat_map(|range| {
-                            let id = post_inc(&mut next_change_id);
-                            let range = start_byte + range.start..start_byte + range.end;
-                            [
-                                ReparseStep::EnterChangedRange {
-                                    id,
-                                    depth,
-                                    range: range.clone(),
-                                },
-                                ReparseStep::LeaveChangedRange {
-                                    id,
-                                    depth,
-                                    range: range.clone(),
-                                },
-                            ]
-                        }));
-
-                        get_injections(
-                            config,
-                            text,
-                            &tree,
-                            registry,
-                            depth,
-                            start_byte,
-                            Point::from_ts_point(start_point),
-                            &changed_ranges,
-                            &mut queue,
-                        );
-                    }
+                    cursor.next(text);
                 }
-                ReparseStep::EnterChangedRange { id, depth, range } => {
-                    let range = text.anchor_before(range.start)..text.anchor_after(range.end);
-                    if current_changes.is_empty() {
-                        let target = (Depth(depth), range.start..Anchor::MAX);
-                        let slice = cursor.slice(&target, Bias::Left, text);
-                        layers.push_tree(slice, text);
-                    } else {
-                        while let Some(layer) = cursor.item() {
-                            if layer.depth > depth
-                                || layer.depth == depth
-                                    && layer.range.end.cmp(&range.start, text).is_gt()
-                            {
-                                break;
-                            }
-                            if !layer_is_changed(layer, text, &current_changes) {
-                                layers.push(layer.clone(), text);
-                            }
-                            cursor.next(text);
-                        }
-                    }
 
-                    current_changes.insert(id, range);
+                changed_regions.retain(|region| {
+                    region.depth > depth
+                        || (region.depth == depth
+                            && region.range.end.cmp(&range.start, text).is_gt())
+                });
+            }
+
+            let (ranges, language) = if let Some(step) = step {
+                (step.ranges, step.language)
+            } else {
+                break;
+            };
+
+            let start_point;
+            let start_byte;
+            let end_byte;
+            if let Some((first, last)) = ranges.first().zip(ranges.last()) {
+                start_point = first.start_point;
+                start_byte = first.start_byte;
+                end_byte = last.end_byte;
+            } else {
+                start_point = Point::zero().to_ts_point();
+                start_byte = 0;
+                end_byte = text.len();
+            };
+
+            let mut old_layer = cursor.item();
+            if let Some(layer) = old_layer {
+                if layer.range.to_offset(text) == (start_byte..end_byte) {
+                    cursor.next(&text);
+                } else {
+                    old_layer = None;
                 }
-                ReparseStep::LeaveChangedRange { id, depth, range } => {
-                    let range = text.anchor_before(range.start)..text.anchor_after(range.end);
-                    while let Some(layer) = cursor.item() {
-                        if layer.depth > depth
-                            || layer.depth == depth
-                                && layer.range.start.cmp(&range.end, text).is_ge()
-                        {
-                            break;
-                        }
-                        if !layer_is_changed(layer, text, &current_changes) {
-                            layers.push(layer.clone(), text);
-                        }
-                        cursor.next(text);
-                    }
+            }
+
+            let grammar = if let Some(grammar) = language.grammar.as_deref() {
+                grammar
+            } else {
+                continue;
+            };
+
+            let tree;
+            let changed_ranges;
+            if let Some(old_layer) = old_layer {
+                tree = parse_text(
+                    grammar,
+                    text.as_rope(),
+                    Some(old_layer.tree.clone()),
+                    ranges,
+                );
 
-                    current_changes.remove(&id);
+                changed_ranges = old_layer
+                    .tree
+                    .changed_ranges(&tree)
+                    .map(|r| r.start_byte..r.end_byte)
+                    .collect();
+            } else {
+                tree = parse_text(grammar, text.as_rope(), None, ranges);
+                changed_ranges = vec![0..end_byte - start_byte];
+            }
+
+            layers.push(
+                SyntaxLayer {
+                    depth,
+                    range,
+                    tree: tree.clone(),
+                    language: language.clone(),
+                },
+                &text,
+            );
+
+            if let (Some((config, registry)), false) = (
+                grammar.injection_config.as_ref().zip(registry.as_ref()),
+                changed_ranges.is_empty(),
+            ) {
+                let depth = depth + 1;
+
+                for range in &changed_ranges {
+                    let region = ChangedRegion {
+                        depth,
+                        range: text.anchor_before(range.start)..text.anchor_after(range.end),
+                    };
+                    if let Err(i) =
+                        changed_regions.binary_search_by(|probe| probe.cmp(&region, text))
+                    {
+                        changed_regions.insert(i, region);
+                    }
                 }
+
+                get_injections(
+                    config,
+                    text,
+                    &tree,
+                    registry,
+                    depth,
+                    start_byte,
+                    Point::from_ts_point(start_point),
+                    &changed_ranges,
+                    &mut queue,
+                );
             }
         }
 
-        let slice = cursor.suffix(&text);
-        layers.push_tree(slice, &text);
         drop(cursor);
         self.layers = layers;
     }
@@ -512,7 +491,7 @@ fn get_injections(
     start_byte: usize,
     start_point: Point,
     query_ranges: &[Range<usize>],
-    stack: &mut BinaryHeap<ReparseStep>,
+    queue: &mut BinaryHeap<ReparseStep>,
 ) -> bool {
     let mut result = false;
     let mut query_cursor = QueryCursorHandle::new();
@@ -547,7 +526,7 @@ fn get_injections(
                     continue;
                 }
             }
-            prev_match = Some((mat.pattern_index, content_range));
+            prev_match = Some((mat.pattern_index, content_range.clone()));
 
             let language_name = config.languages_by_pattern_ix[mat.pattern_index]
                 .as_ref()
@@ -566,10 +545,13 @@ fn get_injections(
             if let Some(language_name) = language_name {
                 if let Some(language) = language_registry.get_language(language_name.as_ref()) {
                     result = true;
-                    stack.push(ReparseStep::CreateLayer {
+                    let range = text.anchor_before(content_range.start)
+                        ..text.anchor_after(content_range.end);
+                    queue.push(ReparseStep {
                         depth,
                         language,
                         ranges: content_ranges,
+                        range,
                     })
                 }
             }
@@ -581,11 +563,11 @@ fn get_injections(
 fn layer_is_changed(
     layer: &SyntaxLayer,
     text: &BufferSnapshot,
-    changed_ranges: &HashMap<usize, Range<Anchor>>,
+    changed_regions: &[ChangedRegion],
 ) -> bool {
-    changed_ranges.values().any(|range| {
-        let is_before_layer = range.end.cmp(&layer.range.start, text).is_le();
-        let is_after_layer = range.start.cmp(&layer.range.end, text).is_ge();
+    changed_regions.iter().any(|region| {
+        let is_before_layer = region.range.end.cmp(&layer.range.start, text).is_le();
+        let is_after_layer = region.range.start.cmp(&layer.range.end, text).is_ge();
         !is_before_layer && !is_after_layer
     })
 }
@@ -598,22 +580,6 @@ impl std::ops::Deref for SyntaxMap {
     }
 }
 
-impl ReparseStep {
-    fn sort_key(&self) -> (usize, Range<usize>) {
-        match self {
-            ReparseStep::CreateLayer { depth, ranges, .. } => (
-                *depth,
-                ranges.first().map_or(0, |r| r.start_byte)
-                    ..ranges.last().map_or(usize::MAX, |r| r.end_byte),
-            ),
-            ReparseStep::EnterChangedRange { depth, range, .. } => {
-                (*depth, range.start..usize::MAX)
-            }
-            ReparseStep::LeaveChangedRange { depth, range, .. } => (*depth, range.end..usize::MAX),
-        }
-    }
-}
-
 impl PartialEq for ReparseStep {
     fn eq(&self, _: &Self) -> bool {
         false
@@ -630,14 +596,32 @@ impl PartialOrd for ReparseStep {
 
 impl Ord for ReparseStep {
     fn cmp(&self, other: &Self) -> Ordering {
-        let (depth_a, range_a) = self.sort_key();
-        let (depth_b, range_b) = other.sort_key();
-        Ord::cmp(&depth_b, &depth_a)
+        let range_a = self.range();
+        let range_b = other.range();
+        Ord::cmp(&other.depth, &self.depth)
             .then_with(|| Ord::cmp(&range_b.start, &range_a.start))
             .then_with(|| Ord::cmp(&range_a.end, &range_b.end))
     }
 }
 
+impl ReparseStep {
+    fn range(&self) -> Range<usize> {
+        let start = self.ranges.first().map_or(0, |r| r.start_byte);
+        let end = self.ranges.last().map_or(0, |r| r.end_byte);
+        start..end
+    }
+}
+
+impl ChangedRegion {
+    fn cmp(&self, other: &Self, buffer: &BufferSnapshot) -> Ordering {
+        let range_a = &self.range;
+        let range_b = &other.range;
+        Ord::cmp(&self.depth, &other.depth)
+            .then_with(|| range_a.start.cmp(&range_b.start, buffer))
+            .then_with(|| range_b.end.cmp(&range_a.end, buffer))
+    }
+}
+
 impl Default for SyntaxLayerSummary {
     fn default() -> Self {
         Self {
@@ -666,29 +650,45 @@ impl sum_tree::Summary for SyntaxLayerSummary {
     }
 }
 
-impl<'a> SeekTarget<'a, SyntaxLayerSummary, SyntaxLayerSummary> for Depth {
-    fn cmp(&self, cursor_location: &SyntaxLayerSummary, _: &BufferSnapshot) -> Ordering {
+impl<'a> SeekTarget<'a, SyntaxLayerSummary, SyntaxLayerSummary> for DepthAndRange {
+    fn cmp(&self, cursor_location: &SyntaxLayerSummary, buffer: &BufferSnapshot) -> Ordering {
         Ord::cmp(&self.0, &cursor_location.max_depth)
+            .then_with(|| {
+                self.1
+                    .start
+                    .cmp(&cursor_location.last_layer_range.start, buffer)
+            })
+            .then_with(|| {
+                cursor_location
+                    .last_layer_range
+                    .end
+                    .cmp(&self.1.end, buffer)
+            })
     }
 }
 
-impl<'a> SeekTarget<'a, SyntaxLayerSummary, SyntaxLayerSummary> for (Depth, MaxPosition) {
+impl<'a> SeekTarget<'a, SyntaxLayerSummary, SyntaxLayerSummary> for DepthAndMaxPosition {
     fn cmp(&self, cursor_location: &SyntaxLayerSummary, text: &BufferSnapshot) -> Ordering {
-        self.0
-            .cmp(&cursor_location, text)
-            .then_with(|| (self.1).0.cmp(&cursor_location.range.end, text))
+        Ord::cmp(&self.0, &cursor_location.max_depth)
+            .then_with(|| self.1.cmp(&cursor_location.range.end, text))
     }
 }
 
-impl<'a> SeekTarget<'a, SyntaxLayerSummary, SyntaxLayerSummary> for (Depth, Range<Anchor>) {
+impl<'a> SeekTarget<'a, SyntaxLayerSummary, SyntaxLayerSummary> for DepthAndRangeOrMaxPosition {
     fn cmp(&self, cursor_location: &SyntaxLayerSummary, buffer: &BufferSnapshot) -> Ordering {
-        self.0
-            .cmp(&cursor_location, buffer)
-            .then_with(|| {
-                self.1
-                    .start
-                    .cmp(&cursor_location.last_layer_range.start, buffer)
-            })
+        let cmp = Ord::cmp(&self.0, &cursor_location.max_depth);
+        if cmp.is_ne() {
+            return cmp;
+        }
+
+        let cmp = self.2.cmp(&cursor_location.range.end, buffer);
+        if cmp.is_gt() {
+            return Ordering::Greater;
+        }
+
+        self.1
+            .start
+            .cmp(&cursor_location.last_layer_range.start, buffer)
             .then_with(|| {
                 cursor_location
                     .last_layer_range