Fix bugs in handling combined injections exposed by HEEx (#2652)

Max Brunsfeld created

Fixes
https://linear.app/zed-industries/issue/Z-2481/heex-this-snippet-triggers-a-hard-crash

Release Notes:

- Fixed a crash that would sometimes occur when editing a HEEx file
([#1703](https://github.com/zed-industries/community/issues/1703)).

Change summary

crates/language/src/syntax_map.rs                  | 239 +++++++++++----
crates/language/src/syntax_map/syntax_map_tests.rs |  33 ++
crates/zed/src/languages/heex/injections.scm       |   2 
3 files changed, 197 insertions(+), 77 deletions(-)

Detailed changes

crates/language/src/syntax_map.rs 🔗

@@ -11,7 +11,7 @@ use std::{
     cell::RefCell,
     cmp::{self, Ordering, Reverse},
     collections::BinaryHeap,
-    iter,
+    fmt, iter,
     ops::{Deref, DerefMut, Range},
     sync::Arc,
 };
@@ -428,6 +428,8 @@ impl SyntaxSnapshot {
         invalidated_ranges: Vec<Range<usize>>,
         registry: Option<&Arc<LanguageRegistry>>,
     ) {
+        log::trace!("reparse. invalidated ranges:{:?}", invalidated_ranges);
+
         let max_depth = self.layers.summary().max_depth;
         let mut cursor = self.layers.cursor::<SyntaxLayerSummary>();
         cursor.next(&text);
@@ -489,6 +491,15 @@ impl SyntaxSnapshot {
                     let Some(layer) = cursor.item() else { break };
 
                     if changed_regions.intersects(&layer, text) {
+                        if let SyntaxLayerContent::Parsed { language, .. } = &layer.content {
+                            log::trace!(
+                                "discard layer. language:{}, range:{:?}. changed_regions:{:?}",
+                                language.name(),
+                                LogAnchorRange(&layer.range, text),
+                                LogChangedRegions(&changed_regions, text),
+                            );
+                        }
+
                         changed_regions.insert(
                             ChangedRegion {
                                 depth: layer.depth + 1,
@@ -541,26 +552,24 @@ impl SyntaxSnapshot {
                             .to_ts_point();
                     }
 
-                    if included_ranges.is_empty() {
-                        included_ranges.push(tree_sitter::Range {
-                            start_byte: 0,
-                            end_byte: 0,
-                            start_point: Default::default(),
-                            end_point: Default::default(),
-                        });
-                    }
-
-                    if let Some(SyntaxLayerContent::Parsed { tree: old_tree, .. }) =
-                        old_layer.map(|layer| &layer.content)
+                    if let Some((SyntaxLayerContent::Parsed { tree: old_tree, .. }, layer_start)) =
+                        old_layer.map(|layer| (&layer.content, layer.range.start))
                     {
+                        log::trace!(
+                            "existing layer. language:{}, start:{:?}, ranges:{:?}",
+                            language.name(),
+                            LogPoint(layer_start.to_point(&text)),
+                            LogIncludedRanges(&old_tree.included_ranges())
+                        );
+
                         if let ParseMode::Combined {
                             mut parent_layer_changed_ranges,
                             ..
                         } = step.mode
                         {
                             for range in &mut parent_layer_changed_ranges {
-                                range.start -= step_start_byte;
-                                range.end -= step_start_byte;
+                                range.start = range.start.saturating_sub(step_start_byte);
+                                range.end = range.end.saturating_sub(step_start_byte);
                             }
 
                             included_ranges = splice_included_ranges(
@@ -570,6 +579,22 @@ impl SyntaxSnapshot {
                             );
                         }
 
+                        if included_ranges.is_empty() {
+                            included_ranges.push(tree_sitter::Range {
+                                start_byte: 0,
+                                end_byte: 0,
+                                start_point: Default::default(),
+                                end_point: Default::default(),
+                            });
+                        }
+
+                        log::trace!(
+                            "update layer. language:{}, start:{:?}, ranges:{:?}",
+                            language.name(),
+                            LogAnchorRange(&step.range, text),
+                            LogIncludedRanges(&included_ranges),
+                        );
+
                         tree = parse_text(
                             grammar,
                             text.as_rope(),
@@ -586,6 +611,22 @@ impl SyntaxSnapshot {
                             }),
                         );
                     } else {
+                        if included_ranges.is_empty() {
+                            included_ranges.push(tree_sitter::Range {
+                                start_byte: 0,
+                                end_byte: 0,
+                                start_point: Default::default(),
+                                end_point: Default::default(),
+                            });
+                        }
+
+                        log::trace!(
+                            "create layer. language:{}, range:{:?}, included_ranges:{:?}",
+                            language.name(),
+                            LogAnchorRange(&step.range, text),
+                            LogIncludedRanges(&included_ranges),
+                        );
+
                         tree = parse_text(
                             grammar,
                             text.as_rope(),
@@ -613,6 +654,7 @@ impl SyntaxSnapshot {
                         get_injections(
                             config,
                             text,
+                            step.range.clone(),
                             tree.root_node_with_offset(
                                 step_start_byte,
                                 step_start_point.to_ts_point(),
@@ -1117,6 +1159,7 @@ fn parse_text(
 fn get_injections(
     config: &InjectionConfig,
     text: &BufferSnapshot,
+    outer_range: Range<Anchor>,
     node: Node,
     language_registry: &Arc<LanguageRegistry>,
     depth: usize,
@@ -1153,16 +1196,17 @@ fn get_injections(
                 continue;
             }
 
-            // Avoid duplicate matches if two changed ranges intersect the same injection.
             let content_range =
                 content_ranges.first().unwrap().start_byte..content_ranges.last().unwrap().end_byte;
-            if let Some((last_pattern_ix, last_range)) = &prev_match {
-                if mat.pattern_index == *last_pattern_ix && content_range == *last_range {
+
+            // Avoid duplicate matches if two changed ranges intersect the same injection.
+            if let Some((prev_pattern_ix, prev_range)) = &prev_match {
+                if mat.pattern_index == *prev_pattern_ix && content_range == *prev_range {
                     continue;
                 }
             }
-            prev_match = Some((mat.pattern_index, content_range.clone()));
 
+            prev_match = Some((mat.pattern_index, content_range.clone()));
             let combined = config.patterns[mat.pattern_index].combined;
 
             let mut language_name = None;
@@ -1218,11 +1262,10 @@ fn get_injections(
 
     for (language, mut included_ranges) in combined_injection_ranges.drain() {
         included_ranges.sort_unstable();
-        let range = text.anchor_before(node.start_byte())..text.anchor_after(node.end_byte());
         queue.push(ParseStep {
             depth,
             language: ParseStepLanguage::Loaded { language },
-            range,
+            range: outer_range.clone(),
             included_ranges,
             mode: ParseMode::Combined {
                 parent_layer_range: node.start_byte()..node.end_byte(),
@@ -1234,72 +1277,77 @@ fn get_injections(
 
 pub(crate) fn splice_included_ranges(
     mut ranges: Vec<tree_sitter::Range>,
-    changed_ranges: &[Range<usize>],
+    removed_ranges: &[Range<usize>],
     new_ranges: &[tree_sitter::Range],
 ) -> Vec<tree_sitter::Range> {
-    let mut changed_ranges = changed_ranges.into_iter().peekable();
-    let mut new_ranges = new_ranges.into_iter().peekable();
+    let mut removed_ranges = removed_ranges.iter().cloned().peekable();
+    let mut new_ranges = new_ranges.into_iter().cloned().peekable();
     let mut ranges_ix = 0;
     loop {
-        let new_range = new_ranges.peek();
-        let mut changed_range = changed_ranges.peek();
-
-        // Remove ranges that have changed before inserting any new ranges
-        // into those ranges.
-        if let Some((changed, new)) = changed_range.zip(new_range) {
-            if new.end_byte < changed.start {
-                changed_range = None;
-            }
-        }
-
-        if let Some(changed) = changed_range {
-            let mut start_ix = ranges_ix
-                + match ranges[ranges_ix..].binary_search_by_key(&changed.start, |r| r.end_byte) {
-                    Ok(ix) | Err(ix) => ix,
-                };
-            let mut end_ix = ranges_ix
-                + match ranges[ranges_ix..].binary_search_by_key(&changed.end, |r| r.start_byte) {
-                    Ok(ix) => ix + 1,
-                    Err(ix) => ix,
-                };
+        let next_new_range = new_ranges.peek();
+        let next_removed_range = removed_ranges.peek();
 
-            // If there are empty ranges, then there may be multiple ranges with the same
-            // start or end. Expand the splice to include any adjacent ranges that touch
-            // the changed range.
-            while start_ix > 0 {
-                if ranges[start_ix - 1].end_byte == changed.start {
-                    start_ix -= 1;
-                } else {
-                    break;
-                }
-            }
-            while let Some(range) = ranges.get(end_ix) {
-                if range.start_byte == changed.end {
-                    end_ix += 1;
+        let (remove, insert) = match (next_removed_range, next_new_range) {
+            (None, None) => break,
+            (Some(_), None) => (removed_ranges.next().unwrap(), None),
+            (Some(next_removed_range), Some(next_new_range)) => {
+                if next_removed_range.end < next_new_range.start_byte {
+                    (removed_ranges.next().unwrap(), None)
                 } else {
-                    break;
+                    let mut start = next_new_range.start_byte;
+                    let mut end = next_new_range.end_byte;
+
+                    while let Some(next_removed_range) = removed_ranges.peek() {
+                        if next_removed_range.start > next_new_range.end_byte {
+                            break;
+                        }
+                        let next_removed_range = removed_ranges.next().unwrap();
+                        start = cmp::min(start, next_removed_range.start);
+                        end = cmp::max(end, next_removed_range.end);
+                    }
+
+                    (start..end, Some(new_ranges.next().unwrap()))
                 }
             }
+            (None, Some(next_new_range)) => (
+                next_new_range.start_byte..next_new_range.end_byte,
+                Some(new_ranges.next().unwrap()),
+            ),
+        };
 
-            if end_ix > start_ix {
-                ranges.splice(start_ix..end_ix, []);
+        let mut start_ix = ranges_ix
+            + match ranges[ranges_ix..].binary_search_by_key(&remove.start, |r| r.end_byte) {
+                Ok(ix) => ix,
+                Err(ix) => ix,
+            };
+        let mut end_ix = ranges_ix
+            + match ranges[ranges_ix..].binary_search_by_key(&remove.end, |r| r.start_byte) {
+                Ok(ix) => ix + 1,
+                Err(ix) => ix,
+            };
+
+        // If there are empty ranges, then there may be multiple ranges with the same
+        // start or end. Expand the splice to include any adjacent ranges that touch
+        // the changed range.
+        while start_ix > 0 {
+            if ranges[start_ix - 1].end_byte == remove.start {
+                start_ix -= 1;
+            } else {
+                break;
+            }
+        }
+        while let Some(range) = ranges.get(end_ix) {
+            if range.start_byte == remove.end {
+                end_ix += 1;
+            } else {
+                break;
             }
-            changed_ranges.next();
-            ranges_ix = start_ix;
-        } else if let Some(new_range) = new_range {
-            let ix = ranges_ix
-                + match ranges[ranges_ix..]
-                    .binary_search_by_key(&new_range.start_byte, |r| r.start_byte)
-                {
-                    Ok(ix) | Err(ix) => ix,
-                };
-            ranges.insert(ix, **new_range);
-            new_ranges.next();
-            ranges_ix = ix + 1;
-        } else {
-            break;
         }
+
+        ranges.splice(start_ix..end_ix, insert);
+        ranges_ix = start_ix;
     }
+
     ranges
 }
 
@@ -1628,3 +1676,46 @@ impl ToTreeSitterPoint for Point {
         Point::new(point.row as u32, point.column as u32)
     }
 }
+
+struct LogIncludedRanges<'a>(&'a [tree_sitter::Range]);
+struct LogPoint(Point);
+struct LogAnchorRange<'a>(&'a Range<Anchor>, &'a text::BufferSnapshot);
+struct LogChangedRegions<'a>(&'a ChangeRegionSet, &'a text::BufferSnapshot);
+
+impl<'a> fmt::Debug for LogIncludedRanges<'a> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_list()
+            .entries(self.0.iter().map(|range| {
+                let start = range.start_point;
+                let end = range.end_point;
+                (start.row, start.column)..(end.row, end.column)
+            }))
+            .finish()
+    }
+}
+
+impl<'a> fmt::Debug for LogAnchorRange<'a> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let range = self.0.to_point(self.1);
+        (LogPoint(range.start)..LogPoint(range.end)).fmt(f)
+    }
+}
+
+impl<'a> fmt::Debug for LogChangedRegions<'a> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_list()
+            .entries(
+                self.0
+                     .0
+                    .iter()
+                    .map(|region| LogAnchorRange(&region.range, self.1)),
+            )
+            .finish()
+    }
+}
+
+impl fmt::Debug for LogPoint {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        (self.0.row, self.0.column).fmt(f)
+    }
+}

crates/language/src/syntax_map/syntax_map_tests.rs 🔗

@@ -48,6 +48,13 @@ fn test_splice_included_ranges() {
     let new_ranges = splice_included_ranges(ranges.clone(), &[30..50], &[ts_range(25..55)]);
     assert_eq!(new_ranges, &[ts_range(25..55), ts_range(80..90)]);
 
+    // does not create overlapping ranges
+    let new_ranges = splice_included_ranges(ranges.clone(), &[0..18], &[ts_range(20..32)]);
+    assert_eq!(
+        new_ranges,
+        &[ts_range(20..32), ts_range(50..60), ts_range(80..90)]
+    );
+
     fn ts_range(range: Range<usize>) -> tree_sitter::Range {
         tree_sitter::Range {
             start_byte: range.start,
@@ -624,6 +631,26 @@ fn test_combined_injections_splitting_some_injections() {
     );
 }
 
+#[gpui::test]
+fn test_combined_injections_editing_after_last_injection() {
+    test_edit_sequence(
+        "ERB",
+        &[
+            r#"
+                <% foo %>
+                <div></div>
+                <% bar %>
+            "#,
+            r#"
+                <% foo %>
+                <div></div>
+                <% bar %>«
+                more text»
+            "#,
+        ],
+    );
+}
+
 #[gpui::test]
 fn test_combined_injections_inside_injections() {
     let (_buffer, _syntax_map) = test_edit_sequence(
@@ -974,13 +1001,16 @@ fn test_edit_sequence(language_name: &str, steps: &[&str]) -> (Buffer, SyntaxMap
     mutated_syntax_map.reparse(language.clone(), &buffer);
 
     for (i, marked_string) in steps.into_iter().enumerate() {
-        buffer.edit_via_marked_text(&marked_string.unindent());
+        let marked_string = marked_string.unindent();
+        log::info!("incremental parse {i}: {marked_string:?}");
+        buffer.edit_via_marked_text(&marked_string);
 
         // Reparse the syntax map
         mutated_syntax_map.interpolate(&buffer);
         mutated_syntax_map.reparse(language.clone(), &buffer);
 
         // Create a second syntax map from scratch
+        log::info!("fresh parse {i}: {marked_string:?}");
         let mut reference_syntax_map = SyntaxMap::new();
         reference_syntax_map.set_language_registry(registry.clone());
         reference_syntax_map.reparse(language.clone(), &buffer);
@@ -1133,6 +1163,7 @@ fn range_for_text(buffer: &Buffer, text: &str) -> Range<usize> {
     start..start + text.len()
 }
 
+#[track_caller]
 fn assert_layers_for_range(
     syntax_map: &SyntaxMap,
     buffer: &BufferSnapshot,

crates/zed/src/languages/heex/injections.scm 🔗

@@ -9,7 +9,5 @@
   (#set! combined)
 )
 
-; expressions live within HTML tags, and do not need to be combined
-;     <link href={ Routes.static_path(..) } />
 ((expression (expression_value) @content)
  (#set! language "elixir"))