Use new Tree-sitter captures API

Max Brunsfeld and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

Cargo.lock                             |  2 
Cargo.toml                             |  2 
zed/assets/themes/light.toml           |  2 
zed/languages/rust/highlights.scm      |  4 +-
zed/src/editor/buffer/mod.rs           | 32 +++++++++++++++++++++-------
zed/src/editor/buffer_view.rs          |  2 
zed/src/editor/display_map/fold_map.rs | 22 ++++++++++++++-----
zed/src/editor/display_map/mod.rs      | 27 ++++++++++++++---------
8 files changed, 62 insertions(+), 31 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -2714,7 +2714,7 @@ dependencies = [
 [[package]]
 name = "tree-sitter"
 version = "0.19.5"
-source = "git+https://github.com/tree-sitter/tree-sitter?rev=a61f25bc58e3affe81aaacaaf5d9b6150a5e90ef#a61f25bc58e3affe81aaacaaf5d9b6150a5e90ef"
+source = "git+https://github.com/tree-sitter/tree-sitter?rev=3112fa546852785151700d3b2fc598599ef75063#3112fa546852785151700d3b2fc598599ef75063"
 dependencies = [
  "cc",
  "regex",

Cargo.toml 🔗

@@ -3,7 +3,7 @@ members = ["zed", "gpui", "gpui_macros", "fsevent", "scoped_pool"]
 
 [patch.crates-io]
 async-task = {git = "https://github.com/zed-industries/async-task", rev = "341b57d6de98cdfd7b418567b8de2022ca993a6e"}
-tree-sitter = {git = "https://github.com/tree-sitter/tree-sitter", rev = "a61f25bc58e3affe81aaacaaf5d9b6150a5e90ef"}
+tree-sitter = {git = "https://github.com/tree-sitter/tree-sitter", rev = "3112fa546852785151700d3b2fc598599ef75063"}
 
 # TODO - Remove when a version is released with this PR: https://github.com/servo/core-foundation-rs/pull/457
 cocoa = {git = "https://github.com/servo/core-foundation-rs", rev = "025dcb3c0d1ef01530f57ef65f3b1deb948f5737"}

zed/assets/themes/light.toml 🔗

@@ -10,4 +10,4 @@ string = 0xa31515
 type = 0x267599
 number = 0x0d885b
 comment = 0x048204
-property = 0x001080
+property = 0x001080

zed/languages/rust/highlights.scm 🔗

@@ -1,5 +1,7 @@
 (type_identifier) @type
 
+(field_identifier) @property
+
 (call_expression
   function: [
     (identifier) @function
@@ -9,8 +11,6 @@
       field: (field_identifier) @function.method)
   ])
 
-(field_identifier) @property
-
 (function_item
   name: (identifier) @function.definition)
 

zed/src/editor/buffer/mod.rs 🔗

@@ -2069,12 +2069,12 @@ impl Snapshot {
         let chunks = self.text.chunks_in_range(range.clone());
         if let Some((language, tree)) = self.language.as_ref().zip(self.tree.as_ref()) {
             let query_cursor = self.query_cursor.as_mut().unwrap();
-            query_cursor.set_byte_range(range.start, range.end);
-            let captures = query_cursor.captures(
+            let mut captures = query_cursor.captures(
                 &language.highlight_query,
                 tree.root_node(),
                 TextProvider(&self.text),
             );
+            captures.set_byte_range(range.start, range.end);
 
             HighlightedChunks {
                 range,
@@ -2277,9 +2277,25 @@ impl<'a> HighlightedChunks<'a> {
         self.range.start = offset;
         self.chunks.seek(self.range.start);
         if let Some(highlights) = self.highlights.as_mut() {
-            highlights.stack.clear();
-            highlights.next_capture.take();
-            highlights.captures.advance_to_byte(self.range.start);
+            highlights
+                .stack
+                .retain(|(end_offset, _)| *end_offset > offset);
+            if let Some((mat, capture_ix)) = &highlights.next_capture {
+                let capture = mat.captures[*capture_ix as usize];
+                if offset >= capture.node.start_byte() {
+                    let next_capture_end = capture.node.end_byte();
+                    if offset < next_capture_end {
+                        highlights.stack.push((
+                            next_capture_end,
+                            highlights.theme_mapping.get(capture.index),
+                        ));
+                    }
+                    highlights.next_capture.take();
+                }
+            }
+            highlights
+                .captures
+                .set_byte_range(self.range.start, self.range.end);
         }
     }
 
@@ -2323,12 +2339,12 @@ impl<'a> Iterator for HighlightedChunks<'a> {
         if let Some(chunk) = self.chunks.peek() {
             let chunk_start = self.range.start;
             let mut chunk_end = (self.chunks.offset() + chunk.len()).min(next_capture_start);
-            let mut capture_ix = StyleId::default();
+            let mut style_id = StyleId::default();
             if let Some((parent_capture_end, parent_style_id)) =
                 self.highlights.as_ref().and_then(|h| h.stack.last())
             {
                 chunk_end = chunk_end.min(*parent_capture_end);
-                capture_ix = *parent_style_id;
+                style_id = *parent_style_id;
             }
 
             let slice =
@@ -2338,7 +2354,7 @@ impl<'a> Iterator for HighlightedChunks<'a> {
                 self.chunks.next().unwrap();
             }
 
-            Some((slice, capture_ix))
+            Some((slice, style_id))
         } else {
             None
         }

zed/src/editor/buffer_view.rs 🔗

@@ -2149,7 +2149,7 @@ impl BufferView {
         let mut styles = Vec::new();
         let mut row = rows.start;
         let mut snapshot = self.display_map.snapshot(ctx);
-        let chunks = snapshot.highlighted_chunks_at(rows.start);
+        let chunks = snapshot.highlighted_chunks_for_rows(rows.clone());
         let theme = settings.theme.clone();
 
         'outer: for (chunk, style_ix) in chunks.chain(Some(("\n", StyleId::default()))) {

zed/src/editor/display_map/fold_map.rs 🔗

@@ -413,6 +413,10 @@ impl FoldMapSnapshot {
         }
     }
 
+    pub fn max_point(&self) -> DisplayPoint {
+        DisplayPoint(self.transforms.summary().display.lines)
+    }
+
     pub fn chunks_at(&self, offset: DisplayOffset) -> Chunks {
         let mut transform_cursor = self.transforms.cursor::<DisplayOffset, TransformSummary>();
         transform_cursor.seek(&offset, SeekBias::Right, &());
@@ -425,17 +429,23 @@ impl FoldMapSnapshot {
         }
     }
 
-    pub fn highlighted_chunks_at(&mut self, offset: DisplayOffset) -> HighlightedChunks {
+    pub fn highlighted_chunks(&mut self, range: Range<DisplayOffset>) -> HighlightedChunks {
         let mut transform_cursor = self.transforms.cursor::<DisplayOffset, TransformSummary>();
-        transform_cursor.seek(&offset, SeekBias::Right, &());
-        let overshoot = offset.0 - transform_cursor.start().display.bytes;
-        let buffer_offset = transform_cursor.start().buffer.bytes + overshoot;
+
+        transform_cursor.seek(&range.end, SeekBias::Right, &());
+        let overshoot = range.end.0 - transform_cursor.start().display.bytes;
+        let buffer_end = transform_cursor.start().buffer.bytes + overshoot;
+
+        transform_cursor.seek(&range.start, SeekBias::Right, &());
+        let overshoot = range.start.0 - transform_cursor.start().display.bytes;
+        let buffer_start = transform_cursor.start().buffer.bytes + overshoot;
+
         HighlightedChunks {
             transform_cursor,
-            buffer_offset,
+            buffer_offset: buffer_start,
             buffer_chunks: self
                 .buffer
-                .highlighted_text_for_range(buffer_offset..self.buffer.len()),
+                .highlighted_text_for_range(buffer_start..buffer_end),
             buffer_chunk: None,
         }
     }

zed/src/editor/display_map/mod.rs 🔗

@@ -104,9 +104,8 @@ impl DisplayMap {
             .column()
     }
 
-    // TODO - make this delegate to the DisplayMapSnapshot
     pub fn max_point(&self, ctx: &AppContext) -> DisplayPoint {
-        self.fold_map.max_point(ctx).expand_tabs(self, ctx)
+        self.snapshot(ctx).max_point().expand_tabs(self, ctx)
     }
 
     pub fn longest_row(&self, ctx: &AppContext) -> u32 {
@@ -136,6 +135,10 @@ impl DisplayMapSnapshot {
         self.folds_snapshot.buffer_rows(start_row)
     }
 
+    pub fn max_point(&self) -> DisplayPoint {
+        self.expand_tabs(self.folds_snapshot.max_point())
+    }
+
     pub fn chunks_at(&self, point: DisplayPoint) -> Chunks {
         let (point, expanded_char_column, to_next_stop) = self.collapse_tabs(point, Bias::Left);
         let fold_chunks = self
@@ -150,11 +153,13 @@ impl DisplayMapSnapshot {
         }
     }
 
-    pub fn highlighted_chunks_at(&mut self, row: u32) -> HighlightedChunks {
-        let point = DisplayPoint::new(row, 0);
-        let offset = self.folds_snapshot.to_display_offset(point);
+    pub fn highlighted_chunks_for_rows(&mut self, rows: Range<u32>) -> HighlightedChunks {
+        let start = DisplayPoint::new(rows.start, 0);
+        let start = self.folds_snapshot.to_display_offset(start);
+        let end = DisplayPoint::new(rows.end, 0).min(self.max_point());
+        let end = self.folds_snapshot.to_display_offset(end);
         HighlightedChunks {
-            fold_chunks: self.folds_snapshot.highlighted_chunks_at(offset),
+            fold_chunks: self.folds_snapshot.highlighted_chunks(start..end),
             column: 0,
             tab_size: self.tab_size,
             chunk: "",
@@ -530,7 +535,7 @@ mod tests {
 
         let mut map = app.read(|ctx| DisplayMap::new(buffer, 2, ctx));
         assert_eq!(
-            app.read(|ctx| highlighted_chunks(0, &map, &theme, ctx)),
+            app.read(|ctx| highlighted_chunks(0..5, &map, &theme, ctx)),
             vec![
                 ("fn ".to_string(), None),
                 ("outer".to_string(), Some("fn.name")),
@@ -541,7 +546,7 @@ mod tests {
             ]
         );
         assert_eq!(
-            app.read(|ctx| highlighted_chunks(3, &map, &theme, ctx)),
+            app.read(|ctx| highlighted_chunks(3..5, &map, &theme, ctx)),
             vec![
                 ("    fn ".to_string(), Some("mod.body")),
                 ("inner".to_string(), Some("fn.name")),
@@ -551,7 +556,7 @@ mod tests {
 
         app.read(|ctx| map.fold(vec![Point::new(0, 6)..Point::new(3, 2)], ctx));
         assert_eq!(
-            app.read(|ctx| highlighted_chunks(0, &map, &theme, ctx)),
+            app.read(|ctx| highlighted_chunks(0..2, &map, &theme, ctx)),
             vec![
                 ("fn ".to_string(), None),
                 ("out".to_string(), Some("fn.name")),
@@ -563,13 +568,13 @@ mod tests {
         );
 
         fn highlighted_chunks<'a>(
-            row: u32,
+            rows: Range<u32>,
             map: &DisplayMap,
             theme: &'a Theme,
             ctx: &AppContext,
         ) -> Vec<(String, Option<&'a str>)> {
             let mut chunks: Vec<(String, Option<&str>)> = Vec::new();
-            for (chunk, style_id) in map.snapshot(ctx).highlighted_chunks_at(row) {
+            for (chunk, style_id) in map.snapshot(ctx).highlighted_chunks_for_rows(rows) {
                 let style_name = theme.syntax_style_name(style_id);
                 if let Some((last_chunk, last_style_name)) = chunks.last_mut() {
                     if style_name == *last_style_name {