Kyle Caverly created this PR
(This PR was written 100% by the Inline Assistant)
This PR brings new components into our ai and assistant crates, namely PromptTemplate and PromptChain. They offer a more flexible and dynamic way to generate prompts than the previous approach (a usage sketch follows the release notes).
Release Notes:
- Introduced PromptTemplate: an abstract base for individual parts of
the prompt.
- Added PromptChain: manages multiple PromptTemplates, sorts them by priority, and regulates the overall prompt size based on token counts.
- Provided a new PromptArguments structure to encapsulate the arguments needed by PromptTemplates.
- Extended repository_context to include PromptCodeSnippet.
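
For reviewers, here is a minimal sketch of how these pieces compose, modeled on the new `generate_content_prompt` in `crates/assistant/src/prompts.rs`. The `build_prompt` helper, the `reserved_tokens` value, and the template selection are illustrative only; `FileContext` is omitted here because no buffer is provided:

```rust
use std::sync::Arc;

use ai::models::{LanguageModel, OpenAILanguageModel};
use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
use ai::templates::generate::GenerateInlineContent;
use ai::templates::preamble::EngineerPreamble;
use ai::templates::repository_context::RepositoryContext;

// Hypothetical helper: build an inline-assist prompt without a buffer or snippets.
fn build_prompt(user_prompt: String, model_name: &str) -> anyhow::Result<String> {
    let model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load(model_name));

    let args = PromptArguments {
        model,
        user_prompt: Some(user_prompt),
        language_name: None,
        project_name: None,
        snippets: Vec::new(),
        reserved_tokens: 1000, // headroom kept for the model's reply
        buffer: None,
        selected_range: None,
    };

    // Templates receive the remaining token budget in priority order (Mandatory
    // first, then Ordered by ascending `order`), so lower-priority templates are
    // the first to lose content; the final prompt keeps the declaration order below.
    let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
        (PromptPriority::Mandatory, Box::new(EngineerPreamble {})),
        (
            PromptPriority::Ordered { order: 1 },
            Box::new(RepositoryContext {}),
        ),
        (
            PromptPriority::Mandatory,
            Box::new(GenerateInlineContent {}),
        ),
    ];

    // `true` truncates the chain to the model's context window minus `reserved_tokens`.
    let (prompt, _token_count) = PromptChain::new(args, templates).generate(true)?;
    anyhow::Ok(prompt)
}
```

The assistant's actual call site (in the `prompts.rs` hunk below) additionally registers `FileContext` at `PromptPriority::Ordered { order: 0 }` and passes the buffer snapshot, the selected range, and semantic-search snippets through `PromptArguments`.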
Cargo.lock | 1
crates/ai/Cargo.toml | 1
crates/ai/src/ai.rs | 2
crates/ai/src/models.rs | 66 +++
crates/ai/src/templates/base.rs | 350 +++++++++++++++++++++
crates/ai/src/templates/file_context.rs | 160 +++++++++
crates/ai/src/templates/generate.rs | 95 +++++
crates/ai/src/templates/mod.rs | 5
crates/ai/src/templates/preamble.rs | 52 +++
crates/ai/src/templates/repository_context.rs | 94 +++++
crates/assistant/src/assistant_panel.rs | 39 +
crates/assistant/src/prompts.rs | 225 ++----------
12 files changed, 895 insertions(+), 195 deletions(-)
@@ -91,6 +91,7 @@ dependencies = [
"futures 0.3.28",
"gpui",
"isahc",
+ "language",
"lazy_static",
"log",
"matrixmultiply",
@@ -11,6 +11,7 @@ doctest = false
[dependencies]
gpui = { path = "../gpui" }
util = { path = "../util" }
+language = { path = "../language" }
async-trait.workspace = true
anyhow.workspace = true
futures.workspace = true
@@ -1,2 +1,4 @@
pub mod completion;
pub mod embedding;
+pub mod models;
+pub mod templates;
@@ -0,0 +1,66 @@
+use anyhow::anyhow;
+use tiktoken_rs::CoreBPE;
+use util::ResultExt;
+
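+/// Tokenizer-aware operations (token counting, truncation, context size) that prompt templates need from a language model.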
+pub trait LanguageModel {
+ fn name(&self) -> String;
+ fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
+ fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
+ fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String>;
+ fn capacity(&self) -> anyhow::Result<usize>;
+}
+
+pub struct OpenAILanguageModel {
+ name: String,
+ bpe: Option<CoreBPE>,
+}
+
+impl OpenAILanguageModel {
+ pub fn load(model_name: &str) -> Self {
+ let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
+ OpenAILanguageModel {
+ name: model_name.to_string(),
+ bpe,
+ }
+ }
+}
+
+impl LanguageModel for OpenAILanguageModel {
+ fn name(&self) -> String {
+ self.name.clone()
+ }
+ fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+ if let Some(bpe) = &self.bpe {
+ anyhow::Ok(bpe.encode_with_special_tokens(content).len())
+ } else {
+ Err(anyhow!("bpe for open ai model was not retrieved"))
+ }
+ }
+ fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
+ if let Some(bpe) = &self.bpe {
+ let tokens = bpe.encode_with_special_tokens(content);
+ if tokens.len() > length {
+ bpe.decode(tokens[..length].to_vec())
+ } else {
+ bpe.decode(tokens)
+ }
+ } else {
+ Err(anyhow!("bpe for open ai model was not retrieved"))
+ }
+ }
+ fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
+ if let Some(bpe) = &self.bpe {
+ let tokens = bpe.encode_with_special_tokens(content);
+ if tokens.len() > length {
+ bpe.decode(tokens[length..].to_vec())
+ } else {
+ bpe.decode(tokens)
+ }
+ } else {
+ Err(anyhow!("bpe for open ai model was not retrieved"))
+ }
+ }
+ fn capacity(&self) -> anyhow::Result<usize> {
+ anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
+ }
+}
@@ -0,0 +1,350 @@
+use std::cmp::Reverse;
+use std::ops::Range;
+use std::sync::Arc;
+
+use language::BufferSnapshot;
+use util::ResultExt;
+
+use crate::models::LanguageModel;
+use crate::templates::repository_context::PromptCodeSnippet;
+
+pub(crate) enum PromptFileType {
+ Text,
+ Code,
+}
+
+// TODO: Set this up to handle defaults well
+pub struct PromptArguments {
+ pub model: Arc<dyn LanguageModel>,
+ pub user_prompt: Option<String>,
+ pub language_name: Option<String>,
+ pub project_name: Option<String>,
+ pub snippets: Vec<PromptCodeSnippet>,
+ pub reserved_tokens: usize,
+ pub buffer: Option<BufferSnapshot>,
+ pub selected_range: Option<Range<usize>>,
+}
+
+impl PromptArguments {
+ pub(crate) fn get_file_type(&self) -> PromptFileType {
+ if self
+ .language_name
+ .as_ref()
+ .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
+ .unwrap_or(true)
+ {
+ PromptFileType::Code
+ } else {
+ PromptFileType::Text
+ }
+ }
+}
+
+pub trait PromptTemplate {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)>;
+}
+
+#[repr(i8)]
+#[derive(PartialEq, Eq)]
+pub enum PromptPriority {
+ Mandatory, // Ignores truncation
+ Ordered { order: usize }, // Truncates based on priority
+}
+
+// Mandatory sorts above Ordered; among Ordered variants, a lower `order` sorts higher.
+impl Ord for PromptPriority {
+ fn cmp(&self, other: &Self) -> std::cmp::Ordering {
+ match (self, other) {
+ (Self::Mandatory, Self::Mandatory) => std::cmp::Ordering::Equal,
+ (Self::Mandatory, Self::Ordered { .. }) => std::cmp::Ordering::Greater,
+ (Self::Ordered { .. }, Self::Mandatory) => std::cmp::Ordering::Less,
+ (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.cmp(a),
+ }
+ }
+}
+
+impl PartialOrd for PromptPriority {
+ fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+ Some(self.cmp(other))
+ }
+}
+
+pub struct PromptChain {
+ args: PromptArguments,
+ templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
+}
+
+impl PromptChain {
+ pub fn new(
+ args: PromptArguments,
+ templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
+ ) -> Self {
+ PromptChain { args, templates }
+ }
+
+ pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
+ // Argsort based on Prompt Priority
+ let separator = "\n";
+ let separator_tokens = self.args.model.count_tokens(separator)?;
+ let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
+ sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
+
+ // If truncating, the budget is the model's capacity minus the reserved tokens
+ let mut tokens_outstanding = if truncate {
+ Some(self.args.model.capacity()? - self.args.reserved_tokens)
+ } else {
+ None
+ };
+
+ let mut prompts = vec!["".to_string(); sorted_indices.len()];
+ for idx in sorted_indices {
+ let (_, template) = &self.templates[idx];
+
+ if let Some((template_prompt, prompt_token_count)) =
+ template.generate(&self.args, tokens_outstanding).log_err()
+ {
+ if template_prompt != "" {
+ prompts[idx] = template_prompt;
+
+ if let Some(remaining_tokens) = tokens_outstanding {
+ let new_tokens = prompt_token_count + separator_tokens;
+ tokens_outstanding = if remaining_tokens > new_tokens {
+ Some(remaining_tokens - new_tokens)
+ } else {
+ Some(0)
+ };
+ }
+ }
+ }
+ }
+
+ prompts.retain(|x| x != "");
+
+ let full_prompt = prompts.join(separator);
+ let total_token_count = self.args.model.count_tokens(&full_prompt)?;
+ anyhow::Ok((full_prompt, total_token_count))
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod tests {
+ use super::*;
+
+ #[test]
+ pub fn test_prompt_chain() {
+ struct TestPromptTemplate {}
+ impl PromptTemplate for TestPromptTemplate {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)> {
+ let mut content = "This is a test prompt template".to_string();
+
+ let mut token_count = args.model.count_tokens(&content)?;
+ if let Some(max_token_length) = max_token_length {
+ if token_count > max_token_length {
+ content = args.model.truncate(&content, max_token_length)?;
+ token_count = max_token_length;
+ }
+ }
+
+ anyhow::Ok((content, token_count))
+ }
+ }
+
+ struct TestLowPriorityTemplate {}
+ impl PromptTemplate for TestLowPriorityTemplate {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)> {
+ let mut content = "This is a low priority test prompt template".to_string();
+
+ let mut token_count = args.model.count_tokens(&content)?;
+ if let Some(max_token_length) = max_token_length {
+ if token_count > max_token_length {
+ content = args.model.truncate(&content, max_token_length)?;
+ token_count = max_token_length;
+ }
+ }
+
+ anyhow::Ok((content, token_count))
+ }
+ }
+
+ #[derive(Clone)]
+ struct DummyLanguageModel {
+ capacity: usize,
+ }
+
+ impl LanguageModel for DummyLanguageModel {
+ fn name(&self) -> String {
+ "dummy".to_string()
+ }
+ fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+ anyhow::Ok(content.chars().collect::<Vec<char>>().len())
+ }
+ fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
+ anyhow::Ok(
+ content.chars().collect::<Vec<char>>()[..length]
+ .into_iter()
+ .collect::<String>(),
+ )
+ }
+ fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
+ anyhow::Ok(
+ content.chars().collect::<Vec<char>>()[length..]
+ .into_iter()
+ .collect::<String>(),
+ )
+ }
+ fn capacity(&self) -> anyhow::Result<usize> {
+ anyhow::Ok(self.capacity)
+ }
+ }
+
+ let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 100 });
+ let args = PromptArguments {
+ model: model.clone(),
+ language_name: None,
+ project_name: None,
+ snippets: Vec::new(),
+ reserved_tokens: 0,
+ buffer: None,
+ selected_range: None,
+ user_prompt: None,
+ };
+
+ let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+ (
+ PromptPriority::Ordered { order: 0 },
+ Box::new(TestPromptTemplate {}),
+ ),
+ (
+ PromptPriority::Ordered { order: 1 },
+ Box::new(TestLowPriorityTemplate {}),
+ ),
+ ];
+ let chain = PromptChain::new(args, templates);
+
+ let (prompt, token_count) = chain.generate(false).unwrap();
+
+ assert_eq!(
+ prompt,
+ "This is a test prompt template\nThis is a low priority test prompt template"
+ .to_string()
+ );
+
+ assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
+
+ // Testing with Truncation Off
+ // Should ignore capacity and return all prompts
+ let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 20 });
+ let args = PromptArguments {
+ model: model.clone(),
+ language_name: None,
+ project_name: None,
+ snippets: Vec::new(),
+ reserved_tokens: 0,
+ buffer: None,
+ selected_range: None,
+ user_prompt: None,
+ };
+
+ let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+ (
+ PromptPriority::Ordered { order: 0 },
+ Box::new(TestPromptTemplate {}),
+ ),
+ (
+ PromptPriority::Ordered { order: 1 },
+ Box::new(TestLowPriorityTemplate {}),
+ ),
+ ];
+ let chain = PromptChain::new(args, templates);
+
+ let (prompt, token_count) = chain.generate(false).unwrap();
+
+ assert_eq!(
+ prompt,
+ "This is a test prompt template\nThis is a low priority test prompt template"
+ .to_string()
+ );
+
+ assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
+
+ // Testing with Truncation On
+ // Should truncate the combined prompt down to the model's capacity
+ let capacity = 20;
+ let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
+ let args = PromptArguments {
+ model: model.clone(),
+ language_name: None,
+ project_name: None,
+ snippets: Vec::new(),
+ reserved_tokens: 0,
+ buffer: None,
+ selected_range: None,
+ user_prompt: None,
+ };
+
+ let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+ (
+ PromptPriority::Ordered { order: 0 },
+ Box::new(TestPromptTemplate {}),
+ ),
+ (
+ PromptPriority::Ordered { order: 1 },
+ Box::new(TestLowPriorityTemplate {}),
+ ),
+ (
+ PromptPriority::Ordered { order: 2 },
+ Box::new(TestLowPriorityTemplate {}),
+ ),
+ ];
+ let chain = PromptChain::new(args, templates);
+
+ let (prompt, token_count) = chain.generate(true).unwrap();
+
+ assert_eq!(prompt, "This is a test promp".to_string());
+ assert_eq!(token_count, capacity);
+
+ // Mixed priorities: the token budget is allocated by priority, while the output keeps declaration order
+ let capacity = 120;
+ let reserved_tokens = 10;
+ let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
+ let args = PromptArguments {
+ model: model.clone(),
+ language_name: None,
+ project_name: None,
+ snippets: Vec::new(),
+ reserved_tokens,
+ buffer: None,
+ selected_range: None,
+ user_prompt: None,
+ };
+ let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+ (
+ PromptPriority::Mandatory,
+ Box::new(TestLowPriorityTemplate {}),
+ ),
+ (
+ PromptPriority::Ordered { order: 0 },
+ Box::new(TestPromptTemplate {}),
+ ),
+ (
+ PromptPriority::Ordered { order: 1 },
+ Box::new(TestLowPriorityTemplate {}),
+ ),
+ ];
+ let chain = PromptChain::new(args, templates);
+
+ let (prompt, token_count) = chain.generate(true).unwrap();
+
+ assert_eq!(
+ prompt,
+ "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
+ .to_string()
+ );
+ assert_eq!(token_count, capacity - reserved_tokens);
+ }
+}
@@ -0,0 +1,160 @@
+use anyhow::anyhow;
+use language::BufferSnapshot;
+use language::ToOffset;
+
+use crate::models::LanguageModel;
+use crate::templates::base::PromptArguments;
+use crate::templates::base::PromptTemplate;
+use std::fmt::Write;
+use std::ops::Range;
+use std::sync::Arc;
+
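+// Assembles the buffer text for the prompt: the selection (if any) is wrapped in <|START| ... |END|>
+// markers (a bare <|START|> marks the cursor), and any token budget is spent trimming the surrounding text.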
+fn retrieve_context(
+ buffer: &BufferSnapshot,
+ selected_range: &Option<Range<usize>>,
+ model: Arc<dyn LanguageModel>,
+ max_token_count: Option<usize>,
+) -> anyhow::Result<(String, usize, bool)> {
+ let mut prompt = String::new();
+ let mut truncated = false;
+ if let Some(selected_range) = selected_range {
+ let start = selected_range.start.to_offset(buffer);
+ let end = selected_range.end.to_offset(buffer);
+
+ let start_window = buffer.text_for_range(0..start).collect::<String>();
+
+ let mut selected_window = String::new();
+ if start == end {
+ write!(selected_window, "<|START|>").unwrap();
+ } else {
+ write!(selected_window, "<|START|").unwrap();
+ }
+
+ write!(
+ selected_window,
+ "{}",
+ buffer.text_for_range(start..end).collect::<String>()
+ )
+ .unwrap();
+
+ if start != end {
+ write!(selected_window, "|END|>").unwrap();
+ }
+
+ let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
+
+ if let Some(max_token_count) = max_token_count {
+ let selected_tokens = model.count_tokens(&selected_window)?;
+ if selected_tokens > max_token_count {
+ return Err(anyhow!(
+ "selected range is greater than model context window, truncation not possible"
+ ));
+ };
+
+ let mut remaining_tokens = max_token_count - selected_tokens;
+ let start_window_tokens = model.count_tokens(&start_window)?;
+ let end_window_tokens = model.count_tokens(&end_window)?;
+ let outside_tokens = start_window_tokens + end_window_tokens;
+ if outside_tokens > remaining_tokens {
+ let (start_goal_tokens, end_goal_tokens) =
+ if start_window_tokens < end_window_tokens {
+ let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
+ remaining_tokens -= start_goal_tokens;
+ let end_goal_tokens = remaining_tokens.min(end_window_tokens);
+ (start_goal_tokens, end_goal_tokens)
+ } else {
+ let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
+ remaining_tokens -= end_goal_tokens;
+ let start_goal_tokens = remaining_tokens.min(start_window_tokens);
+ (start_goal_tokens, end_goal_tokens)
+ };
+
+ let truncated_start_window =
+ model.truncate_start(&start_window, start_goal_tokens)?;
+ let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?;
+ writeln!(
+ prompt,
+ "{truncated_start_window}{selected_window}{truncated_end_window}"
+ )
+ .unwrap();
+ truncated = true;
+ } else {
+ writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
+ }
+ } else {
+ // No token budget was provided; include the entire file.
+ writeln!(prompt, "{}", &buffer.text()).unwrap();
+
+ // Dumb truncation strategy
+ if let Some(max_token_count) = max_token_count {
+ if model.count_tokens(&prompt)? > max_token_count {
+ truncated = true;
+ prompt = model.truncate(&prompt, max_token_count)?;
+ }
+ }
+ }
+ }
+
+ let token_count = model.count_tokens(&prompt)?;
+ anyhow::Ok((prompt, token_count, truncated))
+}
+
+pub struct FileContext {}
+
+impl PromptTemplate for FileContext {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)> {
+ if let Some(buffer) = &args.buffer {
+ let mut prompt = String::new();
+ // Add Initial Preamble
+ // TODO: Do we want to add the path in here?
+ writeln!(
+ prompt,
+ "The file you are currently working on has the following content:"
+ )
+ .unwrap();
+
+ let language_name = args
+ .language_name
+ .clone()
+ .unwrap_or("".to_string())
+ .to_lowercase();
+
+ let (context, _, truncated) = retrieve_context(
+ buffer,
+ &args.selected_range,
+ args.model.clone(),
+ max_token_length,
+ )?;
+ writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
+
+ if truncated {
+ writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap();
+ }
+
+ if let Some(selected_range) = &args.selected_range {
+ let start = selected_range.start.to_offset(buffer);
+ let end = selected_range.end.to_offset(buffer);
+
+ if start == end {
+ writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
+ } else {
+ writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
+ }
+ }
+
+ // Really dumb truncation strategy
+ if let Some(max_tokens) = max_token_length {
+ prompt = args.model.truncate(&prompt, max_tokens)?;
+ }
+
+ let token_count = args.model.count_tokens(&prompt)?;
+ anyhow::Ok((prompt, token_count))
+ } else {
+ Err(anyhow!("no buffer provided to retrieve file context from"))
+ }
+ }
+}
@@ -0,0 +1,95 @@
+use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
+use anyhow::anyhow;
+use std::fmt::Write;
+
+pub fn capitalize(s: &str) -> String {
+ let mut c = s.chars();
+ match c.next() {
+ None => String::new(),
+ Some(f) => f.to_uppercase().collect::<String>() + c.as_str(),
+ }
+}
+
+pub struct GenerateInlineContent {}
+
+impl PromptTemplate for GenerateInlineContent {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)> {
+ let Some(user_prompt) = &args.user_prompt else {
+ return Err(anyhow!("user prompt not provided"));
+ };
+
+ let file_type = args.get_file_type();
+ let content_type = match &file_type {
+ PromptFileType::Code => "code",
+ PromptFileType::Text => "text",
+ };
+
+ let mut prompt = String::new();
+
+ if let Some(selected_range) = &args.selected_range {
+ if selected_range.start == selected_range.end {
+ writeln!(
+ prompt,
+ "Assume the cursor is located where the `<|START|>` span is."
+ )
+ .unwrap();
+ writeln!(
+ prompt,
+ "{} can't be replaced, so assume your answer will be inserted at the cursor.",
+ capitalize(content_type)
+ )
+ .unwrap();
+ writeln!(
+ prompt,
+ "Generate {content_type} based on the users prompt: {user_prompt}",
+ )
+ .unwrap();
+ } else {
+ writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
+ writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
+ writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|'> spans").unwrap();
+ }
+ } else {
+ writeln!(
+ prompt,
+ "Generate {content_type} based on the users prompt: {user_prompt}"
+ )
+ .unwrap();
+ }
+
+ if let Some(language_name) = &args.language_name {
+ writeln!(
+ prompt,
+ "Your answer MUST always and only be valid {}.",
+ language_name
+ )
+ .unwrap();
+ }
+ writeln!(prompt, "Never make remarks about the output.").unwrap();
+ writeln!(
+ prompt,
+ "Do not return anything else, except the generated {content_type}."
+ )
+ .unwrap();
+
+ match file_type {
+ PromptFileType::Code => {
+ writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap();
+ }
+ _ => {}
+ }
+
+ // Really dumb truncation strategy
+ if let Some(max_tokens) = max_token_length {
+ prompt = args.model.truncate(&prompt, max_tokens)?;
+ }
+
+ let token_count = args.model.count_tokens(&prompt)?;
+
+ anyhow::Ok((prompt, token_count))
+ }
+}
@@ -0,0 +1,5 @@
+pub mod base;
+pub mod file_context;
+pub mod generate;
+pub mod preamble;
+pub mod repository_context;
@@ -0,0 +1,52 @@
+use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
+use std::fmt::Write;
+
+pub struct EngineerPreamble {}
+
+impl PromptTemplate for EngineerPreamble {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)> {
+ let mut prompts = Vec::new();
+
+ match args.get_file_type() {
+ PromptFileType::Code => {
+ prompts.push(format!(
+ "You are an expert {}engineer.",
+ args.language_name.clone().unwrap_or("".to_string()) + " "
+ ));
+ }
+ PromptFileType::Text => {
+ prompts.push("You are an expert engineer.".to_string());
+ }
+ }
+
+ if let Some(project_name) = args.project_name.clone() {
+ prompts.push(format!(
+ "You are currently working inside the '{project_name}' project in code editor Zed."
+ ));
+ }
+
+ if let Some(mut remaining_tokens) = max_token_length {
+ let mut prompt = String::new();
+ let mut total_count = 0;
+ for prompt_piece in prompts {
+ let prompt_token_count =
+ args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?;
+ if remaining_tokens > prompt_token_count {
+ writeln!(prompt, "{prompt_piece}").unwrap();
+ remaining_tokens -= prompt_token_count;
+ total_count += prompt_token_count;
+ }
+ }
+
+ anyhow::Ok((prompt, total_count))
+ } else {
+ let prompt = prompts.join("\n");
+ let token_count = args.model.count_tokens(&prompt)?;
+ anyhow::Ok((prompt, token_count))
+ }
+ }
+}
@@ -0,0 +1,94 @@
+use crate::templates::base::{PromptArguments, PromptTemplate};
+use std::fmt::Write;
+use std::{ops::Range, path::PathBuf};
+
+use gpui::{AsyncAppContext, ModelHandle};
+use language::{Anchor, Buffer};
+
+#[derive(Clone)]
+pub struct PromptCodeSnippet {
+ path: Option<PathBuf>,
+ language_name: Option<String>,
+ content: String,
+}
+
+impl PromptCodeSnippet {
+ pub fn new(buffer: ModelHandle<Buffer>, range: Range<Anchor>, cx: &AsyncAppContext) -> Self {
+ let (content, language_name, file_path) = buffer.read_with(cx, |buffer, _| {
+ let snapshot = buffer.snapshot();
+ let content = snapshot.text_for_range(range.clone()).collect::<String>();
+
+ let language_name = buffer
+ .language()
+ .and_then(|language| Some(language.name().to_string().to_lowercase()));
+
+ let file_path = buffer
+ .file()
+ .and_then(|file| Some(file.path().to_path_buf()));
+
+ (content, language_name, file_path)
+ });
+
+ PromptCodeSnippet {
+ path: file_path,
+ language_name,
+ content,
+ }
+ }
+}
+
+impl ToString for PromptCodeSnippet {
+ fn to_string(&self) -> String {
+ let path = self
+ .path
+ .as_ref()
+ .and_then(|path| Some(path.to_string_lossy().to_string()))
+ .unwrap_or("".to_string());
+ let language_name = self.language_name.clone().unwrap_or("".to_string());
+ let content = self.content.clone();
+
+ format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
+ }
+}
+
+pub struct RepositoryContext {}
+
+impl PromptTemplate for RepositoryContext {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)> {
+ const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
+ let template = "You are working inside a large repository, here are a few code snippets that may be useful.";
+ let mut prompt = String::new();
+
+ let mut remaining_tokens = max_token_length.clone();
+ let separator_token_length = args.model.count_tokens("\n")?;
+ for snippet in &args.snippets {
+ let mut snippet_prompt = template.to_string();
+ let content = snippet.to_string();
+ writeln!(snippet_prompt, "{content}").unwrap();
+
+ let token_count = args.model.count_tokens(&snippet_prompt)?;
+ if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
+ if let Some(tokens_left) = remaining_tokens {
+ if tokens_left >= token_count {
+ writeln!(prompt, "{snippet_prompt}").unwrap();
+ remaining_tokens = if tokens_left >= (token_count + separator_token_length)
+ {
+ Some(tokens_left - token_count - separator_token_length)
+ } else {
+ Some(0)
+ };
+ }
+ } else {
+ writeln!(prompt, "{snippet_prompt}").unwrap();
+ }
+ }
+ }
+
+ let total_token_count = args.model.count_tokens(&prompt)?;
+ anyhow::Ok((prompt, total_token_count))
+ }
+}
@@ -1,12 +1,15 @@
use crate::{
assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
codegen::{self, Codegen, CodegenKind},
- prompts::{generate_content_prompt, PromptCodeSnippet},
+ prompts::generate_content_prompt,
MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
SavedMessage,
};
-use ai::completion::{
- stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
+use ai::{
+ completion::{
+ stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
+ },
+ templates::repository_context::PromptCodeSnippet,
};
use anyhow::{anyhow, Result};
use chrono::{DateTime, Local};
@@ -609,6 +612,18 @@ impl AssistantPanel {
let project = pending_assist.project.clone();
+ let project_name = if let Some(project) = project.upgrade(cx) {
+ Some(
+ project
+ .read(cx)
+ .worktree_root_names(cx)
+ .collect::<Vec<&str>>()
+ .join("/"),
+ )
+ } else {
+ None
+ };
+
self.inline_prompt_history
.retain(|prompt| prompt != user_prompt);
self.inline_prompt_history.push_back(user_prompt.into());
@@ -646,7 +661,6 @@ impl AssistantPanel {
None
};
- let codegen_kind = codegen.read(cx).kind().clone();
let user_prompt = user_prompt.to_string();
let snippets = if retrieve_context {
@@ -668,14 +682,7 @@ impl AssistantPanel {
let snippets = cx.spawn(|_, cx| async move {
let mut snippets = Vec::new();
for result in search_results.await {
- snippets.push(PromptCodeSnippet::new(result, &cx));
-
- // snippets.push(result.buffer.read_with(&cx, |buffer, _| {
- // buffer
- // .snapshot()
- // .text_for_range(result.range)
- // .collect::<String>()
- // }));
+ snippets.push(PromptCodeSnippet::new(result.buffer, result.range, &cx));
}
snippets
});
@@ -696,11 +703,11 @@ impl AssistantPanel {
generate_content_prompt(
user_prompt,
language_name,
- &buffer,
+ buffer,
range,
- codegen_kind,
snippets,
model_name,
+ project_name,
)
});
@@ -717,7 +724,8 @@ impl AssistantPanel {
}
cx.spawn(|_, mut cx| async move {
- let prompt = prompt.await;
+ // TODO: Decide whether we actually want to propagate this error with `?`.
+ let prompt = prompt.await?;
messages.push(RequestMessage {
role: Role::User,
@@ -729,6 +737,7 @@ impl AssistantPanel {
stream: true,
};
codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx));
+ anyhow::Ok(())
})
.detach();
}
@@ -1,60 +1,13 @@
-use crate::codegen::CodegenKind;
-use gpui::AsyncAppContext;
+use ai::models::{LanguageModel, OpenAILanguageModel};
+use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
+use ai::templates::file_context::FileContext;
+use ai::templates::generate::GenerateInlineContent;
+use ai::templates::preamble::EngineerPreamble;
+use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext};
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
-use semantic_index::SearchResult;
use std::cmp::{self, Reverse};
-use std::fmt::Write;
use std::ops::Range;
-use std::path::PathBuf;
-use tiktoken_rs::ChatCompletionRequestMessage;
-
-pub struct PromptCodeSnippet {
- path: Option<PathBuf>,
- language_name: Option<String>,
- content: String,
-}
-
-impl PromptCodeSnippet {
- pub fn new(search_result: SearchResult, cx: &AsyncAppContext) -> Self {
- let (content, language_name, file_path) =
- search_result.buffer.read_with(cx, |buffer, _| {
- let snapshot = buffer.snapshot();
- let content = snapshot
- .text_for_range(search_result.range.clone())
- .collect::<String>();
-
- let language_name = buffer
- .language()
- .and_then(|language| Some(language.name().to_string()));
-
- let file_path = buffer
- .file()
- .and_then(|file| Some(file.path().to_path_buf()));
-
- (content, language_name, file_path)
- });
-
- PromptCodeSnippet {
- path: file_path,
- language_name,
- content,
- }
- }
-}
-
-impl ToString for PromptCodeSnippet {
- fn to_string(&self) -> String {
- let path = self
- .path
- .as_ref()
- .and_then(|path| Some(path.to_string_lossy().to_string()))
- .unwrap_or("".to_string());
- let language_name = self.language_name.clone().unwrap_or("".to_string());
- let content = self.content.clone();
-
- format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
- }
-}
+use std::sync::Arc;
#[allow(dead_code)]
fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
@@ -170,138 +123,50 @@ fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> S
pub fn generate_content_prompt(
user_prompt: String,
language_name: Option<&str>,
- buffer: &BufferSnapshot,
- range: Range<impl ToOffset>,
- kind: CodegenKind,
+ buffer: BufferSnapshot,
+ range: Range<usize>,
search_results: Vec<PromptCodeSnippet>,
model: &str,
-) -> String {
- const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
- const RESERVED_TOKENS_FOR_GENERATION: usize = 1000;
-
- let mut prompts = Vec::new();
- let range = range.to_offset(buffer);
-
- // General Preamble
- if let Some(language_name) = language_name {
- prompts.push(format!("You're an expert {language_name} engineer.\n"));
- } else {
- prompts.push("You're an expert engineer.\n".to_string());
- }
-
- // Snippets
- let mut snippet_position = prompts.len() - 1;
-
- let mut content = String::new();
- content.extend(buffer.text_for_range(0..range.start));
- if range.start == range.end {
- content.push_str("<|START|>");
- } else {
- content.push_str("<|START|");
- }
- content.extend(buffer.text_for_range(range.clone()));
- if range.start != range.end {
- content.push_str("|END|>");
- }
- content.extend(buffer.text_for_range(range.end..buffer.len()));
-
- prompts.push("The file you are currently working on has the following content:\n".to_string());
-
- if let Some(language_name) = language_name {
- let language_name = language_name.to_lowercase();
- prompts.push(format!("```{language_name}\n{content}\n```"));
+ project_name: Option<String>,
+) -> anyhow::Result<String> {
+ // Build the prompt with the new prompt template chain
+ let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load(model));
+ let lang_name = if let Some(language_name) = language_name {
+ Some(language_name.to_string())
} else {
- prompts.push(format!("```\n{content}\n```"));
- }
-
- match kind {
- CodegenKind::Generate { position: _ } => {
- prompts.push("In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.".to_string());
- prompts
- .push("Assume the cursor is located where the `<|START|` marker is.".to_string());
- prompts.push(
- "Text can't be replaced, so assume your answer will be inserted at the cursor."
- .to_string(),
- );
- prompts.push(format!(
- "Generate text based on the users prompt: {user_prompt}"
- ));
- }
- CodegenKind::Transform { range: _ } => {
- prompts.push("In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.".to_string());
- prompts.push(format!(
- "Modify the users code selected text based upon the users prompt: '{user_prompt}'"
- ));
- prompts.push("You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file.".to_string());
- }
- }
-
- if let Some(language_name) = language_name {
- prompts.push(format!(
- "Your answer MUST always and only be valid {language_name}"
- ));
- }
- prompts.push("Never make remarks about the output.".to_string());
- prompts.push("Do not return any text, except the generated code.".to_string());
- prompts.push("Always wrap your code in a Markdown block".to_string());
-
- let current_messages = [ChatCompletionRequestMessage {
- role: "user".to_string(),
- content: Some(prompts.join("\n")),
- function_call: None,
- name: None,
- }];
-
- let mut remaining_token_count = if let Ok(current_token_count) =
- tiktoken_rs::num_tokens_from_messages(model, ¤t_messages)
- {
- let max_token_count = tiktoken_rs::model::get_context_size(model);
- let intermediate_token_count = if max_token_count > current_token_count {
- max_token_count - current_token_count
- } else {
- 0
- };
-
- if intermediate_token_count < RESERVED_TOKENS_FOR_GENERATION {
- 0
- } else {
- intermediate_token_count - RESERVED_TOKENS_FOR_GENERATION
- }
- } else {
- // If tiktoken fails to count token count, assume we have no space remaining.
- 0
+ None
};
- // TODO:
- // - add repository name to snippet
- // - add file path
- // - add language
- if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model) {
- let mut template = "You are working inside a large repository, here are a few code snippets that may be useful";
-
- for search_result in search_results {
- let mut snippet_prompt = template.to_string();
- let snippet = search_result.to_string();
- writeln!(snippet_prompt, "```\n{snippet}\n```").unwrap();
-
- let token_count = encoding
- .encode_with_special_tokens(snippet_prompt.as_str())
- .len();
- if token_count <= remaining_token_count {
- if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
- prompts.insert(snippet_position, snippet_prompt);
- snippet_position += 1;
- remaining_token_count -= token_count;
- // If you have already added the template to the prompt, remove the template.
- template = "";
- }
- } else {
- break;
- }
- }
- }
+ let args = PromptArguments {
+ model: openai_model,
+ language_name: lang_name.clone(),
+ project_name,
+ snippets: search_results.clone(),
+ reserved_tokens: 1000,
+ buffer: Some(buffer),
+ selected_range: Some(range),
+ user_prompt: Some(user_prompt.clone()),
+ };
- prompts.join("\n")
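+ // Priorities control how the token budget is allocated during truncation; the generated prompt keeps this declaration order.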
+ let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+ (PromptPriority::Mandatory, Box::new(EngineerPreamble {})),
+ (
+ PromptPriority::Ordered { order: 1 },
+ Box::new(RepositoryContext {}),
+ ),
+ (
+ PromptPriority::Ordered { order: 0 },
+ Box::new(FileContext {}),
+ ),
+ (
+ PromptPriority::Mandatory,
+ Box::new(GenerateInlineContent {}),
+ ),
+ ];
+ let chain = PromptChain::new(args, templates);
+ let (prompt, _) = chain.generate(true)?;
+
+ anyhow::Ok(prompt)
}
#[cfg(test)]