Add SyntaxMap methods for running queries and combining their results

Max Brunsfeld created

Change summary

crates/language/src/syntax_map.rs | 241 ++++++++++++++++++++++++++++++--
1 file changed, 224 insertions(+), 17 deletions(-)

Detailed changes

crates/language/src/syntax_map.rs 🔗

@@ -3,11 +3,19 @@ use crate::{
     ToTreeSitterPoint,
 };
 use std::{
-    borrow::Cow, cell::RefCell, cmp::Ordering, collections::BinaryHeap, ops::Range, sync::Arc,
+    borrow::Cow,
+    cell::RefCell,
+    cmp::{Ordering, Reverse},
+    collections::BinaryHeap,
+    iter::Peekable,
+    ops::{DerefMut, Range},
+    sync::Arc,
 };
 use sum_tree::{Bias, SeekTarget, SumTree};
 use text::{Anchor, BufferSnapshot, OffsetRangeExt, Point, Rope, ToOffset, ToPoint};
-use tree_sitter::{Node, Parser, Tree};
+use tree_sitter::{
+    Node, Parser, Query, QueryCapture, QueryCaptures, QueryCursor, QueryMatch, QueryMatches, Tree,
+};
 
 thread_local! {
     static PARSER: RefCell<Parser> = RefCell::new(Parser::new());
@@ -26,6 +34,42 @@ pub struct SyntaxSnapshot {
     layers: SumTree<SyntaxLayer>,
 }
 
+/// Iterator that merges the query captures of several syntax layers into a
+/// single stream.
+///
+/// `layers` is kept in ascending order of each layer's next-capture sort key
+/// (see `SyntaxMapCapturesLayer::sort_key`), so the first element always holds
+/// the capture that will be yielded next. Exhausted layers are removed.
+pub struct SyntaxMapCaptures<'a> {
+    layers: Vec<SyntaxMapCapturesLayer<'a>>,
+}
+
+/// Per-layer query-match state produced by `SyntaxSnapshot::matches`.
+///
+/// Layers are inserted in ascending order of their next-match sort key (see
+/// `SyntaxMapMatchesLayer::sort_key`); layers with no matches are not stored.
+pub struct SyntaxMapMatches<'a> {
+    layers: Vec<SyntaxMapMatchesLayer<'a>>,
+}
+
+/// A single query capture yielded by `SyntaxMapCaptures`.
+pub struct SyntaxMapCapture<'a> {
+    /// Grammar of the syntax layer the capture came from.
+    pub grammar: &'a Grammar,
+    /// Depth of the syntax layer the capture came from.
+    pub depth: usize,
+    /// The captured syntax node.
+    pub node: Node<'a>,
+    /// The capture's index as reported by tree-sitter; usable as an index into
+    /// the originating query's `capture_names()`.
+    pub index: u32,
+}
+
+/// A single query match tagged with the layer it came from.
+///
+/// NOTE(review): nothing in the visible part of this change constructs this
+/// type; the field meanings below are presumed from their names and from the
+/// parallel `SyntaxMapCapture` — confirm once an iterator over matches exists.
+pub struct SyntaxMapMatch<'a> {
+    /// Presumably the grammar of the layer that produced the match.
+    pub grammar: &'a Grammar,
+    /// Presumably the depth of the layer that produced the match.
+    pub depth: usize,
+    /// Presumably the index of the matched pattern within the query.
+    pub pattern_index: usize,
+    /// Presumably the captures belonging to the match.
+    pub captures: &'a [QueryCapture<'a>],
+}
+
+/// Capture-iteration state for one syntax layer.
+///
+/// NOTE(review): `captures` is created from a cursor reference whose lifetime
+/// was extended to `'static` with `transmute` in `SyntaxSnapshot::captures`.
+/// `captures` is declared before `_query_cursor`, so Rust drops the iterator
+/// before the cursor handle; keep the fields in this order.
+struct SyntaxMapCapturesLayer<'a> {
+    /// Depth of the layer; used as the final tie-breaker in `sort_key`.
+    depth: usize,
+    /// Remaining captures for this layer; peeked to compute the sort key.
+    captures: Peekable<QueryCaptures<'a, 'a, TextProvider<'a>>>,
+    /// Grammar of the layer.
+    grammar: &'a Grammar,
+    /// Never read; held only so the cursor behind `captures` stays alive for
+    /// as long as this layer does.
+    _query_cursor: QueryCursorHandle,
+}
+
+/// Match-iteration state for one syntax layer.
+///
+/// NOTE(review): `matches` is created from a cursor reference whose lifetime
+/// was extended to `'static` with `transmute` in `SyntaxSnapshot::matches`.
+/// `matches` is declared before `_query_cursor`, so Rust drops the iterator
+/// before the cursor handle; keep the fields in this order.
+struct SyntaxMapMatchesLayer<'a> {
+    /// Depth of the layer; used as the final tie-breaker in `sort_key`.
+    depth: usize,
+    /// Remaining matches for this layer; peeked to compute the sort key.
+    matches: Peekable<QueryMatches<'a, 'a, TextProvider<'a>>>,
+    /// Grammar of the layer.
+    grammar: &'a Grammar,
+    /// Never read; held only so the cursor behind `matches` stays alive for
+    /// as long as this layer does.
+    _query_cursor: QueryCursorHandle,
+}
+
 #[derive(Clone)]
 struct SyntaxLayer {
     depth: usize,
@@ -385,6 +429,100 @@ impl SyntaxSnapshot {
         self.layers = layers;
     }
 
+    /// Runs a query over every syntax layer that `layers_for_range` returns
+    /// for `range` and returns an iterator that merges the layers' captures.
+    ///
+    /// * `range` - byte range; used both to select layers and to restrict each
+    ///   query cursor via `set_byte_range`.
+    /// * `buffer` - snapshot whose rope supplies the text for the query.
+    /// * `query` - chooses the query to run for a given grammar; layers for
+    ///   which it returns `None` are skipped.
+    ///
+    /// Layers that produce no captures are discarded. The remaining layers are
+    /// stored in ascending order of their first capture's sort key.
+    pub fn captures<'a>(
+        &'a self,
+        range: Range<usize>,
+        buffer: &'a BufferSnapshot,
+        query: impl Fn(&Grammar) -> Option<&Query>,
+    ) -> SyntaxMapCaptures {
+        let mut result = SyntaxMapCaptures { layers: Vec::new() };
+        for (grammar, depth, node) in self.layers_for_range(range.clone(), buffer) {
+            let query = if let Some(query) = query(grammar) {
+                query
+            } else {
+                continue;
+            };
+
+            let mut query_cursor = QueryCursorHandle::new();
+
+            // TODO - add a Tree-sitter API to remove the need for this.
+            // SAFETY (unverified): the borrow of the cursor is extended to
+            // `'static` so the captures iterator can be stored next to the
+            // handle that owns the cursor. The handle is moved into the layer
+            // below and outlives the iterator there. NOTE(review): this is only
+            // sound if the iterator does not keep a pointer into the handle's
+            // own storage, because the handle is moved after the iterator is
+            // created — confirm against `QueryCursorHandle` and the
+            // tree-sitter bindings.
+            let cursor = unsafe {
+                std::mem::transmute::<_, &'static mut QueryCursor>(query_cursor.deref_mut())
+            };
+
+            cursor.set_byte_range(range.clone());
+            let captures = cursor.captures(query, node, TextProvider(buffer.as_rope()));
+            let mut layer = SyntaxMapCapturesLayer {
+                depth,
+                grammar,
+                captures: captures.peekable(),
+                _query_cursor: query_cursor,
+            };
+
+            // A layer without a sort key has no captures and is dropped here.
+            if let Some(key) = layer.sort_key() {
+                // Linear scan for the first stored layer whose key is not
+                // smaller than the new layer's key; insert in front of it.
+                let mut ix = 0;
+                while let Some(next_layer) = result.layers.get_mut(ix) {
+                    if let Some(next_key) = next_layer.sort_key() {
+                        if key > next_key {
+                            ix += 1;
+                            continue;
+                        }
+                    }
+                    break;
+                }
+                result.layers.insert(ix, layer);
+            }
+        }
+        result
+    }
+
+    /// Runs a query over every syntax layer that `layers_for_range` returns
+    /// for `range` and collects the per-layer match iterators.
+    ///
+    /// * `range` - byte range; used both to select layers and to restrict each
+    ///   query cursor via `set_byte_range`.
+    /// * `buffer` - snapshot whose rope supplies the text for the query.
+    /// * `query` - chooses the query to run for a given grammar; layers for
+    ///   which it returns `None` are skipped.
+    ///
+    /// Layers whose `sort_key` is `None` are discarded. The remaining layers
+    /// are stored in ascending order of their first match's sort key.
+    pub fn matches<'a>(
+        &'a self,
+        range: Range<usize>,
+        buffer: &'a BufferSnapshot,
+        query: impl Fn(&Grammar) -> Option<&Query>,
+    ) -> SyntaxMapMatches {
+        let mut result = SyntaxMapMatches { layers: Vec::new() };
+        for (grammar, depth, node) in self.layers_for_range(range.clone(), buffer) {
+            let query = if let Some(query) = query(grammar) {
+                query
+            } else {
+                continue;
+            };
+
+            let mut query_cursor = QueryCursorHandle::new();
+
+            // TODO - add a Tree-sitter API to remove the need for this.
+            // SAFETY (unverified): the borrow of the cursor is extended to
+            // `'static` so the matches iterator can be stored next to the
+            // handle that owns the cursor. The handle is moved into the layer
+            // below and outlives the iterator there. NOTE(review): this is only
+            // sound if the iterator does not keep a pointer into the handle's
+            // own storage, because the handle is moved after the iterator is
+            // created — confirm against `QueryCursorHandle` and the
+            // tree-sitter bindings.
+            let cursor = unsafe {
+                std::mem::transmute::<_, &'static mut QueryCursor>(query_cursor.deref_mut())
+            };
+
+            cursor.set_byte_range(range.clone());
+            let matches = cursor.matches(query, node, TextProvider(buffer.as_rope()));
+            let mut layer = SyntaxMapMatchesLayer {
+                depth,
+                grammar,
+                matches: matches.peekable(),
+                _query_cursor: query_cursor,
+            };
+
+            // A layer without a sort key is dropped here.
+            if let Some(key) = layer.sort_key() {
+                // Linear scan for the first stored layer whose key is not
+                // smaller than the new layer's key; insert in front of it.
+                let mut ix = 0;
+                while let Some(next_layer) = result.layers.get_mut(ix) {
+                    if let Some(next_key) = next_layer.sort_key() {
+                        if key > next_key {
+                            ix += 1;
+                            continue;
+                        }
+                    }
+                    break;
+                }
+                result.layers.insert(ix, layer);
+            }
+        }
+        result
+    }
+
     pub fn layers(&self, buffer: &BufferSnapshot) -> Vec<(&Grammar, Node)> {
         self.layers
             .iter()
@@ -408,7 +546,7 @@ impl SyntaxSnapshot {
         &self,
         range: Range<T>,
         buffer: &BufferSnapshot,
-    ) -> Vec<(&Grammar, Node)> {
+    ) -> Vec<(&Grammar, usize, Node)> {
         let start = buffer.anchor_before(range.start.to_offset(buffer));
         let end = buffer.anchor_after(range.end.to_offset(buffer));
 
@@ -424,6 +562,7 @@ impl SyntaxSnapshot {
             if let Some(grammar) = &layer.language.grammar {
                 result.push((
                     grammar.as_ref(),
+                    layer.depth,
                     layer.tree.root_node_with_offset(
                         layer.range.start.to_offset(buffer),
                         layer.range.start.to_point(buffer).to_ts_point(),
@@ -437,6 +576,60 @@ impl SyntaxSnapshot {
     }
 }
 
+/// Yields captures from all layers in sort-key order by always consuming from
+/// the first layer and then moving that layer to its new position.
+impl<'a> Iterator for SyntaxMapCaptures<'a> {
+    type Item = SyntaxMapCapture<'a>;
+
+    /// Returns the next capture, or `None` once no layers remain.
+    fn next(&mut self) -> Option<Self::Item> {
+        // The first layer holds the capture with the smallest sort key.
+        let layer = self.layers.first_mut()?;
+        let (mat, ix) = layer.captures.next()?;
+
+        let capture = mat.captures[ix as usize];
+        let grammar = layer.grammar;
+        let depth = layer.depth;
+
+        // Re-establish ordering now that the first layer has advanced.
+        if let Some(key) = layer.sort_key() {
+            // Find how many of the following layers have a smaller key.
+            let mut i = 1;
+            while let Some(later_layer) = self.layers.get_mut(i) {
+                if let Some(later_key) = later_layer.sort_key() {
+                    if key > later_key {
+                        i += 1;
+                        continue;
+                    }
+                }
+                break;
+            }
+            // Move the first layer behind those layers; nothing to do when it
+            // is still the smallest.
+            if i > 1 {
+                self.layers[0..i].rotate_left(1);
+            }
+        } else {
+            // The layer has no captures left.
+            self.layers.remove(0);
+        }
+
+        Some(SyntaxMapCapture {
+            grammar,
+            depth,
+            node: capture.node,
+            index: capture.index,
+        })
+    }
+}
+
+impl<'a> SyntaxMapCapturesLayer<'a> {
+    /// Returns the ordering key of this layer's next capture without
+    /// consuming it, or `None` when the layer has no captures left.
+    ///
+    /// The key is `(start byte, Reverse(end byte), depth)` of the captured
+    /// node: earlier starts sort first, then later ends, then smaller depths.
+    /// Takes `&mut self` because peeking requires mutable access.
+    fn sort_key(&mut self) -> Option<(usize, Reverse<usize>, usize)> {
+        let (mat, ix) = self.captures.peek()?;
+        let range = &mat.captures[*ix].node.byte_range();
+        Some((range.start, Reverse(range.end), self.depth))
+    }
+}
+
+impl<'a> SyntaxMapMatchesLayer<'a> {
+    /// Returns the ordering key of this layer's next match without consuming
+    /// it, or `None` when the layer has no matches left.
+    ///
+    /// The key is `(start byte, Reverse(end byte), depth)`, where the range
+    /// runs from the start of the match's first capture to the end of its
+    /// last capture.
+    ///
+    /// NOTE(review): a pending match with an empty `captures` slice also makes
+    /// this return `None`, which callers treat the same as an exhausted layer
+    /// — confirm whether queries with capture-less patterns can reach here.
+    fn sort_key(&mut self) -> Option<(usize, Reverse<usize>, usize)> {
+        let mat = self.matches.peek()?;
+        let range = mat.captures.first()?.node.start_byte()..mat.captures.last()?.node.end_byte();
+        Some((range.start, Reverse(range.end), self.depth))
+    }
+}
+
 fn join_ranges(
     a: impl Iterator<Item = Range<usize>>,
     b: impl Iterator<Item = Range<usize>>,
@@ -875,10 +1068,10 @@ mod tests {
             "fn a() { dbg!(b.c(vec![d.«e»])) }",
         ]);
 
-        assert_node_ranges(
+        assert_capture_ranges(
             &syntax_map,
             &buffer,
-            "(field_identifier) @_",
+            &["field"],
             "fn a() { dbg!(b.«c»(vec![d.«e»])) }",
         );
     }
@@ -909,10 +1102,10 @@ mod tests {
             ",
         ]);
 
-        assert_node_ranges(
+        assert_capture_ranges(
             &syntax_map,
             &buffer,
-            "(struct_expression) @_",
+            &["struct"],
             "
             fn a() {
                 b!(«B {}»);
@@ -952,10 +1145,10 @@ mod tests {
             ",
         ]);
 
-        assert_node_ranges(
+        assert_capture_ranges(
             &syntax_map,
             &buffer,
-            "(field_identifier) @_",
+            &["field"],
             "
             fn a() {
                 b!(
@@ -1129,6 +1322,13 @@ mod tests {
             },
             Some(tree_sitter_rust::language()),
         )
+        .with_highlights_query(
+            r#"
+                (field_identifier) @field
+                (struct_expression) @struct
+            "#,
+        )
+        .unwrap()
         .with_injection_query(
             r#"
                 (macro_invocation
@@ -1156,7 +1356,7 @@ mod tests {
             expected_layers.len(),
             "wrong number of layers"
         );
-        for (i, ((_, node), expected_s_exp)) in
+        for (i, ((_, _, node), expected_s_exp)) in
             layers.iter().zip(expected_layers.iter()).enumerate()
         {
             let actual_s_exp = node.to_sexp();
@@ -1170,18 +1370,25 @@ mod tests {
         }
     }
 
-    fn assert_node_ranges(
+    fn assert_capture_ranges(
         syntax_map: &SyntaxMap,
         buffer: &BufferSnapshot,
-        query: &str,
+        highlight_query_capture_names: &[&str],
         marked_string: &str,
     ) {
-        let mut cursor = QueryCursorHandle::new();
         let mut actual_ranges = Vec::<Range<usize>>::new();
-        for (grammar, node) in syntax_map.layers(buffer) {
-            let query = Query::new(grammar.ts_language, query).unwrap();
-            for (mat, ix) in cursor.captures(&query, node, TextProvider(buffer.as_rope())) {
-                actual_ranges.push(mat.captures[ix].node.byte_range());
+        for capture in syntax_map.captures(0..buffer.len(), buffer, |grammar| {
+            grammar.highlights_query.as_ref()
+        }) {
+            let name = &capture
+                .grammar
+                .highlights_query
+                .as_ref()
+                .unwrap()
+                .capture_names()[capture.index as usize];
+            dbg!(capture.node, capture.index, name);
+            if highlight_query_capture_names.contains(&name.as_str()) {
+                actual_ranges.push(capture.node.byte_range());
             }
         }