Start on a query-based autoindent system

Max Brunsfeld and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

zed/languages/rust/indent.scm     |   7 
zed/src/editor/buffer/mod.rs      | 266 +++++++++++++++++++++-----------
zed/src/editor/display_map/mod.rs |   5 
zed/src/language.rs               |  21 +
4 files changed, 195 insertions(+), 104 deletions(-)

Detailed changes

zed/languages/rust/indent.scm 🔗

@@ -0,0 +1,7 @@
+(where_clause) @indent
+
+(field_expression) @indent
+
+(_ "(" ")" @outdent) @indent
+(_ "[" "]" @outdent) @indent
+(_ "{" "}" @outdent) @indent

zed/src/editor/buffer/mod.rs 🔗

@@ -516,12 +516,11 @@ impl Buffer {
     }
 
     pub fn snapshot(&self) -> Snapshot {
-        let mut cursors = QUERY_CURSORS.lock();
         Snapshot {
             text: self.visible_text.clone(),
             tree: self.syntax_tree(),
             language: self.language.clone(),
-            query_cursor: Some(cursors.pop().unwrap_or_else(|| QueryCursor::new())),
+            query_cursor: Some(acquire_query_cursor()),
         }
     }
 
@@ -691,73 +690,99 @@ impl Buffer {
     }
 
     fn autoindent_for_rows(&self, rows: Range<u32>) -> Vec<usize> {
-        let mut indents = Vec::new();
-        if let Some((language, syntax_tree)) = self.language.as_ref().zip(self.syntax_tree()) {
-            let mut stack = Vec::new();
-            let mut cursor = syntax_tree.walk();
-            let mut row = rows.start;
-            while row < rows.end {
-                let node = cursor.node();
-                let row_start = Point::new(row, 0).into();
-
-                if node.end_position() <= row_start {
-                    if !cursor.goto_next_sibling() {
-                        if stack.last() == Some(&node) {
-                            stack.pop();
-                        }
+        // Find the indentation level of the previous non-whitespace row.
+        let mut prev_row = rows.start;
+        let prev_indent = loop {
+            if prev_row == 0 {
+                break 0;
+            }
+            prev_row -= 1;
+            let (indent, is_whitespace) = self.indent_for_row(prev_row);
+            if !is_whitespace {
+                break indent;
+            }
+        };
 
-                        if !cursor.goto_parent() {
-                            break;
-                        }
-                    }
-                } else if node.start_position() <= row_start && cursor.goto_first_child() {
-                    if language.config.indent_nodes.contains(node.kind()) {
-                        stack.push(node);
-                    }
-                } else {
-                    let mut indented = false;
-                    for ancestor in stack.iter().rev() {
-                        let ancestor_start_row = ancestor.start_position().row as u32;
-                        if ancestor_start_row < row {
-                            let ancestor_indent = if ancestor_start_row < rows.start {
-                                self.indent_for_row(ancestor_start_row).0
-                            } else {
-                                indents[(ancestor_start_row - rows.start) as usize]
-                            };
+        let (language, syntax_tree) = match self.language.as_ref().zip(self.syntax_tree()) {
+            Some(e) => e,
+            None => return vec![prev_indent; rows.len()],
+        };
 
-                            if ancestor.end_position().row as u32 == row {
-                                indents.push(ancestor_indent);
-                            } else {
-                                indents.push(ancestor_indent + language.config.indent);
-                            }
+        // Find the capture indices in the language's indent query that represent increased
+        // and decreased indentation.
+        let mut indent_capture_ix = u32::MAX;
+        let mut outdent_capture_ix = u32::MAX;
+        for (ix, name) in language.indent_query.capture_names().iter().enumerate() {
+            match name.as_ref() {
+                "indent" => indent_capture_ix = ix as u32,
+                "outdent" => outdent_capture_ix = ix as u32,
+                _ => continue,
+            }
+        }
 
-                            indented = true;
-                            break;
-                        }
-                    }
+        let start_row = rows.start as usize;
+        let mut indents = vec![prev_indent; rows.len()];
 
-                    if !indented {
-                        let mut indent = 0;
-                        for prev_row in (0..row).rev() {
-                            if prev_row < rows.start {
-                                let (prev_indent, is_whitespace) = self.indent_for_row(prev_row);
-                                if prev_indent != 0 || !is_whitespace {
-                                    indent = prev_indent;
-                                    break;
-                                }
-                            } else {
-                                indent = indents[(prev_row - rows.start) as usize];
-                                break;
-                            }
-                        }
-                        indents.push(indent);
+        // Find indent and outdent nodes near the row range; they may extend past it, so clamp.
+        let mut cursor = acquire_query_cursor();
+        cursor.set_point_range(
+            Point::new(prev_row, 0).into(),
+            Point::new(rows.end + 1, 0).into(),
+        );
+        for mat in cursor.matches(
+            &language.indent_query,
+            syntax_tree.root_node(),
+            TextProvider(&self.visible_text),
+        ) {
+            for capture in mat.captures {
+                if capture.index == indent_capture_ix {
+                    let node_start_row = capture.node.start_position().row;
+                    let node_end_row = capture.node.end_position().row;
+                    let start_ix = (node_start_row + 1).saturating_sub(start_row);
+                    let end_ix = (node_end_row + 1).saturating_sub(start_row);
+                    for ix in start_ix..end_ix.min(indents.len()) {
+                        indents[ix] += language.config.indent;
                     }
-
-                    row += 1;
                 }
             }
-        } else {
-            panic!()
+            for capture in mat.captures {
+                if capture.index == outdent_capture_ix {
+                    let node_start_row = capture.node.start_position().row;
+                    let node_end_row = capture.node.end_position().row;
+                    let start_ix = node_start_row.saturating_sub(start_row);
+                    let end_ix = (node_end_row + 1).saturating_sub(start_row);
+                    for ix in start_ix..end_ix.min(indents.len()) {
+                        indents[ix] = indents[ix].saturating_sub(language.config.indent);
+                    }
+                }
+            }
+        }
+
+        // Post-process indents to fix doubly-indented lines.
+        struct Indent {
+            initial: usize,
+            adjusted: usize,
+        }
+        let mut indent_stack = vec![Indent {
+            initial: prev_indent,
+            adjusted: prev_indent,
+        }];
+        for indent in indents.iter_mut() {
+            while *indent < indent_stack.last().unwrap().initial {
+                indent_stack.pop();
+            }
+
+            let cur_indent = indent_stack.last().unwrap();
+            if *indent > cur_indent.initial {
+                let adjusted_indent = cur_indent.adjusted + language.config.indent;
+                indent_stack.push(Indent {
+                    initial: *indent,
+                    adjusted: adjusted_indent,
+                });
+                *indent = adjusted_indent;
+            } else {
+                *indent = cur_indent.adjusted;
+            }
         }
 
         indents
@@ -2201,10 +2226,21 @@ impl Snapshot {
 
 impl Drop for Snapshot {
     fn drop(&mut self) {
-        QUERY_CURSORS.lock().push(self.query_cursor.take().unwrap());
+        release_query_cursor(self.query_cursor.take().unwrap());
     }
 }
 
+fn acquire_query_cursor() -> QueryCursor {
+    QUERY_CURSORS
+        .lock()
+        .pop()
+        .unwrap_or_else(|| QueryCursor::new())
+}
+
+fn release_query_cursor(cursor: QueryCursor) {
+    QUERY_CURSORS.lock().push(cursor)
+}
+
 struct RopeBuilder<'a> {
     old_visible_cursor: rope::Cursor<'a>,
     old_deleted_cursor: rope::Cursor<'a>,
@@ -2731,7 +2767,6 @@ mod tests {
         cell::RefCell,
         cmp::Ordering,
         fs,
-        iter::FromIterator as _,
         rc::Rc,
         sync::atomic::{self, AtomicUsize},
     };
@@ -3722,49 +3757,90 @@ mod tests {
     }
 
     #[gpui::test]
-    async fn test_indent(mut ctx: gpui::TestAppContext) {
+    async fn test_autoindent(mut ctx: gpui::TestAppContext) {
         let grammar = tree_sitter_rust::language();
         let lang = Arc::new(Language {
             config: LanguageConfig {
-                indent: 3,
-                indent_nodes: std::collections::HashSet::from_iter(vec!["block".to_string()]),
+                indent: 4,
                 ..Default::default()
             },
             grammar: grammar.clone(),
             highlight_query: tree_sitter::Query::new(grammar, "").unwrap(),
+            indent_query: tree_sitter::Query::new(
+                grammar,
+                r#"
+                    (block "}" @outdent) @indent
+                    (_ ")" @outdent) @indent
+                    (where_clause) @indent
+                "#,
+            )
+            .unwrap(),
             theme_mapping: Default::default(),
         });
 
-        let buffer = ctx.add_model(|ctx| {
-            let text = "
-                fn a() {}
-
-                 fn b() {
-                 }
-
-                fn c() {
-                 let badly_indented_line;
-
-                }
-
-                struct D { // we deliberately don't auto-indent structs for this example
-                    x: 1,
-
+        let examples = vec![
+            "
+            fn a() {
+                b(
+                    c,
+                    d
+                )
+                e(|f| {
+                    g();
+                    h(|| {
+                        i();
+                    })
+                    j();
+                });
+                k();
+            }
+            "
+            .unindent(),
+            "
+            fn a<B, C>(
+                d: e
+            ) -> D
+            where
+                B: E,
+                C: F
+            {
+                
+            }
+            "
+            .unindent(),
+        ];
 
+        for (example_ix, text) in examples.into_iter().enumerate() {
+            let buffer = ctx.add_model(|ctx| {
+                Buffer::from_history(0, History::new(text.into()), None, Some(lang.clone()), ctx)
+            });
+            buffer.condition(&ctx, |buf, _| !buf.is_parsing()).await;
+
+            buffer.read_with(&ctx, |buffer, _| {
+                let row_range = 0..buffer.row_count();
+                let current_indents = row_range
+                    .clone()
+                    .map(|row| buffer.indent_for_row(row).0)
+                    .collect::<Vec<_>>();
+                let autoindents = buffer.autoindent_for_rows(row_range);
+                assert_eq!(
+                    autoindents.len(),
+                    current_indents.len(),
+                    "wrong number of autoindents returned for example {}",
+                    example_ix
+                );
+                for (row, indent) in autoindents.into_iter().enumerate() {
+                    assert_eq!(
+                        indent,
+                        current_indents[row],
+                        "wrong autoindent for example {}, row {}, line {:?}",
+                        example_ix,
+                        row,
+                        buffer.text().split('\n').skip(row).next().unwrap(),
+                    );
                 }
-                "
-            .unindent();
-            Buffer::from_history(0, History::new(text.into()), None, Some(lang), ctx)
-        });
-
-        buffer.condition(&ctx, |buf, _| !buf.is_parsing()).await;
-        buffer.read_with(&ctx, |buf, _| {
-            assert_eq!(
-                buf.autoindent_for_rows(0..buf.row_count()),
-                vec![0, 0, 0, 0, 0, 0, 3, 3, 0, 0, 0, 0, 0, 0, 0]
-            );
-            todo!("write assertions to test how indents work with different subset of rows");
-        });
+            });
+        }
     }
 
     impl Buffer {

zed/src/editor/display_map/mod.rs 🔗

@@ -509,7 +509,7 @@ mod tests {
                 fn inner() {}
             }"#
         .unindent();
-        let query = tree_sitter::Query::new(
+        let highlight_query = tree_sitter::Query::new(
             grammar,
             r#"
             (mod_item name: (identifier) body: _ @mod.body)
@@ -530,7 +530,8 @@ mod tests {
                 ..Default::default()
             },
             grammar: grammar.clone(),
-            highlight_query: query,
+            highlight_query,
+            indent_query: tree_sitter::Query::new(grammar, "").unwrap(),
             theme_mapping: Default::default(),
         });
         lang.set_theme(&theme);

zed/src/language.rs 🔗

@@ -2,7 +2,7 @@ use crate::settings::{Theme, ThemeMap};
 use parking_lot::Mutex;
 use rust_embed::RustEmbed;
 use serde::Deserialize;
-use std::{collections::HashSet, path::Path, str, sync::Arc};
+use std::{path::Path, str, sync::Arc};
 use tree_sitter::{Language as Grammar, Query};
 pub use tree_sitter::{Parser, Tree};
 
@@ -14,7 +14,6 @@ pub struct LanguageDir;
 pub struct LanguageConfig {
     pub name: String,
     pub indent: usize,
-    pub indent_nodes: HashSet<String>,
     pub path_suffixes: Vec<String>,
 }
 
@@ -22,6 +21,7 @@ pub struct Language {
     pub config: LanguageConfig,
     pub grammar: Grammar,
     pub highlight_query: Query,
+    pub indent_query: Query,
     pub theme_mapping: Mutex<ThemeMap>,
 }
 
@@ -46,11 +46,8 @@ impl LanguageRegistry {
         let rust_language = Language {
             config: rust_config,
             grammar,
-            highlight_query: Query::new(
-                grammar,
-                str::from_utf8(LanguageDir::get("rust/highlights.scm").unwrap().as_ref()).unwrap(),
-            )
-            .unwrap(),
+            highlight_query: Self::load_query(grammar, "rust/highlights.scm"),
+            indent_query: Self::load_query(grammar, "rust/indent.scm"),
             theme_mapping: Mutex::new(ThemeMap::default()),
         };
 
@@ -78,6 +75,14 @@ impl LanguageRegistry {
                 .any(|suffix| path_suffixes.contains(&Some(suffix.as_str())))
         })
     }
+
+    fn load_query(grammar: tree_sitter::Language, path: &str) -> Query {
+        Query::new(
+            grammar,
+            str::from_utf8(LanguageDir::get(path).unwrap().as_ref()).unwrap(),
+        )
+        .unwrap()
+    }
 }
 
 #[cfg(test)]
@@ -97,6 +102,7 @@ mod tests {
                     },
                     grammar,
                     highlight_query: Query::new(grammar, "").unwrap(),
+                    indent_query: Query::new(grammar, "").unwrap(),
                     theme_mapping: Default::default(),
                 }),
                 Arc::new(Language {
@@ -107,6 +113,7 @@ mod tests {
                     },
                     grammar,
                     highlight_query: Query::new(grammar, "").unwrap(),
+                    indent_query: Query::new(grammar, "").unwrap(),
                     theme_mapping: Default::default(),
                 }),
             ],