diff --git a/Cargo.lock b/Cargo.lock
index 6b92170b88ed0417c5a256bd32691d1dba1cacfc..56c9ceda04f62262ea199d2a7424fee053d68284 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -91,6 +91,7 @@ dependencies = [
  "futures 0.3.28",
  "gpui",
  "isahc",
+ "language",
  "lazy_static",
  "log",
  "matrixmultiply",
diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml
index 542d7f422fe8c1eaec7d10bf59cb5ccaa2d65ca3..b24c4e5ece5b02eac003a6c18f186faa1eaef7ef 100644
--- a/crates/ai/Cargo.toml
+++ b/crates/ai/Cargo.toml
@@ -11,6 +11,7 @@ doctest = false
 [dependencies]
 gpui = { path = "../gpui" }
 util = { path = "../util" }
+language = { path = "../language" }
 async-trait.workspace = true
 anyhow.workspace = true
 futures.workspace = true
diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs
index 5256a6a6432907dd22c30d6a03e492a46fef77df..f168c157934f6b70be775f7e17e9ba27ef9b3103 100644
--- a/crates/ai/src/ai.rs
+++ b/crates/ai/src/ai.rs
@@ -1,2 +1,4 @@
 pub mod completion;
 pub mod embedding;
+pub mod models;
+pub mod templates;
diff --git a/crates/ai/src/models.rs b/crates/ai/src/models.rs
new file mode 100644
index 0000000000000000000000000000000000000000..d0206cc41c526f171fef8521a120f8f4ff70aa74
--- /dev/null
+++ b/crates/ai/src/models.rs
@@ -0,0 +1,66 @@
+use anyhow::anyhow;
+use tiktoken_rs::CoreBPE;
+use util::ResultExt;
+
+pub trait LanguageModel {
+    fn name(&self) -> String;
+    fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
+    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
+    fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String>;
+    fn capacity(&self) -> anyhow::Result<usize>;
+}
+
+pub struct OpenAILanguageModel {
+    name: String,
+    bpe: Option<CoreBPE>,
+}
+
+impl OpenAILanguageModel {
+    pub fn load(model_name: &str) -> Self {
+        let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
+        OpenAILanguageModel {
+            name: model_name.to_string(),
+            bpe,
+        }
+    }
+}
+
+impl LanguageModel for OpenAILanguageModel {
+    fn name(&self) -> String {
+        self.name.clone()
+    }
+    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+        if let Some(bpe) = &self.bpe {
+            anyhow::Ok(bpe.encode_with_special_tokens(content).len())
+        } else {
+            Err(anyhow!("bpe for open ai model was not retrieved"))
+        }
+    }
+    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
+        if let Some(bpe) = &self.bpe {
+            let tokens = bpe.encode_with_special_tokens(content);
+            if tokens.len() > length {
+                bpe.decode(tokens[..length].to_vec())
+            } else {
+                bpe.decode(tokens)
+            }
+        } else {
+            Err(anyhow!("bpe for open ai model was not retrieved"))
+        }
+    }
+    fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
+        if let Some(bpe) = &self.bpe {
+            let tokens = bpe.encode_with_special_tokens(content);
+            if tokens.len() > length {
+                bpe.decode(tokens[length..].to_vec())
+            } else {
+                bpe.decode(tokens)
+            }
+        } else {
+            Err(anyhow!("bpe for open ai model was not retrieved"))
+        }
+    }
+    fn capacity(&self) -> anyhow::Result<usize> {
+        anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
+    }
+}
diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/templates/base.rs
new file mode 100644
index 0000000000000000000000000000000000000000..bda1d6c30e61a9e2fd3808fa45a34cbe041cf2b6
--- /dev/null
+++ b/crates/ai/src/templates/base.rs
@@ -0,0 +1,350 @@
+use std::cmp::Reverse;
+use std::ops::Range;
+use std::sync::Arc;
+
+use language::BufferSnapshot;
+use util::ResultExt;
+
+use crate::models::LanguageModel;
+use crate::templates::repository_context::PromptCodeSnippet;
+
+pub(crate) enum PromptFileType {
+    Text,
+    Code,
+}
+
+// TODO: Set this up to manage for defaults well
+pub struct PromptArguments {
+    pub model: Arc<dyn LanguageModel>,
+    pub user_prompt: Option<String>,
+    pub language_name: Option<String>,
+    pub project_name: Option<String>,
+    pub snippets: Vec<PromptCodeSnippet>,
+    pub reserved_tokens: usize,
+    pub buffer: Option<BufferSnapshot>,
+    pub selected_range: Option<Range<language::Anchor>>,
+}
+
+impl PromptArguments {
+    pub(crate) fn get_file_type(&self) -> PromptFileType {
+        if self
+            .language_name
+            .as_ref()
+            .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
+            .unwrap_or(true)
+        {
+            PromptFileType::Code
+        } else {
+            PromptFileType::Text
+        }
+    }
+}
+
+pub trait PromptTemplate {
+    fn generate(
+        &self,
+        args: &PromptArguments,
+        max_token_length: Option<usize>,
+    ) -> anyhow::Result<(String, usize)>;
+}
+
+#[repr(i8)]
+#[derive(PartialEq, Eq, Ord)]
+pub enum PromptPriority {
+    Mandatory,                // Ignores truncation
+    Ordered { order: usize }, // Truncates based on priority
+}
+
+impl PartialOrd for PromptPriority {
+    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+        match (self, other) {
+            (Self::Mandatory, Self::Mandatory) => Some(std::cmp::Ordering::Equal),
+            (Self::Mandatory, Self::Ordered { .. }) => Some(std::cmp::Ordering::Greater),
+            (Self::Ordered { .. }, Self::Mandatory) => Some(std::cmp::Ordering::Less),
+            (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.partial_cmp(a),
+        }
+    }
+}
+
+pub struct PromptChain {
+    args: PromptArguments,
+    templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
+}
+
+impl PromptChain {
+    pub fn new(
+        args: PromptArguments,
+        templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
+    ) -> Self {
+        PromptChain { args, templates }
+    }
+
+    pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
+        // Argsort based on Prompt Priority
+        let seperator = "\n";
+        let seperator_tokens = self.args.model.count_tokens(seperator)?;
+        let mut sorted_indices = (0..self.templates.len()).collect::<Vec<usize>>();
+        sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
+
+        // If Truncate
+        let mut tokens_outstanding = if truncate {
+            Some(self.args.model.capacity()? - self.args.reserved_tokens)
+        } else {
+            None
+        };
+
+        let mut prompts = vec!["".to_string(); sorted_indices.len()];
+        for idx in sorted_indices {
+            let (_, template) = &self.templates[idx];
+
+            if let Some((template_prompt, prompt_token_count)) =
+                template.generate(&self.args, tokens_outstanding).log_err()
+            {
+                if template_prompt != "" {
+                    prompts[idx] = template_prompt;
+
+                    if let Some(remaining_tokens) = tokens_outstanding {
+                        let new_tokens = prompt_token_count + seperator_tokens;
+                        tokens_outstanding = if remaining_tokens > new_tokens {
+                            Some(remaining_tokens - new_tokens)
+                        } else {
+                            Some(0)
+                        };
+                    }
+                }
+            }
+        }
+
+        prompts.retain(|x| x != "");
+
+        let full_prompt = prompts.join(seperator);
+        let total_token_count = self.args.model.count_tokens(&full_prompt)?;
+        anyhow::Ok((prompts.join(seperator), total_token_count))
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod tests {
+    use super::*;
+
+    #[test]
+    pub fn test_prompt_chain() {
+        struct TestPromptTemplate {}
+        impl PromptTemplate for TestPromptTemplate {
+            fn generate(
+                &self,
+                args: &PromptArguments,
+                max_token_length: Option<usize>,
+            ) -> anyhow::Result<(String, usize)> {
+                let mut content = "This is a test prompt template".to_string();
+
+                let mut token_count = args.model.count_tokens(&content)?;
+                if let Some(max_token_length) = max_token_length {
+                    if token_count > max_token_length {
+                        content = args.model.truncate(&content, max_token_length)?;
+                        token_count = max_token_length;
+                    }
+                }
+
+                anyhow::Ok((content, token_count))
+            }
+        }
+
+        struct TestLowPriorityTemplate {}
+        impl PromptTemplate for TestLowPriorityTemplate {
+            fn generate(
+                &self,
+                args: &PromptArguments,
+                max_token_length: Option<usize>,
+            ) -> anyhow::Result<(String, usize)> {
+                let mut content = "This is a low priority test prompt template".to_string();
+
+                let mut token_count = args.model.count_tokens(&content)?;
+                if let Some(max_token_length) = max_token_length {
+                    if token_count > max_token_length {
+                        content = args.model.truncate(&content, max_token_length)?;
+                        token_count = max_token_length;
+                    }
+                }
+
+                anyhow::Ok((content, token_count))
+            }
+        }
+
+        #[derive(Clone)]
+        struct DummyLanguageModel {
+            capacity: usize,
+        }
+
+        impl LanguageModel for DummyLanguageModel {
+            fn name(&self) -> String {
+                "dummy".to_string()
+            }
+            fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+                anyhow::Ok(content.chars().collect::<Vec<char>>().len())
+            }
+            fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
+                anyhow::Ok(
+                    content.chars().collect::<Vec<char>>()[..length]
+                        .into_iter()
+                        .collect::<String>(),
+                )
+            }
+            fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
+                anyhow::Ok(
+                    content.chars().collect::<Vec<char>>()[length..]
+                        .into_iter()
+                        .collect::<String>(),
+                )
+            }
+            fn capacity(&self) -> anyhow::Result<usize> {
+                anyhow::Ok(self.capacity)
+            }
+        }
+
+        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 100 });
+        let args = PromptArguments {
+            model: model.clone(),
+            language_name: None,
+            project_name: None,
+            snippets: Vec::new(),
+            reserved_tokens: 0,
+            buffer: None,
+            selected_range: None,
+            user_prompt: None,
+        };
+
+        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+            (
+                PromptPriority::Ordered { order: 0 },
+                Box::new(TestPromptTemplate {}),
+            ),
+            (
+                PromptPriority::Ordered { order: 1 },
+                Box::new(TestLowPriorityTemplate {}),
+            ),
+        ];
+        let chain = PromptChain::new(args, templates);
+
+        let (prompt, token_count) = chain.generate(false).unwrap();
+
+        assert_eq!(
+            prompt,
+            "This is a test prompt template\nThis is a low priority test prompt template"
+                .to_string()
+        );
+
+        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
+
+        // Testing with Truncation Off
+        // Should ignore capacity and return all prompts
+        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 20 });
+        let args = PromptArguments {
+            model: model.clone(),
+            language_name: None,
+            project_name: None,
+            snippets: Vec::new(),
+            reserved_tokens: 0,
+            buffer: None,
+            selected_range: None,
+            user_prompt: None,
+        };
+
+        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+            (
+                PromptPriority::Ordered { order: 0 },
+                Box::new(TestPromptTemplate {}),
+            ),
+            (
+                PromptPriority::Ordered { order: 1 },
+                Box::new(TestLowPriorityTemplate {}),
+            ),
+        ];
+        let chain = PromptChain::new(args, templates);
+
+        let (prompt, token_count) = chain.generate(false).unwrap();
+
+        assert_eq!(
+            prompt,
+            "This is a test prompt template\nThis is a low priority test prompt template"
+                .to_string()
+        );
+
+        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
+
+        // Testing with Truncation On
+        // Should truncate prompts to fit within the model's capacity
+        let capacity = 20;
+        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
+        let args = PromptArguments {
+            model: model.clone(),
+            language_name: None,
+            project_name: None,
+            snippets: Vec::new(),
+            reserved_tokens: 0,
+            buffer: None,
+            selected_range: None,
+            user_prompt: None,
+        };
+
+        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+            (
+                PromptPriority::Ordered { order: 0 },
+                Box::new(TestPromptTemplate {}),
+            ),
+            (
+                PromptPriority::Ordered { order: 1 },
+                Box::new(TestLowPriorityTemplate {}),
+            ),
+            (
+                PromptPriority::Ordered { order: 2 },
+                Box::new(TestLowPriorityTemplate {}),
+            ),
+        ];
+        let chain = PromptChain::new(args, templates);
+
+        let (prompt, token_count) = chain.generate(true).unwrap();
+
+        assert_eq!(prompt, "This is a test promp".to_string());
+        assert_eq!(token_count, capacity);
+
+        // Change Ordering of Prompts Based on Priority
+        let capacity = 120;
+        let reserved_tokens = 10;
+        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
+        let args = PromptArguments {
+            model: model.clone(),
+            language_name: None,
+            project_name: None,
+            snippets: Vec::new(),
+            reserved_tokens,
+            buffer: None,
+            selected_range: None,
+            user_prompt: None,
+        };
+        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+            (
+                PromptPriority::Mandatory,
+                Box::new(TestLowPriorityTemplate {}),
+            ),
+            (
+                PromptPriority::Ordered { order: 0 },
+                Box::new(TestPromptTemplate {}),
+            ),
+            (
+                PromptPriority::Ordered { order: 1 },
+                Box::new(TestLowPriorityTemplate {}),
+            ),
+        ];
+        let chain = PromptChain::new(args, templates);
+
+        let (prompt, token_count) = chain.generate(true).unwrap();
+
+        assert_eq!(
+            prompt,
+            "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
+                .to_string()
+        );
+        assert_eq!(token_count, capacity - reserved_tokens);
+    }
+}
diff --git a/crates/ai/src/templates/file_context.rs b/crates/ai/src/templates/file_context.rs
new file mode 100644
index 0000000000000000000000000000000000000000..1afd61192edc02b153abe8cd00836d67caa42f02
--- /dev/null
+++ b/crates/ai/src/templates/file_context.rs
@@ -0,0 +1,160 @@
+use anyhow::anyhow;
+use language::BufferSnapshot;
+use language::ToOffset;
+
+use crate::models::LanguageModel;
+use crate::templates::base::PromptArguments;
+use crate::templates::base::PromptTemplate;
+use std::fmt::Write;
+use std::ops::Range;
+use std::sync::Arc;
+
+fn retrieve_context(
+    buffer: &BufferSnapshot,
+    selected_range: &Option<Range<language::Anchor>>,
+    model: Arc<dyn LanguageModel>,
+    max_token_count: Option<usize>,
+) -> anyhow::Result<(String, usize, bool)> {
+    let mut prompt = String::new();
+    let mut truncated = false;
+    if let Some(selected_range) = selected_range {
+        let start = selected_range.start.to_offset(buffer);
+        let end = selected_range.end.to_offset(buffer);
+
+        let start_window = buffer.text_for_range(0..start).collect::<String>();
+
+        let mut selected_window = String::new();
+        if start == end {
+            write!(selected_window, "<|START|>").unwrap();
+        } else {
+            write!(selected_window, "<|START|").unwrap();
+        }
+
+        write!(
+            selected_window,
+            "{}",
+            buffer.text_for_range(start..end).collect::<String>()
+        )
+        .unwrap();
+
+        if start != end {
+            write!(selected_window, "|END|>").unwrap();
+        }
+
+        let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
+
+        if let Some(max_token_count) = max_token_count {
+            let selected_tokens = model.count_tokens(&selected_window)?;
+            if selected_tokens > max_token_count {
+                return Err(anyhow!(
+                    "selected range is greater than model context window, truncation not possible"
+                ));
+            };
+
+            let mut remaining_tokens = max_token_count - selected_tokens;
+            let start_window_tokens = model.count_tokens(&start_window)?;
+            let end_window_tokens = model.count_tokens(&end_window)?;
+            let outside_tokens = start_window_tokens + end_window_tokens;
+            if outside_tokens > remaining_tokens {
+                let (start_goal_tokens, end_goal_tokens) =
+                    if start_window_tokens < end_window_tokens {
+                        let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
+                        remaining_tokens -= start_goal_tokens;
+                        let end_goal_tokens = remaining_tokens.min(end_window_tokens);
+                        (start_goal_tokens, end_goal_tokens)
+                    } else {
+                        let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
+                        remaining_tokens -= end_goal_tokens;
+                        let start_goal_tokens = remaining_tokens.min(start_window_tokens);
+                        (start_goal_tokens, end_goal_tokens)
+                    };
+
+                let truncated_start_window =
+                    model.truncate_start(&start_window, start_goal_tokens)?;
+                let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?;
+                writeln!(
+                    prompt,
+                    "{truncated_start_window}{selected_window}{truncated_end_window}"
+                )
+                .unwrap();
+                truncated = true;
+            } else {
+                writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
+            }
+        } else {
+            // If we dont have a selected range, include entire file.
+            writeln!(prompt, "{}", &buffer.text()).unwrap();
+
+            // Dumb truncation strategy
+            if let Some(max_token_count) = max_token_count {
+                if model.count_tokens(&prompt)? > max_token_count {
+                    truncated = true;
+                    prompt = model.truncate(&prompt, max_token_count)?;
+                }
+            }
+        }
+    }
+
+    let token_count = model.count_tokens(&prompt)?;
+    anyhow::Ok((prompt, token_count, truncated))
+}
+
+pub struct FileContext {}
+
+impl PromptTemplate for FileContext {
+    fn generate(
+        &self,
+        args: &PromptArguments,
+        max_token_length: Option<usize>,
+    ) -> anyhow::Result<(String, usize)> {
+        if let Some(buffer) = &args.buffer {
+            let mut prompt = String::new();
+            // Add Initial Preamble
+            // TODO: Do we want to add the path in here?
+            writeln!(
+                prompt,
+                "The file you are currently working on has the following content:"
+            )
+            .unwrap();
+
+            let language_name = args
+                .language_name
+                .clone()
+                .unwrap_or("".to_string())
+                .to_lowercase();
+
+            let (context, _, truncated) = retrieve_context(
+                buffer,
+                &args.selected_range,
+                args.model.clone(),
+                max_token_length,
+            )?;
+            writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
+
+            if truncated {
+                writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap();
+            }
+
+            if let Some(selected_range) = &args.selected_range {
+                let start = selected_range.start.to_offset(buffer);
+                let end = selected_range.end.to_offset(buffer);
+
+                if start == end {
+                    writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
+                } else {
+                    writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
+                }
+            }
+
+            // Really dumb truncation strategy
+            if let Some(max_tokens) = max_token_length {
+                prompt = args.model.truncate(&prompt, max_tokens)?;
+            }
+
+            let token_count = args.model.count_tokens(&prompt)?;
+            anyhow::Ok((prompt, token_count))
+        } else {
+            Err(anyhow!("no buffer provided to retrieve file context from"))
+        }
+    }
+}
diff --git a/crates/ai/src/templates/generate.rs b/crates/ai/src/templates/generate.rs
new file mode 100644
index 0000000000000000000000000000000000000000..19334340c8e5d302ef7e07f24eedf820670ea9e3
--- /dev/null
+++ b/crates/ai/src/templates/generate.rs
@@ -0,0 +1,95 @@
+use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
+use anyhow::anyhow;
+use std::fmt::Write;
+
+pub fn capitalize(s: &str) -> String {
+    let mut c = s.chars();
+    match c.next() {
+        None => String::new(),
+        Some(f) => f.to_uppercase().collect::<String>() + c.as_str(),
+    }
+}
+
+pub struct GenerateInlineContent {}
+
+impl PromptTemplate for GenerateInlineContent {
+    fn generate(
+        &self,
+        args: &PromptArguments,
+        max_token_length: Option<usize>,
+    ) -> anyhow::Result<(String, usize)> {
+        let Some(user_prompt) = &args.user_prompt else {
+            return Err(anyhow!("user prompt not provided"));
+        };
+
+        let file_type = args.get_file_type();
+        let content_type = match &file_type {
+            PromptFileType::Code => "code",
+            PromptFileType::Text => "text",
+        };
+
+        let mut prompt = String::new();
+
+        if let Some(selected_range) = &args.selected_range {
+            if selected_range.start == selected_range.end {
+                writeln!(
+                    prompt,
+                    "Assume the cursor is located where the `<|START|>` span is."
+                )
+                .unwrap();
+                writeln!(
+                    prompt,
+                    "{} can't be replaced, so assume your answer will be inserted at the cursor.",
+                    capitalize(content_type)
+                )
+                .unwrap();
+                writeln!(
+                    prompt,
+                    "Generate {content_type} based on the users prompt: {user_prompt}",
+                )
+                .unwrap();
+            } else {
+                writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
+                writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
+                writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|'> spans").unwrap();
+            }
+        } else {
+            writeln!(
+                prompt,
+                "Generate {content_type} based on the users prompt: {user_prompt}"
+            )
+            .unwrap();
+        }
+
+        if let Some(language_name) = &args.language_name {
+            writeln!(
+                prompt,
+                "Your answer MUST always and only be valid {}.",
+                language_name
+            )
+            .unwrap();
+        }
+        writeln!(prompt, "Never make remarks about the output.").unwrap();
+        writeln!(
+            prompt,
+            "Do not return anything else, except the generated {content_type}."
+        )
+        .unwrap();
+
+        match file_type {
+            PromptFileType::Code => {
+                writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap();
+            }
+            _ => {}
+        }
+
+        // Really dumb truncation strategy
+        if let Some(max_tokens) = max_token_length {
+            prompt = args.model.truncate(&prompt, max_tokens)?;
+        }
+
+        let token_count = args.model.count_tokens(&prompt)?;
+
+        anyhow::Ok((prompt, token_count))
+    }
+}
diff --git a/crates/ai/src/templates/mod.rs b/crates/ai/src/templates/mod.rs
new file mode 100644
index 0000000000000000000000000000000000000000..0025269a440d1e6ead6a81615a64a3c28da62bb8
--- /dev/null
+++ b/crates/ai/src/templates/mod.rs
@@ -0,0 +1,5 @@
+pub mod base;
+pub mod file_context;
+pub mod generate;
+pub mod preamble;
+pub mod repository_context;
diff --git a/crates/ai/src/templates/preamble.rs b/crates/ai/src/templates/preamble.rs
new file mode 100644
index 0000000000000000000000000000000000000000..9eabaaeb97fe4216c6bac44cf4eabfb7c129ecf2
--- /dev/null
+++ b/crates/ai/src/templates/preamble.rs
@@ -0,0 +1,52 @@
+use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
+use std::fmt::Write;
+
+pub struct EngineerPreamble {}
+
+impl PromptTemplate for EngineerPreamble {
+    fn generate(
+        &self,
+        args: &PromptArguments,
+        max_token_length: Option<usize>,
+    ) -> anyhow::Result<(String, usize)> {
+        let mut prompts = Vec::new();
+
+        match args.get_file_type() {
+            PromptFileType::Code => {
+                prompts.push(format!(
+                    "You are an expert {}engineer.",
+                    args.language_name.clone().unwrap_or("".to_string()) + " "
+                ));
+            }
+            PromptFileType::Text => {
+                prompts.push("You are an expert engineer.".to_string());
+            }
+        }
+
+        if let Some(project_name) = args.project_name.clone() {
+            prompts.push(format!(
+                "You are currently working inside the '{project_name}' project in code editor Zed."
+            ));
+        }
+
+        if let Some(mut remaining_tokens) = max_token_length {
+            let mut prompt = String::new();
+            let mut total_count = 0;
+            for prompt_piece in prompts {
+                let prompt_token_count =
+                    args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?;
+                if remaining_tokens > prompt_token_count {
+                    writeln!(prompt, "{prompt_piece}").unwrap();
+                    remaining_tokens -= prompt_token_count;
+                    total_count += prompt_token_count;
+                }
+            }
+
+            anyhow::Ok((prompt, total_count))
+        } else {
+            let prompt = prompts.join("\n");
+            let token_count = args.model.count_tokens(&prompt)?;
+            anyhow::Ok((prompt, token_count))
+        }
+    }
+}
diff --git a/crates/ai/src/templates/repository_context.rs b/crates/ai/src/templates/repository_context.rs
new file mode 100644
index 0000000000000000000000000000000000000000..a8e7f4b5af7bee4d3f29d70c665965dc7e05ed4b
--- /dev/null
+++ b/crates/ai/src/templates/repository_context.rs
@@ -0,0 +1,94 @@
+use crate::templates::base::{PromptArguments, PromptTemplate};
+use std::fmt::Write;
+use std::{ops::Range, path::PathBuf};
+
+use gpui::{AsyncAppContext, ModelHandle};
+use language::{Anchor, Buffer};
+
+#[derive(Clone)]
+pub struct PromptCodeSnippet {
+    path: Option<PathBuf>,
+    language_name: Option<String>,
+    content: String,
+}
+
+impl PromptCodeSnippet {
+    pub fn new(buffer: ModelHandle<Buffer>, range: Range<Anchor>, cx: &AsyncAppContext) -> Self {
+        let (content, language_name, file_path) = buffer.read_with(cx, |buffer, _| {
+            let snapshot = buffer.snapshot();
+            let content = snapshot.text_for_range(range.clone()).collect::<String>();
+
+            let language_name = buffer
+                .language()
+                .and_then(|language| Some(language.name().to_string().to_lowercase()));
+
+            let file_path = buffer
+                .file()
+                .and_then(|file| Some(file.path().to_path_buf()));
+
+            (content, language_name, file_path)
+        });
+
+        PromptCodeSnippet {
+            path: file_path,
+            language_name,
+            content,
+        }
+    }
+}
+
+impl ToString for PromptCodeSnippet {
+    fn to_string(&self) -> String {
+        let path = self
+            .path
+            .as_ref()
+            .and_then(|path| Some(path.to_string_lossy().to_string()))
+            .unwrap_or("".to_string());
+        let language_name = self.language_name.clone().unwrap_or("".to_string());
+        let content = self.content.clone();
+
+        format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
+    }
+}
+
+pub struct RepositoryContext {}
+
+impl PromptTemplate for RepositoryContext {
+    fn generate(
+        &self,
+        args: &PromptArguments,
+        max_token_length: Option<usize>,
+    ) -> anyhow::Result<(String, usize)> {
+        const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
+        let template = "You are working inside a large repository, here are a few code snippets that may be useful.";
+        let mut prompt = String::new();
+
+        let mut remaining_tokens = max_token_length.clone();
+        let seperator_token_length = args.model.count_tokens("\n")?;
+        for snippet in &args.snippets {
+            let mut snippet_prompt = template.to_string();
+            let content = snippet.to_string();
+            writeln!(snippet_prompt, "{content}").unwrap();
+
+            let token_count = args.model.count_tokens(&snippet_prompt)?;
+            if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
+                if let Some(tokens_left) = remaining_tokens {
+                    if tokens_left >= token_count {
+                        writeln!(prompt, "{snippet_prompt}").unwrap();
+                        remaining_tokens = if tokens_left >= (token_count + seperator_token_length)
+                        {
+                            Some(tokens_left - token_count - seperator_token_length)
+                        } else {
+                            Some(0)
+                        };
+                    }
+                } else {
+                    writeln!(prompt, "{snippet_prompt}").unwrap();
+                }
+            }
+        }
+
+        let total_token_count = args.model.count_tokens(&prompt)?;
+        anyhow::Ok((prompt, total_token_count))
+    }
+}
diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs
index 65edb1832fcfa629a3f03cad439ccbf87cbdc176..4dd4e2a98315c042d74c7ef6bde78200287ab6ad 100644
--- a/crates/assistant/src/assistant_panel.rs
+++ b/crates/assistant/src/assistant_panel.rs
@@ -1,12 +1,15 @@
 use crate::{
     assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
     codegen::{self, Codegen, CodegenKind},
-    prompts::{generate_content_prompt, PromptCodeSnippet},
+    prompts::generate_content_prompt,
     MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
     SavedMessage,
 };
-use ai::completion::{
-    stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
+use ai::{
+    completion::{
+        stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
+    },
+    templates::repository_context::PromptCodeSnippet,
 };
 use anyhow::{anyhow, Result};
 use chrono::{DateTime, Local};
@@ -609,6 +612,18 @@ impl AssistantPanel {
 
         let project = pending_assist.project.clone();
 
+        let project_name = if let Some(project) = project.upgrade(cx) {
+            Some(
+                project
+                    .read(cx)
+                    .worktree_root_names(cx)
+                    .collect::<Vec<&str>>()
+                    .join("/"),
+            )
+        } else {
+            None
+        };
+
         self.inline_prompt_history
             .retain(|prompt| prompt != user_prompt);
         self.inline_prompt_history.push_back(user_prompt.into());
@@ -646,7 +661,6 @@ impl AssistantPanel {
             None
         };
 
-        let codegen_kind = codegen.read(cx).kind().clone();
         let user_prompt = user_prompt.to_string();
 
         let snippets = if retrieve_context {
@@ -668,14 +682,7 @@
             let snippets = cx.spawn(|_, cx| async move {
                 let mut snippets = Vec::new();
                 for result in search_results.await {
-                    snippets.push(PromptCodeSnippet::new(result, &cx));
-
-                    // snippets.push(result.buffer.read_with(&cx, |buffer, _| {
-                    //     buffer
-                    //         .snapshot()
-                    //         .text_for_range(result.range)
-                    //         .collect::<String>()
-                    // }));
+                    snippets.push(PromptCodeSnippet::new(result.buffer, result.range, &cx));
                 }
                 snippets
             });
@@ -696,11 +703,11 @@ impl AssistantPanel {
                 generate_content_prompt(
                     user_prompt,
                     language_name,
-                    &buffer,
+                    buffer,
                     range,
-                    codegen_kind,
                     snippets,
                     model_name,
+                    project_name,
                 )
             });
 
@@ -717,7 +724,8 @@
         }
 
         cx.spawn(|_, mut cx| async move {
-            let prompt = prompt.await;
+            // I Don't know if we want to return a ? here.
+            let prompt = prompt.await?;
 
             messages.push(RequestMessage {
                 role: Role::User,
@@ -729,6 +737,7 @@
                 stream: true,
             };
             codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx));
+            anyhow::Ok(())
         })
         .detach();
     }
diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs
index 18e9e18f7d5f673b608130f970697b6c9b3eb5c8..dffcbc29234d3f24174d1d9a6610045105eae890 100644
--- a/crates/assistant/src/prompts.rs
+++ b/crates/assistant/src/prompts.rs
@@ -1,60 +1,13 @@
-use crate::codegen::CodegenKind;
-use gpui::AsyncAppContext;
+use ai::models::{LanguageModel, OpenAILanguageModel};
+use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
+use ai::templates::file_context::FileContext;
+use ai::templates::generate::GenerateInlineContent;
+use ai::templates::preamble::EngineerPreamble;
+use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext};
 use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
-use semantic_index::SearchResult;
 use std::cmp::{self, Reverse};
-use std::fmt::Write;
 use std::ops::Range;
-use std::path::PathBuf;
-use tiktoken_rs::ChatCompletionRequestMessage;
-
-pub struct PromptCodeSnippet {
-    path: Option<PathBuf>,
-    language_name: Option<String>,
-    content: String,
-}
-
-impl PromptCodeSnippet {
-    pub fn new(search_result: SearchResult, cx: &AsyncAppContext) -> Self {
-        let (content, language_name, file_path) =
-            search_result.buffer.read_with(cx, |buffer, _| {
-                let snapshot = buffer.snapshot();
-                let content = snapshot
-                    .text_for_range(search_result.range.clone())
-                    .collect::<String>();
-
-                let language_name = buffer
-                    .language()
-                    .and_then(|language| Some(language.name().to_string()));
-
-                let file_path = buffer
-                    .file()
-                    .and_then(|file| Some(file.path().to_path_buf()));
-
-                (content, language_name, file_path)
-            });
-
-        PromptCodeSnippet {
-            path: file_path,
-            language_name,
-            content,
-        }
-    }
-}
-
-impl ToString for PromptCodeSnippet {
-    fn to_string(&self) -> String {
-        let path = self
-            .path
-            .as_ref()
-            .and_then(|path| Some(path.to_string_lossy().to_string()))
-            .unwrap_or("".to_string());
-        let language_name = self.language_name.clone().unwrap_or("".to_string());
-        let content = self.content.clone();
-
-        format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
-    }
-}
+use std::sync::Arc;
 
 #[allow(dead_code)]
 fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
@@ -170,138 +123,50 @@ fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> S
 pub fn generate_content_prompt(
     user_prompt: String,
     language_name: Option<&str>,
-    buffer: &BufferSnapshot,
-    range: Range<impl ToOffset>,
-    kind: CodegenKind,
+    buffer: BufferSnapshot,
+    range: Range<language::Anchor>,
     search_results: Vec<PromptCodeSnippet>,
     model: &str,
-) -> String {
-    const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
-    const RESERVED_TOKENS_FOR_GENERATION: usize = 1000;
-
-    let mut prompts = Vec::new();
-    let range = range.to_offset(buffer);
-
-    // General Preamble
-    if let Some(language_name) = language_name {
-        prompts.push(format!("You're an expert {language_name} engineer.\n"));
-    } else {
-        prompts.push("You're an expert engineer.\n".to_string());
-    }
-
-    // Snippets
-    let mut snippet_position = prompts.len() - 1;
-
-    let mut content = String::new();
-    content.extend(buffer.text_for_range(0..range.start));
-    if range.start == range.end {
-        content.push_str("<|START|>");
-    } else {
-        content.push_str("<|START|");
-    }
-    content.extend(buffer.text_for_range(range.clone()));
-    if range.start != range.end {
-        content.push_str("|END|>");
-    }
-    content.extend(buffer.text_for_range(range.end..buffer.len()));
-
-    prompts.push("The file you are currently working on has the following content:\n".to_string());
-
-    if let Some(language_name) = language_name {
-        let language_name = language_name.to_lowercase();
-        prompts.push(format!("```{language_name}\n{content}\n```"));
+    project_name: Option<String>,
+) -> anyhow::Result<String> {
+    // Using new Prompt Templates
+    let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load(model));
+    let lang_name = if let Some(language_name) = language_name {
+        Some(language_name.to_string())
     } else {
-        prompts.push(format!("```\n{content}\n```"));
-    }
-
-    match kind {
-        CodegenKind::Generate { position: _ } => {
-            prompts.push("In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.".to_string());
-            prompts
-                .push("Assume the cursor is located where the `<|START|` marker is.".to_string());
-            prompts.push(
-                "Text can't be replaced, so assume your answer will be inserted at the cursor."
-                    .to_string(),
-            );
-            prompts.push(format!(
-                "Generate text based on the users prompt: {user_prompt}"
-            ));
-        }
-        CodegenKind::Transform { range: _ } => {
-            prompts.push("In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.".to_string());
-            prompts.push(format!(
-                "Modify the users code selected text based upon the users prompt: '{user_prompt}'"
-            ));
-            prompts.push("You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file.".to_string());
-        }
-    }
-
-    if let Some(language_name) = language_name {
-        prompts.push(format!(
-            "Your answer MUST always and only be valid {language_name}"
-        ));
-    }
-    prompts.push("Never make remarks about the output.".to_string());
-    prompts.push("Do not return any text, except the generated code.".to_string());
-    prompts.push("Always wrap your code in a Markdown block".to_string());
-
-    let current_messages = [ChatCompletionRequestMessage {
-        role: "user".to_string(),
-        content: Some(prompts.join("\n")),
-        function_call: None,
-        name: None,
-    }];
-
-    let mut remaining_token_count = if let Ok(current_token_count) =
-        tiktoken_rs::num_tokens_from_messages(model, &current_messages)
-    {
-        let max_token_count = tiktoken_rs::model::get_context_size(model);
-        let intermediate_token_count = if max_token_count > current_token_count {
-            max_token_count - current_token_count
-        } else {
-            0
-        };
-
-        if intermediate_token_count < RESERVED_TOKENS_FOR_GENERATION {
-            0
-        } else {
-            intermediate_token_count - RESERVED_TOKENS_FOR_GENERATION
-        }
-    } else {
-        // If tiktoken fails to count token count, assume we have no space remaining.
-        0
+        None
     };
-
-    // TODO:
-    //     - add repository name to snippet
-    //     - add file path
-    //     - add language
-    if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model) {
-        let mut template = "You are working inside a large repository, here are a few code snippets that may be useful";
-
-        for search_result in search_results {
-            let mut snippet_prompt = template.to_string();
-            let snippet = search_result.to_string();
-            writeln!(snippet_prompt, "```\n{snippet}\n```").unwrap();
-
-            let token_count = encoding
-                .encode_with_special_tokens(snippet_prompt.as_str())
-                .len();
-            if token_count <= remaining_token_count {
-                if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
-                    prompts.insert(snippet_position, snippet_prompt);
-                    snippet_position += 1;
-                    remaining_token_count -= token_count;
-                    // If you have already added the template to the prompt, remove the template.
-                    template = "";
-                }
-            } else {
-                break;
-            }
-        }
-    }
+    let args = PromptArguments {
+        model: openai_model,
+        language_name: lang_name.clone(),
+        project_name,
+        snippets: search_results.clone(),
+        reserved_tokens: 1000,
+        buffer: Some(buffer),
+        selected_range: Some(range),
+        user_prompt: Some(user_prompt.clone()),
+    };
 
-    prompts.join("\n")
+    let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+        (PromptPriority::Mandatory, Box::new(EngineerPreamble {})),
+        (
+            PromptPriority::Ordered { order: 1 },
+            Box::new(RepositoryContext {}),
+        ),
+        (
+            PromptPriority::Ordered { order: 0 },
+            Box::new(FileContext {}),
+        ),
+        (
+            PromptPriority::Mandatory,
+            Box::new(GenerateInlineContent {}),
+        ),
+    ];
+    let chain = PromptChain::new(args, templates);
+    let (prompt, _) = chain.generate(true)?;
+
+    anyhow::Ok(prompt)
 }
 
 #[cfg(test)]