From 9113c94371430ef07fb412aa766ae77db7e164a9 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Tue, 23 Aug 2022 14:26:09 -0700 Subject: [PATCH] Add SyntaxMap methods for running queries and combining their results --- crates/language/src/syntax_map.rs | 241 +++++++++++++++++++++++++++--- 1 file changed, 224 insertions(+), 17 deletions(-) diff --git a/crates/language/src/syntax_map.rs b/crates/language/src/syntax_map.rs index 8800bb5cd885afaa1c5215adb309e4c1398e4094..a578d36a382144dc5bece35448254ab974eb26fa 100644 --- a/crates/language/src/syntax_map.rs +++ b/crates/language/src/syntax_map.rs @@ -3,11 +3,19 @@ use crate::{ ToTreeSitterPoint, }; use std::{ - borrow::Cow, cell::RefCell, cmp::Ordering, collections::BinaryHeap, ops::Range, sync::Arc, + borrow::Cow, + cell::RefCell, + cmp::{Ordering, Reverse}, + collections::BinaryHeap, + iter::Peekable, + ops::{DerefMut, Range}, + sync::Arc, }; use sum_tree::{Bias, SeekTarget, SumTree}; use text::{Anchor, BufferSnapshot, OffsetRangeExt, Point, Rope, ToOffset, ToPoint}; -use tree_sitter::{Node, Parser, Tree}; +use tree_sitter::{ + Node, Parser, Query, QueryCapture, QueryCaptures, QueryCursor, QueryMatch, QueryMatches, Tree, +}; thread_local! { static PARSER: RefCell = RefCell::new(Parser::new()); @@ -26,6 +34,42 @@ pub struct SyntaxSnapshot { layers: SumTree, } +pub struct SyntaxMapCaptures<'a> { + layers: Vec>, +} + +pub struct SyntaxMapMatches<'a> { + layers: Vec>, +} + +pub struct SyntaxMapCapture<'a> { + pub grammar: &'a Grammar, + pub depth: usize, + pub node: Node<'a>, + pub index: u32, +} + +pub struct SyntaxMapMatch<'a> { + pub grammar: &'a Grammar, + pub depth: usize, + pub pattern_index: usize, + pub captures: &'a [QueryCapture<'a>], +} + +struct SyntaxMapCapturesLayer<'a> { + depth: usize, + captures: Peekable>>, + grammar: &'a Grammar, + _query_cursor: QueryCursorHandle, +} + +struct SyntaxMapMatchesLayer<'a> { + depth: usize, + matches: Peekable>>, + grammar: &'a Grammar, + _query_cursor: QueryCursorHandle, +} + #[derive(Clone)] struct SyntaxLayer { depth: usize, @@ -385,6 +429,100 @@ impl SyntaxSnapshot { self.layers = layers; } + pub fn captures<'a>( + &'a self, + range: Range, + buffer: &'a BufferSnapshot, + query: impl Fn(&Grammar) -> Option<&Query>, + ) -> SyntaxMapCaptures { + let mut result = SyntaxMapCaptures { layers: Vec::new() }; + for (grammar, depth, node) in self.layers_for_range(range.clone(), buffer) { + let query = if let Some(query) = query(grammar) { + query + } else { + continue; + }; + + let mut query_cursor = QueryCursorHandle::new(); + + // TODO - add a Tree-sitter API to remove the need for this. + let cursor = unsafe { + std::mem::transmute::<_, &'static mut QueryCursor>(query_cursor.deref_mut()) + }; + + cursor.set_byte_range(range.clone()); + let captures = cursor.captures(query, node, TextProvider(buffer.as_rope())); + let mut layer = SyntaxMapCapturesLayer { + depth, + grammar, + captures: captures.peekable(), + _query_cursor: query_cursor, + }; + + if let Some(key) = layer.sort_key() { + let mut ix = 0; + while let Some(next_layer) = result.layers.get_mut(ix) { + if let Some(next_key) = next_layer.sort_key() { + if key > next_key { + ix += 1; + continue; + } + } + break; + } + result.layers.insert(ix, layer); + } + } + result + } + + pub fn matches<'a>( + &'a self, + range: Range, + buffer: &'a BufferSnapshot, + query: impl Fn(&Grammar) -> Option<&Query>, + ) -> SyntaxMapMatches { + let mut result = SyntaxMapMatches { layers: Vec::new() }; + for (grammar, depth, node) in self.layers_for_range(range.clone(), buffer) { + let query = if let Some(query) = query(grammar) { + query + } else { + continue; + }; + + let mut query_cursor = QueryCursorHandle::new(); + + // TODO - add a Tree-sitter API to remove the need for this. + let cursor = unsafe { + std::mem::transmute::<_, &'static mut QueryCursor>(query_cursor.deref_mut()) + }; + + cursor.set_byte_range(range.clone()); + let matches = cursor.matches(query, node, TextProvider(buffer.as_rope())); + let mut layer = SyntaxMapMatchesLayer { + depth, + grammar, + matches: matches.peekable(), + _query_cursor: query_cursor, + }; + + if let Some(key) = layer.sort_key() { + let mut ix = 0; + while let Some(next_layer) = result.layers.get_mut(ix) { + if let Some(next_key) = next_layer.sort_key() { + if key > next_key { + ix += 1; + continue; + } + } + break; + } + result.layers.insert(ix, layer); + } + } + result + } + pub fn layers(&self, buffer: &BufferSnapshot) -> Vec<(&Grammar, Node)> { self.layers .iter() @@ -408,7 +546,7 @@ impl SyntaxSnapshot { &self, range: Range, buffer: &BufferSnapshot, - ) -> Vec<(&Grammar, Node)> { + ) -> Vec<(&Grammar, usize, Node)> { let start = buffer.anchor_before(range.start.to_offset(buffer)); let end = buffer.anchor_after(range.end.to_offset(buffer)); @@ -424,6 +562,7 @@ impl SyntaxSnapshot { if let Some(grammar) = &layer.language.grammar { result.push(( grammar.as_ref(), + layer.depth, layer.tree.root_node_with_offset( layer.range.start.to_offset(buffer), layer.range.start.to_point(buffer).to_ts_point(), @@ -437,6 +576,60 @@ impl SyntaxSnapshot { } } +impl<'a> Iterator for SyntaxMapCaptures<'a> { + type Item = SyntaxMapCapture<'a>; + + fn next(&mut self) -> Option { + let layer = self.layers.first_mut()?; + let (mat, ix) = layer.captures.next()?; + + let capture = mat.captures[ix as usize]; + let grammar = layer.grammar; + let depth = layer.depth; + + if let Some(key) = layer.sort_key() { + let mut i = 1; + while let Some(later_layer) = self.layers.get_mut(i) { + if let Some(later_key) = later_layer.sort_key() { + if key > later_key { + i += 1; + continue; + } + } + break; + } + if i > 1 { + self.layers[0..i].rotate_left(1); + } + } else { + self.layers.remove(0); + } + + Some(SyntaxMapCapture { + grammar, + depth, + node: capture.node, + index: capture.index, + }) + } +} + +impl<'a> SyntaxMapCapturesLayer<'a> { + fn sort_key(&mut self) -> Option<(usize, Reverse, usize)> { + let (mat, ix) = self.captures.peek()?; + let range = &mat.captures[*ix].node.byte_range(); + Some((range.start, Reverse(range.end), self.depth)) + } +} + +impl<'a> SyntaxMapMatchesLayer<'a> { + fn sort_key(&mut self) -> Option<(usize, Reverse, usize)> { + let mat = self.matches.peek()?; + let range = mat.captures.first()?.node.start_byte()..mat.captures.last()?.node.end_byte(); + Some((range.start, Reverse(range.end), self.depth)) + } +} + fn join_ranges( a: impl Iterator>, b: impl Iterator>, @@ -875,10 +1068,10 @@ mod tests { "fn a() { dbg!(b.c(vec![d.«e»])) }", ]); - assert_node_ranges( + assert_capture_ranges( &syntax_map, &buffer, - "(field_identifier) @_", + &["field"], "fn a() { dbg!(b.«c»(vec![d.«e»])) }", ); } @@ -909,10 +1102,10 @@ mod tests { ", ]); - assert_node_ranges( + assert_capture_ranges( &syntax_map, &buffer, - "(struct_expression) @_", + &["struct"], " fn a() { b!(«B {}»); @@ -952,10 +1145,10 @@ mod tests { ", ]); - assert_node_ranges( + assert_capture_ranges( &syntax_map, &buffer, - "(field_identifier) @_", + &["field"], " fn a() { b!( @@ -1129,6 +1322,13 @@ mod tests { }, Some(tree_sitter_rust::language()), ) + .with_highlights_query( + r#" + (field_identifier) @field + (struct_expression) @struct + "#, + ) + .unwrap() .with_injection_query( r#" (macro_invocation @@ -1156,7 +1356,7 @@ mod tests { expected_layers.len(), "wrong number of layers" ); - for (i, ((_, node), expected_s_exp)) in + for (i, ((_, _, node), expected_s_exp)) in layers.iter().zip(expected_layers.iter()).enumerate() { let actual_s_exp = node.to_sexp(); @@ -1170,18 +1370,25 @@ mod tests { } } - fn assert_node_ranges( + fn assert_capture_ranges( syntax_map: &SyntaxMap, buffer: &BufferSnapshot, - query: &str, + highlight_query_capture_names: &[&str], marked_string: &str, ) { - let mut cursor = QueryCursorHandle::new(); let mut actual_ranges = Vec::>::new(); - for (grammar, node) in syntax_map.layers(buffer) { - let query = Query::new(grammar.ts_language, query).unwrap(); - for (mat, ix) in cursor.captures(&query, node, TextProvider(buffer.as_rope())) { - actual_ranges.push(mat.captures[ix].node.byte_range()); + for capture in syntax_map.captures(0..buffer.len(), buffer, |grammar| { + grammar.highlights_query.as_ref() + }) { + let name = &capture + .grammar + .highlights_query + .as_ref() + .unwrap() + .capture_names()[capture.index as usize]; + dbg!(capture.node, capture.index, name); + if highlight_query_capture_names.contains(&name.as_str()) { + actual_ranges.push(capture.node.byte_range()); } }