// syntax_map.rs

use crate::{
    Grammar, InjectionConfig, Language, LanguageRegistry, QueryCursorHandle, TextProvider,
    ToTreeSitterPoint,
};
use std::{
    borrow::Cow, cell::RefCell, cmp::Ordering, collections::BinaryHeap, ops::Range, sync::Arc,
};
use sum_tree::{Bias, SeekTarget, SumTree};
use text::{Anchor, BufferSnapshot, OffsetRangeExt, Point, Rope, ToOffset, ToPoint};
use tree_sitter::{Node, Parser, Tree};
  11
  12thread_local! {
  13    static PARSER: RefCell<Parser> = RefCell::new(Parser::new());
  14}
  15
  16#[derive(Default)]
  17pub struct SyntaxMap {
  18    parsed_version: clock::Global,
  19    interpolated_version: clock::Global,
  20    snapshot: SyntaxSnapshot,
  21    language_registry: Option<Arc<LanguageRegistry>>,
  22}
  23
  24#[derive(Clone, Default)]
  25pub struct SyntaxSnapshot {
  26    layers: SumTree<SyntaxLayer>,
  27}
  28
  29#[derive(Clone)]
  30struct SyntaxLayer {
  31    depth: usize,
  32    range: Range<Anchor>,
  33    tree: tree_sitter::Tree,
  34    language: Arc<Language>,
  35}
  36
  37#[derive(Debug, Clone)]
  38struct SyntaxLayerSummary {
  39    max_depth: usize,
  40    range: Range<Anchor>,
  41    last_layer_range: Range<Anchor>,
  42}
  43
  44#[derive(Clone, Debug)]
  45struct DepthAndRange(usize, Range<Anchor>);
  46
  47#[derive(Clone, Debug)]
  48struct DepthAndMaxPosition(usize, Anchor);
  49
  50#[derive(Clone, Debug)]
  51struct DepthAndRangeOrMaxPosition(DepthAndRange, DepthAndMaxPosition);
  52
  53struct ReparseStep {
  54    depth: usize,
  55    language: Arc<Language>,
  56    ranges: Vec<tree_sitter::Range>,
  57    range: Range<Anchor>,
  58}
  59
  60#[derive(Debug, PartialEq, Eq)]
  61struct ChangedRegion {
  62    depth: usize,
  63    range: Range<Anchor>,
  64}
  65
  66impl SyntaxMap {
  67    pub fn new() -> Self {
  68        Self::default()
  69    }
  70
  71    pub fn set_language_registry(&mut self, registry: Arc<LanguageRegistry>) {
  72        self.language_registry = Some(registry);
  73    }
  74
  75    pub fn snapshot(&self) -> SyntaxSnapshot {
  76        self.snapshot.clone()
  77    }
  78
  79    pub fn interpolate(&mut self, text: &BufferSnapshot) {
  80        self.snapshot.interpolate(&self.interpolated_version, text);
  81        self.interpolated_version = text.version.clone();
  82    }
  83
  84    pub fn reparse(&mut self, language: Arc<Language>, text: &BufferSnapshot) {
  85        if !self.interpolated_version.observed_all(&text.version) {
  86            self.interpolate(text);
  87        }
  88
  89        self.snapshot.reparse(
  90            &self.parsed_version,
  91            text,
  92            self.language_registry.clone(),
  93            language,
  94        );
  95        self.parsed_version = text.version.clone();
  96    }
  97}
  98
  99impl SyntaxSnapshot {
 100    pub fn interpolate(&mut self, from_version: &clock::Global, text: &BufferSnapshot) {
 101        let edits = text
 102            .edits_since::<(usize, Point)>(&from_version)
 103            .collect::<Vec<_>>();
 104        if edits.is_empty() {
 105            return;
 106        }
 107
 108        let mut layers = SumTree::new();
 109        let mut edits_for_depth = &edits[..];
 110        let mut cursor = self.layers.cursor::<SyntaxLayerSummary>();
 111        cursor.next(text);
 112
 113        'outer: loop {
 114            let depth = cursor.end(text).max_depth;
 115
 116            // Preserve any layers at this depth that precede the first edit.
 117            if let Some(first_edit) = edits_for_depth.first() {
 118                let target = DepthAndMaxPosition(depth, text.anchor_before(first_edit.new.start.0));
 119                if target.cmp(&cursor.start(), text).is_gt() {
 120                    let slice = cursor.slice(&target, Bias::Left, text);
 121                    layers.push_tree(slice, text);
 122                }
 123            }
 124            // If this layer follows all of the edits, then preserve it and any
 125            // subsequent layers at this same depth.
 126            else {
 127                layers.push_tree(
 128                    cursor.slice(
 129                        &DepthAndRange(depth + 1, Anchor::MIN..Anchor::MAX),
 130                        Bias::Left,
 131                        text,
 132                    ),
 133                    text,
 134                );
 135                edits_for_depth = &edits[..];
 136                continue;
 137            };
 138
 139            let layer = if let Some(layer) = cursor.item() {
 140                layer
 141            } else {
 142                break;
 143            };
 144
 145            let mut endpoints = text
 146                .summaries_for_anchors::<(usize, Point), _>([&layer.range.start, &layer.range.end]);
 147            let layer_range = endpoints.next().unwrap()..endpoints.next().unwrap();
 148            let start_byte = layer_range.start.0;
 149            let start_point = layer_range.start.1;
 150            let end_byte = layer_range.end.0;
 151
 152            // Ignore edits that end before the start of this layer, and don't consider them
 153            // for any subsequent layers at this same depth.
 154            loop {
 155                if let Some(edit) = edits_for_depth.first() {
 156                    if edit.new.end.0 < start_byte {
 157                        edits_for_depth = &edits_for_depth[1..];
 158                    } else {
 159                        break;
 160                    }
 161                } else {
 162                    continue 'outer;
 163                }
 164            }
 165
 166            let mut layer = layer.clone();
 167            for edit in edits_for_depth {
 168                // Ignore any edits that follow this layer.
 169                if edit.new.start.0 > end_byte {
 170                    break;
 171                }
 172
 173                // Apply any edits that intersect this layer to the layer's syntax tree.
 174                let tree_edit = if edit.new.start.0 >= start_byte {
 175                    tree_sitter::InputEdit {
 176                        start_byte: edit.new.start.0 - start_byte,
 177                        old_end_byte: edit.new.start.0 - start_byte
 178                            + (edit.old.end.0 - edit.old.start.0),
 179                        new_end_byte: edit.new.end.0 - start_byte,
 180                        start_position: (edit.new.start.1 - start_point).to_ts_point(),
 181                        old_end_position: (edit.new.start.1 - start_point
 182                            + (edit.old.end.1 - edit.old.start.1))
 183                            .to_ts_point(),
 184                        new_end_position: (edit.new.end.1 - start_point).to_ts_point(),
 185                    }
 186                } else {
 187                    tree_sitter::InputEdit {
 188                        start_byte: 0,
 189                        old_end_byte: edit.new.end.0 - start_byte,
 190                        new_end_byte: 0,
 191                        start_position: Default::default(),
 192                        old_end_position: (edit.new.end.1 - start_point).to_ts_point(),
 193                        new_end_position: Default::default(),
 194                    }
 195                };
 196
 197                layer.tree.edit(&tree_edit);
 198                if edit.new.start.0 < start_byte {
 199                    break;
 200                }
 201            }
 202
 203            layers.push(layer, text);
 204            cursor.next(text);
 205        }
 206
 207        layers.push_tree(cursor.suffix(&text), &text);
 208        drop(cursor);
 209        self.layers = layers;
 210    }
 211
 212    pub fn reparse(
 213        &mut self,
 214        from_version: &clock::Global,
 215        text: &BufferSnapshot,
 216        registry: Option<Arc<LanguageRegistry>>,
 217        language: Arc<Language>,
 218    ) {
 219        let edits = text.edits_since::<usize>(from_version).collect::<Vec<_>>();
 220        if edits.is_empty() {
 221            return;
 222        }
 223
 224        let max_depth = self.layers.summary().max_depth;
 225        let mut cursor = self.layers.cursor::<SyntaxLayerSummary>();
 226        cursor.next(&text);
 227        let mut layers = SumTree::new();
 228
 229        let mut changed_regions = Vec::<ChangedRegion>::new();
 230        let mut queue = BinaryHeap::new();
 231        queue.push(ReparseStep {
 232            depth: 0,
 233            language: language.clone(),
 234            ranges: Vec::new(),
 235            range: Anchor::MIN..Anchor::MAX,
 236        });
 237
 238        loop {
 239            let step = queue.pop();
 240            let (depth, range) = if let Some(step) = &step {
 241                (step.depth, step.range.clone())
 242            } else {
 243                (max_depth + 1, Anchor::MAX..Anchor::MAX)
 244            };
 245
 246            let target = DepthAndRange(depth, range.clone());
 247            let mut done = cursor.item().is_none();
 248            while !done && target.cmp(cursor.start(), &text).is_gt() {
 249                let bounded_target = DepthAndRangeOrMaxPosition(
 250                    target.clone(),
 251                    changed_regions
 252                        .first()
 253                        .map_or(DepthAndMaxPosition(usize::MAX, Anchor::MAX), |region| {
 254                            DepthAndMaxPosition(region.depth, region.range.start)
 255                        }),
 256                );
 257                if bounded_target.cmp(&cursor.start(), &text).is_gt() {
 258                    let slice = cursor.slice(&bounded_target, Bias::Left, text);
 259                    layers.push_tree(slice, &text);
 260                }
 261
 262                while target.cmp(&cursor.end(text), text).is_gt() {
 263                    let layer = if let Some(layer) = cursor.item() {
 264                        layer
 265                    } else {
 266                        break;
 267                    };
 268
 269                    if layer_is_changed(layer, text, &changed_regions) {
 270                        ChangedRegion {
 271                            depth: depth + 1,
 272                            range: layer.range.clone(),
 273                        }
 274                        .insert(text, &mut changed_regions);
 275                    } else {
 276                        layers.push(layer.clone(), text);
 277                    }
 278                    cursor.next(text);
 279                }
 280
 281                done = true;
 282                changed_regions.retain(|region| {
 283                    if region.depth > depth
 284                        || (region.depth == depth
 285                            && region.range.end.cmp(&range.start, text).is_gt())
 286                    {
 287                        true
 288                    } else {
 289                        done = false;
 290                        false
 291                    }
 292                });
 293            }
 294
 295            let (ranges, language) = if let Some(step) = step {
 296                (step.ranges, step.language)
 297            } else {
 298                break;
 299            };
 300
 301            let start_point;
 302            let start_byte;
 303            let end_byte;
 304            if let Some((first, last)) = ranges.first().zip(ranges.last()) {
 305                start_point = first.start_point;
 306                start_byte = first.start_byte;
 307                end_byte = last.end_byte;
 308            } else {
 309                start_point = Point::zero().to_ts_point();
 310                start_byte = 0;
 311                end_byte = text.len();
 312            };
 313
 314            let mut old_layer = cursor.item();
 315            if let Some(layer) = old_layer {
 316                if layer.range.to_offset(text) == (start_byte..end_byte) {
 317                    cursor.next(&text);
 318                } else {
 319                    old_layer = None;
 320                }
 321            }
 322
 323            let grammar = if let Some(grammar) = language.grammar.as_deref() {
 324                grammar
 325            } else {
 326                continue;
 327            };
 328
 329            let tree;
 330            let changed_ranges;
 331            if let Some(old_layer) = old_layer {
 332                tree = parse_text(
 333                    grammar,
 334                    text.as_rope(),
 335                    Some(old_layer.tree.clone()),
 336                    ranges,
 337                );
 338                changed_ranges = join_ranges(
 339                    edits
 340                        .iter()
 341                        .map(|e| e.new.clone())
 342                        .filter(|range| range.start < end_byte && range.end > start_byte),
 343                    old_layer
 344                        .tree
 345                        .changed_ranges(&tree)
 346                        .map(|r| start_byte + r.start_byte..start_byte + r.end_byte),
 347                );
 348            } else {
 349                tree = parse_text(grammar, text.as_rope(), None, ranges);
 350                changed_ranges = vec![start_byte..end_byte];
 351            }
 352
 353            layers.push(
 354                SyntaxLayer {
 355                    depth,
 356                    range,
 357                    tree: tree.clone(),
 358                    language: language.clone(),
 359                },
 360                &text,
 361            );
 362
 363            if let (Some((config, registry)), false) = (
 364                grammar.injection_config.as_ref().zip(registry.as_ref()),
 365                changed_ranges.is_empty(),
 366            ) {
 367                let depth = depth + 1;
 368                for range in &changed_ranges {
 369                    ChangedRegion {
 370                        depth,
 371                        range: text.anchor_before(range.start)..text.anchor_after(range.end),
 372                    }
 373                    .insert(text, &mut changed_regions);
 374                }
 375                get_injections(
 376                    config,
 377                    text,
 378                    tree.root_node_with_offset(start_byte, start_point),
 379                    registry,
 380                    depth,
 381                    &changed_ranges,
 382                    &mut queue,
 383                );
 384            }
 385        }
 386
 387        drop(cursor);
 388        self.layers = layers;
 389    }
 390
 391    pub fn layers(&self, buffer: &BufferSnapshot) -> Vec<(&Grammar, Node)> {
 392        self.layers
 393            .iter()
 394            .filter_map(|layer| {
 395                if let Some(grammar) = &layer.language.grammar {
 396                    Some((
 397                        grammar.as_ref(),
 398                        layer.tree.root_node_with_offset(
 399                            layer.range.start.to_offset(buffer),
 400                            layer.range.start.to_point(buffer).to_ts_point(),
 401                        ),
 402                    ))
 403                } else {
 404                    None
 405                }
 406            })
 407            .collect()
 408    }
 409
 410    pub fn layers_for_range<'a, T: ToOffset>(
 411        &self,
 412        range: Range<T>,
 413        buffer: &BufferSnapshot,
 414    ) -> Vec<(&Grammar, Node)> {
 415        let start = buffer.anchor_before(range.start.to_offset(buffer));
 416        let end = buffer.anchor_after(range.end.to_offset(buffer));
 417
 418        let mut cursor = self.layers.filter::<_, ()>(|summary| {
 419            let is_before_start = summary.range.end.cmp(&start, buffer).is_lt();
 420            let is_after_end = summary.range.start.cmp(&end, buffer).is_gt();
 421            !is_before_start && !is_after_end
 422        });
 423
 424        let mut result = Vec::new();
 425        cursor.next(buffer);
 426        while let Some(layer) = cursor.item() {
 427            if let Some(grammar) = &layer.language.grammar {
 428                result.push((
 429                    grammar.as_ref(),
 430                    layer.tree.root_node_with_offset(
 431                        layer.range.start.to_offset(buffer),
 432                        layer.range.start.to_point(buffer).to_ts_point(),
 433                    ),
 434                ));
 435            }
 436            cursor.next(buffer)
 437        }
 438
 439        result
 440    }
 441}
 442
 443fn join_ranges(
 444    a: impl Iterator<Item = Range<usize>>,
 445    b: impl Iterator<Item = Range<usize>>,
 446) -> Vec<Range<usize>> {
 447    let mut result = Vec::<Range<usize>>::new();
 448    let mut a = a.peekable();
 449    let mut b = b.peekable();
 450    loop {
 451        let range = match (a.peek(), b.peek()) {
 452            (Some(range_a), Some(range_b)) => {
 453                if range_a.start < range_b.start {
 454                    a.next().unwrap()
 455                } else {
 456                    b.next().unwrap()
 457                }
 458            }
 459            (None, Some(_)) => b.next().unwrap(),
 460            (Some(_), None) => a.next().unwrap(),
 461            (None, None) => break,
 462        };
 463
 464        if let Some(last) = result.last_mut() {
 465            if range.start <= last.end {
 466                last.end = last.end.max(range.end);
 467                continue;
 468            }
 469        }
 470        result.push(range);
 471    }
 472    result
 473}
 474
 475fn parse_text(
 476    grammar: &Grammar,
 477    text: &Rope,
 478    old_tree: Option<Tree>,
 479    mut ranges: Vec<tree_sitter::Range>,
 480) -> Tree {
 481    let (start_byte, start_point) = ranges
 482        .first()
 483        .map(|range| (range.start_byte, Point::from_ts_point(range.start_point)))
 484        .unwrap_or_default();
 485
 486    for range in &mut ranges {
 487        range.start_byte -= start_byte;
 488        range.end_byte -= start_byte;
 489        range.start_point = (Point::from_ts_point(range.start_point) - start_point).to_ts_point();
 490        range.end_point = (Point::from_ts_point(range.end_point) - start_point).to_ts_point();
 491    }
 492
 493    PARSER.with(|parser| {
 494        let mut parser = parser.borrow_mut();
 495        let mut chunks = text.chunks_in_range(start_byte..text.len());
 496        parser
 497            .set_included_ranges(&ranges)
 498            .expect("overlapping ranges");
 499        parser
 500            .set_language(grammar.ts_language)
 501            .expect("incompatible grammar");
 502        parser
 503            .parse_with(
 504                &mut move |offset, _| {
 505                    chunks.seek(start_byte + offset);
 506                    chunks.next().unwrap_or("").as_bytes()
 507                },
 508                old_tree.as_ref(),
 509            )
 510            .expect("invalid language")
 511    })
 512}
 513
 514fn get_injections(
 515    config: &InjectionConfig,
 516    text: &BufferSnapshot,
 517    node: Node,
 518    language_registry: &LanguageRegistry,
 519    depth: usize,
 520    query_ranges: &[Range<usize>],
 521    queue: &mut BinaryHeap<ReparseStep>,
 522) -> bool {
 523    let mut result = false;
 524    let mut query_cursor = QueryCursorHandle::new();
 525    let mut prev_match = None;
 526    for query_range in query_ranges {
 527        query_cursor.set_byte_range(query_range.start..query_range.end);
 528        for mat in query_cursor.matches(&config.query, node, TextProvider(text.as_rope())) {
 529            let content_ranges = mat
 530                .nodes_for_capture_index(config.content_capture_ix)
 531                .map(|node| node.range())
 532                .collect::<Vec<_>>();
 533            if content_ranges.is_empty() {
 534                continue;
 535            }
 536
 537            // Avoid duplicate matches if two changed ranges intersect the same injection.
 538            let content_range =
 539                content_ranges.first().unwrap().start_byte..content_ranges.last().unwrap().end_byte;
 540            if let Some((last_pattern_ix, last_range)) = &prev_match {
 541                if mat.pattern_index == *last_pattern_ix && content_range == *last_range {
 542                    continue;
 543                }
 544            }
 545            prev_match = Some((mat.pattern_index, content_range.clone()));
 546
 547            let language_name = config.languages_by_pattern_ix[mat.pattern_index]
 548                .as_ref()
 549                .map(|s| Cow::Borrowed(s.as_ref()))
 550                .or_else(|| {
 551                    let ix = config.language_capture_ix?;
 552                    let node = mat.nodes_for_capture_index(ix).next()?;
 553                    Some(Cow::Owned(text.text_for_range(node.byte_range()).collect()))
 554                });
 555
 556            if let Some(language_name) = language_name {
 557                if let Some(language) = language_registry.get_language(language_name.as_ref()) {
 558                    result = true;
 559                    let range = text.anchor_before(content_range.start)
 560                        ..text.anchor_after(content_range.end);
 561                    queue.push(ReparseStep {
 562                        depth,
 563                        language,
 564                        ranges: content_ranges,
 565                        range,
 566                    })
 567                }
 568            }
 569        }
 570    }
 571    result
 572}
 573
 574fn layer_is_changed(
 575    layer: &SyntaxLayer,
 576    text: &BufferSnapshot,
 577    changed_regions: &[ChangedRegion],
 578) -> bool {
 579    changed_regions.iter().any(|region| {
 580        let same_depth = region.depth == layer.depth;
 581        let is_before_layer = region.range.end.cmp(&layer.range.start, text).is_le();
 582        let is_after_layer = region.range.start.cmp(&layer.range.end, text).is_ge();
 583        same_depth && !is_before_layer && !is_after_layer
 584    })
 585}
 586
 587impl std::ops::Deref for SyntaxMap {
 588    type Target = SyntaxSnapshot;
 589
 590    fn deref(&self) -> &Self::Target {
 591        &self.snapshot
 592    }
 593}
 594
 595impl PartialEq for ReparseStep {
 596    fn eq(&self, _: &Self) -> bool {
 597        false
 598    }
 599}
 600
 601impl Eq for ReparseStep {}
 602
 603impl PartialOrd for ReparseStep {
 604    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
 605        Some(self.cmp(&other))
 606    }
 607}
 608
 609impl Ord for ReparseStep {
 610    fn cmp(&self, other: &Self) -> Ordering {
 611        let range_a = self.range();
 612        let range_b = other.range();
 613        Ord::cmp(&other.depth, &self.depth)
 614            .then_with(|| Ord::cmp(&range_b.start, &range_a.start))
 615            .then_with(|| Ord::cmp(&range_a.end, &range_b.end))
 616    }
 617}
 618
 619impl ReparseStep {
 620    fn range(&self) -> Range<usize> {
 621        let start = self.ranges.first().map_or(0, |r| r.start_byte);
 622        let end = self.ranges.last().map_or(0, |r| r.end_byte);
 623        start..end
 624    }
 625}
 626
 627impl ChangedRegion {
 628    fn insert(self, text: &BufferSnapshot, set: &mut Vec<Self>) {
 629        if let Err(ix) = set.binary_search_by(|probe| probe.cmp(&self, text)) {
 630            set.insert(ix, self);
 631        }
 632    }
 633
 634    fn cmp(&self, other: &Self, buffer: &BufferSnapshot) -> Ordering {
 635        let range_a = &self.range;
 636        let range_b = &other.range;
 637        Ord::cmp(&self.depth, &other.depth)
 638            .then_with(|| range_a.start.cmp(&range_b.start, buffer))
 639            .then_with(|| range_b.end.cmp(&range_a.end, buffer))
 640    }
 641}
 642
 643impl Default for SyntaxLayerSummary {
 644    fn default() -> Self {
 645        Self {
 646            max_depth: 0,
 647            range: Anchor::MAX..Anchor::MIN,
 648            last_layer_range: Anchor::MIN..Anchor::MAX,
 649        }
 650    }
 651}
 652
 653impl sum_tree::Summary for SyntaxLayerSummary {
 654    type Context = BufferSnapshot;
 655
 656    fn add_summary(&mut self, other: &Self, buffer: &Self::Context) {
 657        if other.max_depth > self.max_depth {
 658            *self = other.clone();
 659        } else {
 660            if other.range.start.cmp(&self.range.start, buffer).is_lt() {
 661                self.range.start = other.range.start;
 662            }
 663            if other.range.end.cmp(&self.range.end, buffer).is_gt() {
 664                self.range.end = other.range.end;
 665            }
 666            self.last_layer_range = other.last_layer_range.clone();
 667        }
 668    }
 669}
 670
 671impl<'a> SeekTarget<'a, SyntaxLayerSummary, SyntaxLayerSummary> for DepthAndRange {
 672    fn cmp(&self, cursor_location: &SyntaxLayerSummary, buffer: &BufferSnapshot) -> Ordering {
 673        Ord::cmp(&self.0, &cursor_location.max_depth)
 674            .then_with(|| {
 675                self.1
 676                    .start
 677                    .cmp(&cursor_location.last_layer_range.start, buffer)
 678            })
 679            .then_with(|| {
 680                cursor_location
 681                    .last_layer_range
 682                    .end
 683                    .cmp(&self.1.end, buffer)
 684            })
 685    }
 686}
 687
 688impl<'a> SeekTarget<'a, SyntaxLayerSummary, SyntaxLayerSummary> for DepthAndMaxPosition {
 689    fn cmp(&self, cursor_location: &SyntaxLayerSummary, text: &BufferSnapshot) -> Ordering {
 690        Ord::cmp(&self.0, &cursor_location.max_depth)
 691            .then_with(|| self.1.cmp(&cursor_location.range.end, text))
 692    }
 693}
 694
 695impl<'a> SeekTarget<'a, SyntaxLayerSummary, SyntaxLayerSummary> for DepthAndRangeOrMaxPosition {
 696    fn cmp(&self, cursor_location: &SyntaxLayerSummary, buffer: &BufferSnapshot) -> Ordering {
 697        if self.1.cmp(cursor_location, buffer).is_le() {
 698            return Ordering::Less;
 699        } else {
 700            self.0.cmp(cursor_location, buffer)
 701        }
 702    }
 703}
 704
 705impl sum_tree::Item for SyntaxLayer {
 706    type Summary = SyntaxLayerSummary;
 707
 708    fn summary(&self) -> Self::Summary {
 709        SyntaxLayerSummary {
 710            max_depth: self.depth,
 711            range: self.range.clone(),
 712            last_layer_range: self.range.clone(),
 713        }
 714    }
 715}
 716
 717impl std::fmt::Debug for SyntaxLayer {
 718    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 719        f.debug_struct("SyntaxLayer")
 720            .field("depth", &self.depth)
 721            .field("range", &self.range)
 722            .field("tree", &self.tree)
 723            .finish()
 724    }
 725}
 726
 727#[cfg(test)]
 728mod tests {
 729    use super::*;
 730    use crate::LanguageConfig;
 731    use text::{Buffer, Point};
 732    use tree_sitter::Query;
 733    use unindent::Unindent as _;
 734    use util::test::marked_text_ranges;
 735
 736    #[gpui::test]
 737    fn test_syntax_map_layers_for_range() {
 738        let registry = Arc::new(LanguageRegistry::test());
 739        let language = Arc::new(rust_lang());
 740        registry.add(language.clone());
 741
 742        let mut buffer = Buffer::new(
 743            0,
 744            0,
 745            r#"
 746                fn a() {
 747                    assert_eq!(
 748                        b(vec![C {}]),
 749                        vec![d.e],
 750                    );
 751                    println!("{}", f(|_| true));
 752                }
 753            "#
 754            .unindent(),
 755        );
 756
 757        let mut syntax_map = SyntaxMap::new();
 758        syntax_map.set_language_registry(registry.clone());
 759        syntax_map.reparse(language.clone(), &buffer);
 760
 761        assert_layers_for_range(
 762            &syntax_map,
 763            &buffer,
 764            Point::new(2, 0)..Point::new(2, 0),
 765            &[
 766                "...(function_item ... (block (expression_statement (macro_invocation...",
 767                "...(tuple_expression (call_expression ... arguments: (arguments (macro_invocation...",
 768            ],
 769        );
 770        assert_layers_for_range(
 771            &syntax_map,
 772            &buffer,
 773            Point::new(2, 14)..Point::new(2, 16),
 774            &[
 775                "...(function_item ...",
 776                "...(tuple_expression (call_expression ... arguments: (arguments (macro_invocation...",
 777                "...(array_expression (struct_expression ...",
 778            ],
 779        );
 780        assert_layers_for_range(
 781            &syntax_map,
 782            &buffer,
 783            Point::new(3, 14)..Point::new(3, 16),
 784            &[
 785                "...(function_item ...",
 786                "...(tuple_expression (call_expression ... arguments: (arguments (macro_invocation...",
 787                "...(array_expression (field_expression ...",
 788            ],
 789        );
 790        assert_layers_for_range(
 791            &syntax_map,
 792            &buffer,
 793            Point::new(5, 12)..Point::new(5, 16),
 794            &[
 795                "...(function_item ...",
 796                "...(call_expression ... (arguments (closure_expression ...",
 797            ],
 798        );
 799
 800        // Replace a vec! macro invocation with a plain slice, removing a syntactic layer.
 801        let macro_name_range = range_for_text(&buffer, "vec!");
 802        buffer.edit([(macro_name_range, "&")]);
 803        syntax_map.interpolate(&buffer);
 804        syntax_map.reparse(language.clone(), &buffer);
 805
 806        assert_layers_for_range(
 807            &syntax_map,
 808            &buffer,
 809            Point::new(2, 14)..Point::new(2, 16),
 810            &[
 811                "...(function_item ...",
 812                "...(tuple_expression (call_expression ... arguments: (arguments (reference_expression value: (array_expression...",
 813            ],
 814        );
 815
 816        // Put the vec! macro back, adding back the syntactic layer.
 817        buffer.undo();
 818        syntax_map.interpolate(&buffer);
 819        syntax_map.reparse(language.clone(), &buffer);
 820
 821        assert_layers_for_range(
 822            &syntax_map,
 823            &buffer,
 824            Point::new(2, 14)..Point::new(2, 16),
 825            &[
 826                "...(function_item ...",
 827                "...(tuple_expression (call_expression ... arguments: (arguments (macro_invocation...",
 828                "...(array_expression (struct_expression ...",
 829            ],
 830        );
 831    }
 832
 833    #[gpui::test]
 834    fn test_typing_multiple_new_injections() {
 835        let (buffer, syntax_map) = test_edit_sequence(&[
 836            "fn a() { dbg }",
 837            "fn a() { dbg«!» }",
 838            "fn a() { dbg!«()» }",
 839            "fn a() { dbg!(«b») }",
 840            "fn a() { dbg!(b«.») }",
 841            "fn a() { dbg!(b.«c») }",
 842            "fn a() { dbg!(b.c«()») }",
 843            "fn a() { dbg!(b.c(«vec»)) }",
 844            "fn a() { dbg!(b.c(vec«!»)) }",
 845            "fn a() { dbg!(b.c(vec!«[]»)) }",
 846            "fn a() { dbg!(b.c(vec![«d»])) }",
 847            "fn a() { dbg!(b.c(vec![d«.»])) }",
 848            "fn a() { dbg!(b.c(vec![d.«e»])) }",
 849        ]);
 850
 851        assert_node_ranges(
 852            &syntax_map,
 853            &buffer,
 854            "(field_identifier) @_",
 855            "fn a() { dbg!(b.«c»(vec![d.«e»])) }",
 856        );
 857    }
 858
 859    #[gpui::test]
 860    fn test_pasting_new_injection_line_between_others() {
 861        let (buffer, syntax_map) = test_edit_sequence(&[
 862            "
 863                fn a() {
 864                    b!(B {});
 865                    c!(C {});
 866                    d!(D {});
 867                    e!(E {});
 868                    f!(F {});
 869                }
 870            ",
 871            "
 872                fn a() {
 873                    b!(B {});
 874                    c!(C {});
 875                    «g!(G {});
 876                    »d!(D {});
 877                    e!(E {});
 878                    f!(F {});
 879                }
 880            ",
 881        ]);
 882
 883        assert_node_ranges(
 884            &syntax_map,
 885            &buffer,
 886            "(struct_expression) @_",
 887            "
 888            fn a() {
 889                b!(«B {}»);
 890                c!(«C {}»);
 891                g!(«G {}»);
 892                d!(«D {}»);
 893                e!(«E {}»);
 894                f!(«F {}»);
 895            }
 896            ",
 897        );
 898    }
 899
 900    #[gpui::test]
 901    fn test_joining_injections_with_child_injections() {
 902        let (buffer, syntax_map) = test_edit_sequence(&[
 903            "
 904                fn a() {
 905                    b!(
 906                        c![one.two.three],
 907                        d![four.five.six],
 908                    );
 909                    e!(
 910                        f![seven.eight],
 911                    );
 912                }
 913            ",
 914            "
 915                fn a() {
 916                    b!(
 917                        c![one.two.three],
 918                        d![four.five.six],
 919                    ˇ    f![seven.eight],
 920                    );
 921                }
 922            ",
 923        ]);
 924
 925        assert_node_ranges(
 926            &syntax_map,
 927            &buffer,
 928            "(field_identifier) @_",
 929            "
 930            fn a() {
 931                b!(
 932                    c![one.«two».«three»],
 933                    d![four.«five».«six»],
 934                    f![seven.«eight»],
 935                );
 936            }
 937            ",
 938        );
 939    }
 940
 941    #[gpui::test]
 942    fn test_editing_edges_of_injection() {
 943        test_edit_sequence(&[
 944            "
 945                fn a() {
 946                    b!(c!())
 947                }
 948            ",
 949            "
 950                fn a() {
 951                    «d»!(c!())
 952                }
 953            ",
 954            "
 955                fn a() {
 956                    «e»d!(c!())
 957                }
 958            ",
 959            "
 960                fn a() {
 961                    ed!«[»c!()«]»
 962                }
 963            ",
 964        ]);
 965    }
 966
 967    #[gpui::test]
 968    fn test_edits_preceding_and_intersecting_injection() {
 969        test_edit_sequence(&[
 970            //
 971            "const aaaaaaaaaaaa: B = c!(d(e.f));",
 972            "const aˇa: B = c!(d(eˇ));",
 973        ]);
 974    }
 975
 976    #[gpui::test]
 977    fn test_non_local_changes_create_injections() {
 978        test_edit_sequence(&[
 979            "
 980                // a! {
 981                    static B: C = d;
 982                // }
 983            ",
 984            "
 985                ˇa! {
 986                    static B: C = d;
 987                ˇ}
 988            ",
 989        ]);
 990    }
 991
 992    fn test_edit_sequence(steps: &[&str]) -> (Buffer, SyntaxMap) {
 993        let registry = Arc::new(LanguageRegistry::test());
 994        let language = Arc::new(rust_lang());
 995        registry.add(language.clone());
 996        let mut buffer = Buffer::new(0, 0, Default::default());
 997
 998        let mut mutated_syntax_map = SyntaxMap::new();
 999        mutated_syntax_map.set_language_registry(registry.clone());
1000        mutated_syntax_map.reparse(language.clone(), &buffer);
1001
1002        for (i, marked_string) in steps.into_iter().enumerate() {
1003            edit_buffer(&mut buffer, &marked_string.unindent());
1004
1005            // Reparse the syntax map
1006            mutated_syntax_map.interpolate(&buffer);
1007            mutated_syntax_map.reparse(language.clone(), &buffer);
1008
1009            // Create a second syntax map from scratch
1010            let mut reference_syntax_map = SyntaxMap::new();
1011            reference_syntax_map.set_language_registry(registry.clone());
1012            reference_syntax_map.reparse(language.clone(), &buffer);
1013
1014            // Compare the mutated syntax map to the new syntax map
1015            let mutated_layers = mutated_syntax_map.layers(&buffer);
1016            let reference_layers = reference_syntax_map.layers(&buffer);
1017            assert_eq!(
1018                mutated_layers.len(),
1019                reference_layers.len(),
1020                "wrong number of layers at step {i}"
1021            );
1022            for (edited_layer, reference_layer) in
1023                mutated_layers.into_iter().zip(reference_layers.into_iter())
1024            {
1025                assert_eq!(
1026                    edited_layer.1.to_sexp(),
1027                    reference_layer.1.to_sexp(),
1028                    "different layer at step {i}"
1029                );
1030                assert_eq!(
1031                    edited_layer.1.range(),
1032                    reference_layer.1.range(),
1033                    "different layer at step {i}"
1034                );
1035            }
1036        }
1037
1038        (buffer, mutated_syntax_map)
1039    }
1040
1041    fn rust_lang() -> Language {
1042        Language::new(
1043            LanguageConfig {
1044                name: "Rust".into(),
1045                path_suffixes: vec!["rs".to_string()],
1046                ..Default::default()
1047            },
1048            Some(tree_sitter_rust::language()),
1049        )
1050        .with_injection_query(
1051            r#"
1052                (macro_invocation
1053                    (token_tree) @content
1054                    (#set! "language" "rust"))
1055            "#,
1056        )
1057        .unwrap()
1058    }
1059
1060    fn range_for_text(buffer: &Buffer, text: &str) -> Range<usize> {
1061        let start = buffer.as_rope().to_string().find(text).unwrap();
1062        start..start + text.len()
1063    }
1064
1065    fn assert_layers_for_range(
1066        syntax_map: &SyntaxMap,
1067        buffer: &BufferSnapshot,
1068        range: Range<Point>,
1069        expected_layers: &[&str],
1070    ) {
1071        let layers = syntax_map.layers_for_range(range, &buffer);
1072        assert_eq!(
1073            layers.len(),
1074            expected_layers.len(),
1075            "wrong number of layers"
1076        );
1077        for (i, ((_, node), expected_s_exp)) in
1078            layers.iter().zip(expected_layers.iter()).enumerate()
1079        {
1080            let actual_s_exp = node.to_sexp();
1081            assert!(
1082                string_contains_sequence(
1083                    &actual_s_exp,
1084                    &expected_s_exp.split("...").collect::<Vec<_>>()
1085                ),
1086                "layer {i}:\n\nexpected: {expected_s_exp}\nactual:   {actual_s_exp}",
1087            );
1088        }
1089    }
1090
1091    fn assert_node_ranges(
1092        syntax_map: &SyntaxMap,
1093        buffer: &BufferSnapshot,
1094        query: &str,
1095        marked_string: &str,
1096    ) {
1097        let mut cursor = QueryCursorHandle::new();
1098        let mut actual_ranges = Vec::<Range<usize>>::new();
1099        for (grammar, node) in syntax_map.layers(buffer) {
1100            let query = Query::new(grammar.ts_language, query).unwrap();
1101            for (mat, ix) in cursor.captures(&query, node, TextProvider(buffer.as_rope())) {
1102                actual_ranges.push(mat.captures[ix].node.byte_range());
1103            }
1104        }
1105
1106        let (text, expected_ranges) = marked_text_ranges(&marked_string.unindent(), false);
1107        assert_eq!(text, buffer.text());
1108        assert_eq!(actual_ranges, expected_ranges);
1109    }
1110
1111    fn edit_buffer(buffer: &mut Buffer, marked_string: &str) {
1112        let old_text = buffer.text();
1113        let (new_text, mut ranges) = marked_text_ranges(marked_string, false);
1114        if ranges.is_empty() {
1115            ranges.push(0..new_text.len());
1116        }
1117
1118        assert_eq!(
1119            old_text[..ranges[0].start],
1120            new_text[..ranges[0].start],
1121            "invalid edit"
1122        );
1123
1124        let mut delta = 0;
1125        let mut edits = Vec::new();
1126        let mut ranges = ranges.into_iter().peekable();
1127
1128        while let Some(inserted_range) = ranges.next() {
1129            let new_start = inserted_range.start;
1130            let old_start = (new_start as isize - delta) as usize;
1131
1132            let following_text = if let Some(next_range) = ranges.peek() {
1133                &new_text[inserted_range.end..next_range.start]
1134            } else {
1135                &new_text[inserted_range.end..]
1136            };
1137
1138            let inserted_len = inserted_range.len();
1139            let deleted_len = old_text[old_start..]
1140                .find(following_text)
1141                .expect("invalid edit");
1142
1143            let old_range = old_start..old_start + deleted_len;
1144            edits.push((old_range, new_text[inserted_range].to_string()));
1145            delta += inserted_len as isize - deleted_len as isize;
1146        }
1147
1148        assert_eq!(
1149            old_text.len() as isize + delta,
1150            new_text.len() as isize,
1151            "invalid edit"
1152        );
1153
1154        buffer.edit(edits);
1155    }
1156
1157    pub fn string_contains_sequence(text: &str, parts: &[&str]) -> bool {
1158        let mut last_part_end = 0;
1159        for part in parts {
1160            if let Some(start_ix) = text[last_part_end..].find(part) {
1161                last_part_end = start_ix + part.len();
1162            } else {
1163                return false;
1164            }
1165        }
1166        true
1167    }
1168}