implement initial concept of prompt chain

KCaverly created

Change summary

crates/ai/src/templates/base.rs | 229 +++++++++++++++++++++++++++++++---
1 file changed, 208 insertions(+), 21 deletions(-)

Detailed changes

crates/ai/src/templates/base.rs 🔗

@@ -1,15 +1,25 @@
-use std::cmp::Reverse;
+use std::fmt::Write;
+use std::{cmp::Reverse, sync::Arc};
+
+use util::ResultExt;
 
 use crate::templates::repository_context::PromptCodeSnippet;
 
+pub trait LanguageModel {
+    fn name(&self) -> String;
+    fn count_tokens(&self, content: &str) -> usize;
+    fn truncate(&self, content: &str, length: usize) -> String;
+    fn capacity(&self) -> usize;
+}
+
 pub(crate) enum PromptFileType {
     Text,
     Code,
 }
 
-#[derive(Default)]
+// TODO: Set this up to manage for defaults well
 pub struct PromptArguments {
-    pub model_name: String,
+    pub model: Arc<dyn LanguageModel>,
     pub language_name: Option<String>,
     pub project_name: Option<String>,
     pub snippets: Vec<PromptCodeSnippet>,
@@ -32,7 +42,11 @@ impl PromptArguments {
 }
 
 pub trait PromptTemplate {
-    fn generate(&self, args: &PromptArguments, max_token_length: Option<usize>) -> String;
+    fn generate(
+        &self,
+        args: &PromptArguments,
+        max_token_length: Option<usize>,
+    ) -> anyhow::Result<(String, usize)>;
 }
 
 #[repr(i8)]
@@ -53,24 +67,52 @@ impl PromptChain {
         args: PromptArguments,
         templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
     ) -> Self {
-        // templates.sort_by(|a, b| a.0.cmp(&b.0));
-
         PromptChain { args, templates }
     }
 
-    pub fn generate(&self, truncate: bool) -> anyhow::Result<String> {
+    pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
         // Argsort based on Prompt Priority
+        let seperator = "\n";
+        let seperator_tokens = self.args.model.count_tokens(seperator);
         let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
         sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
 
-        println!("{:?}", sorted_indices);
-
         let mut prompts = Vec::new();
-        for (_, template) in &self.templates {
-            prompts.push(template.generate(&self.args, None));
+
+        // If Truncate
+        let mut tokens_outstanding = if truncate {
+            Some(self.args.model.capacity() - self.args.reserved_tokens)
+        } else {
+            None
+        };
+
+        for idx in sorted_indices {
+            let (_, template) = &self.templates[idx];
+            if let Some((template_prompt, prompt_token_count)) =
+                template.generate(&self.args, tokens_outstanding).log_err()
+            {
+                println!(
+                    "GENERATED PROMPT ({:?}): {:?}",
+                    &prompt_token_count, &template_prompt
+                );
+                if template_prompt != "" {
+                    prompts.push(template_prompt);
+
+                    if let Some(remaining_tokens) = tokens_outstanding {
+                        let new_tokens = prompt_token_count + seperator_tokens;
+                        tokens_outstanding = if remaining_tokens > new_tokens {
+                            Some(remaining_tokens - new_tokens)
+                        } else {
+                            Some(0)
+                        };
+                    }
+                }
+            }
         }
 
-        anyhow::Ok(prompts.join("\n"))
+        let full_prompt = prompts.join(seperator);
+        let total_token_count = self.args.model.count_tokens(&full_prompt);
+        anyhow::Ok((prompts.join(seperator), total_token_count))
     }
 }
 
@@ -82,21 +124,81 @@ pub(crate) mod tests {
     pub fn test_prompt_chain() {
         struct TestPromptTemplate {}
         impl PromptTemplate for TestPromptTemplate {
-            fn generate(&self, args: &PromptArguments, max_token_length: Option<usize>) -> String {
-                "This is a test prompt template".to_string()
+            fn generate(
+                &self,
+                args: &PromptArguments,
+                max_token_length: Option<usize>,
+            ) -> anyhow::Result<(String, usize)> {
+                let mut content = "This is a test prompt template".to_string();
+
+                let mut token_count = args.model.count_tokens(&content);
+                if let Some(max_token_length) = max_token_length {
+                    if token_count > max_token_length {
+                        content = args.model.truncate(&content, max_token_length);
+                        token_count = max_token_length;
+                    }
+                }
+
+                anyhow::Ok((content, token_count))
             }
         }
 
         struct TestLowPriorityTemplate {}
         impl PromptTemplate for TestLowPriorityTemplate {
-            fn generate(&self, args: &PromptArguments, max_token_length: Option<usize>) -> String {
-                "This is a low priority test prompt template".to_string()
+            fn generate(
+                &self,
+                args: &PromptArguments,
+                max_token_length: Option<usize>,
+            ) -> anyhow::Result<(String, usize)> {
+                let mut content = "This is a low priority test prompt template".to_string();
+
+                let mut token_count = args.model.count_tokens(&content);
+                if let Some(max_token_length) = max_token_length {
+                    if token_count > max_token_length {
+                        content = args.model.truncate(&content, max_token_length);
+                        token_count = max_token_length;
+                    }
+                }
+
+                anyhow::Ok((content, token_count))
             }
         }
 
+        #[derive(Clone)]
+        struct DummyLanguageModel {
+            capacity: usize,
+        }
+
+        impl DummyLanguageModel {
+            fn set_capacity(&mut self, capacity: usize) {
+                self.capacity = capacity
+            }
+        }
+
+        impl LanguageModel for DummyLanguageModel {
+            fn name(&self) -> String {
+                "dummy".to_string()
+            }
+            fn count_tokens(&self, content: &str) -> usize {
+                content.chars().collect::<Vec<char>>().len()
+            }
+            fn truncate(&self, content: &str, length: usize) -> String {
+                content.chars().collect::<Vec<char>>()[..length]
+                    .into_iter()
+                    .collect::<String>()
+            }
+            fn capacity(&self) -> usize {
+                self.capacity
+            }
+        }
+
+        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 100 });
         let args = PromptArguments {
-            model_name: "gpt-4".to_string(),
-            ..Default::default()
+            model: model.clone(),
+            language_name: None,
+            project_name: None,
+            snippets: Vec::new(),
+            reserved_tokens: 0,
         };
 
         let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
@@ -105,8 +207,93 @@ pub(crate) mod tests {
         ];
         let chain = PromptChain::new(args, templates);
 
-        let prompt = chain.generate(false);
-        println!("{:?}", prompt);
-        panic!();
+        let (prompt, token_count) = chain.generate(false).unwrap();
+
+        assert_eq!(
+            prompt,
+            "This is a test prompt template\nThis is a low priority test prompt template"
+                .to_string()
+        );
+
+        assert_eq!(model.count_tokens(&prompt), token_count);
+
+        // Testing with Truncation Off
+        // Should ignore capacity and return all prompts
+        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 20 });
+        let args = PromptArguments {
+            model: model.clone(),
+            language_name: None,
+            project_name: None,
+            snippets: Vec::new(),
+            reserved_tokens: 0,
+        };
+
+        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+            (PromptPriority::High, Box::new(TestPromptTemplate {})),
+            (PromptPriority::Medium, Box::new(TestLowPriorityTemplate {})),
+        ];
+        let chain = PromptChain::new(args, templates);
+
+        let (prompt, token_count) = chain.generate(false).unwrap();
+
+        assert_eq!(
+            prompt,
+            "This is a test prompt template\nThis is a low priority test prompt template"
+                .to_string()
+        );
+
+        assert_eq!(model.count_tokens(&prompt), token_count);
+
+        // Testing with Truncation Off
+        // Should ignore capacity and return all prompts
+        let capacity = 20;
+        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
+        let args = PromptArguments {
+            model: model.clone(),
+            language_name: None,
+            project_name: None,
+            snippets: Vec::new(),
+            reserved_tokens: 0,
+        };
+
+        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+            (PromptPriority::High, Box::new(TestPromptTemplate {})),
+            (PromptPriority::Medium, Box::new(TestLowPriorityTemplate {})),
+            (PromptPriority::Low, Box::new(TestLowPriorityTemplate {})),
+        ];
+        let chain = PromptChain::new(args, templates);
+
+        let (prompt, token_count) = chain.generate(true).unwrap();
+
+        assert_eq!(prompt, "This is a test promp".to_string());
+        assert_eq!(token_count, capacity);
+
+        // Change Ordering of Prompts Based on Priority
+        let capacity = 120;
+        let reserved_tokens = 10;
+        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
+        let args = PromptArguments {
+            model: model.clone(),
+            language_name: None,
+            project_name: None,
+            snippets: Vec::new(),
+            reserved_tokens,
+        };
+        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+            (PromptPriority::Medium, Box::new(TestPromptTemplate {})),
+            (PromptPriority::High, Box::new(TestLowPriorityTemplate {})),
+            (PromptPriority::Low, Box::new(TestLowPriorityTemplate {})),
+        ];
+        let chain = PromptChain::new(args, templates);
+
+        let (prompt, token_count) = chain.generate(true).unwrap();
+        println!("TOKEN COUNT: {:?}", token_count);
+
+        assert_eq!(
+            prompt,
+            "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
+                .to_string()
+        );
+        assert_eq!(token_count, capacity - reserved_tokens);
     }
 }