Implement bracket matching using queries

Max Brunsfeld and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

Cargo.lock                        |  2 
Cargo.toml                        | 12 +++---
zed/languages/rust/brackets.scm   |  6 +++
zed/src/editor/buffer/mod.rs      | 54 ++++++++++++--------------------
zed/src/editor/buffer_view.rs     |  5 ++
zed/src/editor/display_map/mod.rs |  1 
zed/src/editor/mod.rs             | 10 +++++
zed/src/language.rs               |  5 ++
8 files changed, 51 insertions(+), 44 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -2714,7 +2714,7 @@ dependencies = [
 [[package]]
 name = "tree-sitter"
 version = "0.19.5"
-source = "git+https://github.com/tree-sitter/tree-sitter?rev=036aceed574c2c23eee8f0ff90be5a2409e524c1#036aceed574c2c23eee8f0ff90be5a2409e524c1"
+source = "git+https://github.com/tree-sitter/tree-sitter?rev=97dfee63257b5e92197399b381aa993514640adf#97dfee63257b5e92197399b381aa993514640adf"
 dependencies = [
  "cc",
  "regex",

Cargo.toml 🔗

@@ -2,14 +2,14 @@
 members = ["zed", "gpui", "gpui_macros", "fsevent", "scoped_pool"]
 
 [patch.crates-io]
-async-task = {git = "https://github.com/zed-industries/async-task", rev = "341b57d6de98cdfd7b418567b8de2022ca993a6e"}
-tree-sitter = {git = "https://github.com/tree-sitter/tree-sitter", rev = "036aceed574c2c23eee8f0ff90be5a2409e524c1"}
+async-task = { git = "https://github.com/zed-industries/async-task", rev = "341b57d6de98cdfd7b418567b8de2022ca993a6e" }
+tree-sitter = { git = "https://github.com/tree-sitter/tree-sitter", rev = "97dfee63257b5e92197399b381aa993514640adf" }
 
 # TODO - Remove when a version is released with this PR: https://github.com/servo/core-foundation-rs/pull/457
-cocoa = {git = "https://github.com/servo/core-foundation-rs", rev = "025dcb3c0d1ef01530f57ef65f3b1deb948f5737"}
-cocoa-foundation = {git = "https://github.com/servo/core-foundation-rs", rev = "025dcb3c0d1ef01530f57ef65f3b1deb948f5737"}
-core-foundation = {git = "https://github.com/servo/core-foundation-rs", rev = "025dcb3c0d1ef01530f57ef65f3b1deb948f5737"}
-core-graphics = {git = "https://github.com/servo/core-foundation-rs", rev = "025dcb3c0d1ef01530f57ef65f3b1deb948f5737"}
+cocoa = { git = "https://github.com/servo/core-foundation-rs", rev = "025dcb3c0d1ef01530f57ef65f3b1deb948f5737" }
+cocoa-foundation = { git = "https://github.com/servo/core-foundation-rs", rev = "025dcb3c0d1ef01530f57ef65f3b1deb948f5737" }
+core-foundation = { git = "https://github.com/servo/core-foundation-rs", rev = "025dcb3c0d1ef01530f57ef65f3b1deb948f5737" }
+core-graphics = { git = "https://github.com/servo/core-foundation-rs", rev = "025dcb3c0d1ef01530f57ef65f3b1deb948f5737" }
 
 [profile.dev]
 split-debuginfo = "unpacked"

zed/languages/rust/brackets.scm 🔗

@@ -0,0 +1,6 @@
+("(" @open ")" @close)
+("[" @open "]" @close)
+("{" @open "}" @close)
+("<" @open ">" @close)
+("\"" @open "\"" @close)
+(closure_parameters "|" @open "|" @close)

zed/src/editor/buffer/mod.rs 🔗

@@ -727,41 +727,27 @@ impl Buffer {
         &self,
         range: Range<T>,
     ) -> Option<(Range<usize>, Range<usize>)> {
-        let mut bracket_ranges = None;
-        if let Some((lang, tree)) = self.language.as_ref().zip(self.syntax_tree()) {
-            let range = range.start.to_offset(self)..range.end.to_offset(self);
-            let mut cursor = tree.root_node().walk();
-            'outer: loop {
-                let node = cursor.node();
-                if node.child_count() >= 2 {
-                    if let Some((first_child, last_child)) =
-                        node.child(0).zip(node.child(node.child_count() - 1))
-                    {
-                        for pair in &lang.config.bracket_pairs {
-                            if pair.start == first_child.kind() && pair.end == last_child.kind() {
-                                bracket_ranges =
-                                    Some((first_child.byte_range(), last_child.byte_range()));
-                            }
-                        }
-                    }
-                }
-
-                if !cursor.goto_first_child() {
-                    break;
-                }
-
-                while cursor.node().end_byte() < range.end {
-                    if !cursor.goto_next_sibling() {
-                        break 'outer;
-                    }
-                }
+        let (lang, tree) = self.language.as_ref().zip(self.syntax_tree())?;
+        let open_capture_ix = lang.brackets_query.capture_index_for_name("open")?;
+        let close_capture_ix = lang.brackets_query.capture_index_for_name("close")?;
+
+        // Find bracket pairs that *inclusively* contain the given range.
+        let range = range.start.to_offset(self).saturating_sub(1)..range.end.to_offset(self) + 1;
+        let mut cursor = QueryCursorHandle::new();
+        let matches = cursor.set_byte_range(range.start, range.end).matches(
+            &lang.brackets_query,
+            tree.root_node(),
+            TextProvider(&self.visible_text),
+        );
 
-                if cursor.node().start_byte() > range.start {
-                    break;
-                }
-            }
-        }
-        bracket_ranges
+        // Get the ranges of the innermost pair of brackets.
+        matches
+            .filter_map(|mat| {
+                let open = mat.nodes_for_capture_index(open_capture_ix).next()?;
+                let close = mat.nodes_for_capture_index(close_capture_ix).next()?;
+                Some((open.byte_range(), close.byte_range()))
+            })
+            .min_by_key(|(open_range, close_range)| close_range.end - open_range.start)
     }
 
     fn diff(&self, new_text: Arc<str>, ctx: &AppContext) -> Task<Diff> {

zed/src/editor/buffer_view.rs 🔗

@@ -1826,6 +1826,8 @@ impl BufferView {
     }
 
     pub fn move_to_enclosing_bracket(&mut self, _: &(), ctx: &mut ViewContext<Self>) {
+        use super::RangeExt as _;
+
         let buffer = self.buffer.read(ctx.as_ref());
         let mut selections = self.selections(ctx.as_ref()).to_vec();
         for selection in &mut selections {
@@ -1833,12 +1835,13 @@ impl BufferView {
             if let Some((open_range, close_range)) =
                 buffer.enclosing_bracket_ranges(selection_range.clone())
             {
+                let close_range = close_range.to_inclusive();
                 let destination = if close_range.contains(&selection_range.start)
                     && close_range.contains(&selection_range.end)
                 {
                     open_range.end
                 } else {
-                    close_range.start
+                    *close_range.start()
                 };
                 selection.start = buffer.anchor_before(destination);
                 selection.end = selection.start.clone();

zed/src/editor/display_map/mod.rs 🔗

@@ -535,6 +535,7 @@ mod tests {
             },
             grammar: grammar.clone(),
             highlight_query,
+            brackets_query: tree_sitter::Query::new(grammar, "").unwrap(),
             theme_mapping: Default::default(),
         });
         lang.set_theme(&theme);

zed/src/editor/mod.rs 🔗

@@ -9,7 +9,10 @@ pub use buffer_element::*;
 pub use buffer_view::*;
 pub use display_map::DisplayPoint;
 use display_map::*;
-use std::{cmp, ops::Range};
+use std::{
+    cmp,
+    ops::{Range, RangeInclusive},
+};
 
 #[derive(Copy, Clone)]
 pub enum Bias {
@@ -19,10 +22,15 @@ pub enum Bias {
 
 trait RangeExt<T> {
     fn sorted(&self) -> Range<T>;
+    fn to_inclusive(&self) -> RangeInclusive<T>;
 }
 
 impl<T: Ord + Clone> RangeExt<T> for Range<T> {
     fn sorted(&self) -> Self {
         cmp::min(&self.start, &self.end).clone()..cmp::max(&self.start, &self.end).clone()
     }
+
+    fn to_inclusive(&self) -> RangeInclusive<T> {
+        self.start.clone()..=self.end.clone()
+    }
 }

zed/src/language.rs 🔗

@@ -14,7 +14,6 @@ pub struct LanguageDir;
 pub struct LanguageConfig {
     pub name: String,
     pub path_suffixes: Vec<String>,
-    pub bracket_pairs: Vec<BracketPair>,
 }
 
 #[derive(Deserialize)]
@@ -27,6 +26,7 @@ pub struct Language {
     pub config: LanguageConfig,
     pub grammar: Grammar,
     pub highlight_query: Query,
+    pub brackets_query: Query,
     pub theme_mapping: Mutex<ThemeMap>,
 }
 
@@ -52,6 +52,7 @@ impl LanguageRegistry {
             config: rust_config,
             grammar,
             highlight_query: Self::load_query(grammar, "rust/highlights.scm"),
+            brackets_query: Self::load_query(grammar, "rust/brackets.scm"),
             theme_mapping: Mutex::new(ThemeMap::default()),
         };
 
@@ -106,6 +107,7 @@ mod tests {
                     },
                     grammar,
                     highlight_query: Query::new(grammar, "").unwrap(),
+                    brackets_query: Query::new(grammar, "").unwrap(),
                     theme_mapping: Default::default(),
                 }),
                 Arc::new(Language {
@@ -116,6 +118,7 @@ mod tests {
                     },
                     grammar,
                     highlight_query: Query::new(grammar, "").unwrap(),
+                    brackets_query: Query::new(grammar, "").unwrap(),
                     theme_mapping: Default::default(),
                 }),
             ],