From fa61c1b9c1751912436dc44508af8aaa475493f2 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 18 Oct 2023 13:03:11 -0400 Subject: [PATCH] add prompt template for generate inline content --- crates/ai/src/templates/base.rs | 5 ++ crates/ai/src/templates/generate.rs | 88 +++++++++++++++++++++++++++++ crates/ai/src/templates/mod.rs | 1 + 3 files changed, 94 insertions(+) create mode 100644 crates/ai/src/templates/generate.rs diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/templates/base.rs index 0bf04f5ed17c607ba115446e455ca1ffd937d5bd..d4882bafc91d4a408558a8eafbf7ce5360132217 100644 --- a/crates/ai/src/templates/base.rs +++ b/crates/ai/src/templates/base.rs @@ -17,6 +17,7 @@ pub(crate) enum PromptFileType { // TODO: Set this up to manage for defaults well pub struct PromptArguments { pub model: Arc, + pub user_prompt: Option, pub language_name: Option, pub project_name: Option, pub snippets: Vec, @@ -196,6 +197,7 @@ pub(crate) mod tests { reserved_tokens: 0, buffer: None, selected_range: None, + user_prompt: None, }; let templates: Vec<(PromptPriority, Box)> = vec![ @@ -225,6 +227,7 @@ pub(crate) mod tests { reserved_tokens: 0, buffer: None, selected_range: None, + user_prompt: None, }; let templates: Vec<(PromptPriority, Box)> = vec![ @@ -255,6 +258,7 @@ pub(crate) mod tests { reserved_tokens: 0, buffer: None, selected_range: None, + user_prompt: None, }; let templates: Vec<(PromptPriority, Box)> = vec![ @@ -281,6 +285,7 @@ pub(crate) mod tests { reserved_tokens, buffer: None, selected_range: None, + user_prompt: None, }; let templates: Vec<(PromptPriority, Box)> = vec![ (PromptPriority::Medium, Box::new(TestPromptTemplate {})), diff --git a/crates/ai/src/templates/generate.rs b/crates/ai/src/templates/generate.rs new file mode 100644 index 0000000000000000000000000000000000000000..d8a1ff6cf142fe8a4a81079ed3ada3c4f803eb75 --- /dev/null +++ b/crates/ai/src/templates/generate.rs @@ -0,0 +1,88 @@ +use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate}; +use anyhow::anyhow; +use std::fmt::Write; + +pub fn capitalize(s: &str) -> String { + let mut c = s.chars(); + match c.next() { + None => String::new(), + Some(f) => f.to_uppercase().collect::() + c.as_str(), + } +} + +pub struct GenerateInlineContent {} + +impl PromptTemplate for GenerateInlineContent { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let Some(user_prompt) = &args.user_prompt else { + return Err(anyhow!("user prompt not provided")); + }; + + let file_type = args.get_file_type(); + let content_type = match &file_type { + PromptFileType::Code => "code", + PromptFileType::Text => "text", + }; + + let mut prompt = String::new(); + + if let Some(selected_range) = &args.selected_range { + if selected_range.start == selected_range.end { + writeln!( + prompt, + "Assume the cursor is located where the `<|START|>` span is." + ) + .unwrap(); + writeln!( + prompt, + "{} can't be replaced, so assume your answer will be inserted at the cursor.", + capitalize(content_type) + ) + .unwrap(); + writeln!( + prompt, + "Generate {content_type} based on the users prompt: {user_prompt}", + ) + .unwrap(); + } else { + writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap(); + writeln!(prompt, "You MUST reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans), not the entire file.").unwrap(); + } + } else { + writeln!( + prompt, + "Generate {content_type} based on the users prompt: {user_prompt}" + ) + .unwrap(); + } + + if let Some(language_name) = &args.language_name { + writeln!( + prompt, + "Your answer MUST always and only be valid {}.", + language_name + ) + .unwrap(); + } + writeln!(prompt, "Never make remarks about the output.").unwrap(); + writeln!( + prompt, + "Do not return anything else, except the generated {content_type}." + ) + .unwrap(); + + match file_type { + PromptFileType::Code => { + writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap(); + } + _ => {} + } + + let token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count)) + } +} diff --git a/crates/ai/src/templates/mod.rs b/crates/ai/src/templates/mod.rs index 886af86e91db4dada1a051f211c19e030c100ec7..0025269a440d1e6ead6a81615a64a3c28da62bb8 100644 --- a/crates/ai/src/templates/mod.rs +++ b/crates/ai/src/templates/mod.rs @@ -1,4 +1,5 @@ pub mod base; pub mod file_context; +pub mod generate; pub mod preamble; pub mod repository_context;