diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index 258684db47096bac2d3df33d0289462dbc841214..6c9b14333e34cbf5fd49d8299ba7bd891b607526 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -1,6 +1,7 @@ pub mod assistant_panel; mod assistant_settings; mod codegen; +mod prompts; mod streaming_diff; use ai::completion::Role; diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 42e5fb78979a6b8136c5c60d29e38e064df3435d..b69c12a2a328ed8643315f091be11d764dcdc00d 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -1,6 +1,7 @@ use crate::{ assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel}, codegen::{self, Codegen, CodegenKind}, + prompts::generate_content_prompt, MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata, SavedMessage, }; @@ -273,13 +274,17 @@ impl AssistantPanel { return; }; + let selection = editor.read(cx).selections.newest_anchor().clone(); + if selection.start.excerpt_id() != selection.end.excerpt_id() { + return; + } + let inline_assist_id = post_inc(&mut self.next_inline_assist_id); let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); let provider = Arc::new(OpenAICompletionProvider::new( api_key, cx.background().clone(), )); - let selection = editor.read(cx).selections.newest_anchor().clone(); let codegen_kind = if editor.read(cx).selections.newest::(cx).is_empty() { CodegenKind::Generate { position: selection.start, @@ -541,11 +546,26 @@ impl AssistantPanel { self.inline_prompt_history.pop_front(); } + let codegen = pending_assist.codegen.clone(); let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); - let range = pending_assist.codegen.read(cx).range(); - let selected_text = snapshot.text_for_range(range.clone()).collect::(); + let range = codegen.read(cx).range(); + let start = snapshot.point_to_buffer_offset(range.start); + let end = snapshot.point_to_buffer_offset(range.end); + let (buffer, range) = if let Some((start, end)) = start.zip(end) { + let (start_buffer, start_buffer_offset) = start; + let (end_buffer, end_buffer_offset) = end; + if start_buffer.remote_id() == end_buffer.remote_id() { + (start_buffer.clone(), start_buffer_offset..end_buffer_offset) + } else { + self.finish_inline_assist(inline_assist_id, false, cx); + return; + } + } else { + self.finish_inline_assist(inline_assist_id, false, cx); + return; + }; - let language = snapshot.language_at(range.start); + let language = buffer.language_at(range.start); let language_name = if let Some(language) = language.as_ref() { if Arc::ptr_eq(language, &language::PLAIN_TEXT) { None @@ -555,96 +575,13 @@ impl AssistantPanel { } else { None }; - let language_name = language_name.as_deref(); - - let mut prompt = String::new(); - if let Some(language_name) = language_name { - writeln!(prompt, "You're an expert {language_name} engineer.").unwrap(); - } - match pending_assist.codegen.read(cx).kind() { - CodegenKind::Transform { .. } => { - writeln!( - prompt, - "You're currently working inside an editor on this file:" - ) - .unwrap(); - if let Some(language_name) = language_name { - writeln!(prompt, "```{language_name}").unwrap(); - } else { - writeln!(prompt, "```").unwrap(); - } - for chunk in snapshot.text_for_range(Anchor::min()..Anchor::max()) { - write!(prompt, "{chunk}").unwrap(); - } - writeln!(prompt, "```").unwrap(); - - writeln!( - prompt, - "In particular, the user has selected the following text:" - ) - .unwrap(); - if let Some(language_name) = language_name { - writeln!(prompt, "```{language_name}").unwrap(); - } else { - writeln!(prompt, "```").unwrap(); - } - writeln!(prompt, "{selected_text}").unwrap(); - writeln!(prompt, "```").unwrap(); - writeln!(prompt).unwrap(); - writeln!( - prompt, - "Modify the selected text given the user prompt: {user_prompt}" - ) - .unwrap(); - writeln!( - prompt, - "You MUST reply only with the edited selected text, not the entire file." - ) - .unwrap(); - } - CodegenKind::Generate { .. } => { - writeln!( - prompt, - "You're currently working inside an editor on this file:" - ) - .unwrap(); - if let Some(language_name) = language_name { - writeln!(prompt, "```{language_name}").unwrap(); - } else { - writeln!(prompt, "```").unwrap(); - } - for chunk in snapshot.text_for_range(Anchor::min()..range.start) { - write!(prompt, "{chunk}").unwrap(); - } - write!(prompt, "<|>").unwrap(); - for chunk in snapshot.text_for_range(range.start..Anchor::max()) { - write!(prompt, "{chunk}").unwrap(); - } - writeln!(prompt).unwrap(); - writeln!(prompt, "```").unwrap(); - writeln!( - prompt, - "Assume the cursor is located where the `<|>` marker is." - ) - .unwrap(); - writeln!( - prompt, - "Text can't be replaced, so assume your answer will be inserted at the cursor." - ) - .unwrap(); - writeln!( - prompt, - "Complete the text given the user prompt: {user_prompt}" - ) - .unwrap(); - } - } - if let Some(language_name) = language_name { - writeln!(prompt, "Your answer MUST always be valid {language_name}.").unwrap(); - } - writeln!(prompt, "Always wrap your response in a Markdown codeblock.").unwrap(); - writeln!(prompt, "Never make remarks about the output.").unwrap(); + let codegen_kind = codegen.read(cx).kind().clone(); + let user_prompt = user_prompt.to_string(); + let prompt = cx.background().spawn(async move { + let language_name = language_name.as_deref(); + generate_content_prompt(user_prompt, language_name, &buffer, range, codegen_kind) + }); let mut messages = Vec::new(); let mut model = settings::get::(cx) .default_open_ai_model @@ -660,18 +597,21 @@ impl AssistantPanel { model = conversation.model.clone(); } - messages.push(RequestMessage { - role: Role::User, - content: prompt, - }); - let request = OpenAIRequest { - model: model.full_name().into(), - messages, - stream: true, - }; - pending_assist - .codegen - .update(cx, |codegen, cx| codegen.start(request, cx)); + cx.spawn(|_, mut cx| async move { + let prompt = prompt.await; + + messages.push(RequestMessage { + role: Role::User, + content: prompt, + }); + let request = OpenAIRequest { + model: model.full_name().into(), + messages, + stream: true, + }; + codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx)); + }) + .detach(); } fn update_highlights_for_editor( diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs new file mode 100644 index 0000000000000000000000000000000000000000..bf041dff523d57d62cfbc3f312a350ad4766d160 --- /dev/null +++ b/crates/assistant/src/prompts.rs @@ -0,0 +1,404 @@ +use crate::codegen::CodegenKind; +use language::{BufferSnapshot, OffsetRangeExt, ToOffset}; +use std::cmp::{self, Reverse}; +use std::fmt::Write; +use std::ops::Range; + +fn summarize(buffer: &BufferSnapshot, selected_range: Range) -> String { + #[derive(Debug)] + struct Match { + collapse: Range, + keep: Vec>, + } + + let selected_range = selected_range.to_offset(buffer); + let mut ts_matches = buffer.matches(0..buffer.len(), |grammar| { + Some(&grammar.embedding_config.as_ref()?.query) + }); + let configs = ts_matches + .grammars() + .iter() + .map(|g| g.embedding_config.as_ref().unwrap()) + .collect::>(); + let mut matches = Vec::new(); + while let Some(mat) = ts_matches.peek() { + let config = &configs[mat.grammar_index]; + if let Some(collapse) = mat.captures.iter().find_map(|cap| { + if Some(cap.index) == config.collapse_capture_ix { + Some(cap.node.byte_range()) + } else { + None + } + }) { + let mut keep = Vec::new(); + for capture in mat.captures.iter() { + if Some(capture.index) == config.keep_capture_ix { + keep.push(capture.node.byte_range()); + } else { + continue; + } + } + ts_matches.advance(); + matches.push(Match { collapse, keep }); + } else { + ts_matches.advance(); + } + } + matches.sort_unstable_by_key(|mat| (mat.collapse.start, Reverse(mat.collapse.end))); + let mut matches = matches.into_iter().peekable(); + + let mut summary = String::new(); + let mut offset = 0; + let mut flushed_selection = false; + while let Some(mat) = matches.next() { + // Keep extending the collapsed range if the next match surrounds + // the current one. + while let Some(next_mat) = matches.peek() { + if mat.collapse.start <= next_mat.collapse.start + && mat.collapse.end >= next_mat.collapse.end + { + matches.next().unwrap(); + } else { + break; + } + } + + if offset > mat.collapse.start { + // Skip collapsed nodes that have already been summarized. + offset = cmp::max(offset, mat.collapse.end); + continue; + } + + if offset <= selected_range.start && selected_range.start <= mat.collapse.end { + if !flushed_selection { + // The collapsed node ends after the selection starts, so we'll flush the selection first. + summary.extend(buffer.text_for_range(offset..selected_range.start)); + summary.push_str("<|START|"); + if selected_range.end == selected_range.start { + summary.push_str(">"); + } else { + summary.extend(buffer.text_for_range(selected_range.clone())); + summary.push_str("|END|>"); + } + offset = selected_range.end; + flushed_selection = true; + } + + // If the selection intersects the collapsed node, we won't collapse it. + if selected_range.end >= mat.collapse.start { + continue; + } + } + + summary.extend(buffer.text_for_range(offset..mat.collapse.start)); + for keep in mat.keep { + summary.extend(buffer.text_for_range(keep)); + } + offset = mat.collapse.end; + } + + // Flush selection if we haven't already done so. + if !flushed_selection && offset <= selected_range.start { + summary.extend(buffer.text_for_range(offset..selected_range.start)); + summary.push_str("<|START|"); + if selected_range.end == selected_range.start { + summary.push_str(">"); + } else { + summary.extend(buffer.text_for_range(selected_range.clone())); + summary.push_str("|END|>"); + } + offset = selected_range.end; + } + + summary.extend(buffer.text_for_range(offset..buffer.len())); + summary +} + +pub fn generate_content_prompt( + user_prompt: String, + language_name: Option<&str>, + buffer: &BufferSnapshot, + range: Range, + kind: CodegenKind, +) -> String { + let mut prompt = String::new(); + + // General Preamble + if let Some(language_name) = language_name { + writeln!(prompt, "You're an expert {language_name} engineer.\n").unwrap(); + } else { + writeln!(prompt, "You're an expert engineer.\n").unwrap(); + } + + let outline = summarize(buffer, range); + writeln!( + prompt, + "The file you are currently working on has the following outline:" + ) + .unwrap(); + if let Some(language_name) = language_name { + let language_name = language_name.to_lowercase(); + writeln!(prompt, "```{language_name}\n{outline}\n```").unwrap(); + } else { + writeln!(prompt, "```\n{outline}\n```").unwrap(); + } + + match kind { + CodegenKind::Generate { position: _ } => { + writeln!(prompt, "In particular, the user's cursor is current on the '<|START|>' span in the above outline, with no text selected.").unwrap(); + writeln!( + prompt, + "Assume the cursor is located where the `<|START|` marker is." + ) + .unwrap(); + writeln!( + prompt, + "Text can't be replaced, so assume your answer will be inserted at the cursor." + ) + .unwrap(); + writeln!( + prompt, + "Generate text based on the users prompt: {user_prompt}" + ) + .unwrap(); + } + CodegenKind::Transform { range: _ } => { + writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap(); + writeln!( + prompt, + "Modify the users code selected text based upon the users prompt: {user_prompt}" + ) + .unwrap(); + writeln!( + prompt, + "You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file." + ) + .unwrap(); + } + } + + if let Some(language_name) = language_name { + writeln!(prompt, "Your answer MUST always be valid {language_name}").unwrap(); + } + writeln!(prompt, "Always wrap your response in a Markdown codeblock").unwrap(); + writeln!(prompt, "Never make remarks about the output.").unwrap(); + + prompt +} + +#[cfg(test)] +pub(crate) mod tests { + + use super::*; + use std::sync::Arc; + + use gpui::AppContext; + use indoc::indoc; + use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point}; + use settings::SettingsStore; + + pub(crate) fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + Some(tree_sitter_rust::language()), + ) + .with_embedding_query( + r#" + ( + [(line_comment) (attribute_item)]* @context + . + [ + (struct_item + name: (_) @name) + + (enum_item + name: (_) @name) + + (impl_item + trait: (_)? @name + "for"? @name + type: (_) @name) + + (trait_item + name: (_) @name) + + (function_item + name: (_) @name + body: (block + "{" @keep + "}" @keep) @collapse) + + (macro_definition + name: (_) @name) + ] @item + ) + "#, + ) + .unwrap() + } + + #[gpui::test] + fn test_outline_for_prompt(cx: &mut AppContext) { + cx.set_global(SettingsStore::test(cx)); + language_settings::init(cx); + let text = indoc! {" + struct X { + a: usize, + b: usize, + } + + impl X { + + fn new() -> Self { + let a = 1; + let b = 2; + Self { a, b } + } + + pub fn a(&self, param: bool) -> usize { + self.a + } + + pub fn b(&self) -> usize { + self.b + } + } + "}; + let buffer = + cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx)); + let snapshot = buffer.read(cx).snapshot(); + + assert_eq!( + summarize(&snapshot, Point::new(1, 4)..Point::new(1, 4)), + indoc! {" + struct X { + <|START|>a: usize, + b: usize, + } + + impl X { + + fn new() -> Self {} + + pub fn a(&self, param: bool) -> usize {} + + pub fn b(&self) -> usize {} + } + "} + ); + + assert_eq!( + summarize(&snapshot, Point::new(8, 12)..Point::new(8, 14)), + indoc! {" + struct X { + a: usize, + b: usize, + } + + impl X { + + fn new() -> Self { + let <|START|a |END|>= 1; + let b = 2; + Self { a, b } + } + + pub fn a(&self, param: bool) -> usize {} + + pub fn b(&self) -> usize {} + } + "} + ); + + assert_eq!( + summarize(&snapshot, Point::new(6, 0)..Point::new(6, 0)), + indoc! {" + struct X { + a: usize, + b: usize, + } + + impl X { + <|START|> + fn new() -> Self {} + + pub fn a(&self, param: bool) -> usize {} + + pub fn b(&self) -> usize {} + } + "} + ); + + assert_eq!( + summarize(&snapshot, Point::new(21, 0)..Point::new(21, 0)), + indoc! {" + struct X { + a: usize, + b: usize, + } + + impl X { + + fn new() -> Self {} + + pub fn a(&self, param: bool) -> usize {} + + pub fn b(&self) -> usize {} + } + <|START|>"} + ); + + // Ensure nested functions get collapsed properly. + let text = indoc! {" + struct X { + a: usize, + b: usize, + } + + impl X { + + fn new() -> Self { + let a = 1; + let b = 2; + Self { a, b } + } + + pub fn a(&self, param: bool) -> usize { + let a = 30; + fn nested() -> usize { + 3 + } + self.a + nested() + } + + pub fn b(&self) -> usize { + self.b + } + } + "}; + buffer.update(cx, |buffer, cx| buffer.set_text(text, cx)); + let snapshot = buffer.read(cx).snapshot(); + assert_eq!( + summarize(&snapshot, Point::new(0, 0)..Point::new(0, 0)), + indoc! {" + <|START|>struct X { + a: usize, + b: usize, + } + + impl X { + + fn new() -> Self {} + + pub fn a(&self, param: bool) -> usize {} + + pub fn b(&self) -> usize {} + } + "} + ); + } +} diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index 38b2842c127f2a6ade0d43787a24a0c76ff13374..27b01543e1e3f04f9914b1da5c530ddd26a555c1 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -8,8 +8,8 @@ use crate::{ language_settings::{language_settings, LanguageSettings}, outline::OutlineItem, syntax_map::{ - SyntaxLayerInfo, SyntaxMap, SyntaxMapCapture, SyntaxMapCaptures, SyntaxSnapshot, - ToTreeSitterPoint, + SyntaxLayerInfo, SyntaxMap, SyntaxMapCapture, SyntaxMapCaptures, SyntaxMapMatches, + SyntaxSnapshot, ToTreeSitterPoint, }, CodeLabel, LanguageScope, Outline, }; @@ -2467,6 +2467,14 @@ impl BufferSnapshot { Some(items) } + pub fn matches( + &self, + range: Range, + query: fn(&Grammar) -> Option<&tree_sitter::Query>, + ) -> SyntaxMapMatches { + self.syntax.matches(range, self, query) + } + /// Returns bracket range pairs overlapping or adjacent to `range` pub fn bracket_ranges<'a, T: ToOffset>( &'a self, diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index f2cae8a55701e910379d672225c19ac18a489897..182010ca8339e9cc8ec1ff06ac31741eb4fb78ae 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -305,6 +305,11 @@ async fn test_code_context_retrieval_rust() { todo!(); } } + + #[derive(Clone)] + struct D { + name: String + } " .unindent(); @@ -361,6 +366,15 @@ async fn test_code_context_retrieval_rust() { .unindent(), text.find("fn function_2").unwrap(), ), + ( + " + #[derive(Clone)] + struct D { + name: String + }" + .unindent(), + text.find("struct D").unwrap(), + ), ], ); } @@ -1422,6 +1436,9 @@ fn rust_lang() -> Arc { name: (_) @name) ] @item ) + + (attribute_item) @collapse + (use_declaration) @collapse "#, ) .unwrap(), diff --git a/crates/zed/src/languages/rust/embedding.scm b/crates/zed/src/languages/rust/embedding.scm index e4218382a9b1ceb7e087b0d9247d5a4e66b77236..286b1d13571ad62964e3f38415fc4cbbb04e4e99 100644 --- a/crates/zed/src/languages/rust/embedding.scm +++ b/crates/zed/src/languages/rust/embedding.scm @@ -2,6 +2,7 @@ [(line_comment) (attribute_item)]* @context . [ + (struct_item name: (_) @name) @@ -26,3 +27,6 @@ name: (_) @name) ] @item ) + +(attribute_item) @collapse +(use_declaration) @collapse