Summarize the contents of a file using the embedding query

Antonio Scandurra created

Change summary

crates/assistant/src/assistant_panel.rs   |   1 
crates/assistant/src/prompts.rs           | 458 +++++++++++++-----------
crates/language/src/buffer.rs             |  12 
crates/zed/src/languages/rust/summary.scm |   6 
4 files changed, 253 insertions(+), 224 deletions(-)

Detailed changes

crates/assistant/src/prompts.rs 🔗

@@ -1,86 +1,118 @@
-use gpui::AppContext;
+use crate::codegen::CodegenKind;
 use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
 use std::cmp;
 use std::ops::Range;
 use std::{fmt::Write, iter};
 
-use crate::codegen::CodegenKind;
+fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
+    #[derive(Debug)]
+    struct Match {
+        collapse: Range<usize>,
+        keep: Vec<Range<usize>>,
+    }
 
-fn outline_for_prompt(
-    buffer: &BufferSnapshot,
-    range: Range<language::Anchor>,
-    cx: &AppContext,
-) -> Option<String> {
-    let indent = buffer
-        .language_indent_size_at(0, cx)
-        .chars()
-        .collect::<String>();
-    let outline = buffer.outline(None)?;
-    let range = range.to_offset(buffer);
-
-    let mut text = String::new();
-    let mut items = outline.items.into_iter().peekable();
-
-    let mut intersected = false;
-    let mut intersection_indent = 0;
-    let mut extended_range = range.clone();
-
-    while let Some(item) = items.next() {
-        let item_range = item.range.to_offset(buffer);
-        if item_range.end < range.start || item_range.start > range.end {
-            text.extend(iter::repeat(indent.as_str()).take(item.depth));
-            text.push_str(&item.text);
-            text.push('\n');
-        } else {
-            intersected = true;
-            let is_terminal = items
-                .peek()
-                .map_or(true, |next_item| next_item.depth <= item.depth);
-            if is_terminal {
-                if item_range.start <= extended_range.start {
-                    extended_range.start = item_range.start;
-                    intersection_indent = item.depth;
+    let selected_range = selected_range.to_offset(buffer);
+    let mut matches = buffer.matches(0..buffer.len(), |grammar| {
+        Some(&grammar.embedding_config.as_ref()?.query)
+    });
+    let configs = matches
+        .grammars()
+        .iter()
+        .map(|g| g.embedding_config.as_ref().unwrap())
+        .collect::<Vec<_>>();
+    let mut matches = iter::from_fn(move || {
+        while let Some(mat) = matches.peek() {
+            let config = &configs[mat.grammar_index];
+            if let Some(collapse) = mat.captures.iter().find_map(|cap| {
+                if Some(cap.index) == config.collapse_capture_ix {
+                    Some(cap.node.byte_range())
+                } else {
+                    None
                 }
-                extended_range.end = cmp::max(extended_range.end, item_range.end);
+            }) {
+                let mut keep = Vec::new();
+                for capture in mat.captures.iter() {
+                    if Some(capture.index) == config.keep_capture_ix {
+                        keep.push(capture.node.byte_range());
+                    } else {
+                        continue;
+                    }
+                }
+                matches.advance();
+                return Some(Match { collapse, keep });
+            } else {
+                matches.advance();
+            }
+        }
+        None
+    })
+    .peekable();
+
+    let mut summary = String::new();
+    let mut offset = 0;
+    let mut flushed_selection = false;
+    while let Some(mut mat) = matches.next() {
+        // If the next match's collapsed range encloses the current one, switch
+        // to that enclosing match so the outer range is the one collapsed.
+        while let Some(next_mat) = matches.peek() {
+            if next_mat.collapse.start <= mat.collapse.start
+                && next_mat.collapse.end >= mat.collapse.end
+            {
+                mat = matches.next().unwrap();
             } else {
-                let name_start = item_range.start + item.name_ranges.first().unwrap().start;
-                let name_end = item_range.start + item.name_ranges.last().unwrap().end;
+                break;
+            }
+        }
+
+        if offset >= mat.collapse.start {
+            // Skip collapsed nodes that have already been summarized.
+            offset = cmp::max(offset, mat.collapse.end);
+            continue;
+        }
 
-                if range.start > name_end {
-                    text.extend(iter::repeat(indent.as_str()).take(item.depth));
-                    text.push_str(&item.text);
-                    text.push('\n');
+        if offset <= selected_range.start && selected_range.start <= mat.collapse.end {
+            if !flushed_selection {
+                // The collapsed node ends after the selection starts, so we'll flush the selection first.
+                summary.extend(buffer.text_for_range(offset..selected_range.start));
+                summary.push_str("<|START|");
+                if selected_range.end == selected_range.start {
+                    summary.push_str(">");
                 } else {
-                    if name_start <= extended_range.start {
-                        extended_range.start = item_range.start;
-                        intersection_indent = item.depth;
-                    }
-                    extended_range.end = cmp::max(extended_range.end, name_end);
+                    summary.extend(buffer.text_for_range(selected_range.clone()));
+                    summary.push_str("|END|>");
                 }
+                offset = selected_range.end;
+                flushed_selection = true;
             }
-        }
 
-        if intersected
-            && items.peek().map_or(true, |next_item| {
-                next_item.range.start.to_offset(buffer) > range.end
-            })
-        {
-            intersected = false;
-            text.extend(iter::repeat(indent.as_str()).take(intersection_indent));
-            text.extend(buffer.text_for_range(extended_range.start..range.start));
-            text.push_str("<|START|");
-            text.extend(buffer.text_for_range(range.clone()));
-            if range.start != range.end {
-                text.push_str("|END|>");
-            } else {
-                text.push_str(">");
+            // If the selection intersects the collapsed node, we won't collapse it.
+            if selected_range.end >= mat.collapse.start {
+                continue;
             }
-            text.extend(buffer.text_for_range(range.end..extended_range.end));
-            text.push('\n');
         }
+
+        summary.extend(buffer.text_for_range(offset..mat.collapse.start));
+        for keep in mat.keep {
+            summary.extend(buffer.text_for_range(keep));
+        }
+        offset = mat.collapse.end;
+    }
+
+    // Flush selection if we haven't already done so.
+    if !flushed_selection && offset <= selected_range.start {
+        summary.extend(buffer.text_for_range(offset..selected_range.start));
+        summary.push_str("<|START|");
+        if selected_range.end == selected_range.start {
+            summary.push_str(">");
+        } else {
+            summary.extend(buffer.text_for_range(selected_range.clone()));
+            summary.push_str("|END|>");
+        }
+        offset = selected_range.end;
     }
 
-    Some(text)
+    summary.extend(buffer.text_for_range(offset..buffer.len()));
+    summary
 }
 
 pub fn generate_content_prompt(
@@ -88,7 +120,6 @@ pub fn generate_content_prompt(
     language_name: Option<&str>,
     buffer: &BufferSnapshot,
     range: Range<language::Anchor>,
-    cx: &AppContext,
     kind: CodegenKind,
 ) -> String {
     let mut prompt = String::new();
@@ -100,19 +131,17 @@ pub fn generate_content_prompt(
         writeln!(prompt, "You're an expert engineer.\n").unwrap();
     }
 
-    let outline = outline_for_prompt(buffer, range.clone(), cx);
-    if let Some(outline) = outline {
-        writeln!(
-            prompt,
-            "The file you are currently working on has the following outline:"
-        )
-        .unwrap();
-        if let Some(language_name) = language_name {
-            let language_name = language_name.to_lowercase();
-            writeln!(prompt, "```{language_name}\n{outline}\n```").unwrap();
-        } else {
-            writeln!(prompt, "```\n{outline}\n```").unwrap();
-        }
+    let outline = summarize(buffer, range.clone());
+    writeln!(
+        prompt,
+        "The file you are currently working on has the following outline:"
+    )
+    .unwrap();
+    if let Some(language_name) = language_name {
+        let language_name = language_name.to_lowercase();
+        writeln!(prompt, "```{language_name}\n{outline}\n```").unwrap();
+    } else {
+        writeln!(prompt, "```\n{outline}\n```").unwrap();
     }
 
     // Assume for now that we are just generating
@@ -183,39 +212,37 @@ pub(crate) mod tests {
             },
             Some(tree_sitter_rust::language()),
         )
-        .with_indents_query(
+        .with_embedding_query(
             r#"
-                (call_expression) @indent
-                (field_expression) @indent
-                (_ "(" ")" @end) @indent
-                (_ "{" "}" @end) @indent
-                "#,
-        )
-        .unwrap()
-        .with_outline_query(
-            r#"
-                (struct_item
-                    "struct" @context
-                    name: (_) @name) @item
-                (enum_item
-                    "enum" @context
-                    name: (_) @name) @item
-                (enum_variant
-                    name: (_) @name) @item
-                (field_declaration
-                    name: (_) @name) @item
-                (impl_item
-                    "impl" @context
-                    trait: (_)? @name
-                    "for"? @context
-                    type: (_) @name) @item
-                (function_item
-                    "fn" @context
-                    name: (_) @name) @item
-                (mod_item
-                    "mod" @context
-                    name: (_) @name) @item
-                "#,
+            (
+                [(line_comment) (attribute_item)]* @context
+                .
+                [
+                    (struct_item
+                        name: (_) @name)
+
+                    (enum_item
+                        name: (_) @name)
+
+                    (impl_item
+                        trait: (_)? @name
+                        "for"? @name
+                        type: (_) @name)
+
+                    (trait_item
+                        name: (_) @name)
+
+                    (function_item
+                        name: (_) @name
+                        body: (block
+                            "{" @keep
+                            "}" @keep) @collapse)
+
+                    (macro_definition
+                        name: (_) @name)
+                    ] @item
+                )
+            "#,
         )
         .unwrap()
     }
@@ -251,132 +278,133 @@ pub(crate) mod tests {
             cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
         let snapshot = buffer.read(cx).snapshot();
 
-        let outline = outline_for_prompt(
-            &snapshot,
-            snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_before(Point::new(1, 4)),
-            cx,
-        );
         assert_eq!(
-            outline.as_deref(),
-            Some(indoc! {"
-                struct X
-                    <|START|>a: usize
-                    b
-                impl X
-                    fn new
-                    fn a
-                    fn b
-            "})
-        );
+            summarize(&snapshot, Point::new(1, 4)..Point::new(1, 4)),
+            indoc! {"
+                struct X {
+                    <|START|>a: usize,
+                    b: usize,
+                }
+
+                impl X {
+
+                    fn new() -> Self {}
 
-        let outline = outline_for_prompt(
-            &snapshot,
-            snapshot.anchor_before(Point::new(8, 12))..snapshot.anchor_before(Point::new(8, 14)),
-            cx,
+                    pub fn a(&self, param: bool) -> usize {}
+
+                    pub fn b(&self) -> usize {}
+                }
+            "}
         );
+
         assert_eq!(
-            outline.as_deref(),
-            Some(indoc! {"
-                struct X
-                    a
-                    b
-                impl X
+            summarize(&snapshot, Point::new(8, 12)..Point::new(8, 14)),
+            indoc! {"
+                struct X {
+                    a: usize,
+                    b: usize,
+                }
+
+                impl X {
+
                     fn new() -> Self {
                         let <|START|a |END|>= 1;
                         let b = 2;
                         Self { a, b }
                     }
-                    fn a
-                    fn b
-            "})
-        );
 
-        let outline = outline_for_prompt(
-            &snapshot,
-            snapshot.anchor_before(Point::new(6, 0))..snapshot.anchor_before(Point::new(6, 0)),
-            cx,
+                    pub fn a(&self, param: bool) -> usize {}
+
+                    pub fn b(&self) -> usize {}
+                }
+            "}
         );
+
         assert_eq!(
-            outline.as_deref(),
-            Some(indoc! {"
-                struct X
-                    a
-                    b
-                impl X
+            summarize(&snapshot, Point::new(6, 0)..Point::new(6, 0)),
+            indoc! {"
+                struct X {
+                    a: usize,
+                    b: usize,
+                }
+
+                impl X {
                 <|START|>
-                    fn new
-                    fn a
-                    fn b
-            "})
-        );
+                    fn new() -> Self {}
+
+                    pub fn a(&self, param: bool) -> usize {}
 
-        let outline = outline_for_prompt(
-            &snapshot,
-            snapshot.anchor_before(Point::new(8, 12))..snapshot.anchor_before(Point::new(13, 9)),
-            cx,
+                    pub fn b(&self) -> usize {}
+                }
+            "}
         );
+
         assert_eq!(
-            outline.as_deref(),
-            Some(indoc! {"
-                struct X
-                    a
-                    b
-                impl X
-                    fn new() -> Self {
-                        let <|START|a = 1;
-                        let b = 2;
-                        Self { a, b }
-                    }
+            summarize(&snapshot, Point::new(21, 0)..Point::new(21, 0)),
+            indoc! {"
+                struct X {
+                    a: usize,
+                    b: usize,
+                }
 
-                    pub f|END|>n a(&self, param: bool) -> usize {
-                        self.a
-                    }
-                    fn b
-            "})
-        );
+                impl X {
+
+                    fn new() -> Self {}
+
+                    pub fn a(&self, param: bool) -> usize {}
 
-        let outline = outline_for_prompt(
-            &snapshot,
-            snapshot.anchor_before(Point::new(5, 6))..snapshot.anchor_before(Point::new(12, 0)),
-            cx,
+                    pub fn b(&self) -> usize {}
+                }
+                <|START|>"}
         );
-        assert_eq!(
-            outline.as_deref(),
-            Some(indoc! {"
-                struct X
-                    a
-                    b
-                impl X<|START| {
 
-                    fn new() -> Self {
-                        let a = 1;
-                        let b = 2;
-                        Self { a, b }
+        // Ensure nested functions get collapsed properly.
+        let text = indoc! {"
+            struct X {
+                a: usize,
+                b: usize,
+            }
+
+            impl X {
+
+                fn new() -> Self {
+                    let a = 1;
+                    let b = 2;
+                    Self { a, b }
+                }
+
+                pub fn a(&self, param: bool) -> usize {
+                    let a = 30;
+                    fn nested() -> usize {
+                        3
                     }
-                |END|>
-                    fn a
-                    fn b
-            "})
-        );
+                    self.a + nested()
+                }
 
-        let outline = outline_for_prompt(
-            &snapshot,
-            snapshot.anchor_before(Point::new(18, 8))..snapshot.anchor_before(Point::new(18, 8)),
-            cx,
-        );
+                pub fn b(&self) -> usize {
+                    self.b
+                }
+            }
+        "};
+        buffer.update(cx, |buffer, cx| buffer.set_text(text, cx));
+        let snapshot = buffer.read(cx).snapshot();
         assert_eq!(
-            outline.as_deref(),
-            Some(indoc! {"
-                struct X
-                    a
-                    b
-                impl X
-                    fn new
-                    fn a
-                    pub fn b(&self) -> usize {
-                        <|START|>self.b
-                    }
-            "})
+            summarize(&snapshot, Point::new(0, 0)..Point::new(0, 0)),
+            indoc! {"
+                <|START|>struct X {
+                    a: usize,
+                    b: usize,
+                }
+
+                impl X {
+
+                    fn new() -> Self {}
+
+                    pub fn a(&self, param: bool) -> usize {}
+
+                    pub fn b(&self) -> usize {}
+                }
+            "}
         );
     }
 }

crates/language/src/buffer.rs 🔗

@@ -8,8 +8,8 @@ use crate::{
     language_settings::{language_settings, LanguageSettings},
     outline::OutlineItem,
     syntax_map::{
-        SyntaxLayerInfo, SyntaxMap, SyntaxMapCapture, SyntaxMapCaptures, SyntaxSnapshot,
-        ToTreeSitterPoint,
+        SyntaxLayerInfo, SyntaxMap, SyntaxMapCapture, SyntaxMapCaptures, SyntaxMapMatches,
+        SyntaxSnapshot, ToTreeSitterPoint,
     },
     CodeLabel, LanguageScope, Outline,
 };
@@ -2467,6 +2467,14 @@ impl BufferSnapshot {
         Some(items)
     }
 
+    pub fn matches(
+        &self,
+        range: Range<usize>,
+        query: fn(&Grammar) -> Option<&tree_sitter::Query>,
+    ) -> SyntaxMapMatches {
+        self.syntax.matches(range, self, query)
+    }
+
     /// Returns bracket range pairs overlapping or adjacent to `range`
     pub fn bracket_ranges<'a, T: ToOffset>(
         &'a self,