From 54c63063e483969ee937417ac3c4eb3805f5cdf8 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Tue, 26 Sep 2023 16:23:48 -0400 Subject: [PATCH] changed inline assist generate prompt to leverage outline as opposed to entire prior file Co-Authored-by: Antonio --- crates/assistant/src/assistant.rs | 1 + crates/assistant/src/assistant_panel.rs | 123 +++----- crates/assistant/src/prompts.rs | 378 ++++++++++++++++++++++++ 3 files changed, 412 insertions(+), 90 deletions(-) create mode 100644 crates/assistant/src/prompts.rs diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index 258684db47096bac2d3df33d0289462dbc841214..6c9b14333e34cbf5fd49d8299ba7bd891b607526 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -1,6 +1,7 @@ pub mod assistant_panel; mod assistant_settings; mod codegen; +mod prompts; mod streaming_diff; use ai::completion::Role; diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 42e5fb78979a6b8136c5c60d29e38e064df3435d..55a1dfe0f6f28e6d745a7d077869a6ae8e9ce8bf 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -1,6 +1,7 @@ use crate::{ assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel}, codegen::{self, Codegen, CodegenKind}, + prompts::generate_content_prompt, MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata, SavedMessage, }; @@ -541,11 +542,31 @@ impl AssistantPanel { self.inline_prompt_history.pop_front(); } - let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); + let multi_buffer = editor.read(cx).buffer().read(cx); + let multi_buffer_snapshot = multi_buffer.snapshot(cx); + let snapshot = if multi_buffer.all_buffers().len() > 1 { + return; + } else { + multi_buffer + .all_buffers() + .iter() + .next() + .unwrap() + .read(cx) + .snapshot() + }; + let range = pending_assist.codegen.read(cx).range(); - let selected_text = snapshot.text_for_range(range.clone()).collect::(); + let language_range = snapshot.anchor_at( + range.start.to_offset(&multi_buffer_snapshot), + language::Bias::Left, + ) + ..snapshot.anchor_at( + range.end.to_offset(&multi_buffer_snapshot), + language::Bias::Right, + ); - let language = snapshot.language_at(range.start); + let language = snapshot.language_at(language_range.start); let language_name = if let Some(language) = language.as_ref() { if Arc::ptr_eq(language, &language::PLAIN_TEXT) { None @@ -557,93 +578,15 @@ impl AssistantPanel { }; let language_name = language_name.as_deref(); - let mut prompt = String::new(); - if let Some(language_name) = language_name { - writeln!(prompt, "You're an expert {language_name} engineer.").unwrap(); - } - match pending_assist.codegen.read(cx).kind() { - CodegenKind::Transform { .. } => { - writeln!( - prompt, - "You're currently working inside an editor on this file:" - ) - .unwrap(); - if let Some(language_name) = language_name { - writeln!(prompt, "```{language_name}").unwrap(); - } else { - writeln!(prompt, "```").unwrap(); - } - for chunk in snapshot.text_for_range(Anchor::min()..Anchor::max()) { - write!(prompt, "{chunk}").unwrap(); - } - writeln!(prompt, "```").unwrap(); - - writeln!( - prompt, - "In particular, the user has selected the following text:" - ) - .unwrap(); - if let Some(language_name) = language_name { - writeln!(prompt, "```{language_name}").unwrap(); - } else { - writeln!(prompt, "```").unwrap(); - } - writeln!(prompt, "{selected_text}").unwrap(); - writeln!(prompt, "```").unwrap(); - writeln!(prompt).unwrap(); - writeln!( - prompt, - "Modify the selected text given the user prompt: {user_prompt}" - ) - .unwrap(); - writeln!( - prompt, - "You MUST reply only with the edited selected text, not the entire file." - ) - .unwrap(); - } - CodegenKind::Generate { .. } => { - writeln!( - prompt, - "You're currently working inside an editor on this file:" - ) - .unwrap(); - if let Some(language_name) = language_name { - writeln!(prompt, "```{language_name}").unwrap(); - } else { - writeln!(prompt, "```").unwrap(); - } - for chunk in snapshot.text_for_range(Anchor::min()..range.start) { - write!(prompt, "{chunk}").unwrap(); - } - write!(prompt, "<|>").unwrap(); - for chunk in snapshot.text_for_range(range.start..Anchor::max()) { - write!(prompt, "{chunk}").unwrap(); - } - writeln!(prompt).unwrap(); - writeln!(prompt, "```").unwrap(); - writeln!( - prompt, - "Assume the cursor is located where the `<|>` marker is." - ) - .unwrap(); - writeln!( - prompt, - "Text can't be replaced, so assume your answer will be inserted at the cursor." - ) - .unwrap(); - writeln!( - prompt, - "Complete the text given the user prompt: {user_prompt}" - ) - .unwrap(); - } - } - if let Some(language_name) = language_name { - writeln!(prompt, "Your answer MUST always be valid {language_name}.").unwrap(); - } - writeln!(prompt, "Always wrap your response in a Markdown codeblock.").unwrap(); - writeln!(prompt, "Never make remarks about the output.").unwrap(); + let codegen_kind = pending_assist.codegen.read(cx).kind().clone(); + let prompt = generate_content_prompt( + user_prompt.to_string(), + language_name, + &snapshot, + language_range, + cx, + codegen_kind, + ); let mut messages = Vec::new(); let mut model = settings::get::(cx) diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs new file mode 100644 index 0000000000000000000000000000000000000000..272ae9eac686914faf6b2c0c007f59ac9ff9f77d --- /dev/null +++ b/crates/assistant/src/prompts.rs @@ -0,0 +1,378 @@ +use gpui::AppContext; +use language::{BufferSnapshot, OffsetRangeExt, ToOffset}; +use std::cmp; +use std::ops::Range; +use std::{fmt::Write, iter}; + +use crate::codegen::CodegenKind; + +fn outline_for_prompt( + buffer: &BufferSnapshot, + range: Range, + cx: &AppContext, +) -> Option { + let indent = buffer + .language_indent_size_at(0, cx) + .chars() + .collect::(); + let outline = buffer.outline(None)?; + let range = range.to_offset(buffer); + + let mut text = String::new(); + let mut items = outline.items.into_iter().peekable(); + + let mut intersected = false; + let mut intersection_indent = 0; + let mut extended_range = range.clone(); + + while let Some(item) = items.next() { + let item_range = item.range.to_offset(buffer); + if item_range.end < range.start || item_range.start > range.end { + text.extend(iter::repeat(indent.as_str()).take(item.depth)); + text.push_str(&item.text); + text.push('\n'); + } else { + intersected = true; + let is_terminal = items + .peek() + .map_or(true, |next_item| next_item.depth <= item.depth); + if is_terminal { + if item_range.start <= extended_range.start { + extended_range.start = item_range.start; + intersection_indent = item.depth; + } + extended_range.end = cmp::max(extended_range.end, item_range.end); + } else { + let name_start = item_range.start + item.name_ranges.first().unwrap().start; + let name_end = item_range.start + item.name_ranges.last().unwrap().end; + + if range.start > name_end { + text.extend(iter::repeat(indent.as_str()).take(item.depth)); + text.push_str(&item.text); + text.push('\n'); + } else { + if name_start <= extended_range.start { + extended_range.start = item_range.start; + intersection_indent = item.depth; + } + extended_range.end = cmp::max(extended_range.end, name_end); + } + } + } + + if intersected + && items.peek().map_or(true, |next_item| { + next_item.range.start.to_offset(buffer) > range.end + }) + { + intersected = false; + text.extend(iter::repeat(indent.as_str()).take(intersection_indent)); + text.extend(buffer.text_for_range(extended_range.start..range.start)); + text.push_str("<|"); + text.extend(buffer.text_for_range(range.clone())); + text.push_str("|>"); + text.extend(buffer.text_for_range(range.end..extended_range.end)); + text.push('\n'); + } + } + + Some(text) +} + +pub fn generate_content_prompt( + user_prompt: String, + language_name: Option<&str>, + buffer: &BufferSnapshot, + range: Range, + cx: &AppContext, + kind: CodegenKind, +) -> String { + let mut prompt = String::new(); + + // General Preamble + if let Some(language_name) = language_name { + writeln!(prompt, "You're an expert {language_name} engineer.\n").unwrap(); + } else { + writeln!(prompt, "You're an expert engineer.\n").unwrap(); + } + + let outline = outline_for_prompt(buffer, range.clone(), cx); + if let Some(outline) = outline { + writeln!( + prompt, + "The file you are currently working on has the following outline:" + ) + .unwrap(); + if let Some(language_name) = language_name { + let language_name = language_name.to_lowercase(); + writeln!(prompt, "```{language_name}\n{outline}\n```").unwrap(); + } else { + writeln!(prompt, "```\n{outline}\n```").unwrap(); + } + } + + // Assume for now that we are just generating + if range.clone().start == range.end { + writeln!(prompt, "In particular, the user's cursor is current on the '<||>' span in the above outline, with no text selected.").unwrap(); + } else { + writeln!(prompt, "In particular, the user has selected a section of the text between the '<|' and '|>' spans.").unwrap(); + } + + match kind { + CodegenKind::Generate { position } => { + writeln!( + prompt, + "Assume the cursor is located where the `<|` marker is." + ) + .unwrap(); + writeln!( + prompt, + "Text can't be replaced, so assume your answer will be inserted at the cursor." + ) + .unwrap(); + writeln!( + prompt, + "Generate text based on the users prompt: {user_prompt}" + ) + .unwrap(); + } + CodegenKind::Transform { range } => { + writeln!( + prompt, + "Modify the users code selected text based upon the users prompt: {user_prompt}" + ) + .unwrap(); + writeln!( + prompt, + "You MUST reply with only the adjusted code, not the entire file." + ) + .unwrap(); + } + } + + if let Some(language_name) = language_name { + writeln!(prompt, "Your answer MUST always be valid {language_name}").unwrap(); + } + writeln!(prompt, "Always wrap your response in a Markdown codeblock").unwrap(); + writeln!(prompt, "Never make remarks about the output.").unwrap(); + + prompt +} + +#[cfg(test)] +pub(crate) mod tests { + + use super::*; + use std::sync::Arc; + + use gpui::AppContext; + use indoc::indoc; + use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point}; + use settings::SettingsStore; + + pub(crate) fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + Some(tree_sitter_rust::language()), + ) + .with_indents_query( + r#" + (call_expression) @indent + (field_expression) @indent + (_ "(" ")" @end) @indent + (_ "{" "}" @end) @indent + "#, + ) + .unwrap() + .with_outline_query( + r#" + (struct_item + "struct" @context + name: (_) @name) @item + (enum_item + "enum" @context + name: (_) @name) @item + (enum_variant + name: (_) @name) @item + (field_declaration + name: (_) @name) @item + (impl_item + "impl" @context + trait: (_)? @name + "for"? @context + type: (_) @name) @item + (function_item + "fn" @context + name: (_) @name) @item + (mod_item + "mod" @context + name: (_) @name) @item + "#, + ) + .unwrap() + } + + #[gpui::test] + fn test_outline_for_prompt(cx: &mut AppContext) { + cx.set_global(SettingsStore::test(cx)); + language_settings::init(cx); + let text = indoc! {" + struct X { + a: usize, + b: usize, + } + + impl X { + + fn new() -> Self { + let a = 1; + let b = 2; + Self { a, b } + } + + pub fn a(&self, param: bool) -> usize { + self.a + } + + pub fn b(&self) -> usize { + self.b + } + } + "}; + let buffer = + cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx)); + let snapshot = buffer.read(cx).snapshot(); + + let outline = outline_for_prompt( + &snapshot, + snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_before(Point::new(1, 4)), + cx, + ); + assert_eq!( + outline.as_deref(), + Some(indoc! {" + struct X + <||>a: usize + b + impl X + fn new + fn a + fn b + "}) + ); + + let outline = outline_for_prompt( + &snapshot, + snapshot.anchor_before(Point::new(8, 12))..snapshot.anchor_before(Point::new(8, 14)), + cx, + ); + assert_eq!( + outline.as_deref(), + Some(indoc! {" + struct X + a + b + impl X + fn new() -> Self { + let <|a |>= 1; + let b = 2; + Self { a, b } + } + fn a + fn b + "}) + ); + + let outline = outline_for_prompt( + &snapshot, + snapshot.anchor_before(Point::new(6, 0))..snapshot.anchor_before(Point::new(6, 0)), + cx, + ); + assert_eq!( + outline.as_deref(), + Some(indoc! {" + struct X + a + b + impl X + <||> + fn new + fn a + fn b + "}) + ); + + let outline = outline_for_prompt( + &snapshot, + snapshot.anchor_before(Point::new(8, 12))..snapshot.anchor_before(Point::new(13, 9)), + cx, + ); + assert_eq!( + outline.as_deref(), + Some(indoc! {" + struct X + a + b + impl X + fn new() -> Self { + let <|a = 1; + let b = 2; + Self { a, b } + } + + pub f|>n a(&self, param: bool) -> usize { + self.a + } + fn b + "}) + ); + + let outline = outline_for_prompt( + &snapshot, + snapshot.anchor_before(Point::new(5, 6))..snapshot.anchor_before(Point::new(12, 0)), + cx, + ); + assert_eq!( + outline.as_deref(), + Some(indoc! {" + struct X + a + b + impl X<| { + + fn new() -> Self { + let a = 1; + let b = 2; + Self { a, b } + } + |> + fn a + fn b + "}) + ); + + let outline = outline_for_prompt( + &snapshot, + snapshot.anchor_before(Point::new(18, 8))..snapshot.anchor_before(Point::new(18, 8)), + cx, + ); + assert_eq!( + outline.as_deref(), + Some(indoc! {" + struct X + a + b + impl X + fn new + fn a + pub fn b(&self) -> usize { + <||>self.b + } + "}) + ); + } +}