KCaverly created this change:
Cargo.lock                                    |   1
crates/ai/Cargo.toml                          |   1
crates/ai/src/prompts.rs                      | 149 +++++++++++++++++++++
crates/ai/src/templates.rs                    |  76 ----------
crates/ai/src/templates/base.rs               | 112 +++++++++++++++
crates/ai/src/templates/mod.rs                |   3
crates/ai/src/templates/preamble.rs           |  34 ++++
crates/ai/src/templates/repository_context.rs |  49 ++++++
8 files changed, 349 insertions(+), 76 deletions(-)
Cargo.lock
@@ -91,6 +91,7 @@ dependencies = [
"futures 0.3.28",
"gpui",
"isahc",
+ "language",
"lazy_static",
"log",
"matrixmultiply",
crates/ai/Cargo.toml
@@ -11,6 +11,7 @@ doctest = false
[dependencies]
gpui = { path = "../gpui" }
util = { path = "../util" }
+language = { path = "../language" }
async-trait.workspace = true
anyhow.workspace = true
futures.workspace = true
crates/ai/src/prompts.rs
@@ -0,0 +1,149 @@
+use gpui::{AsyncAppContext, ModelHandle};
+use language::{Anchor, Buffer};
+use std::{fmt::Write, ops::Range, path::PathBuf};
+
+pub struct PromptCodeSnippet {
+ path: Option<PathBuf>,
+ language_name: Option<String>,
+ content: String,
+}
+
+impl PromptCodeSnippet {
+ pub fn new(buffer: ModelHandle<Buffer>, range: Range<Anchor>, cx: &AsyncAppContext) -> Self {
+ let (content, language_name, file_path) = buffer.read_with(cx, |buffer, _| {
+ let snapshot = buffer.snapshot();
+ let content = snapshot.text_for_range(range.clone()).collect::<String>();
+
+ let language_name = buffer
+ .language()
+ .and_then(|language| Some(language.name().to_string()));
+
+ let file_path = buffer
+ .file()
+ .and_then(|file| Some(file.path().to_path_buf()));
+
+ (content, language_name, file_path)
+ });
+
+ PromptCodeSnippet {
+ path: file_path,
+ language_name,
+ content,
+ }
+ }
+}
+
+impl ToString for PromptCodeSnippet {
+ fn to_string(&self) -> String {
+ let path = self
+ .path
+ .as_ref()
+ .and_then(|path| Some(path.to_string_lossy().to_string()))
+ .unwrap_or("".to_string());
+ let language_name = self.language_name.clone().unwrap_or("".to_string());
+ let content = self.content.clone();
+
+ format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
+ }
+}
+
+enum PromptFileType {
+ Text,
+ Code,
+}
+
+#[derive(Default)]
+struct PromptArguments {
+ pub language_name: Option<String>,
+ pub project_name: Option<String>,
+ pub snippets: Vec<PromptCodeSnippet>,
+ pub model_name: String,
+}
+
+impl PromptArguments {
+ pub fn get_file_type(&self) -> PromptFileType {
+ if self
+ .language_name
+ .as_ref()
+ .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
+ .unwrap_or(true)
+ {
+ PromptFileType::Code
+ } else {
+ PromptFileType::Text
+ }
+ }
+}
+
+trait PromptTemplate {
+ fn generate(args: PromptArguments, max_token_length: Option<usize>) -> String;
+}
+
+struct EngineerPreamble {}
+
+impl PromptTemplate for EngineerPreamble {
+ fn generate(args: PromptArguments, max_token_length: Option<usize>) -> String {
+ let mut prompt = String::new();
+
+ match args.get_file_type() {
+ PromptFileType::Code => {
+ writeln!(
+ prompt,
+ "You are an expert {} engineer.",
+ args.language_name.unwrap_or("".to_string())
+ )
+ .unwrap();
+ }
+ PromptFileType::Text => {
+ writeln!(prompt, "You are an expert engineer.").unwrap();
+ }
+ }
+
+ if let Some(project_name) = args.project_name {
+ writeln!(
+ prompt,
+ "You are currently working inside the '{project_name}' in Zed the code editor."
+ )
+ .unwrap();
+ }
+
+ prompt
+ }
+}
+
+struct RepositorySnippets {}
+
+impl PromptTemplate for RepositorySnippets {
+ fn generate(args: PromptArguments, max_token_length: Option<usize>) -> String {
+ const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
+ let mut template = "You are working inside a large repository, here are a few code snippets that may be useful";
+ let mut prompt = String::new();
+
+ if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(args.model_name.as_str()) {
+ let default_token_count =
+ tiktoken_rs::model::get_context_size(args.model_name.as_str());
+ let mut remaining_token_count = max_token_length.unwrap_or(default_token_count);
+
+ for snippet in args.snippets {
+ let mut snippet_prompt = template.to_string();
+ let content = snippet.to_string();
+ writeln!(snippet_prompt, "{content}").unwrap();
+
+ let token_count = encoding
+ .encode_with_special_tokens(snippet_prompt.as_str())
+ .len();
+ if token_count <= remaining_token_count {
+ if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
+ writeln!(prompt, "{snippet_prompt}").unwrap();
+ remaining_token_count -= token_count;
+ template = "";
+ }
+ } else {
+ break;
+ }
+ }
+ }
+
+ prompt
+ }
+}
crates/ai/src/templates.rs
@@ -1,76 +0,0 @@
-use std::fmt::Write;
-
-pub struct PromptCodeSnippet {
- path: Option<PathBuf>,
- language_name: Option<String>,
- content: String,
-}
-
-enum PromptFileType {
- Text,
- Code,
-}
-
-#[derive(Default)]
-struct PromptArguments {
- pub language_name: Option<String>,
- pub project_name: Option<String>,
- pub snippets: Vec<PromptCodeSnippet>,
-}
-
-impl PromptArguments {
- pub fn get_file_type(&self) -> PromptFileType {
- if self
- .language_name
- .as_ref()
- .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
- .unwrap_or(true)
- {
- PromptFileType::Code
- } else {
- PromptFileType::Text
- }
- }
-}
-
-trait PromptTemplate {
- fn generate(args: PromptArguments) -> String;
-}
-
-struct EngineerPreamble {}
-
-impl PromptTemplate for EngineerPreamble {
- fn generate(args: PromptArguments) -> String {
- let mut prompt = String::new();
-
- match args.get_file_type() {
- PromptFileType::Code => {
- writeln!(
- prompt,
- "You are an expert {} engineer.",
- args.language_name.unwrap_or("".to_string())
- )
- .unwrap();
- }
- PromptFileType::Text => {
- writeln!(prompt, "You are an expert engineer.").unwrap();
- }
- }
-
- if let Some(project_name) = args.project_name {
- writeln!(
- prompt,
- "You are currently working inside the '{project_name}' in Zed the code editor."
- )
- .unwrap();
- }
-
- prompt
- }
-}
-
-struct RepositorySnippets {}
-
-impl PromptTemplate for RepositorySnippets {
- fn generate(args: PromptArguments) -> String {}
-}
crates/ai/src/templates/base.rs
@@ -0,0 +1,112 @@
+use std::cmp::Reverse;
+
+use crate::templates::repository_context::PromptCodeSnippet;
+
+pub(crate) enum PromptFileType {
+ Text,
+ Code,
+}
+
+#[derive(Default)]
+pub struct PromptArguments {
+ pub model_name: String,
+ pub language_name: Option<String>,
+ pub project_name: Option<String>,
+ pub snippets: Vec<PromptCodeSnippet>,
+ pub reserved_tokens: usize,
+}
+
+impl PromptArguments {
+ pub(crate) fn get_file_type(&self) -> PromptFileType {
+ if self
+ .language_name
+ .as_ref()
+ .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
+ .unwrap_or(true)
+ {
+ PromptFileType::Code
+ } else {
+ PromptFileType::Text
+ }
+ }
+}
+
+pub trait PromptTemplate {
+ fn generate(&self, args: &PromptArguments, max_token_length: Option<usize>) -> String;
+}
+
+#[repr(i8)]
+#[derive(PartialEq, Eq, PartialOrd, Ord)]
+pub enum PromptPriority {
+ Low,
+ Medium,
+ High,
+}
+
+pub struct PromptChain {
+ args: PromptArguments,
+ templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
+}
+
+impl PromptChain {
+ pub fn new(
+ args: PromptArguments,
+ templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
+ ) -> Self {
+ // templates.sort_by(|a, b| a.0.cmp(&b.0));
+
+ PromptChain { args, templates }
+ }
+
+ pub fn generate(&self, truncate: bool) -> anyhow::Result<String> {
+ // Argsort based on Prompt Priority
+ let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
+ sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
+
+ println!("{:?}", sorted_indices);
+
+ let mut prompts = Vec::new();
+ for (_, template) in &self.templates {
+ prompts.push(template.generate(&self.args, None));
+ }
+
+ anyhow::Ok(prompts.join("\n"))
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod tests {
+ use super::*;
+
+ #[test]
+ pub fn test_prompt_chain() {
+ struct TestPromptTemplate {}
+ impl PromptTemplate for TestPromptTemplate {
+ fn generate(&self, args: &PromptArguments, max_token_length: Option<usize>) -> String {
+ "This is a test prompt template".to_string()
+ }
+ }
+
+ struct TestLowPriorityTemplate {}
+ impl PromptTemplate for TestLowPriorityTemplate {
+ fn generate(&self, args: &PromptArguments, max_token_length: Option<usize>) -> String {
+ "This is a low priority test prompt template".to_string()
+ }
+ }
+
+ let args = PromptArguments {
+ model_name: "gpt-4".to_string(),
+ ..Default::default()
+ };
+
+ let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+ (PromptPriority::High, Box::new(TestPromptTemplate {})),
+ (PromptPriority::Medium, Box::new(TestLowPriorityTemplate {})),
+ ];
+ let chain = PromptChain::new(args, templates);
+
+ let prompt = chain.generate(false);
+ println!("{:?}", prompt);
+ panic!();
+ }
+}
crates/ai/src/templates/mod.rs
@@ -0,0 +1,3 @@
+pub mod base;
+pub mod preamble;
+pub mod repository_context;
crates/ai/src/templates/preamble.rs
@@ -0,0 +1,34 @@
+use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
+use std::fmt::Write;
+
+struct EngineerPreamble {}
+
+impl PromptTemplate for EngineerPreamble {
+ fn generate(&self, args: &PromptArguments, max_token_length: Option<usize>) -> String {
+ let mut prompt = String::new();
+
+ match args.get_file_type() {
+ PromptFileType::Code => {
+ writeln!(
+ prompt,
+ "You are an expert {} engineer.",
+ args.language_name.clone().unwrap_or("".to_string())
+ )
+ .unwrap();
+ }
+ PromptFileType::Text => {
+ writeln!(prompt, "You are an expert engineer.").unwrap();
+ }
+ }
+
+ if let Some(project_name) = args.project_name.clone() {
+ writeln!(
+ prompt,
+ "You are currently working inside the '{project_name}' in Zed the code editor."
+ )
+ .unwrap();
+ }
+
+ prompt
+ }
+}
crates/ai/src/templates/repository_context.rs
@@ -0,0 +1,49 @@
+use std::{ops::Range, path::PathBuf};
+
+use gpui::{AsyncAppContext, ModelHandle};
+use language::{Anchor, Buffer};
+
+pub struct PromptCodeSnippet {
+ path: Option<PathBuf>,
+ language_name: Option<String>,
+ content: String,
+}
+
+impl PromptCodeSnippet {
+ pub fn new(buffer: ModelHandle<Buffer>, range: Range<Anchor>, cx: &AsyncAppContext) -> Self {
+ let (content, language_name, file_path) = buffer.read_with(cx, |buffer, _| {
+ let snapshot = buffer.snapshot();
+ let content = snapshot.text_for_range(range.clone()).collect::<String>();
+
+ let language_name = buffer
+ .language()
+ .and_then(|language| Some(language.name().to_string()));
+
+ let file_path = buffer
+ .file()
+ .and_then(|file| Some(file.path().to_path_buf()));
+
+ (content, language_name, file_path)
+ });
+
+ PromptCodeSnippet {
+ path: file_path,
+ language_name,
+ content,
+ }
+ }
+}
+
+impl ToString for PromptCodeSnippet {
+ fn to_string(&self) -> String {
+ let path = self
+ .path
+ .as_ref()
+ .and_then(|path| Some(path.to_string_lossy().to_string()))
+ .unwrap_or("".to_string());
+ let language_name = self.language_name.clone().unwrap_or("".to_string());
+ let content = self.content.clone();
+
+ format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
+ }
+}