progress on prompt chains

Created by KCaverly

Change summary

Cargo.lock                                    |   1 +
crates/ai/Cargo.toml                          |   1 +
crates/ai/src/prompts.rs                      | 149 +++++++++++++++++++++
crates/ai/src/templates.rs                    |  76 ----------
crates/ai/src/templates/base.rs               | 112 +++++++++++++++
crates/ai/src/templates/mod.rs                |   3 +
crates/ai/src/templates/preamble.rs           |  34 ++++
crates/ai/src/templates/repository_context.rs |  49 ++++++
8 files changed, 349 insertions(+), 76 deletions(-)
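
The 3-line diff for crates/ai/src/templates/mod.rs is not reproduced in the detailed changes below. Given the three new submodules, its contents are presumably just the module declarations; a guess, not taken from the change:

// assumed contents of crates/ai/src/templates/mod.rs (diff not shown); visibility may differ
pub mod base;
pub mod preamble;
pub mod repository_context;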

Detailed changes

Cargo.lock

@@ -91,6 +91,7 @@ dependencies = [
  "futures 0.3.28",
  "gpui",
  "isahc",
+ "language",
  "lazy_static",
  "log",
  "matrixmultiply",

crates/ai/Cargo.toml

@@ -11,6 +11,7 @@ doctest = false
 [dependencies]
 gpui = { path = "../gpui" }
 util = { path = "../util" }
+language = { path = "../language" }
 async-trait.workspace = true
 anyhow.workspace = true
 futures.workspace = true

crates/ai/src/prompts.rs

@@ -0,0 +1,149 @@
+use gpui::{AsyncAppContext, ModelHandle};
+use language::{Anchor, Buffer};
+use std::{fmt::Write, ops::Range, path::PathBuf};
+
+pub struct PromptCodeSnippet {
+    path: Option<PathBuf>,
+    language_name: Option<String>,
+    content: String,
+}
+
+impl PromptCodeSnippet {
+    pub fn new(buffer: ModelHandle<Buffer>, range: Range<Anchor>, cx: &AsyncAppContext) -> Self {
+        let (content, language_name, file_path) = buffer.read_with(cx, |buffer, _| {
+            let snapshot = buffer.snapshot();
+            let content = snapshot.text_for_range(range.clone()).collect::<String>();
+
+            let language_name = buffer
+                .language()
+                .and_then(|language| Some(language.name().to_string()));
+
+            let file_path = buffer
+                .file()
+                .and_then(|file| Some(file.path().to_path_buf()));
+
+            (content, language_name, file_path)
+        });
+
+        PromptCodeSnippet {
+            path: file_path,
+            language_name,
+            content,
+        }
+    }
+}
+
+impl ToString for PromptCodeSnippet {
+    fn to_string(&self) -> String {
+        let path = self
+            .path
+            .as_ref()
+            .and_then(|path| Some(path.to_string_lossy().to_string()))
+            .unwrap_or("".to_string());
+        let language_name = self.language_name.clone().unwrap_or("".to_string());
+        let content = self.content.clone();
+
+        format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
+    }
+}
+
+enum PromptFileType {
+    Text,
+    Code,
+}
+
+#[derive(Default)]
+struct PromptArguments {
+    pub language_name: Option<String>,
+    pub project_name: Option<String>,
+    pub snippets: Vec<PromptCodeSnippet>,
+    pub model_name: String,
+}
+
+impl PromptArguments {
+    pub fn get_file_type(&self) -> PromptFileType {
+        if self
+            .language_name
+            .as_ref()
+            .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
+            .unwrap_or(true)
+        {
+            PromptFileType::Code
+        } else {
+            PromptFileType::Text
+        }
+    }
+}
+
+trait PromptTemplate {
+    fn generate(args: PromptArguments, max_token_length: Option<usize>) -> String;
+}
+
+struct EngineerPreamble {}
+
+impl PromptTemplate for EngineerPreamble {
+    fn generate(args: PromptArguments, max_token_length: Option<usize>) -> String {
+        let mut prompt = String::new();
+
+        match args.get_file_type() {
+            PromptFileType::Code => {
+                writeln!(
+                    prompt,
+                    "You are an expert {} engineer.",
+                    args.language_name.unwrap_or("".to_string())
+                )
+                .unwrap();
+            }
+            PromptFileType::Text => {
+                writeln!(prompt, "You are an expert engineer.").unwrap();
+            }
+        }
+
+        if let Some(project_name) = args.project_name {
+            writeln!(
+                prompt,
+                "You are currently working inside the '{project_name}' in Zed the code editor."
+            )
+            .unwrap();
+        }
+
+        prompt
+    }
+}
+
+struct RepositorySnippets {}
+
+impl PromptTemplate for RepositorySnippets {
+    fn generate(args: PromptArguments, max_token_length: Option<usize>) -> String {
+        const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
+        let mut template = "You are working inside a large repository, here are a few code snippets that may be useful";
+        let mut prompt = String::new();
+
+        if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(args.model_name.as_str()) {
+            let default_token_count =
+                tiktoken_rs::model::get_context_size(args.model_name.as_str());
+            let mut remaining_token_count = max_token_length.unwrap_or(default_token_count);
+
+            for snippet in args.snippets {
+                let mut snippet_prompt = template.to_string();
+                let content = snippet.to_string();
+                writeln!(snippet_prompt, "{content}").unwrap();
+
+                let token_count = encoding
+                    .encode_with_special_tokens(snippet_prompt.as_str())
+                    .len();
+                if token_count <= remaining_token_count {
+                    if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
+                        writeln!(prompt, "{snippet_prompt}").unwrap();
+                        remaining_token_count -= token_count;
+                        template = "";
+                    }
+                } else {
+                    break;
+                }
+            }
+        }
+
+        prompt
+    }
+}
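
RepositorySnippets::generate above trims repository context to fit the model's window: it counts tokens with tiktoken-rs, emits the explanatory header only before the first accepted snippet, skips any single snippet of 500 tokens or more, and stops once the remaining budget is exhausted. A self-contained sketch of the same budgeting pattern, decoupled from gpui and language buffers (the function name and snippet inputs are hypothetical):

use std::fmt::Write;

// Sketch only: packs pre-rendered snippet strings into a token budget,
// mirroring the flow of RepositorySnippets::generate above. The snippet
// strings are assumed to already be formatted, e.g. via PromptCodeSnippet::to_string.
fn pack_snippets(model_name: &str, snippets: &[String], max_token_length: Option<usize>) -> String {
    const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
    let mut header =
        "You are working inside a large repository, here are a few code snippets that may be useful";
    let mut prompt = String::new();

    if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model_name) {
        let default_token_count = tiktoken_rs::model::get_context_size(model_name);
        let mut remaining_token_count = max_token_length.unwrap_or(default_token_count);

        for snippet in snippets {
            let mut snippet_prompt = header.to_string();
            writeln!(snippet_prompt, "{snippet}").unwrap();

            let token_count = encoding
                .encode_with_special_tokens(&snippet_prompt)
                .len();
            if token_count > remaining_token_count {
                break; // budget exhausted; drop the remaining snippets
            }
            if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
                writeln!(prompt, "{snippet_prompt}").unwrap();
                remaining_token_count -= token_count;
                header = ""; // the header is only emitted once
            }
        }
    }

    prompt
}

Calling pack_snippets("gpt-4", &snippets, Some(2000)) would keep appending snippets until roughly 2000 tokens are used, matching the behaviour of the template above.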

crates/ai/src/templates.rs

@@ -1,76 +0,0 @@
-use std::fmt::Write;
-
-pub struct PromptCodeSnippet {
-    path: Option<PathBuf>,
-    language_name: Option<String>,
-    content: String,
-}
-
-enum PromptFileType {
-    Text,
-    Code,
-}
-
-#[derive(Default)]
-struct PromptArguments {
-    pub language_name: Option<String>,
-    pub project_name: Option<String>,
-    pub snippets: Vec<PromptCodeSnippet>,
-}
-
-impl PromptArguments {
-    pub fn get_file_type(&self) -> PromptFileType {
-        if self
-            .language_name
-            .as_ref()
-            .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
-            .unwrap_or(true)
-        {
-            PromptFileType::Code
-        } else {
-            PromptFileType::Text
-        }
-    }
-}
-
-trait PromptTemplate {
-    fn generate(args: PromptArguments) -> String;
-}
-
-struct EngineerPreamble {}
-
-impl PromptTemplate for EngineerPreamble {
-    fn generate(args: PromptArguments) -> String {
-        let mut prompt = String::new();
-
-        match args.get_file_type() {
-            PromptFileType::Code => {
-                writeln!(
-                    prompt,
-                    "You are an expert {} engineer.",
-                    args.language_name.unwrap_or("".to_string())
-                )
-                .unwrap();
-            }
-            PromptFileType::Text => {
-                writeln!(prompt, "You are an expert engineer.").unwrap();
-            }
-        }
-
-        if let Some(project_name) = args.project_name {
-            writeln!(
-                prompt,
-                "You are currently working inside the '{project_name}' in Zed the code editor."
-            )
-            .unwrap();
-        }
-
-        prompt
-    }
-}
-
-struct RepositorySnippets {}
-
-impl PromptTemplate for RepositorySnippets {
-    fn generate(args: PromptArguments) -> String {}
-}

crates/ai/src/templates/base.rs

@@ -0,0 +1,112 @@
+use std::cmp::Reverse;
+
+use crate::templates::repository_context::PromptCodeSnippet;
+
+pub(crate) enum PromptFileType {
+    Text,
+    Code,
+}
+
+#[derive(Default)]
+pub struct PromptArguments {
+    pub model_name: String,
+    pub language_name: Option<String>,
+    pub project_name: Option<String>,
+    pub snippets: Vec<PromptCodeSnippet>,
+    pub reserved_tokens: usize,
+}
+
+impl PromptArguments {
+    pub(crate) fn get_file_type(&self) -> PromptFileType {
+        if self
+            .language_name
+            .as_ref()
+            .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
+            .unwrap_or(true)
+        {
+            PromptFileType::Code
+        } else {
+            PromptFileType::Text
+        }
+    }
+}
+
+pub trait PromptTemplate {
+    fn generate(&self, args: &PromptArguments, max_token_length: Option<usize>) -> String;
+}
+
+#[repr(i8)]
+#[derive(PartialEq, Eq, PartialOrd, Ord)]
+pub enum PromptPriority {
+    Low,
+    Medium,
+    High,
+}
+
+pub struct PromptChain {
+    args: PromptArguments,
+    templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
+}
+
+impl PromptChain {
+    pub fn new(
+        args: PromptArguments,
+        templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
+    ) -> Self {
+        // templates.sort_by(|a, b| a.0.cmp(&b.0));
+
+        PromptChain { args, templates }
+    }
+
+    pub fn generate(&self, truncate: bool) -> anyhow::Result<String> {
+        // Argsort based on Prompt Priority
+        let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
+        sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
+
+        println!("{:?}", sorted_indices);
+
+        let mut prompts = Vec::new();
+        for (_, template) in &self.templates {
+            prompts.push(template.generate(&self.args, None));
+        }
+
+        anyhow::Ok(prompts.join("\n"))
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod tests {
+    use super::*;
+
+    #[test]
+    pub fn test_prompt_chain() {
+        struct TestPromptTemplate {}
+        impl PromptTemplate for TestPromptTemplate {
+            fn generate(&self, args: &PromptArguments, max_token_length: Option<usize>) -> String {
+                "This is a test prompt template".to_string()
+            }
+        }
+
+        struct TestLowPriorityTemplate {}
+        impl PromptTemplate for TestLowPriorityTemplate {
+            fn generate(&self, args: &PromptArguments, max_token_length: Option<usize>) -> String {
+                "This is a low priority test prompt template".to_string()
+            }
+        }
+
+        let args = PromptArguments {
+            model_name: "gpt-4".to_string(),
+            ..Default::default()
+        };
+
+        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+            (PromptPriority::High, Box::new(TestPromptTemplate {})),
+            (PromptPriority::Medium, Box::new(TestLowPriorityTemplate {})),
+        ];
+        let chain = PromptChain::new(args, templates);
+
+        let prompt = chain.generate(false);
+        println!("{:?}", prompt);
+        panic!();
+    }
+}
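
To see how the pieces fit together, here is a minimal sketch of a caller assembling a chain. It assumes EngineerPreamble (from templates/preamble.rs below) and PromptCodeSnippet are exported from the crate; in this diff EngineerPreamble is still private to its module, and generate(false) simply concatenates templates in insertion order because the priority sort is still commented out.

use crate::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
use crate::templates::preamble::EngineerPreamble; // assumes this is made pub
use crate::templates::repository_context::PromptCodeSnippet;

// Hypothetical caller; not part of this change.
fn assemble_prompt(snippets: Vec<PromptCodeSnippet>) -> anyhow::Result<String> {
    let args = PromptArguments {
        model_name: "gpt-4".to_string(),
        language_name: Some("Rust".to_string()),
        project_name: Some("zed".to_string()),
        snippets,
        ..Default::default()
    };

    let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> =
        vec![(PromptPriority::High, Box::new(EngineerPreamble {}))];

    PromptChain::new(args, templates).generate(false)
}

Once truncation lands, the Reverse argsort above would presumably let low-priority templates be trimmed first while the preamble survives.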

crates/ai/src/templates/preamble.rs

@@ -0,0 +1,34 @@
+use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
+use std::fmt::Write;
+
+struct EngineerPreamble {}
+
+impl PromptTemplate for EngineerPreamble {
+    fn generate(&self, args: &PromptArguments, max_token_length: Option<usize>) -> String {
+        let mut prompt = String::new();
+
+        match args.get_file_type() {
+            PromptFileType::Code => {
+                writeln!(
+                    prompt,
+                    "You are an expert {} engineer.",
+                    args.language_name.clone().unwrap_or("".to_string())
+                )
+                .unwrap();
+            }
+            PromptFileType::Text => {
+                writeln!(prompt, "You are an expert engineer.").unwrap();
+            }
+        }
+
+        if let Some(project_name) = args.project_name.clone() {
+            writeln!(
+                prompt,
+                "You are currently working inside the '{project_name}' in Zed the code editor."
+            )
+            .unwrap();
+        }
+
+        prompt
+    }
+}
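
As a worked example, with language_name: Some("Rust".to_string()) and project_name: Some("zed".to_string()) (hypothetical values), this template writes:

You are an expert Rust engineer.
You are currently working inside the 'zed' in Zed the code editor.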

crates/ai/src/templates/repository_context.rs

@@ -0,0 +1,49 @@
+use std::{ops::Range, path::PathBuf};
+
+use gpui::{AsyncAppContext, ModelHandle};
+use language::{Anchor, Buffer};
+
+pub struct PromptCodeSnippet {
+    path: Option<PathBuf>,
+    language_name: Option<String>,
+    content: String,
+}
+
+impl PromptCodeSnippet {
+    pub fn new(buffer: ModelHandle<Buffer>, range: Range<Anchor>, cx: &AsyncAppContext) -> Self {
+        let (content, language_name, file_path) = buffer.read_with(cx, |buffer, _| {
+            let snapshot = buffer.snapshot();
+            let content = snapshot.text_for_range(range.clone()).collect::<String>();
+
+            let language_name = buffer
+                .language()
+                .and_then(|language| Some(language.name().to_string()));
+
+            let file_path = buffer
+                .file()
+                .and_then(|file| Some(file.path().to_path_buf()));
+
+            (content, language_name, file_path)
+        });
+
+        PromptCodeSnippet {
+            path: file_path,
+            language_name,
+            content,
+        }
+    }
+}
+
+impl ToString for PromptCodeSnippet {
+    fn to_string(&self) -> String {
+        let path = self
+            .path
+            .as_ref()
+            .and_then(|path| Some(path.to_string_lossy().to_string()))
+            .unwrap_or("".to_string());
+        let language_name = self.language_name.clone().unwrap_or("".to_string());
+        let content = self.content.clone();
+
+        format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
+    }
+}
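
As a worked example of the formatting above, a snippet taken from a hypothetical file crates/ai/src/templates/base.rs with language "Rust" would render via to_string as (path and content are illustrative):

The below code snippet may be relevant from file: crates/ai/src/templates/base.rs
```Rust
pub trait PromptTemplate { /* ... */ }
```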