diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/templates/base.rs
index 3d8479e51253f8aa7f8157104fb9ed2220cfe3f2..74a4c424ae93b46da34d3f5493f6e2363b31c2f5 100644
--- a/crates/ai/src/templates/base.rs
+++ b/crates/ai/src/templates/base.rs
@@ -1,15 +1,25 @@
-use std::cmp::Reverse;
+use std::fmt::Write;
+use std::{cmp::Reverse, sync::Arc};
+
+use util::ResultExt;
 
 use crate::templates::repository_context::PromptCodeSnippet;
 
+pub trait LanguageModel {
+    fn name(&self) -> String;
+    fn count_tokens(&self, content: &str) -> usize;
+    fn truncate(&self, content: &str, length: usize) -> String;
+    fn capacity(&self) -> usize;
+}
+
 pub(crate) enum PromptFileType {
     Text,
     Code,
 }
 
-#[derive(Default)]
+// TODO: Set this up to manage for defaults well
 pub struct PromptArguments {
-    pub model_name: String,
+    pub model: Arc<dyn LanguageModel>,
     pub language_name: Option<String>,
     pub project_name: Option<String>,
     pub snippets: Vec<PromptCodeSnippet>,
@@ -32,7 +42,11 @@ impl PromptArguments {
 }
 
 pub trait PromptTemplate {
-    fn generate(&self, args: &PromptArguments, max_token_length: Option<usize>) -> String;
+    fn generate(
+        &self,
+        args: &PromptArguments,
+        max_token_length: Option<usize>,
+    ) -> anyhow::Result<(String, usize)>;
 }
 
 #[repr(i8)]
@@ -53,24 +67,52 @@ impl PromptChain {
         args: PromptArguments,
         templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
     ) -> Self {
-        // templates.sort_by(|a, b| a.0.cmp(&b.0));
-
         PromptChain { args, templates }
     }
 
-    pub fn generate(&self, truncate: bool) -> anyhow::Result<String> {
+    pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
         // Argsort based on Prompt Priority
+        let seperator = "\n";
+        let seperator_tokens = self.args.model.count_tokens(seperator);
         let mut sorted_indices = (0..self.templates.len()).collect::<Vec<usize>>();
         sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
 
-        println!("{:?}", sorted_indices);
-
         let mut prompts = Vec::new();
-        for (_, template) in &self.templates {
-            prompts.push(template.generate(&self.args, None));
+
+        // If Truncate
+        let mut tokens_outstanding = if truncate {
+            Some(self.args.model.capacity() - self.args.reserved_tokens)
+        } else {
+            None
+        };
+
+        for idx in sorted_indices {
+            let (_, template) = &self.templates[idx];
+            if let Some((template_prompt, prompt_token_count)) =
+                template.generate(&self.args, tokens_outstanding).log_err()
+            {
+                println!(
+                    "GENERATED PROMPT ({:?}): {:?}",
+                    &prompt_token_count, &template_prompt
+                );
+                if template_prompt != "" {
+                    prompts.push(template_prompt);
+
+                    if let Some(remaining_tokens) = tokens_outstanding {
+                        let new_tokens = prompt_token_count + seperator_tokens;
+                        tokens_outstanding = if remaining_tokens > new_tokens {
+                            Some(remaining_tokens - new_tokens)
+                        } else {
+                            Some(0)
+                        };
+                    }
+                }
+            }
         }
 
-        anyhow::Ok(prompts.join("\n"))
+        let full_prompt = prompts.join(seperator);
+        let total_token_count = self.args.model.count_tokens(&full_prompt);
+        anyhow::Ok((prompts.join(seperator), total_token_count))
     }
 }
 
@@ -82,21 +124,81 @@ pub(crate) mod tests {
     pub fn test_prompt_chain() {
         struct TestPromptTemplate {}
         impl PromptTemplate for TestPromptTemplate {
-            fn generate(&self, args: &PromptArguments, max_token_length: Option<usize>) -> String {
-                "This is a test prompt template".to_string()
+            fn generate(
+                &self,
+                args: &PromptArguments,
+                max_token_length: Option<usize>,
+            ) -> anyhow::Result<(String, usize)> {
+                let mut content = "This is a test prompt template".to_string();
+
+                let mut token_count = args.model.count_tokens(&content);
+                if let Some(max_token_length) = max_token_length {
+                    if token_count > max_token_length {
+                        content = args.model.truncate(&content, max_token_length);
+                        token_count = max_token_length;
+                    }
+                }
+
+                anyhow::Ok((content, token_count))
             }
         }
 
         struct TestLowPriorityTemplate {}
         impl PromptTemplate for TestLowPriorityTemplate {
-            fn generate(&self, args: &PromptArguments, max_token_length: Option<usize>) -> String {
-                "This is a low priority test prompt template".to_string()
+            fn generate(
+                &self,
+                args: &PromptArguments,
+                max_token_length: Option<usize>,
+            ) -> anyhow::Result<(String, usize)> {
+                let mut content = "This is a low priority test prompt template".to_string();
+
+                let mut token_count = args.model.count_tokens(&content);
+                if let Some(max_token_length) = max_token_length {
+                    if token_count > max_token_length {
+                        content = args.model.truncate(&content, max_token_length);
+                        token_count = max_token_length;
+                    }
+                }
+
+                anyhow::Ok((content, token_count))
             }
         }
 
+        #[derive(Clone)]
+        struct DummyLanguageModel {
+            capacity: usize,
+        }
+
+        impl DummyLanguageModel {
+            fn set_capacity(&mut self, capacity: usize) {
+                self.capacity = capacity
+            }
+        }
+
+        impl LanguageModel for DummyLanguageModel {
+            fn name(&self) -> String {
+                "dummy".to_string()
+            }
+            fn count_tokens(&self, content: &str) -> usize {
+                content.chars().collect::<Vec<char>>().len()
+            }
+            fn truncate(&self, content: &str, length: usize) -> String {
+                content.chars().collect::<Vec<char>>()[..length]
+                    .into_iter()
+                    .collect::<String>()
+            }
+            fn capacity(&self) -> usize {
+                self.capacity
+            }
+        }
+
+        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 100 });
         let args = PromptArguments {
-            model_name: "gpt-4".to_string(),
-            ..Default::default()
+            model: model.clone(),
+            language_name: None,
+            project_name: None,
+            snippets: Vec::new(),
+            reserved_tokens: 0,
         };
 
         let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
@@ -105,8 +207,93 @@ pub(crate) mod tests {
         ];
 
         let chain = PromptChain::new(args, templates);
-        let prompt = chain.generate(false);
-        println!("{:?}", prompt);
-        panic!();
+        let (prompt, token_count) = chain.generate(false).unwrap();
+
+        assert_eq!(
+            prompt,
+            "This is a test prompt template\nThis is a low priority test prompt template"
+                .to_string()
+        );
+
+        assert_eq!(model.count_tokens(&prompt), token_count);
+
+        // Testing with Truncation Off
+        // Should ignore capacity and return all prompts
+        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 20 });
+        let args = PromptArguments {
+            model: model.clone(),
+            language_name: None,
+            project_name: None,
+            snippets: Vec::new(),
+            reserved_tokens: 0,
+        };
+
+        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+            (PromptPriority::High, Box::new(TestPromptTemplate {})),
+            (PromptPriority::Medium, Box::new(TestLowPriorityTemplate {})),
+        ];
+        let chain = PromptChain::new(args, templates);
+
+        let (prompt, token_count) = chain.generate(false).unwrap();
+
+        assert_eq!(
+            prompt,
+            "This is a test prompt template\nThis is a low priority test prompt template"
+                .to_string()
+        );
+
+        assert_eq!(model.count_tokens(&prompt), token_count);
+
+        // Testing with Truncation On
+        // Should truncate the combined prompt to fit within capacity
+        let capacity = 20;
+        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
+        let args = PromptArguments {
+            model: model.clone(),
+            language_name: None,
+            project_name: None,
+            snippets: Vec::new(),
+            reserved_tokens: 0,
+        };
+
+        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+            (PromptPriority::High, Box::new(TestPromptTemplate {})),
+            (PromptPriority::Medium, Box::new(TestLowPriorityTemplate {})),
+            (PromptPriority::Low, Box::new(TestLowPriorityTemplate {})),
+        ];
+        let chain = PromptChain::new(args, templates);
+
+        let (prompt, token_count) = chain.generate(true).unwrap();
+
+        assert_eq!(prompt, "This is a test promp".to_string());
+        assert_eq!(token_count, capacity);
+
+        // Change Ordering of Prompts Based on Priority
+        let capacity = 120;
+        let reserved_tokens = 10;
+        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
+        let args = PromptArguments {
+            model: model.clone(),
+            language_name: None,
+            project_name: None,
+            snippets: Vec::new(),
+            reserved_tokens,
+        };
+        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+            (PromptPriority::Medium, Box::new(TestPromptTemplate {})),
+            (PromptPriority::High, Box::new(TestLowPriorityTemplate {})),
+            (PromptPriority::Low, Box::new(TestLowPriorityTemplate {})),
+        ];
+        let chain = PromptChain::new(args, templates);
+
+        let (prompt, token_count) = chain.generate(true).unwrap();
+        println!("TOKEN COUNT: {:?}", token_count);
+
+        assert_eq!(
+            prompt,
+            "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
+                .to_string()
+        );
+        assert_eq!(token_count, capacity - reserved_tokens);
     }
 }
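
Sketch (not part of the diff above): how a caller elsewhere in the crate might plug a concrete model into the new API. WordCountModel, its whitespace-based token counting, assemble_prompt, and the capacity/reserved-token numbers are illustrative assumptions; only LanguageModel, PromptArguments, PromptPriority, PromptTemplate, and PromptChain come from this change, and the snippet assumes this file's definitions and `use` items are in scope.

// Hypothetical adapter for illustration only; a real model would delegate to its tokenizer.
struct WordCountModel {
    capacity: usize,
}

impl LanguageModel for WordCountModel {
    fn name(&self) -> String {
        "word-count".to_string()
    }
    // Treat whitespace-separated words as "tokens".
    fn count_tokens(&self, content: &str) -> usize {
        content.split_whitespace().count()
    }
    // Keep the first `length` words; joining with single spaces is lossy but fine for a sketch.
    fn truncate(&self, content: &str, length: usize) -> String {
        content
            .split_whitespace()
            .take(length)
            .collect::<Vec<_>>()
            .join(" ")
    }
    fn capacity(&self) -> usize {
        self.capacity
    }
}

fn assemble_prompt(
    templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
) -> anyhow::Result<(String, usize)> {
    let model: Arc<dyn LanguageModel> = Arc::new(WordCountModel { capacity: 2048 });
    let args = PromptArguments {
        model,
        language_name: Some("Rust".to_string()),
        project_name: None,
        snippets: Vec::new(),
        reserved_tokens: 256, // head-room left for the model's reply
    };
    // With truncate = true, higher-priority templates consume the budget first and
    // lower-priority ones are trimmed to what remains of capacity - reserved_tokens;
    // the returned usize is the token count of the joined prompt.
    PromptChain::new(args, templates).generate(true)
}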