Start work on handling combined injections in SyntaxMap

Max Brunsfeld created

Change summary

Cargo.lock                        |   5 
Cargo.toml                        |   2 
crates/language/Cargo.toml        |   1 
crates/language/src/language.rs   |  19 +
crates/language/src/syntax_map.rs | 343 +++++++++++++++++++++++++-------
5 files changed, 288 insertions(+), 82 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -3005,6 +3005,7 @@ dependencies = [
  "text",
  "theme",
  "tree-sitter",
+ "tree-sitter-embedded-template",
  "tree-sitter-html",
  "tree-sitter-javascript",
  "tree-sitter-json 0.19.0",
@@ -6381,8 +6382,8 @@ dependencies = [
 
 [[package]]
 name = "tree-sitter"
-version = "0.20.8"
-source = "git+https://github.com/tree-sitter/tree-sitter?rev=366210ae925d7ea0891bc7a0c738f60c77c04d7b#366210ae925d7ea0891bc7a0c738f60c77c04d7b"
+version = "0.20.9"
+source = "git+https://github.com/tree-sitter/tree-sitter?rev=f0177f216e3f76a5f68e792b6f9e45fd32383eb6#f0177f216e3f76a5f68e792b6f9e45fd32383eb6"
 dependencies = [
  "cc",
  "regex",

Cargo.toml 🔗

@@ -65,7 +65,7 @@ serde_json = { version = "1.0", features = ["preserve_order", "raw_value"] }
 rand = { version = "0.8" }
 
 [patch.crates-io]
-tree-sitter = { git = "https://github.com/tree-sitter/tree-sitter", rev = "366210ae925d7ea0891bc7a0c738f60c77c04d7b" }
+tree-sitter = { git = "https://github.com/tree-sitter/tree-sitter", rev = "f0177f216e3f76a5f68e792b6f9e45fd32383eb6" }
 async-task = { git = "https://github.com/zed-industries/async-task", rev = "341b57d6de98cdfd7b418567b8de2022ca993a6e" }
 
 # TODO - Remove when a version is released with this PR: https://github.com/servo/core-foundation-rs/pull/457

crates/language/Cargo.toml 🔗

@@ -72,4 +72,5 @@ tree-sitter-rust = "*"
 tree-sitter-python = "*"
 tree-sitter-typescript = "*"
 tree-sitter-ruby = "*"
+tree-sitter-embedded-template = "*"
 unindent = "0.1.7"

crates/language/src/language.rs 🔗

@@ -28,6 +28,7 @@ use std::{
     any::Any,
     cell::RefCell,
     fmt::Debug,
+    hash::Hash,
     mem,
     ops::Range,
     path::{Path, PathBuf},
@@ -643,6 +644,10 @@ impl Language {
         self.adapter.clone()
     }
 
+    pub fn id(&self) -> Option<usize> {
+        self.grammar.as_ref().map(|g| g.id)
+    }
+
     pub fn with_highlights_query(mut self, source: &str) -> Result<Self> {
         let grammar = self.grammar_mut();
         grammar.highlights_query = Some(Query::new(grammar.ts_language, source)?);
@@ -895,6 +900,20 @@ impl Language {
     }
 }
 
+impl Hash for Language {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.id().hash(state)
+    }
+}
+
+impl PartialEq for Language {
+    fn eq(&self, other: &Self) -> bool {
+        self.id().eq(&other.id())
+    }
+}
+
+impl Eq for Language {}
+
 impl Debug for Language {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         f.debug_struct("Language")

crates/language/src/syntax_map.rs 🔗

@@ -1,4 +1,5 @@
 use crate::{Grammar, InjectionConfig, Language, LanguageRegistry};
+use collections::HashMap;
 use lazy_static::lazy_static;
 use parking_lot::Mutex;
 use std::{
@@ -90,6 +91,7 @@ struct SyntaxLayer {
     range: Range<Anchor>,
     tree: tree_sitter::Tree,
     language: Arc<Language>,
+    combined: bool,
 }
 
 #[derive(Debug)]
@@ -105,22 +107,39 @@ struct SyntaxLayerSummary {
     max_depth: usize,
     range: Range<Anchor>,
     last_layer_range: Range<Anchor>,
+    last_layer_language: Option<usize>,
 }
 
 #[derive(Clone, Debug)]
-struct DepthAndRange(usize, Range<Anchor>);
+struct SyntaxLayerPosition {
+    depth: usize,
+    range: Range<Anchor>,
+    language: Option<usize>,
+}
 
 #[derive(Clone, Debug)]
 struct DepthAndMaxPosition(usize, Anchor);
 
 #[derive(Clone, Debug)]
-struct DepthAndRangeOrMaxPosition(DepthAndRange, DepthAndMaxPosition);
+struct SyntaxLayerPositionBeforeChange {
+    position: SyntaxLayerPosition,
+    change: DepthAndMaxPosition,
+}
 
 struct ReparseStep {
     depth: usize,
     language: Arc<Language>,
-    ranges: Vec<tree_sitter::Range>,
     range: Range<Anchor>,
+    included_ranges: Vec<tree_sitter::Range>,
+    mode: ReparseMode,
+}
+
+enum ReparseMode {
+    Single,
+    Combined {
+        parent_layer_range: Range<usize>,
+        parent_layer_changed_ranges: Vec<Range<usize>>,
+    },
 }
 
 #[derive(Debug, PartialEq, Eq)]
@@ -225,7 +244,11 @@ impl SyntaxSnapshot {
             // subsequent layers at this same depth.
             else if cursor.item().is_some() {
                 let slice = cursor.slice(
-                    &DepthAndRange(depth + 1, Anchor::MIN..Anchor::MAX),
+                    &SyntaxLayerPosition {
+                        depth: depth + 1,
+                        range: Anchor::MIN..Anchor::MAX,
+                        language: None,
+                    },
                     Bias::Left,
                     text,
                 );
@@ -320,28 +343,44 @@ impl SyntaxSnapshot {
 
         let mut changed_regions = ChangeRegionSet::default();
         let mut queue = BinaryHeap::new();
+        let mut combined_injection_ranges = HashMap::default();
         queue.push(ReparseStep {
             depth: 0,
             language: language.clone(),
-            ranges: Vec::new(),
+            included_ranges: vec![tree_sitter::Range {
+                start_byte: 0,
+                end_byte: text.len(),
+                start_point: Point::zero().to_ts_point(),
+                end_point: text.max_point().to_ts_point(),
+            }],
             range: Anchor::MIN..Anchor::MAX,
+            mode: ReparseMode::Single,
         });
 
         loop {
             let step = queue.pop();
-            let (depth, range) = if let Some(step) = &step {
-                (step.depth, step.range.clone())
+            let target = if let Some(step) = &step {
+                SyntaxLayerPosition {
+                    depth: step.depth,
+                    range: step.range.clone(),
+                    language: step.language.id(),
+                }
             } else {
-                (max_depth + 1, Anchor::MAX..Anchor::MAX)
+                SyntaxLayerPosition {
+                    depth: max_depth + 1,
+                    range: Anchor::MAX..Anchor::MAX,
+                    language: None,
+                }
             };
 
-            let target = DepthAndRange(depth, range.clone());
             let mut done = cursor.item().is_none();
             while !done && target.cmp(&cursor.end(text), &text).is_gt() {
                 done = true;
 
-                let bounded_target =
-                    DepthAndRangeOrMaxPosition(target.clone(), changed_regions.start_position());
+                let bounded_target = SyntaxLayerPositionBeforeChange {
+                    position: target.clone(),
+                    change: changed_regions.start_position(),
+                };
                 if bounded_target.cmp(&cursor.start(), &text).is_gt() {
                     let slice = cursor.slice(&bounded_target, Bias::Left, text);
                     if !slice.is_empty() {
@@ -353,11 +392,7 @@ impl SyntaxSnapshot {
                 }
 
                 while target.cmp(&cursor.end(text), text).is_gt() {
-                    let layer = if let Some(layer) = cursor.item() {
-                        layer
-                    } else {
-                        break;
-                    };
+                    let Some(layer) = cursor.item() else { break };
 
                     if changed_regions.intersects(&layer, text) {
                         changed_regions.insert(
@@ -378,70 +413,79 @@ impl SyntaxSnapshot {
                 }
             }
 
-            let (ranges, language) = if let Some(step) = step {
-                (step.ranges, step.language)
-            } else {
-                break;
-            };
-
-            let start_point;
-            let start_byte;
-            let end_byte;
-            if let Some((first, last)) = ranges.first().zip(ranges.last()) {
-                start_point = first.start_point;
-                start_byte = first.start_byte;
-                end_byte = last.end_byte;
-            } else {
-                start_point = Point::zero().to_ts_point();
-                start_byte = 0;
-                end_byte = text.len();
-            };
+            let Some(step) = step else { break };
+            let (step_start_byte, step_start_point) =
+                step.range.start.summary::<(usize, Point)>(text);
+            let step_end_byte = step.range.end.to_offset(text);
+            let Some(grammar) = step.language.grammar.as_deref() else { continue };
 
             let mut old_layer = cursor.item();
             if let Some(layer) = old_layer {
-                if layer.range.to_offset(text) == (start_byte..end_byte) {
+                if layer.range.to_offset(text) == (step_start_byte..step_end_byte)
+                    && layer.language.id() == step.language.id()
+                {
                     cursor.next(&text);
                 } else {
                     old_layer = None;
                 }
             }
 
-            let grammar = if let Some(grammar) = language.grammar.as_deref() {
-                grammar
-            } else {
-                continue;
-            };
+            let mut combined = false;
+            let mut included_ranges = step.included_ranges;
 
             let tree;
             let changed_ranges;
             if let Some(old_layer) = old_layer {
+                if let ReparseMode::Combined {
+                    parent_layer_changed_ranges,
+                    ..
+                } = step.mode
+                {
+                    combined = true;
+                    included_ranges = splice_included_ranges(
+                        old_layer.tree.included_ranges(),
+                        &parent_layer_changed_ranges,
+                        &included_ranges,
+                    );
+                }
+
                 tree = parse_text(
                     grammar,
                     text.as_rope(),
+                    step_start_byte,
+                    step_start_point,
+                    included_ranges,
                     Some(old_layer.tree.clone()),
-                    ranges,
                 );
                 changed_ranges = join_ranges(
                     edits
                         .iter()
                         .map(|e| e.new.clone())
-                        .filter(|range| range.start < end_byte && range.end > start_byte),
+                        .filter(|range| range.start < step_end_byte && range.end > step_start_byte),
                     old_layer
                         .tree
                         .changed_ranges(&tree)
-                        .map(|r| start_byte + r.start_byte..start_byte + r.end_byte),
+                        .map(|r| step_start_byte + r.start_byte..step_start_byte + r.end_byte),
                 );
             } else {
-                tree = parse_text(grammar, text.as_rope(), None, ranges);
-                changed_ranges = vec![start_byte..end_byte];
+                tree = parse_text(
+                    grammar,
+                    text.as_rope(),
+                    step_start_byte,
+                    step_start_point,
+                    included_ranges,
+                    None,
+                );
+                changed_ranges = vec![step_start_byte..step_end_byte];
             }
 
             layers.push(
                 SyntaxLayer {
-                    depth,
-                    range,
+                    depth: step.depth,
+                    range: step.range,
                     tree: tree.clone(),
                     language: language.clone(),
+                    combined,
                 },
                 &text,
             );
@@ -450,11 +494,10 @@ impl SyntaxSnapshot {
                 grammar.injection_config.as_ref().zip(registry.as_ref()),
                 changed_ranges.is_empty(),
             ) {
-                let depth = depth + 1;
                 for range in &changed_ranges {
                     changed_regions.insert(
                         ChangedRegion {
-                            depth,
+                            depth: step.depth + 1,
                             range: text.anchor_before(range.start)..text.anchor_after(range.end),
                         },
                         text,
@@ -463,10 +506,11 @@ impl SyntaxSnapshot {
                 get_injections(
                     config,
                     text,
-                    tree.root_node_with_offset(start_byte, start_point),
+                    tree.root_node_with_offset(step_start_byte, step_start_point.to_ts_point()),
                     registry,
-                    depth,
+                    step.depth + 1,
                     &changed_ranges,
+                    &mut combined_injection_ranges,
                     &mut queue,
                 );
             }
@@ -547,7 +591,6 @@ impl SyntaxSnapshot {
             }
         });
 
-        // let mut result = Vec::new();
         cursor.next(buffer);
         std::iter::from_fn(move || {
             if let Some(layer) = cursor.item() {
@@ -565,8 +608,6 @@ impl SyntaxSnapshot {
                 None
             }
         })
-
-        // result
     }
 }
 
@@ -892,14 +933,11 @@ fn join_ranges(
 fn parse_text(
     grammar: &Grammar,
     text: &Rope,
-    old_tree: Option<Tree>,
+    start_byte: usize,
+    start_point: Point,
     mut ranges: Vec<tree_sitter::Range>,
+    old_tree: Option<Tree>,
 ) -> Tree {
-    let (start_byte, start_point) = ranges
-        .first()
-        .map(|range| (range.start_byte, Point::from_ts_point(range.start_point)))
-        .unwrap_or_default();
-
     for range in &mut ranges {
         range.start_byte -= start_byte;
         range.end_byte -= start_byte;
@@ -934,13 +972,16 @@ fn get_injections(
     node: Node,
     language_registry: &LanguageRegistry,
     depth: usize,
-    query_ranges: &[Range<usize>],
+    changed_ranges: &[Range<usize>],
+    combined_injection_ranges: &mut HashMap<Arc<Language>, Vec<tree_sitter::Range>>,
     queue: &mut BinaryHeap<ReparseStep>,
 ) -> bool {
     let mut result = false;
     let mut query_cursor = QueryCursorHandle::new();
     let mut prev_match = None;
-    for query_range in query_ranges {
+
+    combined_injection_ranges.clear();
+    for query_range in changed_ranges {
         query_cursor.set_byte_range(query_range.start.saturating_sub(1)..query_range.end);
         for mat in query_cursor.matches(&config.query, node, TextProvider(text.as_rope())) {
             let content_ranges = mat
@@ -961,7 +1002,9 @@ fn get_injections(
             }
             prev_match = Some((mat.pattern_index, content_range.clone()));
 
-            let language_name = config.patterns[mat.pattern_index].language
+            let combined = config.patterns[mat.pattern_index].combined;
+            let language_name = config.patterns[mat.pattern_index]
+                .language
                 .as_ref()
                 .map(|s| Cow::Borrowed(s.as_ref()))
                 .or_else(|| {
@@ -975,19 +1018,93 @@ fn get_injections(
                     result = true;
                     let range = text.anchor_before(content_range.start)
                         ..text.anchor_after(content_range.end);
-                    queue.push(ReparseStep {
-                        depth,
-                        language,
-                        ranges: content_ranges,
-                        range,
-                    })
+                    if combined {
+                        combined_injection_ranges
+                            .entry(language.clone())
+                            .or_default()
+                            .extend(content_ranges);
+                    } else {
+                        queue.push(ReparseStep {
+                            depth,
+                            language,
+                            included_ranges: content_ranges,
+                            range,
+                            mode: ReparseMode::Single,
+                        });
+                    }
                 }
             }
         }
     }
+
+    for (language, mut included_ranges) in combined_injection_ranges.drain() {
+        included_ranges.sort_unstable();
+        let range = text.anchor_before(node.start_byte())..text.anchor_after(node.end_byte());
+        queue.push(ReparseStep {
+            depth,
+            language,
+            range,
+            included_ranges,
+            mode: ReparseMode::Combined {
+                parent_layer_range: node.start_byte()..node.end_byte(),
+                parent_layer_changed_ranges: changed_ranges.to_vec(),
+            },
+        })
+    }
+
     result
 }
 
+fn splice_included_ranges(
+    mut ranges: Vec<tree_sitter::Range>,
+    changed_ranges: &[Range<usize>],
+    new_ranges: &[tree_sitter::Range],
+) -> Vec<tree_sitter::Range> {
+    let mut changed_ranges = changed_ranges.into_iter().peekable();
+    let mut new_ranges = new_ranges.into_iter().peekable();
+    let mut ranges_ix = 0;
+    loop {
+        let new_range = new_ranges.peek();
+        let mut changed_range = changed_ranges.peek();
+
+        // process changed ranges before any overlapping new ranges
+        if let Some((changed, new)) = changed_range.zip(new_range) {
+            if new.end_byte < changed.start {
+                changed_range = None;
+            }
+        }
+
+        if let Some(changed) = changed_range {
+            let start_ix = ranges_ix
+                + match ranges[ranges_ix..].binary_search_by_key(&changed.start, |r| r.end_byte) {
+                    Ok(ix) | Err(ix) => ix,
+                };
+            let end_ix = ranges_ix
+                + match ranges[ranges_ix..].binary_search_by_key(&changed.end, |r| r.start_byte) {
+                    Ok(ix) | Err(ix) => ix,
+                };
+            if end_ix > start_ix {
+                ranges.splice(start_ix..end_ix, []);
+            }
+            changed_ranges.next();
+            ranges_ix = start_ix;
+        } else if let Some(new_range) = new_range {
+            let ix = ranges_ix
+                + match ranges[ranges_ix..]
+                    .binary_search_by_key(&new_range.start_byte, |r| r.start_byte)
+                {
+                    Ok(ix) | Err(ix) => ix,
+                };
+            ranges.insert(ix, **new_range);
+            new_ranges.next();
+            ranges_ix = ix + 1;
+        } else {
+            break;
+        }
+    }
+    ranges
+}
+
 impl std::ops::Deref for SyntaxMap {
     type Target = SyntaxSnapshot;
 
@@ -1017,14 +1134,22 @@ impl Ord for ReparseStep {
         Ord::cmp(&other.depth, &self.depth)
             .then_with(|| Ord::cmp(&range_b.start, &range_a.start))
             .then_with(|| Ord::cmp(&range_a.end, &range_b.end))
+            .then_with(|| self.language.id().cmp(&other.language.id()))
     }
 }
 
 impl ReparseStep {
     fn range(&self) -> Range<usize> {
-        let start = self.ranges.first().map_or(0, |r| r.start_byte);
-        let end = self.ranges.last().map_or(0, |r| r.end_byte);
-        start..end
+        if let ReparseMode::Combined {
+            parent_layer_range, ..
+        } = &self.mode
+        {
+            parent_layer_range.clone()
+        } else {
+            let start = self.included_ranges.first().map_or(0, |r| r.start_byte);
+            let end = self.included_ranges.last().map_or(0, |r| r.end_byte);
+            start..end
+        }
     }
 }
 
@@ -1094,6 +1219,7 @@ impl Default for SyntaxLayerSummary {
             min_depth: 0,
             range: Anchor::MAX..Anchor::MIN,
             last_layer_range: Anchor::MIN..Anchor::MAX,
+            last_layer_language: None,
         }
     }
 }
@@ -1114,14 +1240,15 @@ impl sum_tree::Summary for SyntaxLayerSummary {
             }
         }
         self.last_layer_range = other.last_layer_range.clone();
+        self.last_layer_language = other.last_layer_language;
     }
 }
 
-impl<'a> SeekTarget<'a, SyntaxLayerSummary, SyntaxLayerSummary> for DepthAndRange {
+impl<'a> SeekTarget<'a, SyntaxLayerSummary, SyntaxLayerSummary> for SyntaxLayerPosition {
     fn cmp(&self, cursor_location: &SyntaxLayerSummary, buffer: &BufferSnapshot) -> Ordering {
-        Ord::cmp(&self.0, &cursor_location.max_depth)
+        Ord::cmp(&self.depth, &cursor_location.max_depth)
             .then_with(|| {
-                self.1
+                self.range
                     .start
                     .cmp(&cursor_location.last_layer_range.start, buffer)
             })
@@ -1129,8 +1256,9 @@ impl<'a> SeekTarget<'a, SyntaxLayerSummary, SyntaxLayerSummary> for DepthAndRang
                 cursor_location
                     .last_layer_range
                     .end
-                    .cmp(&self.1.end, buffer)
+                    .cmp(&self.range.end, buffer)
             })
+            .then_with(|| self.language.cmp(&cursor_location.last_layer_language))
     }
 }
 
@@ -1141,12 +1269,14 @@ impl<'a> SeekTarget<'a, SyntaxLayerSummary, SyntaxLayerSummary> for DepthAndMaxP
     }
 }
 
-impl<'a> SeekTarget<'a, SyntaxLayerSummary, SyntaxLayerSummary> for DepthAndRangeOrMaxPosition {
+impl<'a> SeekTarget<'a, SyntaxLayerSummary, SyntaxLayerSummary>
+    for SyntaxLayerPositionBeforeChange
+{
     fn cmp(&self, cursor_location: &SyntaxLayerSummary, buffer: &BufferSnapshot) -> Ordering {
-        if self.1.cmp(cursor_location, buffer).is_le() {
+        if self.change.cmp(cursor_location, buffer).is_le() {
             return Ordering::Less;
         } else {
-            self.0.cmp(cursor_location, buffer)
+            self.position.cmp(cursor_location, buffer)
         }
     }
 }
@@ -1160,6 +1290,7 @@ impl sum_tree::Item for SyntaxLayer {
             max_depth: self.depth,
             range: self.range.clone(),
             last_layer_range: self.range.clone(),
+            last_layer_language: self.language.id(),
         }
     }
 }
@@ -1246,6 +1377,60 @@ mod tests {
     use unindent::Unindent as _;
     use util::test::marked_text_ranges;
 
+    #[test]
+    fn test_splice_included_ranges() {
+        let ranges = vec![ts_range(20..30), ts_range(50..60), ts_range(80..90)];
+
+        let new_ranges = splice_included_ranges(
+            ranges.clone(),
+            &[54..56, 58..68],
+            &[ts_range(50..54), ts_range(59..67)],
+        );
+        assert_eq!(
+            new_ranges,
+            &[
+                ts_range(20..30),
+                ts_range(50..54),
+                ts_range(59..67),
+                ts_range(80..90),
+            ]
+        );
+
+        let new_ranges = splice_included_ranges(ranges.clone(), &[70..71, 91..100], &[]);
+        assert_eq!(
+            new_ranges,
+            &[ts_range(20..30), ts_range(50..60), ts_range(80..90)]
+        );
+
+        let new_ranges =
+            splice_included_ranges(ranges.clone(), &[], &[ts_range(0..2), ts_range(70..75)]);
+        assert_eq!(
+            new_ranges,
+            &[
+                ts_range(0..2),
+                ts_range(20..30),
+                ts_range(50..60),
+                ts_range(70..75),
+                ts_range(80..90)
+            ]
+        );
+
+        fn ts_range(range: Range<usize>) -> tree_sitter::Range {
+            tree_sitter::Range {
+                start_byte: range.start,
+                start_point: tree_sitter::Point {
+                    row: 0,
+                    column: range.start,
+                },
+                end_byte: range.end,
+                end_point: tree_sitter::Point {
+                    row: 0,
+                    column: range.end,
+                },
+            }
+        }
+    }
+
     #[gpui::test]
     fn test_syntax_map_layers_for_range() {
         let registry = Arc::new(LanguageRegistry::test());