Leverage embeddings query to collapse syntax nodes if not selected (#3067)

Kyle Caverly created

Reverts zed-industries/zed#3049

Change summary

crates/assistant/src/assistant.rs                 |   1 
crates/assistant/src/assistant_panel.rs           | 150 +----
crates/assistant/src/prompts.rs                   | 404 +++++++++++++++++
crates/language/src/buffer.rs                     |  12 
crates/semantic_index/src/semantic_index_tests.rs |  17 
crates/zed/src/languages/rust/embedding.scm       |   4 
6 files changed, 481 insertions(+), 107 deletions(-)

Detailed changes

crates/assistant/src/assistant_panel.rs 🔗

@@ -1,6 +1,7 @@
 use crate::{
     assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
     codegen::{self, Codegen, CodegenKind},
+    prompts::generate_content_prompt,
     MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
     SavedMessage,
 };
@@ -273,13 +274,17 @@ impl AssistantPanel {
             return;
         };
 
+        let selection = editor.read(cx).selections.newest_anchor().clone();
+        if selection.start.excerpt_id() != selection.end.excerpt_id() {
+            return;
+        }
+
         let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
         let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
         let provider = Arc::new(OpenAICompletionProvider::new(
             api_key,
             cx.background().clone(),
         ));
-        let selection = editor.read(cx).selections.newest_anchor().clone();
         let codegen_kind = if editor.read(cx).selections.newest::<usize>(cx).is_empty() {
             CodegenKind::Generate {
                 position: selection.start,
@@ -541,11 +546,26 @@ impl AssistantPanel {
             self.inline_prompt_history.pop_front();
         }
 
+        let codegen = pending_assist.codegen.clone();
         let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
-        let range = pending_assist.codegen.read(cx).range();
-        let selected_text = snapshot.text_for_range(range.clone()).collect::<String>();
+        let range = codegen.read(cx).range();
+        let start = snapshot.point_to_buffer_offset(range.start);
+        let end = snapshot.point_to_buffer_offset(range.end);
+        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
+            let (start_buffer, start_buffer_offset) = start;
+            let (end_buffer, end_buffer_offset) = end;
+            if start_buffer.remote_id() == end_buffer.remote_id() {
+                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
+            } else {
+                self.finish_inline_assist(inline_assist_id, false, cx);
+                return;
+            }
+        } else {
+            self.finish_inline_assist(inline_assist_id, false, cx);
+            return;
+        };
 
-        let language = snapshot.language_at(range.start);
+        let language = buffer.language_at(range.start);
         let language_name = if let Some(language) = language.as_ref() {
             if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
                 None
@@ -555,96 +575,13 @@ impl AssistantPanel {
         } else {
             None
         };
-        let language_name = language_name.as_deref();
-
-        let mut prompt = String::new();
-        if let Some(language_name) = language_name {
-            writeln!(prompt, "You're an expert {language_name} engineer.").unwrap();
-        }
-        match pending_assist.codegen.read(cx).kind() {
-            CodegenKind::Transform { .. } => {
-                writeln!(
-                    prompt,
-                    "You're currently working inside an editor on this file:"
-                )
-                .unwrap();
-                if let Some(language_name) = language_name {
-                    writeln!(prompt, "```{language_name}").unwrap();
-                } else {
-                    writeln!(prompt, "```").unwrap();
-                }
-                for chunk in snapshot.text_for_range(Anchor::min()..Anchor::max()) {
-                    write!(prompt, "{chunk}").unwrap();
-                }
-                writeln!(prompt, "```").unwrap();
-
-                writeln!(
-                    prompt,
-                    "In particular, the user has selected the following text:"
-                )
-                .unwrap();
-                if let Some(language_name) = language_name {
-                    writeln!(prompt, "```{language_name}").unwrap();
-                } else {
-                    writeln!(prompt, "```").unwrap();
-                }
-                writeln!(prompt, "{selected_text}").unwrap();
-                writeln!(prompt, "```").unwrap();
-                writeln!(prompt).unwrap();
-                writeln!(
-                    prompt,
-                    "Modify the selected text given the user prompt: {user_prompt}"
-                )
-                .unwrap();
-                writeln!(
-                    prompt,
-                    "You MUST reply only with the edited selected text, not the entire file."
-                )
-                .unwrap();
-            }
-            CodegenKind::Generate { .. } => {
-                writeln!(
-                    prompt,
-                    "You're currently working inside an editor on this file:"
-                )
-                .unwrap();
-                if let Some(language_name) = language_name {
-                    writeln!(prompt, "```{language_name}").unwrap();
-                } else {
-                    writeln!(prompt, "```").unwrap();
-                }
-                for chunk in snapshot.text_for_range(Anchor::min()..range.start) {
-                    write!(prompt, "{chunk}").unwrap();
-                }
-                write!(prompt, "<|>").unwrap();
-                for chunk in snapshot.text_for_range(range.start..Anchor::max()) {
-                    write!(prompt, "{chunk}").unwrap();
-                }
-                writeln!(prompt).unwrap();
-                writeln!(prompt, "```").unwrap();
-                writeln!(
-                    prompt,
-                    "Assume the cursor is located where the `<|>` marker is."
-                )
-                .unwrap();
-                writeln!(
-                    prompt,
-                    "Text can't be replaced, so assume your answer will be inserted at the cursor."
-                )
-                .unwrap();
-                writeln!(
-                    prompt,
-                    "Complete the text given the user prompt: {user_prompt}"
-                )
-                .unwrap();
-            }
-        }
-        if let Some(language_name) = language_name {
-            writeln!(prompt, "Your answer MUST always be valid {language_name}.").unwrap();
-        }
-        writeln!(prompt, "Always wrap your response in a Markdown codeblock.").unwrap();
-        writeln!(prompt, "Never make remarks about the output.").unwrap();
 
+        let codegen_kind = codegen.read(cx).kind().clone();
+        let user_prompt = user_prompt.to_string();
+        let prompt = cx.background().spawn(async move {
+            let language_name = language_name.as_deref();
+            generate_content_prompt(user_prompt, language_name, &buffer, range, codegen_kind)
+        });
         let mut messages = Vec::new();
         let mut model = settings::get::<AssistantSettings>(cx)
             .default_open_ai_model
@@ -660,18 +597,21 @@ impl AssistantPanel {
             model = conversation.model.clone();
         }
 
-        messages.push(RequestMessage {
-            role: Role::User,
-            content: prompt,
-        });
-        let request = OpenAIRequest {
-            model: model.full_name().into(),
-            messages,
-            stream: true,
-        };
-        pending_assist
-            .codegen
-            .update(cx, |codegen, cx| codegen.start(request, cx));
+        cx.spawn(|_, mut cx| async move {
+            let prompt = prompt.await;
+
+            messages.push(RequestMessage {
+                role: Role::User,
+                content: prompt,
+            });
+            let request = OpenAIRequest {
+                model: model.full_name().into(),
+                messages,
+                stream: true,
+            };
+            codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx));
+        })
+        .detach();
     }
 
     fn update_highlights_for_editor(

crates/assistant/src/prompts.rs 🔗

@@ -0,0 +1,404 @@
+use crate::codegen::CodegenKind;
+use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
+use std::cmp::{self, Reverse};
+use std::fmt::Write;
+use std::ops::Range;
+
+fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
+    #[derive(Debug)]
+    struct Match {
+        collapse: Range<usize>,
+        keep: Vec<Range<usize>>,
+    }
+
+    let selected_range = selected_range.to_offset(buffer);
+    let mut ts_matches = buffer.matches(0..buffer.len(), |grammar| {
+        Some(&grammar.embedding_config.as_ref()?.query)
+    });
+    let configs = ts_matches
+        .grammars()
+        .iter()
+        .map(|g| g.embedding_config.as_ref().unwrap())
+        .collect::<Vec<_>>();
+    let mut matches = Vec::new();
+    while let Some(mat) = ts_matches.peek() {
+        let config = &configs[mat.grammar_index];
+        if let Some(collapse) = mat.captures.iter().find_map(|cap| {
+            if Some(cap.index) == config.collapse_capture_ix {
+                Some(cap.node.byte_range())
+            } else {
+                None
+            }
+        }) {
+            let mut keep = Vec::new();
+            for capture in mat.captures.iter() {
+                if Some(capture.index) == config.keep_capture_ix {
+                    keep.push(capture.node.byte_range());
+                } else {
+                    continue;
+                }
+            }
+            ts_matches.advance();
+            matches.push(Match { collapse, keep });
+        } else {
+            ts_matches.advance();
+        }
+    }
+    matches.sort_unstable_by_key(|mat| (mat.collapse.start, Reverse(mat.collapse.end)));
+    let mut matches = matches.into_iter().peekable();
+
+    let mut summary = String::new();
+    let mut offset = 0;
+    let mut flushed_selection = false;
+    while let Some(mat) = matches.next() {
+        // Drop any following matches whose collapse range is nested inside
+        // the current one; the enclosing node already covers them.
+        while let Some(next_mat) = matches.peek() {
+            if mat.collapse.start <= next_mat.collapse.start
+                && mat.collapse.end >= next_mat.collapse.end
+            {
+                matches.next().unwrap();
+            } else {
+                break;
+            }
+        }
+
+        if offset > mat.collapse.start {
+            // Skip collapsed nodes that have already been summarized.
+            offset = cmp::max(offset, mat.collapse.end);
+            continue;
+        }
+
+        if offset <= selected_range.start && selected_range.start <= mat.collapse.end {
+            if !flushed_selection {
+                // The collapsed node ends after the selection starts, so we'll flush the selection first.
+                summary.extend(buffer.text_for_range(offset..selected_range.start));
+                summary.push_str("<|START|");
+                if selected_range.end == selected_range.start {
+                    summary.push_str(">");
+                } else {
+                    summary.extend(buffer.text_for_range(selected_range.clone()));
+                    summary.push_str("|END|>");
+                }
+                offset = selected_range.end;
+                flushed_selection = true;
+            }
+
+            // If the selection intersects the collapsed node, we won't collapse it.
+            if selected_range.end >= mat.collapse.start {
+                continue;
+            }
+        }
+
+        summary.extend(buffer.text_for_range(offset..mat.collapse.start));
+        for keep in mat.keep {
+            summary.extend(buffer.text_for_range(keep));
+        }
+        offset = mat.collapse.end;
+    }
+
+    // Flush selection if we haven't already done so.
+    if !flushed_selection && offset <= selected_range.start {
+        summary.extend(buffer.text_for_range(offset..selected_range.start));
+        summary.push_str("<|START|");
+        if selected_range.end == selected_range.start {
+            summary.push_str(">");
+        } else {
+            summary.extend(buffer.text_for_range(selected_range.clone()));
+            summary.push_str("|END|>");
+        }
+        offset = selected_range.end;
+    }
+
+    summary.extend(buffer.text_for_range(offset..buffer.len()));
+    summary
+}
+
+pub fn generate_content_prompt(
+    user_prompt: String,
+    language_name: Option<&str>,
+    buffer: &BufferSnapshot,
+    range: Range<impl ToOffset>,
+    kind: CodegenKind,
+) -> String {
+    let mut prompt = String::new();
+
+    // General Preamble
+    if let Some(language_name) = language_name {
+        writeln!(prompt, "You're an expert {language_name} engineer.\n").unwrap();
+    } else {
+        writeln!(prompt, "You're an expert engineer.\n").unwrap();
+    }
+
+    let outline = summarize(buffer, range);
+    writeln!(
+        prompt,
+        "The file you are currently working on has the following outline:"
+    )
+    .unwrap();
+    if let Some(language_name) = language_name {
+        let language_name = language_name.to_lowercase();
+        writeln!(prompt, "```{language_name}\n{outline}\n```").unwrap();
+    } else {
+        writeln!(prompt, "```\n{outline}\n```").unwrap();
+    }
+
+    match kind {
+        CodegenKind::Generate { position: _ } => {
+            writeln!(prompt, "In particular, the user's cursor is current on the '<|START|>' span in the above outline, with no text selected.").unwrap();
+            writeln!(
+                prompt,
+                "Assume the cursor is located where the `<|START|` marker is."
+            )
+            .unwrap();
+            writeln!(
+                prompt,
+                "Text can't be replaced, so assume your answer will be inserted at the cursor."
+            )
+            .unwrap();
+            writeln!(
+                prompt,
+                "Generate text based on the users prompt: {user_prompt}"
+            )
+            .unwrap();
+        }
+        CodegenKind::Transform { range: _ } => {
+            writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
+            writeln!(
+                prompt,
+                "Modify the users code selected text based upon the users prompt: {user_prompt}"
+            )
+            .unwrap();
+            writeln!(
+                prompt,
+                "You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file."
+            )
+            .unwrap();
+        }
+    }
+
+    if let Some(language_name) = language_name {
+        writeln!(prompt, "Your answer MUST always be valid {language_name}").unwrap();
+    }
+    writeln!(prompt, "Always wrap your response in a Markdown codeblock").unwrap();
+    writeln!(prompt, "Never make remarks about the output.").unwrap();
+
+    prompt
+}
+
+#[cfg(test)]
+pub(crate) mod tests {
+
+    use super::*;
+    use std::sync::Arc;
+
+    use gpui::AppContext;
+    use indoc::indoc;
+    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
+    use settings::SettingsStore;
+
+    pub(crate) fn rust_lang() -> Language {
+        Language::new(
+            LanguageConfig {
+                name: "Rust".into(),
+                path_suffixes: vec!["rs".to_string()],
+                ..Default::default()
+            },
+            Some(tree_sitter_rust::language()),
+        )
+        .with_embedding_query(
+            r#"
+            (
+                [(line_comment) (attribute_item)]* @context
+                .
+                [
+                    (struct_item
+                        name: (_) @name)
+
+                    (enum_item
+                        name: (_) @name)
+
+                    (impl_item
+                        trait: (_)? @name
+                        "for"? @name
+                        type: (_) @name)
+
+                    (trait_item
+                        name: (_) @name)
+
+                    (function_item
+                        name: (_) @name
+                        body: (block
+                            "{" @keep
+                            "}" @keep) @collapse)
+
+                    (macro_definition
+                        name: (_) @name)
+                    ] @item
+                )
+            "#,
+        )
+        .unwrap()
+    }
+
+    #[gpui::test]
+    fn test_outline_for_prompt(cx: &mut AppContext) {
+        cx.set_global(SettingsStore::test(cx));
+        language_settings::init(cx);
+        let text = indoc! {"
+            struct X {
+                a: usize,
+                b: usize,
+            }
+
+            impl X {
+
+                fn new() -> Self {
+                    let a = 1;
+                    let b = 2;
+                    Self { a, b }
+                }
+
+                pub fn a(&self, param: bool) -> usize {
+                    self.a
+                }
+
+                pub fn b(&self) -> usize {
+                    self.b
+                }
+            }
+        "};
+        let buffer =
+            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
+        let snapshot = buffer.read(cx).snapshot();
+
+        assert_eq!(
+            summarize(&snapshot, Point::new(1, 4)..Point::new(1, 4)),
+            indoc! {"
+                struct X {
+                    <|START|>a: usize,
+                    b: usize,
+                }
+
+                impl X {
+
+                    fn new() -> Self {}
+
+                    pub fn a(&self, param: bool) -> usize {}
+
+                    pub fn b(&self) -> usize {}
+                }
+            "}
+        );
+
+        assert_eq!(
+            summarize(&snapshot, Point::new(8, 12)..Point::new(8, 14)),
+            indoc! {"
+                struct X {
+                    a: usize,
+                    b: usize,
+                }
+
+                impl X {
+
+                    fn new() -> Self {
+                        let <|START|a |END|>= 1;
+                        let b = 2;
+                        Self { a, b }
+                    }
+
+                    pub fn a(&self, param: bool) -> usize {}
+
+                    pub fn b(&self) -> usize {}
+                }
+            "}
+        );
+
+        assert_eq!(
+            summarize(&snapshot, Point::new(6, 0)..Point::new(6, 0)),
+            indoc! {"
+                struct X {
+                    a: usize,
+                    b: usize,
+                }
+
+                impl X {
+                <|START|>
+                    fn new() -> Self {}
+
+                    pub fn a(&self, param: bool) -> usize {}
+
+                    pub fn b(&self) -> usize {}
+                }
+            "}
+        );
+
+        assert_eq!(
+            summarize(&snapshot, Point::new(21, 0)..Point::new(21, 0)),
+            indoc! {"
+                struct X {
+                    a: usize,
+                    b: usize,
+                }
+
+                impl X {
+
+                    fn new() -> Self {}
+
+                    pub fn a(&self, param: bool) -> usize {}
+
+                    pub fn b(&self) -> usize {}
+                }
+                <|START|>"}
+        );
+
+        // Ensure nested functions get collapsed properly.
+        let text = indoc! {"
+            struct X {
+                a: usize,
+                b: usize,
+            }
+
+            impl X {
+
+                fn new() -> Self {
+                    let a = 1;
+                    let b = 2;
+                    Self { a, b }
+                }
+
+                pub fn a(&self, param: bool) -> usize {
+                    let a = 30;
+                    fn nested() -> usize {
+                        3
+                    }
+                    self.a + nested()
+                }
+
+                pub fn b(&self) -> usize {
+                    self.b
+                }
+            }
+        "};
+        buffer.update(cx, |buffer, cx| buffer.set_text(text, cx));
+        let snapshot = buffer.read(cx).snapshot();
+        assert_eq!(
+            summarize(&snapshot, Point::new(0, 0)..Point::new(0, 0)),
+            indoc! {"
+                <|START|>struct X {
+                    a: usize,
+                    b: usize,
+                }
+
+                impl X {
+
+                    fn new() -> Self {}
+
+                    pub fn a(&self, param: bool) -> usize {}
+
+                    pub fn b(&self) -> usize {}
+                }
+            "}
+        );
+    }
+}

crates/language/src/buffer.rs 🔗

@@ -8,8 +8,8 @@ use crate::{
     language_settings::{language_settings, LanguageSettings},
     outline::OutlineItem,
     syntax_map::{
-        SyntaxLayerInfo, SyntaxMap, SyntaxMapCapture, SyntaxMapCaptures, SyntaxSnapshot,
-        ToTreeSitterPoint,
+        SyntaxLayerInfo, SyntaxMap, SyntaxMapCapture, SyntaxMapCaptures, SyntaxMapMatches,
+        SyntaxSnapshot, ToTreeSitterPoint,
     },
     CodeLabel, LanguageScope, Outline,
 };
@@ -2467,6 +2467,14 @@ impl BufferSnapshot {
         Some(items)
     }
 
+    pub fn matches(
+        &self,
+        range: Range<usize>,
+        query: fn(&Grammar) -> Option<&tree_sitter::Query>,
+    ) -> SyntaxMapMatches {
+        self.syntax.matches(range, self, query)
+    }
+
     /// Returns bracket range pairs overlapping or adjacent to `range`
     pub fn bracket_ranges<'a, T: ToOffset>(
         &'a self,

crates/semantic_index/src/semantic_index_tests.rs 🔗

@@ -305,6 +305,11 @@ async fn test_code_context_retrieval_rust() {
                 todo!();
             }
         }
+
+        #[derive(Clone)]
+        struct D {
+            name: String
+        }
     "
     .unindent();
 
@@ -361,6 +366,15 @@ async fn test_code_context_retrieval_rust() {
                 .unindent(),
                 text.find("fn function_2").unwrap(),
             ),
+            (
+                "
+                #[derive(Clone)]
+                struct D {
+                    name: String
+                }"
+                .unindent(),
+                text.find("struct D").unwrap(),
+            ),
         ],
     );
 }
@@ -1422,6 +1436,9 @@ fn rust_lang() -> Arc<Language> {
                         name: (_) @name)
                 ] @item
             )
+
+            (attribute_item) @collapse
+            (use_declaration) @collapse
             "#,
         )
         .unwrap(),

crates/zed/src/languages/rust/embedding.scm 🔗

@@ -2,6 +2,7 @@
     [(line_comment) (attribute_item)]* @context
     .
     [
+
         (struct_item
             name: (_) @name)
 
@@ -26,3 +27,6 @@
             name: (_) @name)
         ] @item
     )
+
+(attribute_item) @collapse
+(use_declaration) @collapse