syntax_map.rs

  1use crate::{
  2    Grammar, Language, LanguageRegistry, QueryCursorHandle, TextProvider, ToTreeSitterPoint,
  3};
  4use collections::VecDeque;
  5use gpui::executor::Background;
  6use std::{borrow::Cow, cell::RefCell, cmp::Ordering, ops::Range, sync::Arc};
  7use sum_tree::{SeekTarget, SumTree};
  8use text::{Anchor, BufferSnapshot, Point, Rope, ToOffset};
  9use tree_sitter::{Parser, Tree};
 10use util::post_inc;
 11
thread_local! {
    // One tree-sitter `Parser` per thread. `SyntaxMap::parse_text` borrows it
    // mutably and sets its included ranges and language again on every call,
    // so no configuration is assumed to carry over between parses.
    static PARSER: RefCell<Parser> = RefCell::new(Parser::new());
}
 15
/// Owns the current [`SyntaxMapSnapshot`] together with the counter used to
/// number syntax layers.
#[derive(Default)]
pub struct SyntaxMap {
    /// Id to give to the next layer. `SyntaxMap::new` stores its layer counter
    /// here once it has numbered every layer it parsed.
    next_layer_id: usize,
    /// The current set of layers; cloned by `SyntaxMap::snapshot` and exposed
    /// by reference through the `Deref` impl.
    snapshot: SyntaxMapSnapshot,
}
 21
/// A cloneable view of the syntax layers computed for a buffer.
#[derive(Clone, Default)]
pub struct SyntaxMapSnapshot {
    /// Buffer version the layers were computed from; `SyntaxMap::new` takes it
    /// from the `BufferSnapshot` it parsed.
    version: clock::Global,
    /// The parsed layers. `SyntaxMap::new` sorts them with
    /// `SyntaxLayerRange`'s `SeekTarget::cmp` before building the tree.
    layers: SumTree<SyntaxLayer>,
}
 27
/// One parsed tree covering a region of the buffer in a single language.
#[derive(Clone)]
struct SyntaxLayer {
    /// Id assigned from the map's layer counter when the layer was parsed.
    id: usize,
    /// Id of the layer whose injection query produced this one; `None` for
    /// the layer parsed with the buffer's own language.
    parent_id: Option<usize>,
    /// Anchored extent of the layer within the buffer.
    range: SyntaxLayerRange,
    /// The tree-sitter parse tree for this layer.
    tree: tree_sitter::Tree,
    /// Language the layer was parsed with.
    language: Arc<Language>,
}
 36
/// `SumTree` summary for a run of [`SyntaxLayer`]s.
#[derive(Debug, Clone)]
struct SyntaxLayerSummary {
    /// Smallest start and largest end among the summarized layers' ranges
    /// (see the `add_summary` impl, which widens this range as it merges).
    range: Range<Anchor>,
    /// Range of the last layer that was added to the summary.
    last_layer_range: Range<Anchor>,
}
 42
/// Newtype around a layer's anchored range; it implements both `SeekTarget`
/// and `sum_tree::Dimension` for [`SyntaxLayerSummary`] in this file.
#[derive(Clone, Debug)]
struct SyntaxLayerRange(Range<Anchor>);
 45
 46impl SyntaxMap {
 47    pub fn new(
 48        executor: Arc<Background>,
 49        registry: Arc<LanguageRegistry>,
 50        language: Arc<Language>,
 51        text: BufferSnapshot,
 52        prev_set: Option<Self>,
 53    ) -> Self {
 54        let mut next_layer_id = 0;
 55        let mut layers = Vec::new();
 56        let mut injections = VecDeque::<(Option<usize>, _, Vec<tree_sitter::Range>)>::new();
 57
 58        injections.push_back((None, language, vec![]));
 59        while let Some((parent_id, language, ranges)) = injections.pop_front() {
 60            if let Some(grammar) = &language.grammar.as_deref() {
 61                let id = post_inc(&mut next_layer_id);
 62                let range = if let Some((first, last)) = ranges.first().zip(ranges.last()) {
 63                    text.anchor_before(first.start_byte)..text.anchor_after(last.end_byte)
 64                } else {
 65                    Anchor::MIN..Anchor::MAX
 66                };
 67                let tree = Self::parse_text(grammar, text.as_rope(), None, ranges);
 68                Self::get_injections(grammar, &text, &tree, id, &registry, &mut injections);
 69                layers.push(SyntaxLayer {
 70                    id,
 71                    parent_id,
 72                    range: SyntaxLayerRange(range),
 73                    tree,
 74                    language,
 75                });
 76            }
 77        }
 78
 79        layers.sort_unstable_by(|a, b| SeekTarget::cmp(&a.range, &b.range, &text));
 80
 81        Self {
 82            next_layer_id,
 83            snapshot: SyntaxMapSnapshot {
 84                layers: SumTree::from_iter(layers, &text),
 85                version: text.version,
 86            },
 87        }
 88    }
 89
 90    pub fn snapshot(&self) -> SyntaxMapSnapshot {
 91        self.snapshot.clone()
 92    }
 93
 94    fn interpolate(&mut self, text: &BufferSnapshot) {
 95        let edits = text
 96            .edits_since::<(Point, usize)>(&self.version)
 97            .map(|edit| {
 98                let (lines, bytes) = edit.flatten();
 99                tree_sitter::InputEdit {
100                    start_byte: bytes.new.start,
101                    old_end_byte: bytes.new.start + bytes.old.len(),
102                    new_end_byte: bytes.new.end,
103                    start_position: lines.new.start.to_ts_point(),
104                    old_end_position: (lines.new.start + (lines.old.end - lines.old.start))
105                        .to_ts_point(),
106                    new_end_position: lines.new.end.to_ts_point(),
107                }
108            })
109            .collect::<Vec<_>>();
110        if edits.is_empty() {
111            return;
112        }
113    }
114
115    fn get_injections(
116        grammar: &Grammar,
117        text: &BufferSnapshot,
118        tree: &Tree,
119        id: usize,
120        registry: &Arc<LanguageRegistry>,
121        output: &mut VecDeque<(Option<usize>, Arc<Language>, Vec<tree_sitter::Range>)>,
122    ) {
123        let config = if let Some(config) = &grammar.injection_config {
124            config
125        } else {
126            return;
127        };
128
129        let mut query_cursor = QueryCursorHandle::new();
130        for mat in query_cursor.matches(
131            &config.query,
132            tree.root_node(),
133            TextProvider(text.as_rope()),
134        ) {
135            let content_ranges = mat
136                .nodes_for_capture_index(config.content_capture_ix)
137                .map(|node| node.range())
138                .collect::<Vec<_>>();
139            if content_ranges.is_empty() {
140                continue;
141            }
142            let language_name = config.languages_by_pattern_ix[mat.pattern_index]
143                .as_ref()
144                .map(|s| Cow::Borrowed(s.as_ref()))
145                .or_else(|| {
146                    let ix = config.language_capture_ix?;
147                    let node = mat.nodes_for_capture_index(ix).next()?;
148                    Some(Cow::Owned(text.text_for_range(node.byte_range()).collect()))
149                });
150            if let Some(language_name) = language_name {
151                if let Some(language) = registry.get_language(language_name.as_ref()) {
152                    output.push_back((Some(id), language, content_ranges))
153                }
154            }
155        }
156    }
157
158    fn parse_text(
159        grammar: &Grammar,
160        text: &Rope,
161        old_tree: Option<Tree>,
162        ranges: Vec<tree_sitter::Range>,
163    ) -> Tree {
164        PARSER.with(|parser| {
165            let mut parser = parser.borrow_mut();
166            let mut chunks = text.chunks_in_range(0..text.len());
167            parser
168                .set_included_ranges(&ranges)
169                .expect("overlapping ranges");
170            parser
171                .set_language(grammar.ts_language)
172                .expect("incompatible grammar");
173            parser
174                .parse_with(
175                    &mut move |offset, _| {
176                        chunks.seek(offset);
177                        chunks.next().unwrap_or("").as_bytes()
178                    },
179                    old_tree.as_ref(),
180                )
181                .expect("invalid language")
182        })
183    }
184}
185
186impl SyntaxMapSnapshot {
187    pub fn layers_for_range<'a, T: ToOffset>(
188        &self,
189        range: Range<T>,
190        buffer: &BufferSnapshot,
191    ) -> Vec<(Tree, &Grammar)> {
192        let start = buffer.anchor_before(range.start.to_offset(buffer));
193        let end = buffer.anchor_after(range.end.to_offset(buffer));
194
195        let mut cursor = self.layers.filter::<_, ()>(|summary| {
196            let is_before_start = summary.range.end.cmp(&start, buffer).is_lt();
197            let is_after_end = summary.range.start.cmp(&end, buffer).is_gt();
198            !is_before_start && !is_after_end
199        });
200
201        let mut result = Vec::new();
202        cursor.next(buffer);
203        while let Some(item) = cursor.item() {
204            if let Some(grammar) = &item.language.grammar {
205                result.push((item.tree.clone(), grammar.as_ref()));
206            }
207            cursor.next(buffer)
208        }
209
210        result
211    }
212}
213
214impl std::ops::Deref for SyntaxMap {
215    type Target = SyntaxMapSnapshot;
216
217    fn deref(&self) -> &Self::Target {
218        &self.snapshot
219    }
220}
221
222impl Default for SyntaxLayerSummary {
223    fn default() -> Self {
224        Self {
225            range: Anchor::MAX..Anchor::MIN,
226            last_layer_range: Anchor::MIN..Anchor::MAX,
227        }
228    }
229}
230
231impl sum_tree::Summary for SyntaxLayerSummary {
232    type Context = BufferSnapshot;
233
234    fn add_summary(&mut self, other: &Self, buffer: &Self::Context) {
235        if other.range.start.cmp(&self.range.start, buffer).is_lt() {
236            self.range.start = other.range.start;
237        }
238        if other.range.end.cmp(&self.range.end, buffer).is_gt() {
239            self.range.end = other.range.end;
240        }
241        self.last_layer_range = other.last_layer_range.clone();
242    }
243}
244
245impl Default for SyntaxLayerRange {
246    fn default() -> Self {
247        Self(Anchor::MIN..Anchor::MAX)
248    }
249}
250
251impl<'a> SeekTarget<'a, SyntaxLayerSummary, SyntaxLayerRange> for SyntaxLayerRange {
252    fn cmp(&self, cursor_location: &Self, buffer: &BufferSnapshot) -> Ordering {
253        self.0
254            .start
255            .cmp(&cursor_location.0.start, buffer)
256            .then_with(|| cursor_location.0.end.cmp(&self.0.end, buffer))
257    }
258}
259
260impl<'a> sum_tree::Dimension<'a, SyntaxLayerSummary> for SyntaxLayerRange {
261    fn add_summary(
262        &mut self,
263        summary: &'a SyntaxLayerSummary,
264        _: &<SyntaxLayerSummary as sum_tree::Summary>::Context,
265    ) {
266        self.0 = summary.last_layer_range.clone();
267    }
268}
269
270impl sum_tree::Item for SyntaxLayer {
271    type Summary = SyntaxLayerSummary;
272
273    fn summary(&self) -> Self::Summary {
274        SyntaxLayerSummary {
275            range: self.range.0.clone(),
276            last_layer_range: self.range.0.clone(),
277        }
278    }
279}
280
281impl std::fmt::Debug for SyntaxLayer {
282    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
283        f.debug_struct("SyntaxLayer")
284            .field("id", &self.id)
285            .field("parent_id", &self.parent_id)
286            .field("range", &self.range)
287            .field("tree", &self.tree)
288            .finish()
289    }
290}
291
292#[cfg(test)]
293mod tests {
294    use super::*;
295    use crate::LanguageConfig;
296    use gpui::MutableAppContext;
297    use text::{Buffer, Point};
298    use unindent::Unindent as _;
299
300    #[gpui::test]
301    fn test_syntax_map(cx: &mut MutableAppContext) {
302        let buffer = Buffer::new(
303            0,
304            0,
305            r#"
306                fn a() {
307                    assert_eq!(
308                        b(vec![C {}]),
309                        vec![d.e],
310                    );
311                    println!("{}", f(|_| true));
312                }
313            "#
314            .unindent(),
315        );
316
317        let executor = cx.background().clone();
318        let registry = Arc::new(LanguageRegistry::test());
319        let language = Arc::new(rust_lang());
320        let snapshot = buffer.snapshot();
321        registry.add(language.clone());
322
323        let syntax_map = SyntaxMap::new(executor, registry, language, snapshot.clone(), None);
324
325        let layers = syntax_map.layers_for_range(Point::new(0, 0)..Point::new(0, 1), &snapshot);
326        assert_layers(
327            &layers,
328            &["(source_file (function_item name: (identifier)..."],
329        );
330
331        let layers = syntax_map.layers_for_range(Point::new(2, 0)..Point::new(2, 0), &snapshot);
332        assert_layers(
333            &layers,
334            &[
335                "...(function_item ... (block (expression_statement (macro_invocation...",
336                "...(tuple_expression (call_expression ... arguments: (arguments (macro_invocation...",
337            ],
338        );
339
340        let layers = syntax_map.layers_for_range(Point::new(2, 14)..Point::new(2, 16), &snapshot);
341        assert_layers(
342            &layers,
343            &[
344                "...(function_item ...",
345                "...(tuple_expression (call_expression ... arguments: (arguments (macro_invocation...",
346                "...(array_expression (struct_expression ...",
347            ],
348        );
349
350        let layers = syntax_map.layers_for_range(Point::new(3, 14)..Point::new(3, 16), &snapshot);
351        assert_layers(
352            &layers,
353            &[
354                "...(function_item ...",
355                "...(tuple_expression (call_expression ... arguments: (arguments (macro_invocation...",
356                "...(array_expression (field_expression ...",
357            ],
358        );
359
360        let layers = syntax_map.layers_for_range(Point::new(5, 12)..Point::new(5, 16), &snapshot);
361        assert_layers(
362            &layers,
363            &[
364                "...(function_item ...",
365                "...(call_expression ... (arguments (closure_expression ...",
366            ],
367        );
368    }
369
370    fn rust_lang() -> Language {
371        Language::new(
372            LanguageConfig {
373                name: "Rust".into(),
374                path_suffixes: vec!["rs".to_string()],
375                ..Default::default()
376            },
377            Some(tree_sitter_rust::language()),
378        )
379        .with_injection_query(
380            r#"
381                (macro_invocation
382                    (token_tree) @content
383                    (#set! "language" "rust"))
384            "#,
385        )
386        .unwrap()
387    }
388
389    fn assert_layers(layers: &[(Tree, &Grammar)], expected_layers: &[&str]) {
390        assert_eq!(
391            layers.len(),
392            expected_layers.len(),
393            "wrong number of layers"
394        );
395        for (i, (layer, expected_s_exp)) in layers.iter().zip(expected_layers.iter()).enumerate() {
396            let actual_s_exp = layer.0.root_node().to_sexp();
397            assert!(
398                string_contains_sequence(
399                    &actual_s_exp,
400                    &expected_s_exp.split("...").collect::<Vec<_>>()
401                ),
402                "layer {i}:\n\nexpected: {expected_s_exp}\nactual:   {actual_s_exp}",
403            );
404        }
405    }
406
407    pub fn string_contains_sequence(text: &str, parts: &[&str]) -> bool {
408        let mut last_part_end = 0;
409        for part in parts {
410            if let Some(start_ix) = text[last_part_end..].find(part) {
411                last_part_end = start_ix + part.len();
412            } else {
413                return false;
414            }
415        }
416        true
417    }
418}