Introduction of PromptTemplate and PromptChains (#3139)

Created by Kyle Caverly

(This PR was written 100% by the Inline Assistant)

This PR introduces new components into our ai and assistant crates, namely
PromptTemplate and PromptChain. They offer a new way to generate prompts,
allowing for a more flexible and dynamic approach than before.

Release Notes:
- Introduced PromptTemplate: an abstract base for the individual parts of
a prompt.
- Added PromptChain: manages multiple PromptTemplates, sorts them based
on priority, and regulates the output size based on tokens.
- Provided a new PromptArguments structure to encapsulate the arguments
needed by a PromptTemplate.
- Extended repository_context to include PromptCodeSnippet.
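
To give a concrete sense of how the pieces fit together, here is a minimal
sketch of assembling a chain, mirroring the new code path in
crates/assistant/src/prompts.rs (the model name and prompt text are
placeholders):

```rust
use std::sync::Arc;

use ai::models::{LanguageModel, OpenAILanguageModel};
use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
use ai::templates::generate::GenerateInlineContent;
use ai::templates::preamble::EngineerPreamble;
use ai::templates::repository_context::RepositoryContext;

fn build_prompt() -> anyhow::Result<String> {
    // Shared arguments that every template in the chain reads from.
    let model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load("gpt-4"));
    let args = PromptArguments {
        model,
        user_prompt: Some("Add error handling".to_string()),
        language_name: Some("Rust".to_string()),
        project_name: None,
        snippets: Vec::new(),
        reserved_tokens: 1000,
        buffer: None,
        selected_range: None,
    };

    // Mandatory templates are never dropped; Ordered templates compete for
    // whatever token budget remains, lowest `order` first.
    let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
        (PromptPriority::Mandatory, Box::new(EngineerPreamble {})),
        (
            PromptPriority::Ordered { order: 0 },
            Box::new(RepositoryContext {}),
        ),
        (
            PromptPriority::Mandatory,
            Box::new(GenerateInlineContent {}),
        ),
    ];

    let chain = PromptChain::new(args, templates);
    let (prompt, _token_count) = chain.generate(true)?; // `true` enables truncation
    Ok(prompt)
}
```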

Change summary

Cargo.lock                                    |   1 
crates/ai/Cargo.toml                          |   1 
crates/ai/src/ai.rs                           |   2 
crates/ai/src/models.rs                       |  66 +++
crates/ai/src/templates/base.rs               | 350 +++++++++++++++++++++
crates/ai/src/templates/file_context.rs       | 160 +++++++++
crates/ai/src/templates/generate.rs           |  95 +++++
crates/ai/src/templates/mod.rs                |   5 
crates/ai/src/templates/preamble.rs           |  52 +++
crates/ai/src/templates/repository_context.rs |  94 +++++
crates/assistant/src/assistant_panel.rs       |  39 +
crates/assistant/src/prompts.rs               | 225 ++----------
12 files changed, 895 insertions(+), 195 deletions(-)

Detailed changes

Cargo.lock

@@ -91,6 +91,7 @@ dependencies = [
  "futures 0.3.28",
  "gpui",
  "isahc",
+ "language",
  "lazy_static",
  "log",
  "matrixmultiply",

crates/ai/Cargo.toml

@@ -11,6 +11,7 @@ doctest = false
 [dependencies]
 gpui = { path = "../gpui" }
 util = { path = "../util" }
+language = { path = "../language" }
 async-trait.workspace = true
 anyhow.workspace = true
 futures.workspace = true

crates/ai/src/ai.rs

@@ -1,2 +1,4 @@
 pub mod completion;
 pub mod embedding;
+pub mod models;
+pub mod templates;

crates/ai/src/models.rs

@@ -0,0 +1,66 @@
+use anyhow::anyhow;
+use tiktoken_rs::CoreBPE;
+use util::ResultExt;
+
+pub trait LanguageModel {
+    fn name(&self) -> String;
+    fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
+    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
+    fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String>;
+    fn capacity(&self) -> anyhow::Result<usize>;
+}
+
+pub struct OpenAILanguageModel {
+    name: String,
+    bpe: Option<CoreBPE>,
+}
+
+impl OpenAILanguageModel {
+    pub fn load(model_name: &str) -> Self {
+        let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
+        OpenAILanguageModel {
+            name: model_name.to_string(),
+            bpe,
+        }
+    }
+}
+
+impl LanguageModel for OpenAILanguageModel {
+    fn name(&self) -> String {
+        self.name.clone()
+    }
+    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+        if let Some(bpe) = &self.bpe {
+            anyhow::Ok(bpe.encode_with_special_tokens(content).len())
+        } else {
+            Err(anyhow!("bpe for open ai model was not retrieved"))
+        }
+    }
+    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
+        if let Some(bpe) = &self.bpe {
+            let tokens = bpe.encode_with_special_tokens(content);
+            if tokens.len() > length {
+                bpe.decode(tokens[..length].to_vec())
+            } else {
+                bpe.decode(tokens)
+            }
+        } else {
+            Err(anyhow!("bpe for open ai model was not retrieved"))
+        }
+    }
+    fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
+        if let Some(bpe) = &self.bpe {
+            let tokens = bpe.encode_with_special_tokens(content);
+            if tokens.len() > length {
+                bpe.decode(tokens[length..].to_vec())
+            } else {
+                bpe.decode(tokens)
+            }
+        } else {
+            Err(anyhow!("bpe for open ai model was not retrieved"))
+        }
+    }
+    fn capacity(&self) -> anyhow::Result<usize> {
+        anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
+    }
+}
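
The trait is deliberately small; here is a hedged sketch of its contract
(exact token counts depend on the model's BPE vocabulary, and the model name
is illustrative):

```rust
use ai::models::{LanguageModel, OpenAILanguageModel};

fn demo() -> anyhow::Result<()> {
    let model = OpenAILanguageModel::load("gpt-3.5-turbo");

    // How many tokens the model's tokenizer produces for this content.
    let token_count = model.count_tokens("fn main() {}")?;

    // Keep at most the first 8 tokens of the content...
    let head = model.truncate("a longer piece of content to trim", 8)?;
    // ...or drop the first 8 tokens and keep the remainder.
    let tail = model.truncate_start("a longer piece of content to trim", 8)?;

    // Total context window; PromptChain budgets against this.
    println!("{token_count} tokens, capacity {}", model.capacity()?);
    println!("head: {head:?}, tail: {tail:?}");
    Ok(())
}
```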

crates/ai/src/templates/base.rs

@@ -0,0 +1,350 @@
+use std::cmp::Reverse;
+use std::ops::Range;
+use std::sync::Arc;
+
+use language::BufferSnapshot;
+use util::ResultExt;
+
+use crate::models::LanguageModel;
+use crate::templates::repository_context::PromptCodeSnippet;
+
+pub(crate) enum PromptFileType {
+    Text,
+    Code,
+}
+
+// TODO: Set this up to manage for defaults well
+pub struct PromptArguments {
+    pub model: Arc<dyn LanguageModel>,
+    pub user_prompt: Option<String>,
+    pub language_name: Option<String>,
+    pub project_name: Option<String>,
+    pub snippets: Vec<PromptCodeSnippet>,
+    pub reserved_tokens: usize,
+    pub buffer: Option<BufferSnapshot>,
+    pub selected_range: Option<Range<usize>>,
+}
+
+impl PromptArguments {
+    pub(crate) fn get_file_type(&self) -> PromptFileType {
+        if self
+            .language_name
+            .as_ref()
+            .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
+            .unwrap_or(true)
+        {
+            PromptFileType::Code
+        } else {
+            PromptFileType::Text
+        }
+    }
+}
+
+pub trait PromptTemplate {
+    fn generate(
+        &self,
+        args: &PromptArguments,
+        max_token_length: Option<usize>,
+    ) -> anyhow::Result<(String, usize)>;
+}
+
+#[repr(i8)]
+#[derive(PartialEq, Eq, Ord)]
+pub enum PromptPriority {
+    Mandatory,                // Ignores truncation
+    Ordered { order: usize }, // Truncates based on priority
+}
+
+impl PartialOrd for PromptPriority {
+    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+        match (self, other) {
+            (Self::Mandatory, Self::Mandatory) => Some(std::cmp::Ordering::Equal),
+            (Self::Mandatory, Self::Ordered { .. }) => Some(std::cmp::Ordering::Greater),
+            (Self::Ordered { .. }, Self::Mandatory) => Some(std::cmp::Ordering::Less),
+            (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.partial_cmp(a),
+        }
+    }
+}
+
+pub struct PromptChain {
+    args: PromptArguments,
+    templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
+}
+
+impl PromptChain {
+    pub fn new(
+        args: PromptArguments,
+        templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
+    ) -> Self {
+        PromptChain { args, templates }
+    }
+
+    pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
+        // Argsort based on Prompt Priority
+        let seperator = "\n";
+        let seperator_tokens = self.args.model.count_tokens(seperator)?;
+        let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
+        sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
+
+        // If Truncate
+        let mut tokens_outstanding = if truncate {
+            Some(self.args.model.capacity()? - self.args.reserved_tokens)
+        } else {
+            None
+        };
+
+        let mut prompts = vec!["".to_string(); sorted_indices.len()];
+        for idx in sorted_indices {
+            let (_, template) = &self.templates[idx];
+
+            if let Some((template_prompt, prompt_token_count)) =
+                template.generate(&self.args, tokens_outstanding).log_err()
+            {
+                if template_prompt != "" {
+                    prompts[idx] = template_prompt;
+
+                    if let Some(remaining_tokens) = tokens_outstanding {
+                        let new_tokens = prompt_token_count + seperator_tokens;
+                        tokens_outstanding = if remaining_tokens > new_tokens {
+                            Some(remaining_tokens - new_tokens)
+                        } else {
+                            Some(0)
+                        };
+                    }
+                }
+            }
+        }
+
+        prompts.retain(|x| x != "");
+
+        let full_prompt = prompts.join(seperator);
+        let total_token_count = self.args.model.count_tokens(&full_prompt)?;
+        anyhow::Ok((prompts.join(seperator), total_token_count))
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod tests {
+    use super::*;
+
+    #[test]
+    pub fn test_prompt_chain() {
+        struct TestPromptTemplate {}
+        impl PromptTemplate for TestPromptTemplate {
+            fn generate(
+                &self,
+                args: &PromptArguments,
+                max_token_length: Option<usize>,
+            ) -> anyhow::Result<(String, usize)> {
+                let mut content = "This is a test prompt template".to_string();
+
+                let mut token_count = args.model.count_tokens(&content)?;
+                if let Some(max_token_length) = max_token_length {
+                    if token_count > max_token_length {
+                        content = args.model.truncate(&content, max_token_length)?;
+                        token_count = max_token_length;
+                    }
+                }
+
+                anyhow::Ok((content, token_count))
+            }
+        }
+
+        struct TestLowPriorityTemplate {}
+        impl PromptTemplate for TestLowPriorityTemplate {
+            fn generate(
+                &self,
+                args: &PromptArguments,
+                max_token_length: Option<usize>,
+            ) -> anyhow::Result<(String, usize)> {
+                let mut content = "This is a low priority test prompt template".to_string();
+
+                let mut token_count = args.model.count_tokens(&content)?;
+                if let Some(max_token_length) = max_token_length {
+                    if token_count > max_token_length {
+                        content = args.model.truncate(&content, max_token_length)?;
+                        token_count = max_token_length;
+                    }
+                }
+
+                anyhow::Ok((content, token_count))
+            }
+        }
+
+        #[derive(Clone)]
+        struct DummyLanguageModel {
+            capacity: usize,
+        }
+
+        impl LanguageModel for DummyLanguageModel {
+            fn name(&self) -> String {
+                "dummy".to_string()
+            }
+            fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+                anyhow::Ok(content.chars().collect::<Vec<char>>().len())
+            }
+            fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
+                anyhow::Ok(
+                    content.chars().collect::<Vec<char>>()[..length]
+                        .into_iter()
+                        .collect::<String>(),
+                )
+            }
+            fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
+                anyhow::Ok(
+                    content.chars().collect::<Vec<char>>()[length..]
+                        .into_iter()
+                        .collect::<String>(),
+                )
+            }
+            fn capacity(&self) -> anyhow::Result<usize> {
+                anyhow::Ok(self.capacity)
+            }
+        }
+
+        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 100 });
+        let args = PromptArguments {
+            model: model.clone(),
+            language_name: None,
+            project_name: None,
+            snippets: Vec::new(),
+            reserved_tokens: 0,
+            buffer: None,
+            selected_range: None,
+            user_prompt: None,
+        };
+
+        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+            (
+                PromptPriority::Ordered { order: 0 },
+                Box::new(TestPromptTemplate {}),
+            ),
+            (
+                PromptPriority::Ordered { order: 1 },
+                Box::new(TestLowPriorityTemplate {}),
+            ),
+        ];
+        let chain = PromptChain::new(args, templates);
+
+        let (prompt, token_count) = chain.generate(false).unwrap();
+
+        assert_eq!(
+            prompt,
+            "This is a test prompt template\nThis is a low priority test prompt template"
+                .to_string()
+        );
+
+        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
+
+        // Testing with Truncation Off
+        // Should ignore capacity and return all prompts
+        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 20 });
+        let args = PromptArguments {
+            model: model.clone(),
+            language_name: None,
+            project_name: None,
+            snippets: Vec::new(),
+            reserved_tokens: 0,
+            buffer: None,
+            selected_range: None,
+            user_prompt: None,
+        };
+
+        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+            (
+                PromptPriority::Ordered { order: 0 },
+                Box::new(TestPromptTemplate {}),
+            ),
+            (
+                PromptPriority::Ordered { order: 1 },
+                Box::new(TestLowPriorityTemplate {}),
+            ),
+        ];
+        let chain = PromptChain::new(args, templates);
+
+        let (prompt, token_count) = chain.generate(false).unwrap();
+
+        assert_eq!(
+            prompt,
+            "This is a test prompt template\nThis is a low priority test prompt template"
+                .to_string()
+        );
+
+        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
+
+        // Testing with Truncation On
+        // Should truncate the prompts to fit within capacity
+        let capacity = 20;
+        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
+        let args = PromptArguments {
+            model: model.clone(),
+            language_name: None,
+            project_name: None,
+            snippets: Vec::new(),
+            reserved_tokens: 0,
+            buffer: None,
+            selected_range: None,
+            user_prompt: None,
+        };
+
+        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+            (
+                PromptPriority::Ordered { order: 0 },
+                Box::new(TestPromptTemplate {}),
+            ),
+            (
+                PromptPriority::Ordered { order: 1 },
+                Box::new(TestLowPriorityTemplate {}),
+            ),
+            (
+                PromptPriority::Ordered { order: 2 },
+                Box::new(TestLowPriorityTemplate {}),
+            ),
+        ];
+        let chain = PromptChain::new(args, templates);
+
+        let (prompt, token_count) = chain.generate(true).unwrap();
+
+        assert_eq!(prompt, "This is a test promp".to_string());
+        assert_eq!(token_count, capacity);
+
+        // Change Ordering of Prompts Based on Priority
+        let capacity = 120;
+        let reserved_tokens = 10;
+        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
+        let args = PromptArguments {
+            model: model.clone(),
+            language_name: None,
+            project_name: None,
+            snippets: Vec::new(),
+            reserved_tokens,
+            buffer: None,
+            selected_range: None,
+            user_prompt: None,
+        };
+        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+            (
+                PromptPriority::Mandatory,
+                Box::new(TestLowPriorityTemplate {}),
+            ),
+            (
+                PromptPriority::Ordered { order: 0 },
+                Box::new(TestPromptTemplate {}),
+            ),
+            (
+                PromptPriority::Ordered { order: 1 },
+                Box::new(TestLowPriorityTemplate {}),
+            ),
+        ];
+        let chain = PromptChain::new(args, templates);
+
+        let (prompt, token_count) = chain.generate(true).unwrap();
+
+        assert_eq!(
+            prompt,
+            "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
+                .to_string()
+        );
+        assert_eq!(token_count, capacity - reserved_tokens);
+    }
+}
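
One subtlety worth calling out in base.rs: Mandatory always outranks Ordered,
and among Ordered templates a lower `order` value means higher priority (the
comparison is intentionally reversed). Truncation priority is independent of
output order, so the rendered prompt keeps the templates in the order they
were passed to PromptChain::new; priority only decides which templates get
tokens first. A small sketch of the ordering:

```rust
use ai::templates::base::PromptPriority;

// Mandatory templates get their token budget before any Ordered template.
assert!(PromptPriority::Mandatory > PromptPriority::Ordered { order: 0 });

// Among Ordered templates, order 0 is budgeted before order 1.
assert!(PromptPriority::Ordered { order: 0 } > PromptPriority::Ordered { order: 1 });
```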

crates/ai/src/templates/file_context.rs

@@ -0,0 +1,160 @@
+use anyhow::anyhow;
+use language::BufferSnapshot;
+use language::ToOffset;
+
+use crate::models::LanguageModel;
+use crate::templates::base::PromptArguments;
+use crate::templates::base::PromptTemplate;
+use std::fmt::Write;
+use std::ops::Range;
+use std::sync::Arc;
+
+fn retrieve_context(
+    buffer: &BufferSnapshot,
+    selected_range: &Option<Range<usize>>,
+    model: Arc<dyn LanguageModel>,
+    max_token_count: Option<usize>,
+) -> anyhow::Result<(String, usize, bool)> {
+    let mut prompt = String::new();
+    let mut truncated = false;
+    if let Some(selected_range) = selected_range {
+        let start = selected_range.start.to_offset(buffer);
+        let end = selected_range.end.to_offset(buffer);
+
+        let start_window = buffer.text_for_range(0..start).collect::<String>();
+
+        let mut selected_window = String::new();
+        if start == end {
+            write!(selected_window, "<|START|>").unwrap();
+        } else {
+            write!(selected_window, "<|START|").unwrap();
+        }
+
+        write!(
+            selected_window,
+            "{}",
+            buffer.text_for_range(start..end).collect::<String>()
+        )
+        .unwrap();
+
+        if start != end {
+            write!(selected_window, "|END|>").unwrap();
+        }
+
+        let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
+
+        if let Some(max_token_count) = max_token_count {
+            let selected_tokens = model.count_tokens(&selected_window)?;
+            if selected_tokens > max_token_count {
+                return Err(anyhow!(
+                    "selected range is greater than model context window, truncation not possible"
+                ));
+            };
+
+            let mut remaining_tokens = max_token_count - selected_tokens;
+            let start_window_tokens = model.count_tokens(&start_window)?;
+            let end_window_tokens = model.count_tokens(&end_window)?;
+            let outside_tokens = start_window_tokens + end_window_tokens;
+            if outside_tokens > remaining_tokens {
+                let (start_goal_tokens, end_goal_tokens) =
+                    if start_window_tokens < end_window_tokens {
+                        let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
+                        remaining_tokens -= start_goal_tokens;
+                        let end_goal_tokens = remaining_tokens.min(end_window_tokens);
+                        (start_goal_tokens, end_goal_tokens)
+                    } else {
+                        let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
+                        remaining_tokens -= end_goal_tokens;
+                        let start_goal_tokens = remaining_tokens.min(start_window_tokens);
+                        (start_goal_tokens, end_goal_tokens)
+                    };
+
+                let truncated_start_window =
+                    model.truncate_start(&start_window, start_goal_tokens)?;
+                let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?;
+                writeln!(
+                    prompt,
+                    "{truncated_start_window}{selected_window}{truncated_end_window}"
+                )
+                .unwrap();
+                truncated = true;
+            } else {
+                writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
+            }
+        } else {
+            // If we dont have a selected range, include entire file.
+            writeln!(prompt, "{}", &buffer.text()).unwrap();
+
+            // Dumb truncation strategy
+            if let Some(max_token_count) = max_token_count {
+                if model.count_tokens(&prompt)? > max_token_count {
+                    truncated = true;
+                    prompt = model.truncate(&prompt, max_token_count)?;
+                }
+            }
+        }
+    }
+
+    let token_count = model.count_tokens(&prompt)?;
+    anyhow::Ok((prompt, token_count, truncated))
+}
+
+pub struct FileContext {}
+
+impl PromptTemplate for FileContext {
+    fn generate(
+        &self,
+        args: &PromptArguments,
+        max_token_length: Option<usize>,
+    ) -> anyhow::Result<(String, usize)> {
+        if let Some(buffer) = &args.buffer {
+            let mut prompt = String::new();
+            // Add Initial Preamble
+            // TODO: Do we want to add the path in here?
+            writeln!(
+                prompt,
+                "The file you are currently working on has the following content:"
+            )
+            .unwrap();
+
+            let language_name = args
+                .language_name
+                .clone()
+                .unwrap_or("".to_string())
+                .to_lowercase();
+
+            let (context, _, truncated) = retrieve_context(
+                buffer,
+                &args.selected_range,
+                args.model.clone(),
+                max_token_length,
+            )?;
+            writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
+
+            if truncated {
+                writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap();
+            }
+
+            if let Some(selected_range) = &args.selected_range {
+                let start = selected_range.start.to_offset(buffer);
+                let end = selected_range.end.to_offset(buffer);
+
+                if start == end {
+                    writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
+                } else {
+                    writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
+                }
+            }
+
+            // Really dumb truncation strategy
+            if let Some(max_tokens) = max_token_length {
+                prompt = args.model.truncate(&prompt, max_tokens)?;
+            }
+
+            let token_count = args.model.count_tokens(&prompt)?;
+            anyhow::Ok((prompt, token_count))
+        } else {
+            Err(anyhow!("no buffer provided to retrieve file context from"))
+        }
+    }
+}
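
The window-splitting arithmetic in retrieve_context is easiest to see with
numbers (hypothetical token counts for illustration): the selection itself is
never truncated, the smaller of the two surrounding windows is capped at half
the remaining budget, and the larger window takes whatever is left. The prefix
is trimmed from its start and the suffix from its end, so the text nearest the
selection survives.

```rust
// Say 100 tokens remain after counting the selection, with a 30-token
// prefix (start window) and an 80-token suffix (end window):
let (mut remaining, start_window, end_window) = (100usize, 30usize, 80usize);
let start_goal = (remaining / 2).min(start_window); // min(50, 30) = 30
remaining -= start_goal;                            // 70 tokens left
let end_goal = remaining.min(end_window);           // min(70, 80) = 70
assert_eq!((start_goal, end_goal), (30, 70));
```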

crates/ai/src/templates/generate.rs

@@ -0,0 +1,95 @@
+use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
+use anyhow::anyhow;
+use std::fmt::Write;
+
+pub fn capitalize(s: &str) -> String {
+    let mut c = s.chars();
+    match c.next() {
+        None => String::new(),
+        Some(f) => f.to_uppercase().collect::<String>() + c.as_str(),
+    }
+}
+
+pub struct GenerateInlineContent {}
+
+impl PromptTemplate for GenerateInlineContent {
+    fn generate(
+        &self,
+        args: &PromptArguments,
+        max_token_length: Option<usize>,
+    ) -> anyhow::Result<(String, usize)> {
+        let Some(user_prompt) = &args.user_prompt else {
+            return Err(anyhow!("user prompt not provided"));
+        };
+
+        let file_type = args.get_file_type();
+        let content_type = match &file_type {
+            PromptFileType::Code => "code",
+            PromptFileType::Text => "text",
+        };
+
+        let mut prompt = String::new();
+
+        if let Some(selected_range) = &args.selected_range {
+            if selected_range.start == selected_range.end {
+                writeln!(
+                    prompt,
+                    "Assume the cursor is located where the `<|START|>` span is."
+                )
+                .unwrap();
+                writeln!(
+                    prompt,
+                    "{} can't be replaced, so assume your answer will be inserted at the cursor.",
+                    capitalize(content_type)
+                )
+                .unwrap();
+                writeln!(
+                    prompt,
+                    "Generate {content_type} based on the users prompt: {user_prompt}",
+                )
+                .unwrap();
+            } else {
+                writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
+                writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
+                writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|'> spans").unwrap();
+            }
+        } else {
+            writeln!(
+                prompt,
+                "Generate {content_type} based on the users prompt: {user_prompt}"
+            )
+            .unwrap();
+        }
+
+        if let Some(language_name) = &args.language_name {
+            writeln!(
+                prompt,
+                "Your answer MUST always and only be valid {}.",
+                language_name
+            )
+            .unwrap();
+        }
+        writeln!(prompt, "Never make remarks about the output.").unwrap();
+        writeln!(
+            prompt,
+            "Do not return anything else, except the generated {content_type}."
+        )
+        .unwrap();
+
+        match file_type {
+            PromptFileType::Code => {
+                writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap();
+            }
+            _ => {}
+        }
+
+        // Really dumb truncation strategy
+        if let Some(max_tokens) = max_token_length {
+            prompt = args.model.truncate(&prompt, max_tokens)?;
+        }
+
+        let token_count = args.model.count_tokens(&prompt)?;
+
+        anyhow::Ok((prompt, token_count))
+    }
+}
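
A small note on the capitalize helper: it uppercases only the first character,
handles the empty string, and copes with characters whose uppercase form
expands to more than one character:

```rust
assert_eq!(capitalize("code"), "Code");
assert_eq!(capitalize(""), "");
// `char::to_uppercase` can yield multiple characters:
assert_eq!(capitalize("ßeta"), "SSeta");
```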

crates/ai/src/templates/preamble.rs

@@ -0,0 +1,52 @@
+use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
+use std::fmt::Write;
+
+pub struct EngineerPreamble {}
+
+impl PromptTemplate for EngineerPreamble {
+    fn generate(
+        &self,
+        args: &PromptArguments,
+        max_token_length: Option<usize>,
+    ) -> anyhow::Result<(String, usize)> {
+        let mut prompts = Vec::new();
+
+        match args.get_file_type() {
+            PromptFileType::Code => {
+                prompts.push(format!(
+                    "You are an expert {}engineer.",
+                    args.language_name.clone().unwrap_or("".to_string()) + " "
+                ));
+            }
+            PromptFileType::Text => {
+                prompts.push("You are an expert engineer.".to_string());
+            }
+        }
+
+        if let Some(project_name) = args.project_name.clone() {
+            prompts.push(format!(
+                "You are currently working inside the '{project_name}' project in code editor Zed."
+            ));
+        }
+
+        if let Some(mut remaining_tokens) = max_token_length {
+            let mut prompt = String::new();
+            let mut total_count = 0;
+            for prompt_piece in prompts {
+                let prompt_token_count =
+                    args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?;
+                if remaining_tokens > prompt_token_count {
+                    writeln!(prompt, "{prompt_piece}").unwrap();
+                    remaining_tokens -= prompt_token_count;
+                    total_count += prompt_token_count;
+                }
+            }
+
+            anyhow::Ok((prompt, total_count))
+        } else {
+            let prompt = prompts.join("\n");
+            let token_count = args.model.count_tokens(&prompt)?;
+            anyhow::Ok((prompt, token_count))
+        }
+    }
+}
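
For example, with language_name set to "Rust" and project_name set to "zed"
(both hypothetical values), the preamble renders roughly as follows, assuming
the token budget admits both lines:

```
You are an expert Rust engineer.
You are currently working inside the 'zed' project in code editor Zed.
```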

crates/ai/src/templates/repository_context.rs

@@ -0,0 +1,94 @@
+use crate::templates::base::{PromptArguments, PromptTemplate};
+use std::fmt::Write;
+use std::{ops::Range, path::PathBuf};
+
+use gpui::{AsyncAppContext, ModelHandle};
+use language::{Anchor, Buffer};
+
+#[derive(Clone)]
+pub struct PromptCodeSnippet {
+    path: Option<PathBuf>,
+    language_name: Option<String>,
+    content: String,
+}
+
+impl PromptCodeSnippet {
+    pub fn new(buffer: ModelHandle<Buffer>, range: Range<Anchor>, cx: &AsyncAppContext) -> Self {
+        let (content, language_name, file_path) = buffer.read_with(cx, |buffer, _| {
+            let snapshot = buffer.snapshot();
+            let content = snapshot.text_for_range(range.clone()).collect::<String>();
+
+            let language_name = buffer
+                .language()
+                .and_then(|language| Some(language.name().to_string().to_lowercase()));
+
+            let file_path = buffer
+                .file()
+                .and_then(|file| Some(file.path().to_path_buf()));
+
+            (content, language_name, file_path)
+        });
+
+        PromptCodeSnippet {
+            path: file_path,
+            language_name,
+            content,
+        }
+    }
+}
+
+impl ToString for PromptCodeSnippet {
+    fn to_string(&self) -> String {
+        let path = self
+            .path
+            .as_ref()
+            .and_then(|path| Some(path.to_string_lossy().to_string()))
+            .unwrap_or("".to_string());
+        let language_name = self.language_name.clone().unwrap_or("".to_string());
+        let content = self.content.clone();
+
+        format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
+    }
+}
+
+pub struct RepositoryContext {}
+
+impl PromptTemplate for RepositoryContext {
+    fn generate(
+        &self,
+        args: &PromptArguments,
+        max_token_length: Option<usize>,
+    ) -> anyhow::Result<(String, usize)> {
+        const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
+        let template = "You are working inside a large repository, here are a few code snippets that may be useful.";
+        let mut prompt = String::new();
+
+        let mut remaining_tokens = max_token_length.clone();
+        let seperator_token_length = args.model.count_tokens("\n")?;
+        for snippet in &args.snippets {
+            let mut snippet_prompt = template.to_string();
+            let content = snippet.to_string();
+            writeln!(snippet_prompt, "{content}").unwrap();
+
+            let token_count = args.model.count_tokens(&snippet_prompt)?;
+            if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
+                if let Some(tokens_left) = remaining_tokens {
+                    if tokens_left >= token_count {
+                        writeln!(prompt, "{snippet_prompt}").unwrap();
+                        remaining_tokens = if tokens_left >= (token_count + seperator_token_length)
+                        {
+                            Some(tokens_left - token_count - seperator_token_length)
+                        } else {
+                            Some(0)
+                        };
+                    }
+                } else {
+                    writeln!(prompt, "{snippet_prompt}").unwrap();
+                }
+            }
+        }
+
+        let total_token_count = args.model.count_tokens(&prompt)?;
+        anyhow::Ok((prompt, total_token_count))
+    }
+}
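
Each PromptCodeSnippet renders into a short header plus a fenced block. A
snippet taken from, say, src/main.rs in a Rust buffer (hypothetical path and
content) would stringify as shown below; RepositoryContext skips any snippet
whose rendered form exceeds MAXIMUM_SNIPPET_TOKEN_COUNT (500 tokens) and
appends the rest while the chain's remaining token budget allows:

````
The below code snippet may be relevant from file: src/main.rs
```rust
fn main() {}
```
````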

crates/assistant/src/assistant_panel.rs

@@ -1,12 +1,15 @@
 use crate::{
     assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
     codegen::{self, Codegen, CodegenKind},
-    prompts::{generate_content_prompt, PromptCodeSnippet},
+    prompts::generate_content_prompt,
     MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
     SavedMessage,
 };
-use ai::completion::{
-    stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
+use ai::{
+    completion::{
+        stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
+    },
+    templates::repository_context::PromptCodeSnippet,
 };
 use anyhow::{anyhow, Result};
 use chrono::{DateTime, Local};
@@ -609,6 +612,18 @@ impl AssistantPanel {
 
         let project = pending_assist.project.clone();
 
+        let project_name = if let Some(project) = project.upgrade(cx) {
+            Some(
+                project
+                    .read(cx)
+                    .worktree_root_names(cx)
+                    .collect::<Vec<&str>>()
+                    .join("/"),
+            )
+        } else {
+            None
+        };
+
         self.inline_prompt_history
             .retain(|prompt| prompt != user_prompt);
         self.inline_prompt_history.push_back(user_prompt.into());
@@ -646,7 +661,6 @@ impl AssistantPanel {
             None
         };
 
-        let codegen_kind = codegen.read(cx).kind().clone();
         let user_prompt = user_prompt.to_string();
 
         let snippets = if retrieve_context {
@@ -668,14 +682,7 @@ impl AssistantPanel {
             let snippets = cx.spawn(|_, cx| async move {
                 let mut snippets = Vec::new();
                 for result in search_results.await {
-                    snippets.push(PromptCodeSnippet::new(result, &cx));
-
-                    // snippets.push(result.buffer.read_with(&cx, |buffer, _| {
-                    //     buffer
-                    //         .snapshot()
-                    //         .text_for_range(result.range)
-                    //         .collect::<String>()
-                    // }));
+                    snippets.push(PromptCodeSnippet::new(result.buffer, result.range, &cx));
                 }
                 snippets
             });
@@ -696,11 +703,11 @@ impl AssistantPanel {
             generate_content_prompt(
                 user_prompt,
                 language_name,
-                &buffer,
+                buffer,
                 range,
-                codegen_kind,
                 snippets,
                 model_name,
+                project_name,
             )
         });
 
@@ -717,7 +724,8 @@ impl AssistantPanel {
         }
 
         cx.spawn(|_, mut cx| async move {
-            let prompt = prompt.await;
+            // I Don't know if we want to return a ? here.
+            let prompt = prompt.await?;
 
             messages.push(RequestMessage {
                 role: Role::User,
@@ -729,6 +737,7 @@ impl AssistantPanel {
                 stream: true,
             };
             codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx));
+            anyhow::Ok(())
         })
         .detach();
     }

crates/assistant/src/prompts.rs

@@ -1,60 +1,13 @@
-use crate::codegen::CodegenKind;
-use gpui::AsyncAppContext;
+use ai::models::{LanguageModel, OpenAILanguageModel};
+use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
+use ai::templates::file_context::FileContext;
+use ai::templates::generate::GenerateInlineContent;
+use ai::templates::preamble::EngineerPreamble;
+use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext};
 use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
-use semantic_index::SearchResult;
 use std::cmp::{self, Reverse};
-use std::fmt::Write;
 use std::ops::Range;
-use std::path::PathBuf;
-use tiktoken_rs::ChatCompletionRequestMessage;
-
-pub struct PromptCodeSnippet {
-    path: Option<PathBuf>,
-    language_name: Option<String>,
-    content: String,
-}
-
-impl PromptCodeSnippet {
-    pub fn new(search_result: SearchResult, cx: &AsyncAppContext) -> Self {
-        let (content, language_name, file_path) =
-            search_result.buffer.read_with(cx, |buffer, _| {
-                let snapshot = buffer.snapshot();
-                let content = snapshot
-                    .text_for_range(search_result.range.clone())
-                    .collect::<String>();
-
-                let language_name = buffer
-                    .language()
-                    .and_then(|language| Some(language.name().to_string()));
-
-                let file_path = buffer
-                    .file()
-                    .and_then(|file| Some(file.path().to_path_buf()));
-
-                (content, language_name, file_path)
-            });
-
-        PromptCodeSnippet {
-            path: file_path,
-            language_name,
-            content,
-        }
-    }
-}
-
-impl ToString for PromptCodeSnippet {
-    fn to_string(&self) -> String {
-        let path = self
-            .path
-            .as_ref()
-            .and_then(|path| Some(path.to_string_lossy().to_string()))
-            .unwrap_or("".to_string());
-        let language_name = self.language_name.clone().unwrap_or("".to_string());
-        let content = self.content.clone();
-
-        format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
-    }
-}
+use std::sync::Arc;
 
 #[allow(dead_code)]
 fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
@@ -170,138 +123,50 @@ fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> S
 pub fn generate_content_prompt(
     user_prompt: String,
     language_name: Option<&str>,
-    buffer: &BufferSnapshot,
-    range: Range<impl ToOffset>,
-    kind: CodegenKind,
+    buffer: BufferSnapshot,
+    range: Range<usize>,
     search_results: Vec<PromptCodeSnippet>,
     model: &str,
-) -> String {
-    const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
-    const RESERVED_TOKENS_FOR_GENERATION: usize = 1000;
-
-    let mut prompts = Vec::new();
-    let range = range.to_offset(buffer);
-
-    // General Preamble
-    if let Some(language_name) = language_name {
-        prompts.push(format!("You're an expert {language_name} engineer.\n"));
-    } else {
-        prompts.push("You're an expert engineer.\n".to_string());
-    }
-
-    // Snippets
-    let mut snippet_position = prompts.len() - 1;
-
-    let mut content = String::new();
-    content.extend(buffer.text_for_range(0..range.start));
-    if range.start == range.end {
-        content.push_str("<|START|>");
-    } else {
-        content.push_str("<|START|");
-    }
-    content.extend(buffer.text_for_range(range.clone()));
-    if range.start != range.end {
-        content.push_str("|END|>");
-    }
-    content.extend(buffer.text_for_range(range.end..buffer.len()));
-
-    prompts.push("The file you are currently working on has the following content:\n".to_string());
-
-    if let Some(language_name) = language_name {
-        let language_name = language_name.to_lowercase();
-        prompts.push(format!("```{language_name}\n{content}\n```"));
+    project_name: Option<String>,
+) -> anyhow::Result<String> {
+    // Using new Prompt Templates
+    let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load(model));
+    let lang_name = if let Some(language_name) = language_name {
+        Some(language_name.to_string())
     } else {
-        prompts.push(format!("```\n{content}\n```"));
-    }
-
-    match kind {
-        CodegenKind::Generate { position: _ } => {
-            prompts.push("In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.".to_string());
-            prompts
-                .push("Assume the cursor is located where the `<|START|` marker is.".to_string());
-            prompts.push(
-                "Text can't be replaced, so assume your answer will be inserted at the cursor."
-                    .to_string(),
-            );
-            prompts.push(format!(
-                "Generate text based on the users prompt: {user_prompt}"
-            ));
-        }
-        CodegenKind::Transform { range: _ } => {
-            prompts.push("In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.".to_string());
-            prompts.push(format!(
-                "Modify the users code selected text based upon the users prompt: '{user_prompt}'"
-            ));
-            prompts.push("You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file.".to_string());
-        }
-    }
-
-    if let Some(language_name) = language_name {
-        prompts.push(format!(
-            "Your answer MUST always and only be valid {language_name}"
-        ));
-    }
-    prompts.push("Never make remarks about the output.".to_string());
-    prompts.push("Do not return any text, except the generated code.".to_string());
-    prompts.push("Always wrap your code in a Markdown block".to_string());
-
-    let current_messages = [ChatCompletionRequestMessage {
-        role: "user".to_string(),
-        content: Some(prompts.join("\n")),
-        function_call: None,
-        name: None,
-    }];
-
-    let mut remaining_token_count = if let Ok(current_token_count) =
-        tiktoken_rs::num_tokens_from_messages(model, &current_messages)
-    {
-        let max_token_count = tiktoken_rs::model::get_context_size(model);
-        let intermediate_token_count = if max_token_count > current_token_count {
-            max_token_count - current_token_count
-        } else {
-            0
-        };
-
-        if intermediate_token_count < RESERVED_TOKENS_FOR_GENERATION {
-            0
-        } else {
-            intermediate_token_count - RESERVED_TOKENS_FOR_GENERATION
-        }
-    } else {
-        // If tiktoken fails to count token count, assume we have no space remaining.
-        0
+        None
     };
 
-    // TODO:
-    //   - add repository name to snippet
-    //   - add file path
-    //   - add language
-    if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model) {
-        let mut template = "You are working inside a large repository, here are a few code snippets that may be useful";
-
-        for search_result in search_results {
-            let mut snippet_prompt = template.to_string();
-            let snippet = search_result.to_string();
-            writeln!(snippet_prompt, "```\n{snippet}\n```").unwrap();
-
-            let token_count = encoding
-                .encode_with_special_tokens(snippet_prompt.as_str())
-                .len();
-            if token_count <= remaining_token_count {
-                if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
-                    prompts.insert(snippet_position, snippet_prompt);
-                    snippet_position += 1;
-                    remaining_token_count -= token_count;
-                    // If you have already added the template to the prompt, remove the template.
-                    template = "";
-                }
-            } else {
-                break;
-            }
-        }
-    }
+    let args = PromptArguments {
+        model: openai_model,
+        language_name: lang_name.clone(),
+        project_name,
+        snippets: search_results.clone(),
+        reserved_tokens: 1000,
+        buffer: Some(buffer),
+        selected_range: Some(range),
+        user_prompt: Some(user_prompt.clone()),
+    };
 
-    prompts.join("\n")
+    let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+        (PromptPriority::Mandatory, Box::new(EngineerPreamble {})),
+        (
+            PromptPriority::Ordered { order: 1 },
+            Box::new(RepositoryContext {}),
+        ),
+        (
+            PromptPriority::Ordered { order: 0 },
+            Box::new(FileContext {}),
+        ),
+        (
+            PromptPriority::Mandatory,
+            Box::new(GenerateInlineContent {}),
+        ),
+    ];
+    let chain = PromptChain::new(args, templates);
+    let (prompt, _) = chain.generate(true)?;
+
+    anyhow::Ok(prompt)
 }
 
 #[cfg(test)]