From 40755961ea0d0f3e252e2248b027fdbf21a2f659 Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Mon, 16 Oct 2023 11:54:32 -0400
Subject: [PATCH 01/16] added initial template outline

---
 crates/ai/src/ai.rs        |  1 +
 crates/ai/src/templates.rs | 76 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 77 insertions(+)
 create mode 100644 crates/ai/src/templates.rs

diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs
index 5256a6a6432907dd22c30d6a03e492a46fef77df..04e9e14536c16d80de133940db6723349e8d2371 100644
--- a/crates/ai/src/ai.rs
+++ b/crates/ai/src/ai.rs
@@ -1,2 +1,3 @@
 pub mod completion;
 pub mod embedding;
+pub mod templates;
diff --git a/crates/ai/src/templates.rs b/crates/ai/src/templates.rs
new file mode 100644
index 0000000000000000000000000000000000000000..d9771ce56964dcc782eb4c3aaa8a5ec6c8a76cd3
--- /dev/null
+++ b/crates/ai/src/templates.rs
@@ -0,0 +1,76 @@
+use std::fmt::Write;
+
+pub struct PromptCodeSnippet {
+    path: Option<PathBuf>,
+    language_name: Option<String>,
+    content: String,
+}
+
+enum PromptFileType {
+    Text,
+    Code,
+}
+
+#[derive(Default)]
+struct PromptArguments {
+    pub language_name: Option<String>,
+    pub project_name: Option<String>,
+    pub snippets: Vec<PromptCodeSnippet>,
+}
+
+impl PromptArguments {
+    pub fn get_file_type(&self) -> PromptFileType {
+        if self
+            .language_name
+            .as_ref()
+            .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
+            .unwrap_or(true)
+        {
+            PromptFileType::Code
+        } else {
+            PromptFileType::Text
+        }
+    }
+}
+
+trait PromptTemplate {
+    fn generate(args: PromptArguments) -> String;
+}
+
+struct EngineerPreamble {}
+
+impl PromptTemplate for EngineerPreamble {
+    fn generate(args: PromptArguments) -> String {
+        let mut prompt = String::new();
+
+        match args.get_file_type() {
+            PromptFileType::Code => {
+                writeln!(
+                    prompt,
+                    "You are an expert {} engineer.",
+                    args.language_name.unwrap_or("".to_string())
+                )
+                .unwrap();
+            }
+            PromptFileType::Text => {
+                writeln!(prompt, "You are an expert engineer.").unwrap();
+            }
+        }
+
+        if let Some(project_name) = args.project_name {
+            writeln!(
+                prompt,
+                "You are currently working inside the '{project_name}' in Zed the code editor."
+            )
+            .unwrap();
+        }
+
+        prompt
+    }
+}
+
+struct RepositorySnippets {}
+
+impl PromptTemplate for RepositorySnippets {
+    fn generate(args: PromptArguments) -> String {}
+}

From 500af6d7754adf1a60f245200271e4dd40d7fb8f Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Mon, 16 Oct 2023 18:47:10 -0400
Subject: [PATCH 02/16] progress on prompt chains

---
 Cargo.lock                                    |   1 +
 crates/ai/Cargo.toml                          |   1 +
 crates/ai/src/prompts.rs                      | 149 ++++++++++++++++++
 crates/ai/src/templates.rs                    |  76 ---------
 crates/ai/src/templates/base.rs               | 112 +++++++++++++
 crates/ai/src/templates/mod.rs                |   3 +
 crates/ai/src/templates/preamble.rs           |  34 ++++
 crates/ai/src/templates/repository_context.rs |  49 ++++++
 8 files changed, 349 insertions(+), 76 deletions(-)
 create mode 100644 crates/ai/src/prompts.rs
 delete mode 100644 crates/ai/src/templates.rs
 create mode 100644 crates/ai/src/templates/base.rs
 create mode 100644 crates/ai/src/templates/mod.rs
 create mode 100644 crates/ai/src/templates/preamble.rs
 create mode 100644 crates/ai/src/templates/repository_context.rs

diff --git a/Cargo.lock b/Cargo.lock
index cd9dee0bda70dd5180b1f59201dd69feeebba1b6..9938c5d2fa328fb9834db781e32f02eb3b39e5fd 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -91,6 +91,7 @@ dependencies = [
  "futures 0.3.28",
  "gpui",
  "isahc",
+ "language",
  "lazy_static",
  "log",
  "matrixmultiply",
diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml
index 542d7f422fe8c1eaec7d10bf59cb5ccaa2d65ca3..b24c4e5ece5b02eac003a6c18f186faa1eaef7ef 100644
--- a/crates/ai/Cargo.toml
+++ b/crates/ai/Cargo.toml
@@ -11,6 +11,7 @@ doctest = false
 [dependencies]
 gpui = { path = "../gpui" }
 util = { path = "../util" }
+language = { path = "../language" }
 async-trait.workspace = true
 anyhow.workspace = true
 futures.workspace = true
diff --git a/crates/ai/src/prompts.rs b/crates/ai/src/prompts.rs
new file mode 100644
index 0000000000000000000000000000000000000000..6d2c0629fa08e2d464adc5bf1d48c44659da8545
--- /dev/null
+++ b/crates/ai/src/prompts.rs
@@ -0,0 +1,149 @@
+use gpui::{AsyncAppContext, ModelHandle};
+use language::{Anchor, Buffer};
+use std::{fmt::Write, ops::Range, path::PathBuf};
+
+pub struct PromptCodeSnippet {
+    path: Option<PathBuf>,
+    language_name: Option<String>,
+    content: String,
+}
+
+impl PromptCodeSnippet {
+    pub fn new(buffer: ModelHandle<Buffer>, range: Range<Anchor>, cx: &AsyncAppContext) -> Self {
+        let (content, language_name, file_path) = buffer.read_with(cx, |buffer, _| {
+            let snapshot = buffer.snapshot();
+            let content = snapshot.text_for_range(range.clone()).collect::<String>();
+
+            let language_name = buffer
+                .language()
+                .and_then(|language| Some(language.name().to_string()));
+
+            let file_path = buffer
+                .file()
+                .and_then(|file| Some(file.path().to_path_buf()));
+
+            (content, language_name, file_path)
+        });
+
+        PromptCodeSnippet {
+            path: file_path,
+            language_name,
+            content,
+        }
+    }
+}
+
+impl ToString for PromptCodeSnippet {
+    fn to_string(&self) -> String {
+        let path = self
+            .path
+            .as_ref()
+            .and_then(|path| Some(path.to_string_lossy().to_string()))
+            .unwrap_or("".to_string());
+        let language_name = self.language_name.clone().unwrap_or("".to_string());
+        let content = self.content.clone();
+
+        format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
+    }
+}
+
+enum PromptFileType {
+    Text,
+    Code,
+}
+
+#[derive(Default)]
+struct PromptArguments {
+    pub language_name: Option<String>,
+    pub project_name: Option<String>,
+    pub snippets: Vec<PromptCodeSnippet>,
+    pub model_name: String,
+}
+
+impl PromptArguments {
+    pub fn get_file_type(&self) -> PromptFileType {
+        if self
.language_name + .as_ref() + .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str()))) + .unwrap_or(true) + { + PromptFileType::Code + } else { + PromptFileType::Text + } + } +} + +trait PromptTemplate { + fn generate(args: PromptArguments, max_token_length: Option) -> String; +} + +struct EngineerPreamble {} + +impl PromptTemplate for EngineerPreamble { + fn generate(args: PromptArguments, max_token_length: Option) -> String { + let mut prompt = String::new(); + + match args.get_file_type() { + PromptFileType::Code => { + writeln!( + prompt, + "You are an expert {} engineer.", + args.language_name.unwrap_or("".to_string()) + ) + .unwrap(); + } + PromptFileType::Text => { + writeln!(prompt, "You are an expert engineer.").unwrap(); + } + } + + if let Some(project_name) = args.project_name { + writeln!( + prompt, + "You are currently working inside the '{project_name}' in Zed the code editor." + ) + .unwrap(); + } + + prompt + } +} + +struct RepositorySnippets {} + +impl PromptTemplate for RepositorySnippets { + fn generate(args: PromptArguments, max_token_length: Option) -> String { + const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500; + let mut template = "You are working inside a large repository, here are a few code snippets that may be useful"; + let mut prompt = String::new(); + + if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(args.model_name.as_str()) { + let default_token_count = + tiktoken_rs::model::get_context_size(args.model_name.as_str()); + let mut remaining_token_count = max_token_length.unwrap_or(default_token_count); + + for snippet in args.snippets { + let mut snippet_prompt = template.to_string(); + let content = snippet.to_string(); + writeln!(snippet_prompt, "{content}").unwrap(); + + let token_count = encoding + .encode_with_special_tokens(snippet_prompt.as_str()) + .len(); + if token_count <= remaining_token_count { + if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT { + writeln!(prompt, "{snippet_prompt}").unwrap(); + remaining_token_count -= token_count; + template = ""; + } + } else { + break; + } + } + } + + prompt + } +} diff --git a/crates/ai/src/templates.rs b/crates/ai/src/templates.rs deleted file mode 100644 index d9771ce56964dcc782eb4c3aaa8a5ec6c8a76cd3..0000000000000000000000000000000000000000 --- a/crates/ai/src/templates.rs +++ /dev/null @@ -1,76 +0,0 @@ -use std::fmt::Write; - -pub struct PromptCodeSnippet { - path: Option, - language_name: Option, - content: String, -} - -enum PromptFileType { - Text, - Code, -} - -#[derive(Default)] -struct PromptArguments { - pub language_name: Option, - pub project_name: Option, - pub snippets: Vec, -} - -impl PromptArguments { - pub fn get_file_type(&self) -> PromptFileType { - if self - .language_name - .as_ref() - .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str()))) - .unwrap_or(true) - { - PromptFileType::Code - } else { - PromptFileType::Text - } - } -} - -trait PromptTemplate { - fn generate(args: PromptArguments) -> String; -} - -struct EngineerPreamble {} - -impl PromptTemplate for EngineerPreamble { - fn generate(args: PromptArguments) -> String { - let mut prompt = String::new(); - - match args.get_file_type() { - PromptFileType::Code => { - writeln!( - prompt, - "You are an expert {} engineer.", - args.language_name.unwrap_or("".to_string()) - ) - .unwrap(); - } - PromptFileType::Text => { - writeln!(prompt, "You are an expert engineer.").unwrap(); - } - } - - if let Some(project_name) = args.project_name { - writeln!( - prompt, - "You are currently working inside 
the '{project_name}' in Zed the code editor." - ) - .unwrap(); - } - - prompt - } -} - -struct RepositorySnippets {} - -impl PromptTemplate for RepositorySnippets { - fn generate(args: PromptArguments) -> String {} -} diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/templates/base.rs new file mode 100644 index 0000000000000000000000000000000000000000..3d8479e51253f8aa7f8157104fb9ed2220cfe3f2 --- /dev/null +++ b/crates/ai/src/templates/base.rs @@ -0,0 +1,112 @@ +use std::cmp::Reverse; + +use crate::templates::repository_context::PromptCodeSnippet; + +pub(crate) enum PromptFileType { + Text, + Code, +} + +#[derive(Default)] +pub struct PromptArguments { + pub model_name: String, + pub language_name: Option, + pub project_name: Option, + pub snippets: Vec, + pub reserved_tokens: usize, +} + +impl PromptArguments { + pub(crate) fn get_file_type(&self) -> PromptFileType { + if self + .language_name + .as_ref() + .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str()))) + .unwrap_or(true) + { + PromptFileType::Code + } else { + PromptFileType::Text + } + } +} + +pub trait PromptTemplate { + fn generate(&self, args: &PromptArguments, max_token_length: Option) -> String; +} + +#[repr(i8)] +#[derive(PartialEq, Eq, PartialOrd, Ord)] +pub enum PromptPriority { + Low, + Medium, + High, +} + +pub struct PromptChain { + args: PromptArguments, + templates: Vec<(PromptPriority, Box)>, +} + +impl PromptChain { + pub fn new( + args: PromptArguments, + templates: Vec<(PromptPriority, Box)>, + ) -> Self { + // templates.sort_by(|a, b| a.0.cmp(&b.0)); + + PromptChain { args, templates } + } + + pub fn generate(&self, truncate: bool) -> anyhow::Result { + // Argsort based on Prompt Priority + let mut sorted_indices = (0..self.templates.len()).collect::>(); + sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0)); + + println!("{:?}", sorted_indices); + + let mut prompts = Vec::new(); + for (_, template) in &self.templates { + prompts.push(template.generate(&self.args, None)); + } + + anyhow::Ok(prompts.join("\n")) + } +} + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + + #[test] + pub fn test_prompt_chain() { + struct TestPromptTemplate {} + impl PromptTemplate for TestPromptTemplate { + fn generate(&self, args: &PromptArguments, max_token_length: Option) -> String { + "This is a test prompt template".to_string() + } + } + + struct TestLowPriorityTemplate {} + impl PromptTemplate for TestLowPriorityTemplate { + fn generate(&self, args: &PromptArguments, max_token_length: Option) -> String { + "This is a low priority test prompt template".to_string() + } + } + + let args = PromptArguments { + model_name: "gpt-4".to_string(), + ..Default::default() + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + (PromptPriority::High, Box::new(TestPromptTemplate {})), + (PromptPriority::Medium, Box::new(TestLowPriorityTemplate {})), + ]; + let chain = PromptChain::new(args, templates); + + let prompt = chain.generate(false); + println!("{:?}", prompt); + panic!(); + } +} diff --git a/crates/ai/src/templates/mod.rs b/crates/ai/src/templates/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..62cb600eca4fb641265a3937ec5bf8f1e8c2d9c2 --- /dev/null +++ b/crates/ai/src/templates/mod.rs @@ -0,0 +1,3 @@ +pub mod base; +pub mod preamble; +pub mod repository_context; diff --git a/crates/ai/src/templates/preamble.rs b/crates/ai/src/templates/preamble.rs new file mode 100644 index 
0000000000000000000000000000000000000000..b1d33f885ea493f9488894154fe262e7ce177edc --- /dev/null +++ b/crates/ai/src/templates/preamble.rs @@ -0,0 +1,34 @@ +use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate}; +use std::fmt::Write; + +struct EngineerPreamble {} + +impl PromptTemplate for EngineerPreamble { + fn generate(&self, args: &PromptArguments, max_token_length: Option) -> String { + let mut prompt = String::new(); + + match args.get_file_type() { + PromptFileType::Code => { + writeln!( + prompt, + "You are an expert {} engineer.", + args.language_name.clone().unwrap_or("".to_string()) + ) + .unwrap(); + } + PromptFileType::Text => { + writeln!(prompt, "You are an expert engineer.").unwrap(); + } + } + + if let Some(project_name) = args.project_name.clone() { + writeln!( + prompt, + "You are currently working inside the '{project_name}' in Zed the code editor." + ) + .unwrap(); + } + + prompt + } +} diff --git a/crates/ai/src/templates/repository_context.rs b/crates/ai/src/templates/repository_context.rs new file mode 100644 index 0000000000000000000000000000000000000000..f9c2253c654de8da59ffa99ec07a12233b121d01 --- /dev/null +++ b/crates/ai/src/templates/repository_context.rs @@ -0,0 +1,49 @@ +use std::{ops::Range, path::PathBuf}; + +use gpui::{AsyncAppContext, ModelHandle}; +use language::{Anchor, Buffer}; + +pub struct PromptCodeSnippet { + path: Option, + language_name: Option, + content: String, +} + +impl PromptCodeSnippet { + pub fn new(buffer: ModelHandle, range: Range, cx: &AsyncAppContext) -> Self { + let (content, language_name, file_path) = buffer.read_with(cx, |buffer, _| { + let snapshot = buffer.snapshot(); + let content = snapshot.text_for_range(range.clone()).collect::(); + + let language_name = buffer + .language() + .and_then(|language| Some(language.name().to_string())); + + let file_path = buffer + .file() + .and_then(|file| Some(file.path().to_path_buf())); + + (content, language_name, file_path) + }); + + PromptCodeSnippet { + path: file_path, + language_name, + content, + } + } +} + +impl ToString for PromptCodeSnippet { + fn to_string(&self) -> String { + let path = self + .path + .as_ref() + .and_then(|path| Some(path.to_string_lossy().to_string())) + .unwrap_or("".to_string()); + let language_name = self.language_name.clone().unwrap_or("".to_string()); + let content = self.content.clone(); + + format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```") + } +} From ad92fe49c7deeb098dcd442bc996602630f4f056 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Tue, 17 Oct 2023 11:58:45 -0400 Subject: [PATCH 03/16] implement initial concept of prompt chain --- crates/ai/src/templates/base.rs | 229 +++++++++++++++++++++++++++++--- 1 file changed, 208 insertions(+), 21 deletions(-) diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/templates/base.rs index 3d8479e51253f8aa7f8157104fb9ed2220cfe3f2..74a4c424ae93b46da34d3f5493f6e2363b31c2f5 100644 --- a/crates/ai/src/templates/base.rs +++ b/crates/ai/src/templates/base.rs @@ -1,15 +1,25 @@ -use std::cmp::Reverse; +use std::fmt::Write; +use std::{cmp::Reverse, sync::Arc}; + +use util::ResultExt; use crate::templates::repository_context::PromptCodeSnippet; +pub trait LanguageModel { + fn name(&self) -> String; + fn count_tokens(&self, content: &str) -> usize; + fn truncate(&self, content: &str, length: usize) -> String; + fn capacity(&self) -> usize; +} + pub(crate) enum PromptFileType { Text, Code, } -#[derive(Default)] +// TODO: Set this up to 
manage for defaults well pub struct PromptArguments { - pub model_name: String, + pub model: Arc, pub language_name: Option, pub project_name: Option, pub snippets: Vec, @@ -32,7 +42,11 @@ impl PromptArguments { } pub trait PromptTemplate { - fn generate(&self, args: &PromptArguments, max_token_length: Option) -> String; + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)>; } #[repr(i8)] @@ -53,24 +67,52 @@ impl PromptChain { args: PromptArguments, templates: Vec<(PromptPriority, Box)>, ) -> Self { - // templates.sort_by(|a, b| a.0.cmp(&b.0)); - PromptChain { args, templates } } - pub fn generate(&self, truncate: bool) -> anyhow::Result { + pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> { // Argsort based on Prompt Priority + let seperator = "\n"; + let seperator_tokens = self.args.model.count_tokens(seperator); let mut sorted_indices = (0..self.templates.len()).collect::>(); sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0)); - println!("{:?}", sorted_indices); - let mut prompts = Vec::new(); - for (_, template) in &self.templates { - prompts.push(template.generate(&self.args, None)); + + // If Truncate + let mut tokens_outstanding = if truncate { + Some(self.args.model.capacity() - self.args.reserved_tokens) + } else { + None + }; + + for idx in sorted_indices { + let (_, template) = &self.templates[idx]; + if let Some((template_prompt, prompt_token_count)) = + template.generate(&self.args, tokens_outstanding).log_err() + { + println!( + "GENERATED PROMPT ({:?}): {:?}", + &prompt_token_count, &template_prompt + ); + if template_prompt != "" { + prompts.push(template_prompt); + + if let Some(remaining_tokens) = tokens_outstanding { + let new_tokens = prompt_token_count + seperator_tokens; + tokens_outstanding = if remaining_tokens > new_tokens { + Some(remaining_tokens - new_tokens) + } else { + Some(0) + }; + } + } + } } - anyhow::Ok(prompts.join("\n")) + let full_prompt = prompts.join(seperator); + let total_token_count = self.args.model.count_tokens(&full_prompt); + anyhow::Ok((prompts.join(seperator), total_token_count)) } } @@ -82,21 +124,81 @@ pub(crate) mod tests { pub fn test_prompt_chain() { struct TestPromptTemplate {} impl PromptTemplate for TestPromptTemplate { - fn generate(&self, args: &PromptArguments, max_token_length: Option) -> String { - "This is a test prompt template".to_string() + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let mut content = "This is a test prompt template".to_string(); + + let mut token_count = args.model.count_tokens(&content); + if let Some(max_token_length) = max_token_length { + if token_count > max_token_length { + content = args.model.truncate(&content, max_token_length); + token_count = max_token_length; + } + } + + anyhow::Ok((content, token_count)) } } struct TestLowPriorityTemplate {} impl PromptTemplate for TestLowPriorityTemplate { - fn generate(&self, args: &PromptArguments, max_token_length: Option) -> String { - "This is a low priority test prompt template".to_string() + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let mut content = "This is a low priority test prompt template".to_string(); + + let mut token_count = args.model.count_tokens(&content); + if let Some(max_token_length) = max_token_length { + if token_count > max_token_length { + content = args.model.truncate(&content, 
max_token_length); + token_count = max_token_length; + } + } + + anyhow::Ok((content, token_count)) } } + #[derive(Clone)] + struct DummyLanguageModel { + capacity: usize, + } + + impl DummyLanguageModel { + fn set_capacity(&mut self, capacity: usize) { + self.capacity = capacity + } + } + + impl LanguageModel for DummyLanguageModel { + fn name(&self) -> String { + "dummy".to_string() + } + fn count_tokens(&self, content: &str) -> usize { + content.chars().collect::>().len() + } + fn truncate(&self, content: &str, length: usize) -> String { + content.chars().collect::>()[..length] + .into_iter() + .collect::() + } + fn capacity(&self) -> usize { + self.capacity + } + } + + let model: Arc = Arc::new(DummyLanguageModel { capacity: 100 }); let args = PromptArguments { - model_name: "gpt-4".to_string(), - ..Default::default() + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens: 0, }; let templates: Vec<(PromptPriority, Box)> = vec![ @@ -105,8 +207,93 @@ pub(crate) mod tests { ]; let chain = PromptChain::new(args, templates); - let prompt = chain.generate(false); - println!("{:?}", prompt); - panic!(); + let (prompt, token_count) = chain.generate(false).unwrap(); + + assert_eq!( + prompt, + "This is a test prompt template\nThis is a low priority test prompt template" + .to_string() + ); + + assert_eq!(model.count_tokens(&prompt), token_count); + + // Testing with Truncation Off + // Should ignore capacity and return all prompts + let model: Arc = Arc::new(DummyLanguageModel { capacity: 20 }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens: 0, + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + (PromptPriority::High, Box::new(TestPromptTemplate {})), + (PromptPriority::Medium, Box::new(TestLowPriorityTemplate {})), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(false).unwrap(); + + assert_eq!( + prompt, + "This is a test prompt template\nThis is a low priority test prompt template" + .to_string() + ); + + assert_eq!(model.count_tokens(&prompt), token_count); + + // Testing with Truncation Off + // Should ignore capacity and return all prompts + let capacity = 20; + let model: Arc = Arc::new(DummyLanguageModel { capacity }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens: 0, + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + (PromptPriority::High, Box::new(TestPromptTemplate {})), + (PromptPriority::Medium, Box::new(TestLowPriorityTemplate {})), + (PromptPriority::Low, Box::new(TestLowPriorityTemplate {})), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(true).unwrap(); + + assert_eq!(prompt, "This is a test promp".to_string()); + assert_eq!(token_count, capacity); + + // Change Ordering of Prompts Based on Priority + let capacity = 120; + let reserved_tokens = 10; + let model: Arc = Arc::new(DummyLanguageModel { capacity }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens, + }; + let templates: Vec<(PromptPriority, Box)> = vec![ + (PromptPriority::Medium, Box::new(TestPromptTemplate {})), + (PromptPriority::High, Box::new(TestLowPriorityTemplate {})), + (PromptPriority::Low, Box::new(TestLowPriorityTemplate {})), + ]; + let chain = 
PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(true).unwrap(); + println!("TOKEN COUNT: {:?}", token_count); + + assert_eq!( + prompt, + "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt " + .to_string() + ); + assert_eq!(token_count, capacity - reserved_tokens); } } From a874a09b7e3b30696dad650bc997342fd8a53a61 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Tue, 17 Oct 2023 16:21:03 -0400 Subject: [PATCH 04/16] added openai language model tokenizer and LanguageModel trait --- crates/ai/src/ai.rs | 1 + crates/ai/src/models.rs | 49 ++++++++++++++++++++++++++ crates/ai/src/templates/base.rs | 54 ++++++++++++----------------- crates/ai/src/templates/preamble.rs | 42 +++++++++++++++------- 4 files changed, 102 insertions(+), 44 deletions(-) create mode 100644 crates/ai/src/models.rs diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs index 04e9e14536c16d80de133940db6723349e8d2371..f168c157934f6b70be775f7e17e9ba27ef9b3103 100644 --- a/crates/ai/src/ai.rs +++ b/crates/ai/src/ai.rs @@ -1,3 +1,4 @@ pub mod completion; pub mod embedding; +pub mod models; pub mod templates; diff --git a/crates/ai/src/models.rs b/crates/ai/src/models.rs new file mode 100644 index 0000000000000000000000000000000000000000..4fe96d44f33f10ad1e6ee8572a8cceb02fca8fd4 --- /dev/null +++ b/crates/ai/src/models.rs @@ -0,0 +1,49 @@ +use anyhow::anyhow; +use tiktoken_rs::CoreBPE; +use util::ResultExt; + +pub trait LanguageModel { + fn name(&self) -> String; + fn count_tokens(&self, content: &str) -> anyhow::Result; + fn truncate(&self, content: &str, length: usize) -> anyhow::Result; + fn capacity(&self) -> anyhow::Result; +} + +struct OpenAILanguageModel { + name: String, + bpe: Option, +} + +impl OpenAILanguageModel { + pub fn load(model_name: String) -> Self { + let bpe = tiktoken_rs::get_bpe_from_model(&model_name).log_err(); + OpenAILanguageModel { + name: model_name, + bpe, + } + } +} + +impl LanguageModel for OpenAILanguageModel { + fn name(&self) -> String { + self.name.clone() + } + fn count_tokens(&self, content: &str) -> anyhow::Result { + if let Some(bpe) = &self.bpe { + anyhow::Ok(bpe.encode_with_special_tokens(content).len()) + } else { + Err(anyhow!("bpe for open ai model was not retrieved")) + } + } + fn truncate(&self, content: &str, length: usize) -> anyhow::Result { + if let Some(bpe) = &self.bpe { + let tokens = bpe.encode_with_special_tokens(content); + bpe.decode(tokens[..length].to_vec()) + } else { + Err(anyhow!("bpe for open ai model was not retrieved")) + } + } + fn capacity(&self) -> anyhow::Result { + anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name)) + } +} diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/templates/base.rs index 74a4c424ae93b46da34d3f5493f6e2363b31c2f5..b5f9da3586f7793e601ca8f5bf7a3158da5949c8 100644 --- a/crates/ai/src/templates/base.rs +++ b/crates/ai/src/templates/base.rs @@ -1,17 +1,11 @@ -use std::fmt::Write; -use std::{cmp::Reverse, sync::Arc}; +use std::cmp::Reverse; +use std::sync::Arc; use util::ResultExt; +use crate::models::LanguageModel; use crate::templates::repository_context::PromptCodeSnippet; -pub trait LanguageModel { - fn name(&self) -> String; - fn count_tokens(&self, content: &str) -> usize; - fn truncate(&self, content: &str, length: usize) -> String; - fn capacity(&self) -> usize; -} - pub(crate) enum PromptFileType { Text, Code, @@ -73,7 +67,7 @@ impl PromptChain { pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> { 
// Argsort based on Prompt Priority let seperator = "\n"; - let seperator_tokens = self.args.model.count_tokens(seperator); + let seperator_tokens = self.args.model.count_tokens(seperator)?; let mut sorted_indices = (0..self.templates.len()).collect::>(); sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0)); @@ -81,7 +75,7 @@ impl PromptChain { // If Truncate let mut tokens_outstanding = if truncate { - Some(self.args.model.capacity() - self.args.reserved_tokens) + Some(self.args.model.capacity()? - self.args.reserved_tokens) } else { None }; @@ -111,7 +105,7 @@ impl PromptChain { } let full_prompt = prompts.join(seperator); - let total_token_count = self.args.model.count_tokens(&full_prompt); + let total_token_count = self.args.model.count_tokens(&full_prompt)?; anyhow::Ok((prompts.join(seperator), total_token_count)) } } @@ -131,10 +125,10 @@ pub(crate) mod tests { ) -> anyhow::Result<(String, usize)> { let mut content = "This is a test prompt template".to_string(); - let mut token_count = args.model.count_tokens(&content); + let mut token_count = args.model.count_tokens(&content)?; if let Some(max_token_length) = max_token_length { if token_count > max_token_length { - content = args.model.truncate(&content, max_token_length); + content = args.model.truncate(&content, max_token_length)?; token_count = max_token_length; } } @@ -152,10 +146,10 @@ pub(crate) mod tests { ) -> anyhow::Result<(String, usize)> { let mut content = "This is a low priority test prompt template".to_string(); - let mut token_count = args.model.count_tokens(&content); + let mut token_count = args.model.count_tokens(&content)?; if let Some(max_token_length) = max_token_length { if token_count > max_token_length { - content = args.model.truncate(&content, max_token_length); + content = args.model.truncate(&content, max_token_length)?; token_count = max_token_length; } } @@ -169,26 +163,22 @@ pub(crate) mod tests { capacity: usize, } - impl DummyLanguageModel { - fn set_capacity(&mut self, capacity: usize) { - self.capacity = capacity - } - } - impl LanguageModel for DummyLanguageModel { fn name(&self) -> String { "dummy".to_string() } - fn count_tokens(&self, content: &str) -> usize { - content.chars().collect::>().len() + fn count_tokens(&self, content: &str) -> anyhow::Result { + anyhow::Ok(content.chars().collect::>().len()) } - fn truncate(&self, content: &str, length: usize) -> String { - content.chars().collect::>()[..length] - .into_iter() - .collect::() + fn truncate(&self, content: &str, length: usize) -> anyhow::Result { + anyhow::Ok( + content.chars().collect::>()[..length] + .into_iter() + .collect::(), + ) } - fn capacity(&self) -> usize { - self.capacity + fn capacity(&self) -> anyhow::Result { + anyhow::Ok(self.capacity) } } @@ -215,7 +205,7 @@ pub(crate) mod tests { .to_string() ); - assert_eq!(model.count_tokens(&prompt), token_count); + assert_eq!(model.count_tokens(&prompt).unwrap(), token_count); // Testing with Truncation Off // Should ignore capacity and return all prompts @@ -242,7 +232,7 @@ pub(crate) mod tests { .to_string() ); - assert_eq!(model.count_tokens(&prompt), token_count); + assert_eq!(model.count_tokens(&prompt).unwrap(), token_count); // Testing with Truncation Off // Should ignore capacity and return all prompts diff --git a/crates/ai/src/templates/preamble.rs b/crates/ai/src/templates/preamble.rs index b1d33f885ea493f9488894154fe262e7ce177edc..f395dbf8beeabde2a703214cc0426900908be990 100644 --- a/crates/ai/src/templates/preamble.rs +++ 
b/crates/ai/src/templates/preamble.rs @@ -4,31 +4,49 @@ use std::fmt::Write; struct EngineerPreamble {} impl PromptTemplate for EngineerPreamble { - fn generate(&self, args: &PromptArguments, max_token_length: Option) -> String { - let mut prompt = String::new(); + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let mut prompts = Vec::new(); match args.get_file_type() { PromptFileType::Code => { - writeln!( - prompt, + prompts.push(format!( "You are an expert {} engineer.", args.language_name.clone().unwrap_or("".to_string()) - ) - .unwrap(); + )); } PromptFileType::Text => { - writeln!(prompt, "You are an expert engineer.").unwrap(); + prompts.push("You are an expert engineer.".to_string()); } } if let Some(project_name) = args.project_name.clone() { - writeln!( - prompt, + prompts.push(format!( "You are currently working inside the '{project_name}' in Zed the code editor." - ) - .unwrap(); + )); } - prompt + if let Some(mut remaining_tokens) = max_token_length { + let mut prompt = String::new(); + let mut total_count = 0; + for prompt_piece in prompts { + let prompt_token_count = + args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?; + if remaining_tokens > prompt_token_count { + writeln!(prompt, "{prompt_piece}").unwrap(); + remaining_tokens -= prompt_token_count; + total_count += prompt_token_count; + } + } + + anyhow::Ok((prompt, total_count)) + } else { + let prompt = prompts.join("\n"); + let token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count)) + } } } From 02853bbd606dc87a638bd2ca01a5232203069499 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Tue, 17 Oct 2023 17:29:07 -0400 Subject: [PATCH 05/16] added prompt template for repository context --- crates/ai/src/models.rs | 8 +- crates/ai/src/prompts.rs | 149 ------------------ crates/ai/src/templates/preamble.rs | 6 +- crates/ai/src/templates/repository_context.rs | 47 +++++- crates/assistant/src/assistant_panel.rs | 22 ++- crates/assistant/src/prompts.rs | 87 ++++------ 6 files changed, 96 insertions(+), 223 deletions(-) delete mode 100644 crates/ai/src/prompts.rs diff --git a/crates/ai/src/models.rs b/crates/ai/src/models.rs index 4fe96d44f33f10ad1e6ee8572a8cceb02fca8fd4..69e73e9b56ec7db7023983a67f5ee994c97c5725 100644 --- a/crates/ai/src/models.rs +++ b/crates/ai/src/models.rs @@ -9,16 +9,16 @@ pub trait LanguageModel { fn capacity(&self) -> anyhow::Result; } -struct OpenAILanguageModel { +pub struct OpenAILanguageModel { name: String, bpe: Option, } impl OpenAILanguageModel { - pub fn load(model_name: String) -> Self { - let bpe = tiktoken_rs::get_bpe_from_model(&model_name).log_err(); + pub fn load(model_name: &str) -> Self { + let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err(); OpenAILanguageModel { - name: model_name, + name: model_name.to_string(), bpe, } } diff --git a/crates/ai/src/prompts.rs b/crates/ai/src/prompts.rs deleted file mode 100644 index 6d2c0629fa08e2d464adc5bf1d48c44659da8545..0000000000000000000000000000000000000000 --- a/crates/ai/src/prompts.rs +++ /dev/null @@ -1,149 +0,0 @@ -use gpui::{AsyncAppContext, ModelHandle}; -use language::{Anchor, Buffer}; -use std::{fmt::Write, ops::Range, path::PathBuf}; - -pub struct PromptCodeSnippet { - path: Option, - language_name: Option, - content: String, -} - -impl PromptCodeSnippet { - pub fn new(buffer: ModelHandle, range: Range, cx: &AsyncAppContext) -> Self { - let (content, language_name, file_path) = buffer.read_with(cx, |buffer, 
_| { - let snapshot = buffer.snapshot(); - let content = snapshot.text_for_range(range.clone()).collect::(); - - let language_name = buffer - .language() - .and_then(|language| Some(language.name().to_string())); - - let file_path = buffer - .file() - .and_then(|file| Some(file.path().to_path_buf())); - - (content, language_name, file_path) - }); - - PromptCodeSnippet { - path: file_path, - language_name, - content, - } - } -} - -impl ToString for PromptCodeSnippet { - fn to_string(&self) -> String { - let path = self - .path - .as_ref() - .and_then(|path| Some(path.to_string_lossy().to_string())) - .unwrap_or("".to_string()); - let language_name = self.language_name.clone().unwrap_or("".to_string()); - let content = self.content.clone(); - - format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```") - } -} - -enum PromptFileType { - Text, - Code, -} - -#[derive(Default)] -struct PromptArguments { - pub language_name: Option, - pub project_name: Option, - pub snippets: Vec, - pub model_name: String, -} - -impl PromptArguments { - pub fn get_file_type(&self) -> PromptFileType { - if self - .language_name - .as_ref() - .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str()))) - .unwrap_or(true) - { - PromptFileType::Code - } else { - PromptFileType::Text - } - } -} - -trait PromptTemplate { - fn generate(args: PromptArguments, max_token_length: Option) -> String; -} - -struct EngineerPreamble {} - -impl PromptTemplate for EngineerPreamble { - fn generate(args: PromptArguments, max_token_length: Option) -> String { - let mut prompt = String::new(); - - match args.get_file_type() { - PromptFileType::Code => { - writeln!( - prompt, - "You are an expert {} engineer.", - args.language_name.unwrap_or("".to_string()) - ) - .unwrap(); - } - PromptFileType::Text => { - writeln!(prompt, "You are an expert engineer.").unwrap(); - } - } - - if let Some(project_name) = args.project_name { - writeln!( - prompt, - "You are currently working inside the '{project_name}' in Zed the code editor." 
- ) - .unwrap(); - } - - prompt - } -} - -struct RepositorySnippets {} - -impl PromptTemplate for RepositorySnippets { - fn generate(args: PromptArguments, max_token_length: Option) -> String { - const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500; - let mut template = "You are working inside a large repository, here are a few code snippets that may be useful"; - let mut prompt = String::new(); - - if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(args.model_name.as_str()) { - let default_token_count = - tiktoken_rs::model::get_context_size(args.model_name.as_str()); - let mut remaining_token_count = max_token_length.unwrap_or(default_token_count); - - for snippet in args.snippets { - let mut snippet_prompt = template.to_string(); - let content = snippet.to_string(); - writeln!(snippet_prompt, "{content}").unwrap(); - - let token_count = encoding - .encode_with_special_tokens(snippet_prompt.as_str()) - .len(); - if token_count <= remaining_token_count { - if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT { - writeln!(prompt, "{snippet_prompt}").unwrap(); - remaining_token_count -= token_count; - template = ""; - } - } else { - break; - } - } - } - - prompt - } -} diff --git a/crates/ai/src/templates/preamble.rs b/crates/ai/src/templates/preamble.rs index f395dbf8beeabde2a703214cc0426900908be990..5834fa1b21b2011fbbc82d781493c4e4e523b685 100644 --- a/crates/ai/src/templates/preamble.rs +++ b/crates/ai/src/templates/preamble.rs @@ -1,7 +1,7 @@ use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate}; use std::fmt::Write; -struct EngineerPreamble {} +pub struct EngineerPreamble {} impl PromptTemplate for EngineerPreamble { fn generate( @@ -14,8 +14,8 @@ impl PromptTemplate for EngineerPreamble { match args.get_file_type() { PromptFileType::Code => { prompts.push(format!( - "You are an expert {} engineer.", - args.language_name.clone().unwrap_or("".to_string()) + "You are an expert {}engineer.", + args.language_name.clone().unwrap_or("".to_string()) + " " )); } PromptFileType::Text => { diff --git a/crates/ai/src/templates/repository_context.rs b/crates/ai/src/templates/repository_context.rs index f9c2253c654de8da59ffa99ec07a12233b121d01..7dd1647c440a228b5fc2c8317fe35e0931d4c1a9 100644 --- a/crates/ai/src/templates/repository_context.rs +++ b/crates/ai/src/templates/repository_context.rs @@ -1,8 +1,11 @@ +use crate::templates::base::{PromptArguments, PromptTemplate}; +use std::fmt::Write; use std::{ops::Range, path::PathBuf}; use gpui::{AsyncAppContext, ModelHandle}; use language::{Anchor, Buffer}; +#[derive(Clone)] pub struct PromptCodeSnippet { path: Option, language_name: Option, @@ -17,7 +20,7 @@ impl PromptCodeSnippet { let language_name = buffer .language() - .and_then(|language| Some(language.name().to_string())); + .and_then(|language| Some(language.name().to_string().to_lowercase())); let file_path = buffer .file() @@ -47,3 +50,45 @@ impl ToString for PromptCodeSnippet { format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```") } } + +pub struct RepositoryContext {} + +impl PromptTemplate for RepositoryContext { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500; + let mut template = "You are working inside a large repository, here are a few code snippets that may be useful."; + let mut prompt = String::new(); + + let mut remaining_tokens = max_token_length.clone(); + let seperator_token_length = 
args.model.count_tokens("\n")?; + for snippet in &args.snippets { + let mut snippet_prompt = template.to_string(); + let content = snippet.to_string(); + writeln!(snippet_prompt, "{content}").unwrap(); + + let token_count = args.model.count_tokens(&snippet_prompt)?; + if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT { + if let Some(tokens_left) = remaining_tokens { + if tokens_left >= token_count { + writeln!(prompt, "{snippet_prompt}").unwrap(); + remaining_tokens = if tokens_left >= (token_count + seperator_token_length) + { + Some(tokens_left - token_count - seperator_token_length) + } else { + Some(0) + }; + } + } else { + writeln!(prompt, "{snippet_prompt}").unwrap(); + } + } + } + + let total_token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, total_token_count)) + } +} diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index e8edf70498c14e7324073b732cae6b887f3131f9..06de5c135fdf535e4f253ec6f92ef2d449f769a2 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -1,12 +1,15 @@ use crate::{ assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel}, codegen::{self, Codegen, CodegenKind}, - prompts::{generate_content_prompt, PromptCodeSnippet}, + prompts::generate_content_prompt, MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata, SavedMessage, }; -use ai::completion::{ - stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL, +use ai::{ + completion::{ + stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL, + }, + templates::repository_context::PromptCodeSnippet, }; use anyhow::{anyhow, Result}; use chrono::{DateTime, Local}; @@ -668,14 +671,7 @@ impl AssistantPanel { let snippets = cx.spawn(|_, cx| async move { let mut snippets = Vec::new(); for result in search_results.await { - snippets.push(PromptCodeSnippet::new(result, &cx)); - - // snippets.push(result.buffer.read_with(&cx, |buffer, _| { - // buffer - // .snapshot() - // .text_for_range(result.range) - // .collect::() - // })); + snippets.push(PromptCodeSnippet::new(result.buffer, result.range, &cx)); } snippets }); @@ -717,7 +713,8 @@ impl AssistantPanel { } cx.spawn(|_, mut cx| async move { - let prompt = prompt.await; + // I Don't know if we want to return a ? here. 
+ let prompt = prompt.await?; messages.push(RequestMessage { role: Role::User, @@ -729,6 +726,7 @@ impl AssistantPanel { stream: true, }; codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx)); + anyhow::Ok(()) }) .detach(); } diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs index 7aafe75920b351e7244f14858036a4aa9af64f6f..e33a6e4022e87c99b899b0f492d0c25e1514cb4f 100644 --- a/crates/assistant/src/prompts.rs +++ b/crates/assistant/src/prompts.rs @@ -1,61 +1,15 @@ use crate::codegen::CodegenKind; -use gpui::AsyncAppContext; +use ai::models::{LanguageModel, OpenAILanguageModel}; +use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate}; +use ai::templates::preamble::EngineerPreamble; +use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext}; use language::{BufferSnapshot, OffsetRangeExt, ToOffset}; -use semantic_index::SearchResult; use std::cmp::{self, Reverse}; use std::fmt::Write; use std::ops::Range; -use std::path::PathBuf; +use std::sync::Arc; use tiktoken_rs::ChatCompletionRequestMessage; -pub struct PromptCodeSnippet { - path: Option, - language_name: Option, - content: String, -} - -impl PromptCodeSnippet { - pub fn new(search_result: SearchResult, cx: &AsyncAppContext) -> Self { - let (content, language_name, file_path) = - search_result.buffer.read_with(cx, |buffer, _| { - let snapshot = buffer.snapshot(); - let content = snapshot - .text_for_range(search_result.range.clone()) - .collect::(); - - let language_name = buffer - .language() - .and_then(|language| Some(language.name().to_string())); - - let file_path = buffer - .file() - .and_then(|file| Some(file.path().to_path_buf())); - - (content, language_name, file_path) - }); - - PromptCodeSnippet { - path: file_path, - language_name, - content, - } - } -} - -impl ToString for PromptCodeSnippet { - fn to_string(&self) -> String { - let path = self - .path - .as_ref() - .and_then(|path| Some(path.to_string_lossy().to_string())) - .unwrap_or("".to_string()); - let language_name = self.language_name.clone().unwrap_or("".to_string()); - let content = self.content.clone(); - - format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```") - } -} - #[allow(dead_code)] fn summarize(buffer: &BufferSnapshot, selected_range: Range) -> String { #[derive(Debug)] @@ -175,7 +129,32 @@ pub fn generate_content_prompt( kind: CodegenKind, search_results: Vec, model: &str, -) -> String { +) -> anyhow::Result { + // Using new Prompt Templates + let openai_model: Arc = Arc::new(OpenAILanguageModel::load(model)); + let lang_name = if let Some(language_name) = language_name { + Some(language_name.to_string()) + } else { + None + }; + + let args = PromptArguments { + model: openai_model, + language_name: lang_name.clone(), + project_name: None, + snippets: search_results.clone(), + reserved_tokens: 1000, + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + (PromptPriority::High, Box::new(EngineerPreamble {})), + (PromptPriority::Low, Box::new(RepositoryContext {})), + ]; + let chain = PromptChain::new(args, templates); + + let prompt = chain.generate(true)?; + println!("{:?}", prompt); + const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500; const RESERVED_TOKENS_FOR_GENERATION: usize = 1000; @@ -183,7 +162,7 @@ pub fn generate_content_prompt( let range = range.to_offset(buffer); // General Preamble - if let Some(language_name) = language_name { + if let Some(language_name) = language_name.clone() { 
prompts.push(format!("You're an expert {language_name} engineer.\n")); } else { prompts.push("You're an expert engineer.\n".to_string()); @@ -297,7 +276,7 @@ pub fn generate_content_prompt( } } - prompts.join("\n") + anyhow::Ok(prompts.join("\n")) } #[cfg(test)] From 178a79fc471c541cc6351f491fbf585a551a9bce Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 18 Oct 2023 12:29:10 -0400 Subject: [PATCH 06/16] added prompt template for file context without truncation --- crates/ai/src/templates/base.rs | 13 ++++ crates/ai/src/templates/file_context.rs | 85 +++++++++++++++++++++++++ crates/ai/src/templates/mod.rs | 1 + 3 files changed, 99 insertions(+) create mode 100644 crates/ai/src/templates/file_context.rs diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/templates/base.rs index b5f9da3586f7793e601ca8f5bf7a3158da5949c8..0bf04f5ed17c607ba115446e455ca1ffd937d5bd 100644 --- a/crates/ai/src/templates/base.rs +++ b/crates/ai/src/templates/base.rs @@ -1,6 +1,9 @@ use std::cmp::Reverse; +use std::ops::Range; use std::sync::Arc; +use gpui::ModelHandle; +use language::{Anchor, Buffer, BufferSnapshot, ToOffset}; use util::ResultExt; use crate::models::LanguageModel; @@ -18,6 +21,8 @@ pub struct PromptArguments { pub project_name: Option, pub snippets: Vec, pub reserved_tokens: usize, + pub buffer: Option, + pub selected_range: Option>, } impl PromptArguments { @@ -189,6 +194,8 @@ pub(crate) mod tests { project_name: None, snippets: Vec::new(), reserved_tokens: 0, + buffer: None, + selected_range: None, }; let templates: Vec<(PromptPriority, Box)> = vec![ @@ -216,6 +223,8 @@ pub(crate) mod tests { project_name: None, snippets: Vec::new(), reserved_tokens: 0, + buffer: None, + selected_range: None, }; let templates: Vec<(PromptPriority, Box)> = vec![ @@ -244,6 +253,8 @@ pub(crate) mod tests { project_name: None, snippets: Vec::new(), reserved_tokens: 0, + buffer: None, + selected_range: None, }; let templates: Vec<(PromptPriority, Box)> = vec![ @@ -268,6 +279,8 @@ pub(crate) mod tests { project_name: None, snippets: Vec::new(), reserved_tokens, + buffer: None, + selected_range: None, }; let templates: Vec<(PromptPriority, Box)> = vec![ (PromptPriority::Medium, Box::new(TestPromptTemplate {})), diff --git a/crates/ai/src/templates/file_context.rs b/crates/ai/src/templates/file_context.rs new file mode 100644 index 0000000000000000000000000000000000000000..68bf424db1ddb6c3cd11907688ee5080e8f41c5f --- /dev/null +++ b/crates/ai/src/templates/file_context.rs @@ -0,0 +1,85 @@ +use language::ToOffset; + +use crate::templates::base::PromptArguments; +use crate::templates::base::PromptTemplate; +use std::fmt::Write; + +pub struct FileContext {} + +impl PromptTemplate for FileContext { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let mut prompt = String::new(); + + // Add Initial Preamble + // TODO: Do we want to add the path in here? 
+ writeln!( + prompt, + "The file you are currently working on has the following content:" + ) + .unwrap(); + + let language_name = args + .language_name + .clone() + .unwrap_or("".to_string()) + .to_lowercase(); + writeln!(prompt, "```{language_name}").unwrap(); + + if let Some(buffer) = &args.buffer { + let mut content = String::new(); + + if let Some(selected_range) = &args.selected_range { + let start = selected_range.start.to_offset(buffer); + let end = selected_range.end.to_offset(buffer); + + writeln!( + prompt, + "{}", + buffer.text_for_range(0..start).collect::() + ) + .unwrap(); + + if start == end { + writeln!(prompt, "<|START|>").unwrap(); + } else { + writeln!(prompt, "<|START|").unwrap(); + } + + writeln!( + prompt, + "{}", + buffer.text_for_range(start..end).collect::() + ) + .unwrap(); + if start != end { + writeln!(prompt, "|END|>").unwrap(); + } + + writeln!( + prompt, + "{}", + buffer.text_for_range(end..buffer.len()).collect::() + ) + .unwrap(); + + writeln!(prompt, "```").unwrap(); + + if start == end { + writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap(); + } else { + writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap(); + } + } else { + // If we dont have a selected range, include entire file. + writeln!(prompt, "{}", &buffer.text()).unwrap(); + } + } + + let token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count)) + } +} diff --git a/crates/ai/src/templates/mod.rs b/crates/ai/src/templates/mod.rs index 62cb600eca4fb641265a3937ec5bf8f1e8c2d9c2..886af86e91db4dada1a051f211c19e030c100ec7 100644 --- a/crates/ai/src/templates/mod.rs +++ b/crates/ai/src/templates/mod.rs @@ -1,3 +1,4 @@ pub mod base; +pub mod file_context; pub mod preamble; pub mod repository_context; From fa61c1b9c1751912436dc44508af8aaa475493f2 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 18 Oct 2023 13:03:11 -0400 Subject: [PATCH 07/16] add prompt template for generate inline content --- crates/ai/src/templates/base.rs | 5 ++ crates/ai/src/templates/generate.rs | 88 +++++++++++++++++++++++++++++ crates/ai/src/templates/mod.rs | 1 + 3 files changed, 94 insertions(+) create mode 100644 crates/ai/src/templates/generate.rs diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/templates/base.rs index 0bf04f5ed17c607ba115446e455ca1ffd937d5bd..d4882bafc91d4a408558a8eafbf7ce5360132217 100644 --- a/crates/ai/src/templates/base.rs +++ b/crates/ai/src/templates/base.rs @@ -17,6 +17,7 @@ pub(crate) enum PromptFileType { // TODO: Set this up to manage for defaults well pub struct PromptArguments { pub model: Arc, + pub user_prompt: Option, pub language_name: Option, pub project_name: Option, pub snippets: Vec, @@ -196,6 +197,7 @@ pub(crate) mod tests { reserved_tokens: 0, buffer: None, selected_range: None, + user_prompt: None, }; let templates: Vec<(PromptPriority, Box)> = vec![ @@ -225,6 +227,7 @@ pub(crate) mod tests { reserved_tokens: 0, buffer: None, selected_range: None, + user_prompt: None, }; let templates: Vec<(PromptPriority, Box)> = vec![ @@ -255,6 +258,7 @@ pub(crate) mod tests { reserved_tokens: 0, buffer: None, selected_range: None, + user_prompt: None, }; let templates: Vec<(PromptPriority, Box)> = vec![ @@ -281,6 +285,7 @@ pub(crate) mod tests { reserved_tokens, buffer: None, selected_range: None, + user_prompt: None, }; let templates: Vec<(PromptPriority, Box)> = vec![ 
(PromptPriority::Medium, Box::new(TestPromptTemplate {})), diff --git a/crates/ai/src/templates/generate.rs b/crates/ai/src/templates/generate.rs new file mode 100644 index 0000000000000000000000000000000000000000..d8a1ff6cf142fe8a4a81079ed3ada3c4f803eb75 --- /dev/null +++ b/crates/ai/src/templates/generate.rs @@ -0,0 +1,88 @@ +use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate}; +use anyhow::anyhow; +use std::fmt::Write; + +pub fn capitalize(s: &str) -> String { + let mut c = s.chars(); + match c.next() { + None => String::new(), + Some(f) => f.to_uppercase().collect::() + c.as_str(), + } +} + +pub struct GenerateInlineContent {} + +impl PromptTemplate for GenerateInlineContent { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let Some(user_prompt) = &args.user_prompt else { + return Err(anyhow!("user prompt not provided")); + }; + + let file_type = args.get_file_type(); + let content_type = match &file_type { + PromptFileType::Code => "code", + PromptFileType::Text => "text", + }; + + let mut prompt = String::new(); + + if let Some(selected_range) = &args.selected_range { + if selected_range.start == selected_range.end { + writeln!( + prompt, + "Assume the cursor is located where the `<|START|>` span is." + ) + .unwrap(); + writeln!( + prompt, + "{} can't be replaced, so assume your answer will be inserted at the cursor.", + capitalize(content_type) + ) + .unwrap(); + writeln!( + prompt, + "Generate {content_type} based on the users prompt: {user_prompt}", + ) + .unwrap(); + } else { + writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap(); + writeln!(prompt, "You MUST reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans), not the entire file.").unwrap(); + } + } else { + writeln!( + prompt, + "Generate {content_type} based on the users prompt: {user_prompt}" + ) + .unwrap(); + } + + if let Some(language_name) = &args.language_name { + writeln!( + prompt, + "Your answer MUST always and only be valid {}.", + language_name + ) + .unwrap(); + } + writeln!(prompt, "Never make remarks about the output.").unwrap(); + writeln!( + prompt, + "Do not return anything else, except the generated {content_type}." 
+ ) + .unwrap(); + + match file_type { + PromptFileType::Code => { + writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap(); + } + _ => {} + } + + let token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count)) + } +} diff --git a/crates/ai/src/templates/mod.rs b/crates/ai/src/templates/mod.rs index 886af86e91db4dada1a051f211c19e030c100ec7..0025269a440d1e6ead6a81615a64a3c28da62bb8 100644 --- a/crates/ai/src/templates/mod.rs +++ b/crates/ai/src/templates/mod.rs @@ -1,4 +1,5 @@ pub mod base; pub mod file_context; +pub mod generate; pub mod preamble; pub mod repository_context; From b9bb27512caf402727680fc3ad6926f9006adfce Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 18 Oct 2023 13:10:31 -0400 Subject: [PATCH 08/16] fix template ordering during prompt chain generation --- crates/ai/src/templates/base.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/templates/base.rs index d4882bafc91d4a408558a8eafbf7ce5360132217..db437a029cd73ec620385362ed83061103d82078 100644 --- a/crates/ai/src/templates/base.rs +++ b/crates/ai/src/templates/base.rs @@ -77,8 +77,6 @@ impl PromptChain { let mut sorted_indices = (0..self.templates.len()).collect::>(); sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0)); - let mut prompts = Vec::new(); - // If Truncate let mut tokens_outstanding = if truncate { Some(self.args.model.capacity()? - self.args.reserved_tokens) @@ -86,6 +84,7 @@ impl PromptChain { None }; + let mut prompts = vec!["".to_string(); sorted_indices.len()]; for idx in sorted_indices { let (_, template) = &self.templates[idx]; if let Some((template_prompt, prompt_token_count)) = @@ -96,7 +95,7 @@ impl PromptChain { &prompt_token_count, &template_prompt ); if template_prompt != "" { - prompts.push(template_prompt); + prompts[idx] = template_prompt; if let Some(remaining_tokens) = tokens_outstanding { let new_tokens = prompt_token_count + seperator_tokens; From aa1825681c60176d391ba497a9d28b0e5703fa60 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 18 Oct 2023 14:20:12 -0400 Subject: [PATCH 09/16] update the assistant panel to use new prompt templates --- crates/ai/src/templates/base.rs | 4 - crates/ai/src/templates/file_context.rs | 10 +- crates/ai/src/templates/preamble.rs | 2 +- crates/assistant/src/assistant_panel.rs | 17 ++- crates/assistant/src/prompts.rs | 146 +++--------------------- 5 files changed, 33 insertions(+), 146 deletions(-) diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/templates/base.rs index db437a029cd73ec620385362ed83061103d82078..aaf08d755efb4746192bb75e64f0f7cc7e7a4e83 100644 --- a/crates/ai/src/templates/base.rs +++ b/crates/ai/src/templates/base.rs @@ -90,10 +90,6 @@ impl PromptChain { if let Some((template_prompt, prompt_token_count)) = template.generate(&self.args, tokens_outstanding).log_err() { - println!( - "GENERATED PROMPT ({:?}): {:?}", - &prompt_token_count, &template_prompt - ); if template_prompt != "" { prompts[idx] = template_prompt; diff --git a/crates/ai/src/templates/file_context.rs b/crates/ai/src/templates/file_context.rs index 68bf424db1ddb6c3cd11907688ee5080e8f41c5f..6d0630504983fbe90597525ea8f49dd23e0a1036 100644 --- a/crates/ai/src/templates/file_context.rs +++ b/crates/ai/src/templates/file_context.rs @@ -44,22 +44,22 @@ impl PromptTemplate for FileContext { .unwrap(); if start == end { - writeln!(prompt, "<|START|>").unwrap(); + write!(prompt, "<|START|>").unwrap(); } else { - writeln!(prompt, 
"<|START|").unwrap(); + write!(prompt, "<|START|").unwrap(); } - writeln!( + write!( prompt, "{}", buffer.text_for_range(start..end).collect::() ) .unwrap(); if start != end { - writeln!(prompt, "|END|>").unwrap(); + write!(prompt, "|END|>").unwrap(); } - writeln!( + write!( prompt, "{}", buffer.text_for_range(end..buffer.len()).collect::() diff --git a/crates/ai/src/templates/preamble.rs b/crates/ai/src/templates/preamble.rs index 5834fa1b21b2011fbbc82d781493c4e4e523b685..9eabaaeb97fe4216c6bac44cf4eabfb7c129ecf2 100644 --- a/crates/ai/src/templates/preamble.rs +++ b/crates/ai/src/templates/preamble.rs @@ -25,7 +25,7 @@ impl PromptTemplate for EngineerPreamble { if let Some(project_name) = args.project_name.clone() { prompts.push(format!( - "You are currently working inside the '{project_name}' in Zed the code editor." + "You are currently working inside the '{project_name}' project in code editor Zed." )); } diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 3a0f05379e1d06ce9900bcb2179ed1a347c96f70..4dd4e2a98315c042d74c7ef6bde78200287ab6ad 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -612,6 +612,18 @@ impl AssistantPanel { let project = pending_assist.project.clone(); + let project_name = if let Some(project) = project.upgrade(cx) { + Some( + project + .read(cx) + .worktree_root_names(cx) + .collect::>() + .join("/"), + ) + } else { + None + }; + self.inline_prompt_history .retain(|prompt| prompt != user_prompt); self.inline_prompt_history.push_back(user_prompt.into()); @@ -649,7 +661,6 @@ impl AssistantPanel { None }; - let codegen_kind = codegen.read(cx).kind().clone(); let user_prompt = user_prompt.to_string(); let snippets = if retrieve_context { @@ -692,11 +703,11 @@ impl AssistantPanel { generate_content_prompt( user_prompt, language_name, - &buffer, + buffer, range, - codegen_kind, snippets, model_name, + project_name, ) }); diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs index 333742aa0525afe4d362523868ada9cb187cc363..1457d28fff22c83c29090dcded37aa9a915918bd 100644 --- a/crates/assistant/src/prompts.rs +++ b/crates/assistant/src/prompts.rs @@ -1,6 +1,8 @@ use crate::codegen::CodegenKind; use ai::models::{LanguageModel, OpenAILanguageModel}; use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate}; +use ai::templates::file_context::FileContext; +use ai::templates::generate::GenerateInlineContent; use ai::templates::preamble::EngineerPreamble; use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext}; use language::{BufferSnapshot, OffsetRangeExt, ToOffset}; @@ -124,11 +126,11 @@ fn summarize(buffer: &BufferSnapshot, selected_range: Range) -> S pub fn generate_content_prompt( user_prompt: String, language_name: Option<&str>, - buffer: &BufferSnapshot, - range: Range, - kind: CodegenKind, + buffer: BufferSnapshot, + range: Range, search_results: Vec, model: &str, + project_name: Option, ) -> anyhow::Result { // Using new Prompt Templates let openai_model: Arc = Arc::new(OpenAILanguageModel::load(model)); @@ -141,146 +143,24 @@ pub fn generate_content_prompt( let args = PromptArguments { model: openai_model, language_name: lang_name.clone(), - project_name: None, + project_name, snippets: search_results.clone(), reserved_tokens: 1000, + buffer: Some(buffer), + selected_range: Some(range), + user_prompt: Some(user_prompt.clone()), }; let templates: Vec<(PromptPriority, Box)> = vec![ 
(PromptPriority::High, Box::new(EngineerPreamble {})), (PromptPriority::Low, Box::new(RepositoryContext {})), + (PromptPriority::Medium, Box::new(FileContext {})), + (PromptPriority::High, Box::new(GenerateInlineContent {})), ]; let chain = PromptChain::new(args, templates); + let (prompt, _) = chain.generate(true)?; - let prompt = chain.generate(true)?; - println!("{:?}", prompt); - - const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500; - const RESERVED_TOKENS_FOR_GENERATION: usize = 1000; - - let mut prompts = Vec::new(); - let range = range.to_offset(buffer); - - // General Preamble - if let Some(language_name) = language_name.clone() { - prompts.push(format!("You're an expert {language_name} engineer.\n")); - } else { - prompts.push("You're an expert engineer.\n".to_string()); - } - - // Snippets - let mut snippet_position = prompts.len() - 1; - - let mut content = String::new(); - content.extend(buffer.text_for_range(0..range.start)); - if range.start == range.end { - content.push_str("<|START|>"); - } else { - content.push_str("<|START|"); - } - content.extend(buffer.text_for_range(range.clone())); - if range.start != range.end { - content.push_str("|END|>"); - } - content.extend(buffer.text_for_range(range.end..buffer.len())); - - prompts.push("The file you are currently working on has the following content:\n".to_string()); - - if let Some(language_name) = language_name { - let language_name = language_name.to_lowercase(); - prompts.push(format!("```{language_name}\n{content}\n```")); - } else { - prompts.push(format!("```\n{content}\n```")); - } - - match kind { - CodegenKind::Generate { position: _ } => { - prompts.push("In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.".to_string()); - prompts - .push("Assume the cursor is located where the `<|START|` marker is.".to_string()); - prompts.push( - "Text can't be replaced, so assume your answer will be inserted at the cursor." 
- .to_string(), - ); - prompts.push(format!( - "Generate text based on the users prompt: {user_prompt}" - )); - } - CodegenKind::Transform { range: _ } => { - prompts.push("In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.".to_string()); - prompts.push(format!( - "Modify the users code selected text based upon the users prompt: '{user_prompt}'" - )); - prompts.push("You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file.".to_string()); - } - } - - if let Some(language_name) = language_name { - prompts.push(format!( - "Your answer MUST always and only be valid {language_name}" - )); - } - prompts.push("Never make remarks about the output.".to_string()); - prompts.push("Do not return any text, except the generated code.".to_string()); - prompts.push("Always wrap your code in a Markdown block".to_string()); - - let current_messages = [ChatCompletionRequestMessage { - role: "user".to_string(), - content: Some(prompts.join("\n")), - function_call: None, - name: None, - }]; - - let mut remaining_token_count = if let Ok(current_token_count) = - tiktoken_rs::num_tokens_from_messages(model, ¤t_messages) - { - let max_token_count = tiktoken_rs::model::get_context_size(model); - let intermediate_token_count = if max_token_count > current_token_count { - max_token_count - current_token_count - } else { - 0 - }; - - if intermediate_token_count < RESERVED_TOKENS_FOR_GENERATION { - 0 - } else { - intermediate_token_count - RESERVED_TOKENS_FOR_GENERATION - } - } else { - // If tiktoken fails to count token count, assume we have no space remaining. - 0 - }; - - // TODO: - // - add repository name to snippet - // - add file path - // - add language - if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model) { - let mut template = "You are working inside a large repository, here are a few code snippets that may be useful"; - - for search_result in search_results { - let mut snippet_prompt = template.to_string(); - let snippet = search_result.to_string(); - writeln!(snippet_prompt, "```\n{snippet}\n```").unwrap(); - - let token_count = encoding - .encode_with_special_tokens(snippet_prompt.as_str()) - .len(); - if token_count <= remaining_token_count { - if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT { - prompts.insert(snippet_position, snippet_prompt); - snippet_position += 1; - remaining_token_count -= token_count; - // If you have already added the template to the prompt, remove the template. 
- template = ""; - } - } else { - break; - } - } - } - - anyhow::Ok(prompts.join("\n")) + anyhow::Ok(prompt) } #[cfg(test)] From 473067db3173f6e43666f1283c850cff8d2b8cd5 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 18 Oct 2023 15:56:39 -0400 Subject: [PATCH 10/16] update PromptPriority to accomodate for both Mandatory and Ordered prompts --- crates/ai/src/templates/base.rs | 101 ++++++++++++++---- crates/ai/src/templates/file_context.rs | 2 - crates/ai/src/templates/repository_context.rs | 2 +- crates/assistant/src/prompts.rs | 20 ++-- 4 files changed, 96 insertions(+), 29 deletions(-) diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/templates/base.rs index aaf08d755efb4746192bb75e64f0f7cc7e7a4e83..2afcc87ff5dc49072b558fffc4f22da1a34909e9 100644 --- a/crates/ai/src/templates/base.rs +++ b/crates/ai/src/templates/base.rs @@ -1,9 +1,9 @@ +use anyhow::anyhow; use std::cmp::Reverse; use std::ops::Range; use std::sync::Arc; -use gpui::ModelHandle; -use language::{Anchor, Buffer, BufferSnapshot, ToOffset}; +use language::BufferSnapshot; use util::ResultExt; use crate::models::LanguageModel; @@ -50,11 +50,21 @@ pub trait PromptTemplate { } #[repr(i8)] -#[derive(PartialEq, Eq, PartialOrd, Ord)] +#[derive(PartialEq, Eq, Ord)] pub enum PromptPriority { - Low, - Medium, - High, + Mandatory, // Ignores truncation + Ordered { order: usize }, // Truncates based on priority +} + +impl PartialOrd for PromptPriority { + fn partial_cmp(&self, other: &Self) -> Option { + match (self, other) { + (Self::Mandatory, Self::Mandatory) => Some(std::cmp::Ordering::Equal), + (Self::Mandatory, Self::Ordered { .. }) => Some(std::cmp::Ordering::Greater), + (Self::Ordered { .. }, Self::Mandatory) => Some(std::cmp::Ordering::Less), + (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.partial_cmp(a), + } + } } pub struct PromptChain { @@ -86,14 +96,36 @@ impl PromptChain { let mut prompts = vec!["".to_string(); sorted_indices.len()]; for idx in sorted_indices { - let (_, template) = &self.templates[idx]; + let (priority, template) = &self.templates[idx]; + + // If PromptPriority is marked as mandatory, we ignore the tokens outstanding + // However, if a prompt is generated in excess of the available tokens, + // we raise an error outlining that a mandatory prompt has exceeded the available + // balance + let template_tokens = if let Some(template_tokens) = tokens_outstanding { + match priority { + &PromptPriority::Mandatory => None, + _ => Some(template_tokens), + } + } else { + None + }; + if let Some((template_prompt, prompt_token_count)) = - template.generate(&self.args, tokens_outstanding).log_err() + template.generate(&self.args, template_tokens).log_err() { if template_prompt != "" { prompts[idx] = template_prompt; if let Some(remaining_tokens) = tokens_outstanding { + if prompt_token_count > remaining_tokens + && priority == &PromptPriority::Mandatory + { + return Err(anyhow!( + "mandatory template added in excess of model capacity" + )); + } + let new_tokens = prompt_token_count + seperator_tokens; tokens_outstanding = if remaining_tokens > new_tokens { Some(remaining_tokens - new_tokens) @@ -105,6 +137,8 @@ impl PromptChain { } } + prompts.retain(|x| x != ""); + let full_prompt = prompts.join(seperator); let total_token_count = self.args.model.count_tokens(&full_prompt)?; anyhow::Ok((prompts.join(seperator), total_token_count)) @@ -196,8 +230,14 @@ pub(crate) mod tests { }; let templates: Vec<(PromptPriority, Box)> = vec![ - (PromptPriority::High, Box::new(TestPromptTemplate {})), - 
(PromptPriority::Medium, Box::new(TestLowPriorityTemplate {})), + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), ]; let chain = PromptChain::new(args, templates); @@ -226,8 +266,14 @@ pub(crate) mod tests { }; let templates: Vec<(PromptPriority, Box)> = vec![ - (PromptPriority::High, Box::new(TestPromptTemplate {})), - (PromptPriority::Medium, Box::new(TestLowPriorityTemplate {})), + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), ]; let chain = PromptChain::new(args, templates); @@ -257,9 +303,18 @@ pub(crate) mod tests { }; let templates: Vec<(PromptPriority, Box)> = vec![ - (PromptPriority::High, Box::new(TestPromptTemplate {})), - (PromptPriority::Medium, Box::new(TestLowPriorityTemplate {})), - (PromptPriority::Low, Box::new(TestLowPriorityTemplate {})), + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ( + PromptPriority::Ordered { order: 2 }, + Box::new(TestLowPriorityTemplate {}), + ), ]; let chain = PromptChain::new(args, templates); @@ -283,14 +338,22 @@ pub(crate) mod tests { user_prompt: None, }; let templates: Vec<(PromptPriority, Box)> = vec![ - (PromptPriority::Medium, Box::new(TestPromptTemplate {})), - (PromptPriority::High, Box::new(TestLowPriorityTemplate {})), - (PromptPriority::Low, Box::new(TestLowPriorityTemplate {})), + ( + PromptPriority::Mandatory, + Box::new(TestLowPriorityTemplate {}), + ), + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), ]; let chain = PromptChain::new(args, templates); let (prompt, token_count) = chain.generate(true).unwrap(); - println!("TOKEN COUNT: {:?}", token_count); assert_eq!( prompt, diff --git a/crates/ai/src/templates/file_context.rs b/crates/ai/src/templates/file_context.rs index 6d0630504983fbe90597525ea8f49dd23e0a1036..94b194d9bf7ac4a247d8feb9c8327a50e034cf2a 100644 --- a/crates/ai/src/templates/file_context.rs +++ b/crates/ai/src/templates/file_context.rs @@ -30,8 +30,6 @@ impl PromptTemplate for FileContext { writeln!(prompt, "```{language_name}").unwrap(); if let Some(buffer) = &args.buffer { - let mut content = String::new(); - if let Some(selected_range) = &args.selected_range { let start = selected_range.start.to_offset(buffer); let end = selected_range.end.to_offset(buffer); diff --git a/crates/ai/src/templates/repository_context.rs b/crates/ai/src/templates/repository_context.rs index 7dd1647c440a228b5fc2c8317fe35e0931d4c1a9..a8e7f4b5af7bee4d3f29d70c665965dc7e05ed4b 100644 --- a/crates/ai/src/templates/repository_context.rs +++ b/crates/ai/src/templates/repository_context.rs @@ -60,7 +60,7 @@ impl PromptTemplate for RepositoryContext { max_token_length: Option, ) -> anyhow::Result<(String, usize)> { const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500; - let mut template = "You are working inside a large repository, here are a few code snippets that may be useful."; + let template = "You are working inside a large repository, here are a few code snippets that may be useful."; let mut prompt = String::new(); let mut remaining_tokens = max_token_length.clone(); diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs index 
1457d28fff22c83c29090dcded37aa9a915918bd..dffcbc29234d3f24174d1d9a6610045105eae890 100644 --- a/crates/assistant/src/prompts.rs +++ b/crates/assistant/src/prompts.rs @@ -1,4 +1,3 @@ -use crate::codegen::CodegenKind; use ai::models::{LanguageModel, OpenAILanguageModel}; use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate}; use ai::templates::file_context::FileContext; @@ -7,10 +6,8 @@ use ai::templates::preamble::EngineerPreamble; use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext}; use language::{BufferSnapshot, OffsetRangeExt, ToOffset}; use std::cmp::{self, Reverse}; -use std::fmt::Write; use std::ops::Range; use std::sync::Arc; -use tiktoken_rs::ChatCompletionRequestMessage; #[allow(dead_code)] fn summarize(buffer: &BufferSnapshot, selected_range: Range) -> String { @@ -152,10 +149,19 @@ pub fn generate_content_prompt( }; let templates: Vec<(PromptPriority, Box)> = vec![ - (PromptPriority::High, Box::new(EngineerPreamble {})), - (PromptPriority::Low, Box::new(RepositoryContext {})), - (PromptPriority::Medium, Box::new(FileContext {})), - (PromptPriority::High, Box::new(GenerateInlineContent {})), + (PromptPriority::Mandatory, Box::new(EngineerPreamble {})), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(RepositoryContext {}), + ), + ( + PromptPriority::Ordered { order: 0 }, + Box::new(FileContext {}), + ), + ( + PromptPriority::Mandatory, + Box::new(GenerateInlineContent {}), + ), ]; let chain = PromptChain::new(args, templates); let (prompt, _) = chain.generate(true)?; From 32853c20447d35abbf732441cc2e02cd48587938 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 18 Oct 2023 16:23:53 -0400 Subject: [PATCH 11/16] added initial placeholder for truncation without a valid strategy --- crates/ai/src/templates/file_context.rs | 7 +++++++ crates/ai/src/templates/generate.rs | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/crates/ai/src/templates/file_context.rs b/crates/ai/src/templates/file_context.rs index 94b194d9bf7ac4a247d8feb9c8327a50e034cf2a..e28f9ccdedb293817c22f54e0b2a12f17a40ac9f 100644 --- a/crates/ai/src/templates/file_context.rs +++ b/crates/ai/src/templates/file_context.rs @@ -1,3 +1,4 @@ +use anyhow::anyhow; use language::ToOffset; use crate::templates::base::PromptArguments; @@ -12,6 +13,12 @@ impl PromptTemplate for FileContext { args: &PromptArguments, max_token_length: Option, ) -> anyhow::Result<(String, usize)> { + if max_token_length.is_some() { + return Err(anyhow!( + "no truncation strategy established for file_context template" + )); + } + let mut prompt = String::new(); // Add Initial Preamble diff --git a/crates/ai/src/templates/generate.rs b/crates/ai/src/templates/generate.rs index d8a1ff6cf142fe8a4a81079ed3ada3c4f803eb75..053398e873828a22e35e90e5a12b7137c83b2de0 100644 --- a/crates/ai/src/templates/generate.rs +++ b/crates/ai/src/templates/generate.rs @@ -18,6 +18,12 @@ impl PromptTemplate for GenerateInlineContent { args: &PromptArguments, max_token_length: Option, ) -> anyhow::Result<(String, usize)> { + if max_token_length.is_some() { + return Err(anyhow!( + "no truncation strategy established for generating inline content template" + )); + } + let Some(user_prompt) = &args.user_prompt else { return Err(anyhow!("user prompt not provided")); }; @@ -83,6 +89,7 @@ impl PromptTemplate for GenerateInlineContent { } let token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count)) } } From a0e01e075d46b05ddc0737065348e57f38952edf Mon Sep 17 00:00:00 2001 From: 
KCaverly Date: Wed, 18 Oct 2023 16:31:29 -0400 Subject: [PATCH 12/16] fix for error when truncating a length less than the string length --- crates/ai/src/models.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/crates/ai/src/models.rs b/crates/ai/src/models.rs index 69e73e9b56ec7db7023983a67f5ee994c97c5725..0cafb49d94705a25322e3d6540d9a4a87e434c41 100644 --- a/crates/ai/src/models.rs +++ b/crates/ai/src/models.rs @@ -38,7 +38,11 @@ impl LanguageModel for OpenAILanguageModel { fn truncate(&self, content: &str, length: usize) -> anyhow::Result { if let Some(bpe) = &self.bpe { let tokens = bpe.encode_with_special_tokens(content); - bpe.decode(tokens[..length].to_vec()) + if tokens.len() > length { + bpe.decode(tokens[..length].to_vec()) + } else { + bpe.decode(tokens) + } } else { Err(anyhow!("bpe for open ai model was not retrieved")) } From f59f2eccd5e7f0706f0a3c5b1db6832d67380708 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 18 Oct 2023 16:32:14 -0400 Subject: [PATCH 13/16] added dumb truncation strategies to file_context and generate --- crates/ai/src/templates/base.rs | 26 ++----------------------- crates/ai/src/templates/file_context.rs | 12 +++++------- crates/ai/src/templates/generate.rs | 11 +++++------ 3 files changed, 12 insertions(+), 37 deletions(-) diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/templates/base.rs index 2afcc87ff5dc49072b558fffc4f22da1a34909e9..923e1833c2115953a27044d198497db256287907 100644 --- a/crates/ai/src/templates/base.rs +++ b/crates/ai/src/templates/base.rs @@ -1,4 +1,3 @@ -use anyhow::anyhow; use std::cmp::Reverse; use std::ops::Range; use std::sync::Arc; @@ -96,36 +95,15 @@ impl PromptChain { let mut prompts = vec!["".to_string(); sorted_indices.len()]; for idx in sorted_indices { - let (priority, template) = &self.templates[idx]; - - // If PromptPriority is marked as mandatory, we ignore the tokens outstanding - // However, if a prompt is generated in excess of the available tokens, - // we raise an error outlining that a mandatory prompt has exceeded the available - // balance - let template_tokens = if let Some(template_tokens) = tokens_outstanding { - match priority { - &PromptPriority::Mandatory => None, - _ => Some(template_tokens), - } - } else { - None - }; + let (_, template) = &self.templates[idx]; if let Some((template_prompt, prompt_token_count)) = - template.generate(&self.args, template_tokens).log_err() + template.generate(&self.args, tokens_outstanding).log_err() { if template_prompt != "" { prompts[idx] = template_prompt; if let Some(remaining_tokens) = tokens_outstanding { - if prompt_token_count > remaining_tokens - && priority == &PromptPriority::Mandatory - { - return Err(anyhow!( - "mandatory template added in excess of model capacity" - )); - } - let new_tokens = prompt_token_count + seperator_tokens; tokens_outstanding = if remaining_tokens > new_tokens { Some(remaining_tokens - new_tokens) diff --git a/crates/ai/src/templates/file_context.rs b/crates/ai/src/templates/file_context.rs index e28f9ccdedb293817c22f54e0b2a12f17a40ac9f..00fe99dd7ffc257339caa8c5198e532b967bee40 100644 --- a/crates/ai/src/templates/file_context.rs +++ b/crates/ai/src/templates/file_context.rs @@ -1,4 +1,3 @@ -use anyhow::anyhow; use language::ToOffset; use crate::templates::base::PromptArguments; @@ -13,12 +12,6 @@ impl PromptTemplate for FileContext { args: &PromptArguments, max_token_length: Option, ) -> anyhow::Result<(String, usize)> { - if max_token_length.is_some() { - return Err(anyhow!( - "no truncation 
strategy established for file_context template" - )); - } - let mut prompt = String::new(); // Add Initial Preamble @@ -84,6 +77,11 @@ impl PromptTemplate for FileContext { } } + // Really dumb truncation strategy + if let Some(max_tokens) = max_token_length { + prompt = args.model.truncate(&prompt, max_tokens)?; + } + let token_count = args.model.count_tokens(&prompt)?; anyhow::Ok((prompt, token_count)) } diff --git a/crates/ai/src/templates/generate.rs b/crates/ai/src/templates/generate.rs index 053398e873828a22e35e90e5a12b7137c83b2de0..34d874cc4128ccae034a0ecf3beace159bbec1ac 100644 --- a/crates/ai/src/templates/generate.rs +++ b/crates/ai/src/templates/generate.rs @@ -18,12 +18,6 @@ impl PromptTemplate for GenerateInlineContent { args: &PromptArguments, max_token_length: Option, ) -> anyhow::Result<(String, usize)> { - if max_token_length.is_some() { - return Err(anyhow!( - "no truncation strategy established for generating inline content template" - )); - } - let Some(user_prompt) = &args.user_prompt else { return Err(anyhow!("user prompt not provided")); }; @@ -88,6 +82,11 @@ impl PromptTemplate for GenerateInlineContent { _ => {} } + // Really dumb truncation strategy + if let Some(max_tokens) = max_token_length { + prompt = args.model.truncate(&prompt, max_tokens)?; + } + let token_count = args.model.count_tokens(&prompt)?; anyhow::Ok((prompt, token_count)) From 587fd707ba9c19fcef18b6bb0f5507fab79641d9 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 18 Oct 2023 16:40:09 -0400 Subject: [PATCH 14/16] added smarter error handling for file_context prompts without provided buffers --- crates/ai/src/templates/file_context.rs | 49 +++++++++++++------------ 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/crates/ai/src/templates/file_context.rs b/crates/ai/src/templates/file_context.rs index 00fe99dd7ffc257339caa8c5198e532b967bee40..5a6489a00c89f637f03d5c25a357f82f24accfd3 100644 --- a/crates/ai/src/templates/file_context.rs +++ b/crates/ai/src/templates/file_context.rs @@ -1,3 +1,4 @@ +use anyhow::anyhow; use language::ToOffset; use crate::templates::base::PromptArguments; @@ -12,24 +13,23 @@ impl PromptTemplate for FileContext { args: &PromptArguments, max_token_length: Option, ) -> anyhow::Result<(String, usize)> { - let mut prompt = String::new(); - - // Add Initial Preamble - // TODO: Do we want to add the path in here? - writeln!( - prompt, - "The file you are currently working on has the following content:" - ) - .unwrap(); + if let Some(buffer) = &args.buffer { + let mut prompt = String::new(); + // Add Initial Preamble + // TODO: Do we want to add the path in here? + writeln!( + prompt, + "The file you are currently working on has the following content:" + ) + .unwrap(); - let language_name = args - .language_name - .clone() - .unwrap_or("".to_string()) - .to_lowercase(); - writeln!(prompt, "```{language_name}").unwrap(); + let language_name = args + .language_name + .clone() + .unwrap_or("".to_string()) + .to_lowercase(); + writeln!(prompt, "```{language_name}").unwrap(); - if let Some(buffer) = &args.buffer { if let Some(selected_range) = &args.selected_range { let start = selected_range.start.to_offset(buffer); let end = selected_range.end.to_offset(buffer); @@ -74,15 +74,18 @@ impl PromptTemplate for FileContext { } else { // If we dont have a selected range, include entire file. 
writeln!(prompt, "{}", &buffer.text()).unwrap(); + writeln!(prompt, "```").unwrap(); } - } - // Really dumb truncation strategy - if let Some(max_tokens) = max_token_length { - prompt = args.model.truncate(&prompt, max_tokens)?; - } + // Really dumb truncation strategy + if let Some(max_tokens) = max_token_length { + prompt = args.model.truncate(&prompt, max_tokens)?; + } - let token_count = args.model.count_tokens(&prompt)?; - anyhow::Ok((prompt, token_count)) + let token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count)) + } else { + Err(anyhow!("no buffer provided to retrieve file context from")) + } } } From 178a84bcf62641ec39bd185b389a6381db06c4ba Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 18 Oct 2023 17:56:59 -0400 Subject: [PATCH 15/16] progress on smarter truncation strategy for file context --- crates/ai/src/models.rs | 13 +++ crates/ai/src/templates/base.rs | 7 ++ crates/ai/src/templates/file_context.rs | 139 +++++++++++++++++------- crates/assistant/src/prompts.rs | 2 + 4 files changed, 124 insertions(+), 37 deletions(-) diff --git a/crates/ai/src/models.rs b/crates/ai/src/models.rs index 0cafb49d94705a25322e3d6540d9a4a87e434c41..d0206cc41c526f171fef8521a120f8f4ff70aa74 100644 --- a/crates/ai/src/models.rs +++ b/crates/ai/src/models.rs @@ -6,6 +6,7 @@ pub trait LanguageModel { fn name(&self) -> String; fn count_tokens(&self, content: &str) -> anyhow::Result; fn truncate(&self, content: &str, length: usize) -> anyhow::Result; + fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result; fn capacity(&self) -> anyhow::Result; } @@ -47,6 +48,18 @@ impl LanguageModel for OpenAILanguageModel { Err(anyhow!("bpe for open ai model was not retrieved")) } } + fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result { + if let Some(bpe) = &self.bpe { + let tokens = bpe.encode_with_special_tokens(content); + if tokens.len() > length { + bpe.decode(tokens[length..].to_vec()) + } else { + bpe.decode(tokens) + } + } else { + Err(anyhow!("bpe for open ai model was not retrieved")) + } + } fn capacity(&self) -> anyhow::Result { anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name)) } diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/templates/base.rs index 923e1833c2115953a27044d198497db256287907..bda1d6c30e61a9e2fd3808fa45a34cbe041cf2b6 100644 --- a/crates/ai/src/templates/base.rs +++ b/crates/ai/src/templates/base.rs @@ -190,6 +190,13 @@ pub(crate) mod tests { .collect::(), ) } + fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result { + anyhow::Ok( + content.chars().collect::>()[length..] 
+ .into_iter() + .collect::(), + ) + } fn capacity(&self) -> anyhow::Result { anyhow::Ok(self.capacity) } diff --git a/crates/ai/src/templates/file_context.rs b/crates/ai/src/templates/file_context.rs index 5a6489a00c89f637f03d5c25a357f82f24accfd3..253d24e4691d52371d2af13e07e160ec3ac6e0f6 100644 --- a/crates/ai/src/templates/file_context.rs +++ b/crates/ai/src/templates/file_context.rs @@ -1,9 +1,103 @@ use anyhow::anyhow; +use language::BufferSnapshot; use language::ToOffset; +use crate::models::LanguageModel; use crate::templates::base::PromptArguments; use crate::templates::base::PromptTemplate; use std::fmt::Write; +use std::ops::Range; +use std::sync::Arc; + +fn retrieve_context( + buffer: &BufferSnapshot, + selected_range: &Option>, + model: Arc, + max_token_count: Option, +) -> anyhow::Result<(String, usize, bool)> { + let mut prompt = String::new(); + let mut truncated = false; + if let Some(selected_range) = selected_range { + let start = selected_range.start.to_offset(buffer); + let end = selected_range.end.to_offset(buffer); + + let start_window = buffer.text_for_range(0..start).collect::(); + + let mut selected_window = String::new(); + if start == end { + write!(selected_window, "<|START|>").unwrap(); + } else { + write!(selected_window, "<|START|").unwrap(); + } + + write!( + selected_window, + "{}", + buffer.text_for_range(start..end).collect::() + ) + .unwrap(); + + if start != end { + write!(selected_window, "|END|>").unwrap(); + } + + let end_window = buffer.text_for_range(end..buffer.len()).collect::(); + + if let Some(max_token_count) = max_token_count { + let selected_tokens = model.count_tokens(&selected_window)?; + if selected_tokens > max_token_count { + return Err(anyhow!( + "selected range is greater than model context window, truncation not possible" + )); + }; + + let mut remaining_tokens = max_token_count - selected_tokens; + let start_window_tokens = model.count_tokens(&start_window)?; + let end_window_tokens = model.count_tokens(&end_window)?; + let outside_tokens = start_window_tokens + end_window_tokens; + if outside_tokens > remaining_tokens { + let (start_goal_tokens, end_goal_tokens) = + if start_window_tokens < end_window_tokens { + let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens); + remaining_tokens -= start_goal_tokens; + let end_goal_tokens = remaining_tokens.min(end_window_tokens); + (start_goal_tokens, end_goal_tokens) + } else { + let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens); + remaining_tokens -= end_goal_tokens; + let start_goal_tokens = remaining_tokens.min(start_window_tokens); + (start_goal_tokens, end_goal_tokens) + }; + + let truncated_start_window = + model.truncate_start(&start_window, start_goal_tokens)?; + let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?; + writeln!( + prompt, + "{truncated_start_window}{selected_window}{truncated_end_window}" + ) + .unwrap(); + truncated = true; + } else { + writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap(); + } + } else { + // If we dont have a selected range, include entire file. + writeln!(prompt, "{}", &buffer.text()).unwrap(); + + // Dumb truncation strategy + if let Some(max_token_count) = max_token_count { + if model.count_tokens(&prompt)? 
> max_token_count { + truncated = true; + prompt = model.truncate(&prompt, max_token_count)?; + } + } + } + } + + let token_count = model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count, truncated)) +} pub struct FileContext {} @@ -28,53 +122,24 @@ impl PromptTemplate for FileContext { .clone() .unwrap_or("".to_string()) .to_lowercase(); - writeln!(prompt, "```{language_name}").unwrap(); + + let (context, _, truncated) = retrieve_context( + buffer, + &args.selected_range, + args.model.clone(), + max_token_length, + )?; + writeln!(prompt, "```{language_name}\n{context}\n```").unwrap(); if let Some(selected_range) = &args.selected_range { let start = selected_range.start.to_offset(buffer); let end = selected_range.end.to_offset(buffer); - writeln!( - prompt, - "{}", - buffer.text_for_range(0..start).collect::() - ) - .unwrap(); - - if start == end { - write!(prompt, "<|START|>").unwrap(); - } else { - write!(prompt, "<|START|").unwrap(); - } - - write!( - prompt, - "{}", - buffer.text_for_range(start..end).collect::() - ) - .unwrap(); - if start != end { - write!(prompt, "|END|>").unwrap(); - } - - write!( - prompt, - "{}", - buffer.text_for_range(end..buffer.len()).collect::() - ) - .unwrap(); - - writeln!(prompt, "```").unwrap(); - if start == end { writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap(); } else { writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap(); } - } else { - // If we dont have a selected range, include entire file. - writeln!(prompt, "{}", &buffer.text()).unwrap(); - writeln!(prompt, "```").unwrap(); } // Really dumb truncation strategy diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs index dffcbc29234d3f24174d1d9a6610045105eae890..c7b52a35405deadc3a9319ea77aea34e1989f273 100644 --- a/crates/assistant/src/prompts.rs +++ b/crates/assistant/src/prompts.rs @@ -166,6 +166,8 @@ pub fn generate_content_prompt( let chain = PromptChain::new(args, templates); let (prompt, _) = chain.generate(true)?; + println!("PROMPT: {:?}", &prompt); + anyhow::Ok(prompt) } From 19c2df4822db4731760547f9fea6fe4e810c1115 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Thu, 19 Oct 2023 14:33:52 -0400 Subject: [PATCH 16/16] outlined when truncation is taking place in the prompt --- crates/ai/src/templates/file_context.rs | 4 ++++ crates/ai/src/templates/generate.rs | 3 ++- crates/assistant/src/prompts.rs | 2 -- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/crates/ai/src/templates/file_context.rs b/crates/ai/src/templates/file_context.rs index 253d24e4691d52371d2af13e07e160ec3ac6e0f6..1afd61192edc02b153abe8cd00836d67caa42f02 100644 --- a/crates/ai/src/templates/file_context.rs +++ b/crates/ai/src/templates/file_context.rs @@ -131,6 +131,10 @@ impl PromptTemplate for FileContext { )?; writeln!(prompt, "```{language_name}\n{context}\n```").unwrap(); + if truncated { + writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap(); + } + if let Some(selected_range) = &args.selected_range { let start = selected_range.start.to_offset(buffer); let end = selected_range.end.to_offset(buffer); diff --git a/crates/ai/src/templates/generate.rs b/crates/ai/src/templates/generate.rs index 34d874cc4128ccae034a0ecf3beace159bbec1ac..19334340c8e5d302ef7e07f24eedf820670ea9e3 100644 --- a/crates/ai/src/templates/generate.rs +++ 
b/crates/ai/src/templates/generate.rs @@ -50,7 +50,8 @@ impl PromptTemplate for GenerateInlineContent { .unwrap(); } else { writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap(); - writeln!(prompt, "You MUST reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans), not the entire file.").unwrap(); + writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap(); + writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|'> spans").unwrap(); } } else { writeln!( diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs index c7b52a35405deadc3a9319ea77aea34e1989f273..dffcbc29234d3f24174d1d9a6610045105eae890 100644 --- a/crates/assistant/src/prompts.rs +++ b/crates/assistant/src/prompts.rs @@ -166,8 +166,6 @@ pub fn generate_content_prompt( let chain = PromptChain::new(args, templates); let (prompt, _) = chain.generate(true)?; - println!("PROMPT: {:?}", &prompt); - anyhow::Ok(prompt) }
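
The chain behaviour introduced across PATCH 08/16, 10/16 and 13/16 can be summarised outside the crate. The sketch below is a standalone model of that logic, not the crate's real API: the priority key, the word-based token counter and the budget numbers are stand-ins, and truncation happens in the chain itself rather than inside each template as it does in file_context.rs and generate.rs. What it preserves is the ordering rule (Mandatory templates are generated first, then Ordered templates by ascending order), the shrinking token budget shared by later templates, and the final join back in the original template order.

// Standalone sketch of the ordering and budgeting idea in
// crates/ai/src/templates/base.rs; simplified stand-ins only.
#[derive(Clone, Copy)]
enum Priority {
    Mandatory,                // always generated first
    Ordered { order: usize }, // lower `order` sees the remaining budget earlier
}

fn sort_key(priority: Priority) -> (u8, usize) {
    match priority {
        Priority::Mandatory => (0, 0),
        Priority::Ordered { order } => (1, order),
    }
}

// Assumption for the sketch: one whitespace-separated word counts as one token.
fn count_tokens(text: &str) -> usize {
    text.split_whitespace().count()
}

fn truncate_words(text: &str, limit: usize) -> String {
    text.split_whitespace().take(limit).collect::<Vec<_>>().join(" ")
}

fn generate_chain(templates: &[(Priority, &str)], capacity: usize) -> String {
    let separator_tokens = 1; // assumed cost of the "\n" join between sections

    // Generate in priority order...
    let mut indices: Vec<usize> = (0..templates.len()).collect();
    indices.sort_by_key(|&i| sort_key(templates[i].0));

    // ...but join the surviving pieces back in the original template order.
    let mut prompts = vec![String::new(); templates.len()];
    let mut outstanding = capacity;

    for i in indices {
        let piece = truncate_words(templates[i].1, outstanding);
        if piece.is_empty() {
            continue;
        }
        outstanding = outstanding.saturating_sub(count_tokens(&piece) + separator_tokens);
        prompts[i] = piece;
    }

    prompts.retain(|p| !p.is_empty());
    prompts.join("\n")
}

fn main() {
    let templates = [
        (Priority::Ordered { order: 1 }, "low priority context that can be dropped entirely"),
        (Priority::Mandatory, "You are an expert engineer."),
        (Priority::Ordered { order: 0 }, "The file you are currently working on ..."),
    ];
    // With a 12 token budget the mandatory preamble survives untouched, the
    // order-0 template is truncated to the leftover budget, and the order-1
    // template is dropped.
    println!("{}", generate_chain(&templates, 12));
}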
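
PATCH 12/16 and PATCH 15/16 give the model two truncation directions, truncate and truncate_start. A character-based analogue is shown below; the real methods operate on tiktoken BPE tokens, and plain characters are used here only so the sketch runs without the tokenizer.

// Character-based analogue of the truncate/truncate_start primitives.
// `truncate` keeps the first `length` units, `truncate_start` removes the
// first `length` units, and both return the content unchanged when it is
// already within the requested length (PATCH 12 added that guard to avoid
// slicing past the end).
fn truncate(content: &str, length: usize) -> String {
    let chars: Vec<char> = content.chars().collect();
    if chars.len() > length {
        chars[..length].iter().collect()
    } else {
        content.to_string()
    }
}

fn truncate_start(content: &str, length: usize) -> String {
    let chars: Vec<char> = content.chars().collect();
    if chars.len() > length {
        chars[length..].iter().collect()
    } else {
        content.to_string()
    }
}

fn main() {
    assert_eq!(truncate("let value = compute();", 9), "let value");
    assert_eq!(truncate_start("let value = compute();", 12), "compute();");
    assert_eq!(truncate("ok", 10), "ok");
}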
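
When the text before and after the selection does not fit the budget that remains after reserving room for the selection itself, PATCH 15/16 splits that remaining budget between the two sides: the smaller side is capped at half of what is left, and the other side takes the rest. The arithmetic on its own, with plain integers standing in for token counts:

// Restatement of the budget split used by retrieve_context in
// crates/ai/src/templates/file_context.rs.
fn split_window_budget(
    start_window_tokens: usize,
    end_window_tokens: usize,
    mut remaining_tokens: usize,
) -> (usize, usize) {
    if start_window_tokens < end_window_tokens {
        let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
        remaining_tokens -= start_goal_tokens;
        let end_goal_tokens = remaining_tokens.min(end_window_tokens);
        (start_goal_tokens, end_goal_tokens)
    } else {
        let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
        remaining_tokens -= end_goal_tokens;
        let start_goal_tokens = remaining_tokens.min(start_window_tokens);
        (start_goal_tokens, end_goal_tokens)
    }
}

fn main() {
    // 1200 tokens before the selection, 300 after, 800 tokens of budget left:
    // the short suffix keeps all 300 and the prefix gets the remaining 500.
    assert_eq!(split_window_budget(1200, 300, 800), (500, 300));

    // Both sides large: each ends up with roughly half of the budget.
    assert_eq!(split_window_budget(1000, 1000, 801), (401, 400));
}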