From 178a79fc471c541cc6351f491fbf585a551a9bce Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 18 Oct 2023 12:29:10 -0400 Subject: [PATCH] added prompt template for file context without truncation --- crates/ai/src/templates/base.rs | 13 ++++ crates/ai/src/templates/file_context.rs | 85 +++++++++++++++++++++++++ crates/ai/src/templates/mod.rs | 1 + 3 files changed, 99 insertions(+) create mode 100644 crates/ai/src/templates/file_context.rs diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/templates/base.rs index b5f9da3586f7793e601ca8f5bf7a3158da5949c8..0bf04f5ed17c607ba115446e455ca1ffd937d5bd 100644 --- a/crates/ai/src/templates/base.rs +++ b/crates/ai/src/templates/base.rs @@ -1,6 +1,9 @@ use std::cmp::Reverse; +use std::ops::Range; use std::sync::Arc; +use gpui::ModelHandle; +use language::{Anchor, Buffer, BufferSnapshot, ToOffset}; use util::ResultExt; use crate::models::LanguageModel; @@ -18,6 +21,8 @@ pub struct PromptArguments { pub project_name: Option, pub snippets: Vec, pub reserved_tokens: usize, + pub buffer: Option, + pub selected_range: Option>, } impl PromptArguments { @@ -189,6 +194,8 @@ pub(crate) mod tests { project_name: None, snippets: Vec::new(), reserved_tokens: 0, + buffer: None, + selected_range: None, }; let templates: Vec<(PromptPriority, Box)> = vec![ @@ -216,6 +223,8 @@ pub(crate) mod tests { project_name: None, snippets: Vec::new(), reserved_tokens: 0, + buffer: None, + selected_range: None, }; let templates: Vec<(PromptPriority, Box)> = vec![ @@ -244,6 +253,8 @@ pub(crate) mod tests { project_name: None, snippets: Vec::new(), reserved_tokens: 0, + buffer: None, + selected_range: None, }; let templates: Vec<(PromptPriority, Box)> = vec![ @@ -268,6 +279,8 @@ pub(crate) mod tests { project_name: None, snippets: Vec::new(), reserved_tokens, + buffer: None, + selected_range: None, }; let templates: Vec<(PromptPriority, Box)> = vec![ (PromptPriority::Medium, Box::new(TestPromptTemplate {})), diff --git a/crates/ai/src/templates/file_context.rs b/crates/ai/src/templates/file_context.rs new file mode 100644 index 0000000000000000000000000000000000000000..68bf424db1ddb6c3cd11907688ee5080e8f41c5f --- /dev/null +++ b/crates/ai/src/templates/file_context.rs @@ -0,0 +1,85 @@ +use language::ToOffset; + +use crate::templates::base::PromptArguments; +use crate::templates::base::PromptTemplate; +use std::fmt::Write; + +pub struct FileContext {} + +impl PromptTemplate for FileContext { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let mut prompt = String::new(); + + // Add Initial Preamble + // TODO: Do we want to add the path in here? + writeln!( + prompt, + "The file you are currently working on has the following content:" + ) + .unwrap(); + + let language_name = args + .language_name + .clone() + .unwrap_or("".to_string()) + .to_lowercase(); + writeln!(prompt, "```{language_name}").unwrap(); + + if let Some(buffer) = &args.buffer { + let mut content = String::new(); + + if let Some(selected_range) = &args.selected_range { + let start = selected_range.start.to_offset(buffer); + let end = selected_range.end.to_offset(buffer); + + writeln!( + prompt, + "{}", + buffer.text_for_range(0..start).collect::() + ) + .unwrap(); + + if start == end { + writeln!(prompt, "<|START|>").unwrap(); + } else { + writeln!(prompt, "<|START|").unwrap(); + } + + writeln!( + prompt, + "{}", + buffer.text_for_range(start..end).collect::() + ) + .unwrap(); + if start != end { + writeln!(prompt, "|END|>").unwrap(); + } + + writeln!( + prompt, + "{}", + buffer.text_for_range(end..buffer.len()).collect::() + ) + .unwrap(); + + writeln!(prompt, "```").unwrap(); + + if start == end { + writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap(); + } else { + writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap(); + } + } else { + // If we dont have a selected range, include entire file. + writeln!(prompt, "{}", &buffer.text()).unwrap(); + } + } + + let token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count)) + } +} diff --git a/crates/ai/src/templates/mod.rs b/crates/ai/src/templates/mod.rs index 62cb600eca4fb641265a3937ec5bf8f1e8c2d9c2..886af86e91db4dada1a051f211c19e030c100ec7 100644 --- a/crates/ai/src/templates/mod.rs +++ b/crates/ai/src/templates/mod.rs @@ -1,3 +1,4 @@ pub mod base; +pub mod file_context; pub mod preamble; pub mod repository_context;