Add `Buffer::enclosing_bracket_ranges`

Antonio Scandurra , Nathan Sobo , and Max Brunsfeld created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>
Co-Authored-By: Max Brunsfeld <max@zed.dev>

Change summary

zed/languages/rust/config.toml |   6 +
zed/src/editor/buffer/mod.rs   | 110 +++++++++++++++++++++++++++++++++--
zed/src/language.rs            |   7 ++
3 files changed, 116 insertions(+), 7 deletions(-)

Detailed changes

zed/languages/rust/config.toml 🔗

@@ -1,2 +1,8 @@
 name = "Rust"
 path_suffixes = ["rs"]
+bracket_pairs = [
+    { start = "{", end = "}" },
+    { start = "[", end = "]" },
+    { start = "(", end = ")" },
+    { start = "<", end = ">" },
+]

zed/src/editor/buffer/mod.rs 🔗

@@ -723,6 +723,47 @@ impl Buffer {
         }
     }
 
+    pub fn enclosing_bracket_ranges<T: ToOffset>(
+        &self,
+        range: Range<T>,
+    ) -> Option<(Range<usize>, Range<usize>)> {
+        let mut bracket_ranges = None;
+        if let Some((lang, tree)) = self.language.as_ref().zip(self.syntax_tree()) {
+            let range = range.start.to_offset(self)..range.end.to_offset(self);
+            let mut cursor = tree.root_node().walk();
+            'outer: loop {
+                let node = cursor.node();
+                if node.child_count() >= 2 {
+                    if let Some((first_child, last_child)) =
+                        node.child(0).zip(node.child(node.child_count() - 1))
+                    {
+                        for pair in &lang.config.bracket_pairs {
+                            if pair.start == first_child.kind() && pair.end == last_child.kind() {
+                                bracket_ranges =
+                                    Some((first_child.byte_range(), last_child.byte_range()));
+                            }
+                        }
+                    }
+                }
+
+                if !cursor.goto_first_child() {
+                    break;
+                }
+
+                while cursor.node().end_byte() < range.end {
+                    if !cursor.goto_next_sibling() {
+                        break 'outer;
+                    }
+                }
+
+                if cursor.node().start_byte() > range.start {
+                    break;
+                }
+            }
+        }
+        bracket_ranges
+    }
+
     fn diff(&self, new_text: Arc<str>, ctx: &AppContext) -> Task<Diff> {
         // TODO: it would be nice to not allocate here.
         let old_text = self.text();
@@ -3531,16 +3572,12 @@ mod tests {
     #[gpui::test]
     async fn test_reparse(mut ctx: gpui::TestAppContext) {
         let app_state = ctx.read(build_app_state);
-        let rust_lang = app_state
-            .language_registry
-            .select_language("test.rs")
-            .cloned();
+        let rust_lang = app_state.language_registry.select_language("test.rs");
         assert!(rust_lang.is_some());
 
         let buffer = ctx.add_model(|ctx| {
-            let text = "fn a() {}";
-
-            let buffer = Buffer::from_history(0, History::new(text.into()), None, rust_lang, ctx);
+            let text = "fn a() {}".into();
+            let buffer = Buffer::from_history(0, History::new(text), None, rust_lang.cloned(), ctx);
             assert!(buffer.is_parsing());
             assert!(buffer.syntax_tree().is_none());
             buffer
@@ -3671,6 +3708,54 @@ mod tests {
         }
     }
 
+    #[gpui::test]
+    async fn test_enclosing_bracket_ranges(mut ctx: gpui::TestAppContext) {
+        use unindent::Unindent as _;
+
+        let app_state = ctx.read(build_app_state);
+        let rust_lang = app_state.language_registry.select_language("test.rs");
+        assert!(rust_lang.is_some());
+
+        let buffer = ctx.add_model(|ctx| {
+            let text = "
+                mod x {
+                    mod y {
+                        
+                    }
+                }
+            "
+            .unindent()
+            .into();
+            Buffer::from_history(0, History::new(text), None, rust_lang.cloned(), ctx)
+        });
+        buffer
+            .condition(&ctx, |buffer, _| !buffer.is_parsing())
+            .await;
+        buffer.read_with(&ctx, |buf, _| {
+            assert_eq!(
+                buf.enclosing_bracket_point_ranges(Point::new(1, 6)..Point::new(1, 6)),
+                Some((
+                    Point::new(0, 6)..Point::new(0, 7),
+                    Point::new(4, 0)..Point::new(4, 1)
+                ))
+            );
+            assert_eq!(
+                buf.enclosing_bracket_point_ranges(Point::new(1, 10)..Point::new(1, 10)),
+                Some((
+                    Point::new(1, 10)..Point::new(1, 11),
+                    Point::new(3, 4)..Point::new(3, 5)
+                ))
+            );
+            assert_eq!(
+                buf.enclosing_bracket_point_ranges(Point::new(3, 5)..Point::new(3, 5)),
+                Some((
+                    Point::new(1, 10)..Point::new(1, 11),
+                    Point::new(3, 4)..Point::new(3, 5)
+                ))
+            );
+        });
+    }
+
     impl Buffer {
         fn random_byte_range(&mut self, start_offset: usize, rng: &mut impl Rng) -> Range<usize> {
             let end = self.clip_offset(rng.gen_range(start_offset..=self.len()), Bias::Right);
@@ -3817,6 +3902,17 @@ mod tests {
                 .keys()
                 .map(move |set_id| (*set_id, self.selection_ranges(*set_id).unwrap()))
         }
+
+        pub fn enclosing_bracket_point_ranges<T: ToOffset>(
+            &self,
+            range: Range<T>,
+        ) -> Option<(Range<Point>, Range<Point>)> {
+            self.enclosing_bracket_ranges(range).map(|(start, end)| {
+                let point_start = start.start.to_point(self)..start.end.to_point(self);
+                let point_end = end.start.to_point(self)..end.end.to_point(self);
+                (point_start, point_end)
+            })
+        }
     }
 
     impl Operation {

zed/src/language.rs 🔗

@@ -14,6 +14,13 @@ pub struct LanguageDir;
 pub struct LanguageConfig {
     pub name: String,
     pub path_suffixes: Vec<String>,
+    pub bracket_pairs: Vec<BracketPair>,
+}
+
+#[derive(Deserialize)]
+pub struct BracketPair {
+    pub start: String,
+    pub end: String,
 }
 
 pub struct Language {