From 02853bbd606dc87a638bd2ca01a5232203069499 Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Tue, 17 Oct 2023 17:29:07 -0400
Subject: [PATCH] added prompt template for repository context

---
 crates/ai/src/models.rs                       |   8 +-
 crates/ai/src/prompts.rs                      | 149 ------------------
 crates/ai/src/templates/preamble.rs           |   6 +-
 crates/ai/src/templates/repository_context.rs |  47 +++++-
 crates/assistant/src/assistant_panel.rs       |  22 ++-
 crates/assistant/src/prompts.rs               |  87 ++++------
 6 files changed, 96 insertions(+), 223 deletions(-)
 delete mode 100644 crates/ai/src/prompts.rs

diff --git a/crates/ai/src/models.rs b/crates/ai/src/models.rs
index 4fe96d44f33f10ad1e6ee8572a8cceb02fca8fd4..69e73e9b56ec7db7023983a67f5ee994c97c5725 100644
--- a/crates/ai/src/models.rs
+++ b/crates/ai/src/models.rs
@@ -9,16 +9,16 @@ pub trait LanguageModel {
     fn capacity(&self) -> anyhow::Result<usize>;
 }
 
-struct OpenAILanguageModel {
+pub struct OpenAILanguageModel {
     name: String,
     bpe: Option<CoreBPE>,
 }
 
 impl OpenAILanguageModel {
-    pub fn load(model_name: String) -> Self {
-        let bpe = tiktoken_rs::get_bpe_from_model(&model_name).log_err();
+    pub fn load(model_name: &str) -> Self {
+        let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
         OpenAILanguageModel {
-            name: model_name,
+            name: model_name.to_string(),
             bpe,
         }
     }
diff --git a/crates/ai/src/prompts.rs b/crates/ai/src/prompts.rs
deleted file mode 100644
index 6d2c0629fa08e2d464adc5bf1d48c44659da8545..0000000000000000000000000000000000000000
--- a/crates/ai/src/prompts.rs
+++ /dev/null
@@ -1,149 +0,0 @@
-use gpui::{AsyncAppContext, ModelHandle};
-use language::{Anchor, Buffer};
-use std::{fmt::Write, ops::Range, path::PathBuf};
-
-pub struct PromptCodeSnippet {
-    path: Option<PathBuf>,
-    language_name: Option<String>,
-    content: String,
-}
-
-impl PromptCodeSnippet {
-    pub fn new(buffer: ModelHandle<Buffer>, range: Range<Anchor>, cx: &AsyncAppContext) -> Self {
-        let (content, language_name, file_path) = buffer.read_with(cx, |buffer, _| {
-            let snapshot = buffer.snapshot();
-            let content = snapshot.text_for_range(range.clone()).collect::<String>();
-
-            let language_name = buffer
-                .language()
-                .and_then(|language| Some(language.name().to_string()));
-
-            let file_path = buffer
-                .file()
-                .and_then(|file| Some(file.path().to_path_buf()));
-
-            (content, language_name, file_path)
-        });
-
-        PromptCodeSnippet {
-            path: file_path,
-            language_name,
-            content,
-        }
-    }
-}
-
-impl ToString for PromptCodeSnippet {
-    fn to_string(&self) -> String {
-        let path = self
-            .path
-            .as_ref()
-            .and_then(|path| Some(path.to_string_lossy().to_string()))
-            .unwrap_or("".to_string());
-        let language_name = self.language_name.clone().unwrap_or("".to_string());
-        let content = self.content.clone();
-
-        format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
-    }
-}
-
-enum PromptFileType {
-    Text,
-    Code,
-}
-
-#[derive(Default)]
-struct PromptArguments {
-    pub language_name: Option<String>,
-    pub project_name: Option<String>,
-    pub snippets: Vec<PromptCodeSnippet>,
-    pub model_name: String,
-}
-
-impl PromptArguments {
-    pub fn get_file_type(&self) -> PromptFileType {
-        if self
-            .language_name
-            .as_ref()
-            .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
-            .unwrap_or(true)
-        {
-            PromptFileType::Code
-        } else {
-            PromptFileType::Text
-        }
-    }
-}
-
-trait PromptTemplate {
-    fn generate(args: PromptArguments, max_token_length: Option<usize>) -> String;
-}
-
-struct EngineerPreamble {}
-
-impl PromptTemplate for EngineerPreamble {
-    fn generate(args: PromptArguments, max_token_length: Option<usize>) -> String {
-        let mut prompt = String::new();
-
-        match args.get_file_type() {
-            PromptFileType::Code => {
-                writeln!(
-                    prompt,
-                    "You are an expert {} engineer.",
-                    args.language_name.unwrap_or("".to_string())
-                )
-                .unwrap();
-            }
-            PromptFileType::Text => {
-                writeln!(prompt, "You are an expert engineer.").unwrap();
-            }
-        }
-
-        if let Some(project_name) = args.project_name {
-            writeln!(
-                prompt,
-                "You are currently working inside the '{project_name}' in Zed the code editor."
-            )
-            .unwrap();
-        }
-
-        prompt
-    }
-}
-
-struct RepositorySnippets {}
-
-impl PromptTemplate for RepositorySnippets {
-    fn generate(args: PromptArguments, max_token_length: Option<usize>) -> String {
-        const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
-        let mut template = "You are working inside a large repository, here are a few code snippets that may be useful";
-        let mut prompt = String::new();
-
-        if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(args.model_name.as_str()) {
-            let default_token_count =
-                tiktoken_rs::model::get_context_size(args.model_name.as_str());
-            let mut remaining_token_count = max_token_length.unwrap_or(default_token_count);
-
-            for snippet in args.snippets {
-                let mut snippet_prompt = template.to_string();
-                let content = snippet.to_string();
-                writeln!(snippet_prompt, "{content}").unwrap();
-
-                let token_count = encoding
-                    .encode_with_special_tokens(snippet_prompt.as_str())
-                    .len();
-                if token_count <= remaining_token_count {
-                    if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
-                        writeln!(prompt, "{snippet_prompt}").unwrap();
-                        remaining_token_count -= token_count;
-                        template = "";
-                    }
-                } else {
-                    break;
-                }
-            }
-        }
-
-        prompt
-    }
-}
diff --git a/crates/ai/src/templates/preamble.rs b/crates/ai/src/templates/preamble.rs
index f395dbf8beeabde2a703214cc0426900908be990..5834fa1b21b2011fbbc82d781493c4e4e523b685 100644
--- a/crates/ai/src/templates/preamble.rs
+++ b/crates/ai/src/templates/preamble.rs
@@ -1,7 +1,7 @@
 use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
 use std::fmt::Write;
 
-struct EngineerPreamble {}
+pub struct EngineerPreamble {}
 
 impl PromptTemplate for EngineerPreamble {
     fn generate(
@@ -14,8 +14,8 @@ impl PromptTemplate for EngineerPreamble {
         match args.get_file_type() {
             PromptFileType::Code => {
                 prompts.push(format!(
-                    "You are an expert {} engineer.",
-                    args.language_name.clone().unwrap_or("".to_string())
+                    "You are an expert {}engineer.",
+                    args.language_name.clone().unwrap_or("".to_string()) + " "
                 ));
             }
             PromptFileType::Text => {
diff --git a/crates/ai/src/templates/repository_context.rs b/crates/ai/src/templates/repository_context.rs
index f9c2253c654de8da59ffa99ec07a12233b121d01..7dd1647c440a228b5fc2c8317fe35e0931d4c1a9 100644
--- a/crates/ai/src/templates/repository_context.rs
+++ b/crates/ai/src/templates/repository_context.rs
@@ -1,8 +1,11 @@
+use crate::templates::base::{PromptArguments, PromptTemplate};
+use std::fmt::Write;
 use std::{ops::Range, path::PathBuf};
 
 use gpui::{AsyncAppContext, ModelHandle};
 use language::{Anchor, Buffer};
 
+#[derive(Clone)]
 pub struct PromptCodeSnippet {
     path: Option<PathBuf>,
     language_name: Option<String>,
@@ -17,7 +20,7 @@ impl PromptCodeSnippet {
 
             let language_name = buffer
                 .language()
-                .and_then(|language| Some(language.name().to_string()));
+                .and_then(|language| Some(language.name().to_string().to_lowercase()));
 
             let file_path = buffer
                 .file()
@@ -47,3 +50,45 @@ impl ToString for PromptCodeSnippet {
         format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
     }
 }
+
+pub struct RepositoryContext {}
+
+impl PromptTemplate for RepositoryContext {
+    fn generate(
+        &self,
+        args: &PromptArguments,
+        max_token_length: Option<usize>,
+    ) -> anyhow::Result<(String, usize)> {
+        const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
+        let mut template = "You are working inside a large repository, here are a few code snippets that may be useful.";
+        let mut prompt = String::new();
+
+        let mut remaining_tokens = max_token_length.clone();
+        let seperator_token_length = args.model.count_tokens("\n")?;
+        for snippet in &args.snippets {
+            let mut snippet_prompt = template.to_string();
+            let content = snippet.to_string();
+            writeln!(snippet_prompt, "{content}").unwrap();
+
+            let token_count = args.model.count_tokens(&snippet_prompt)?;
+            if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
+                if let Some(tokens_left) = remaining_tokens {
+                    if tokens_left >= token_count {
+                        writeln!(prompt, "{snippet_prompt}").unwrap();
+                        remaining_tokens = if tokens_left >= (token_count + seperator_token_length)
+                        {
+                            Some(tokens_left - token_count - seperator_token_length)
+                        } else {
+                            Some(0)
+                        };
+                    }
+                } else {
+                    writeln!(prompt, "{snippet_prompt}").unwrap();
+                }
+            }
+        }
+
+        let total_token_count = args.model.count_tokens(&prompt)?;
+        anyhow::Ok((prompt, total_token_count))
+    }
+}
diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs
index e8edf70498c14e7324073b732cae6b887f3131f9..06de5c135fdf535e4f253ec6f92ef2d449f769a2 100644
--- a/crates/assistant/src/assistant_panel.rs
+++ b/crates/assistant/src/assistant_panel.rs
@@ -1,12 +1,15 @@
 use crate::{
     assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
     codegen::{self, Codegen, CodegenKind},
-    prompts::{generate_content_prompt, PromptCodeSnippet},
+    prompts::generate_content_prompt,
     MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
     SavedMessage,
 };
-use ai::completion::{
-    stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
+use ai::{
+    completion::{
+        stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
+    },
+    templates::repository_context::PromptCodeSnippet,
 };
 use anyhow::{anyhow, Result};
 use chrono::{DateTime, Local};
@@ -668,14 +671,7 @@ impl AssistantPanel {
         let snippets = cx.spawn(|_, cx| async move {
             let mut snippets = Vec::new();
             for result in search_results.await {
-                snippets.push(PromptCodeSnippet::new(result, &cx));
-
-                // snippets.push(result.buffer.read_with(&cx, |buffer, _| {
-                //     buffer
-                //         .snapshot()
-                //         .text_for_range(result.range)
-                //         .collect::<String>()
-                // }));
+                snippets.push(PromptCodeSnippet::new(result.buffer, result.range, &cx));
             }
             snippets
         });
@@ -717,7 +713,8 @@
         }
 
         cx.spawn(|_, mut cx| async move {
-            let prompt = prompt.await;
+            // I Don't know if we want to return a ? here.
+            let prompt = prompt.await?;
 
             messages.push(RequestMessage {
                 role: Role::User,
@@ -729,6 +726,7 @@
                 stream: true,
             };
             codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx));
+            anyhow::Ok(())
         })
         .detach();
     }
diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs
index 7aafe75920b351e7244f14858036a4aa9af64f6f..e33a6e4022e87c99b899b0f492d0c25e1514cb4f 100644
--- a/crates/assistant/src/prompts.rs
+++ b/crates/assistant/src/prompts.rs
@@ -1,61 +1,15 @@
 use crate::codegen::CodegenKind;
-use gpui::AsyncAppContext;
+use ai::models::{LanguageModel, OpenAILanguageModel};
+use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
+use ai::templates::preamble::EngineerPreamble;
+use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext};
 use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
-use semantic_index::SearchResult;
 use std::cmp::{self, Reverse};
 use std::fmt::Write;
 use std::ops::Range;
-use std::path::PathBuf;
+use std::sync::Arc;
 use tiktoken_rs::ChatCompletionRequestMessage;
 
-pub struct PromptCodeSnippet {
-    path: Option<PathBuf>,
-    language_name: Option<String>,
-    content: String,
-}
-
-impl PromptCodeSnippet {
-    pub fn new(search_result: SearchResult, cx: &AsyncAppContext) -> Self {
-        let (content, language_name, file_path) =
-            search_result.buffer.read_with(cx, |buffer, _| {
-                let snapshot = buffer.snapshot();
-                let content = snapshot
-                    .text_for_range(search_result.range.clone())
-                    .collect::<String>();
-
-                let language_name = buffer
-                    .language()
-                    .and_then(|language| Some(language.name().to_string()));
-
-                let file_path = buffer
-                    .file()
-                    .and_then(|file| Some(file.path().to_path_buf()));
-
-                (content, language_name, file_path)
-            });
-
-        PromptCodeSnippet {
-            path: file_path,
-            language_name,
-            content,
-        }
-    }
-}
-
-impl ToString for PromptCodeSnippet {
-    fn to_string(&self) -> String {
-        let path = self
-            .path
-            .as_ref()
-            .and_then(|path| Some(path.to_string_lossy().to_string()))
-            .unwrap_or("".to_string());
-        let language_name = self.language_name.clone().unwrap_or("".to_string());
-        let content = self.content.clone();
-
-        format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
-    }
-}
-
 #[allow(dead_code)]
 fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
     #[derive(Debug)]
@@ -175,7 +129,32 @@ pub fn generate_content_prompt(
     kind: CodegenKind,
     search_results: Vec<PromptCodeSnippet>,
     model: &str,
-) -> String {
+) -> anyhow::Result<String> {
+    // Using new Prompt Templates
+    let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load(model));
+    let lang_name = if let Some(language_name) = language_name {
+        Some(language_name.to_string())
+    } else {
+        None
+    };
+
+    let args = PromptArguments {
+        model: openai_model,
+        language_name: lang_name.clone(),
+        project_name: None,
+        snippets: search_results.clone(),
+        reserved_tokens: 1000,
+    };
+
+    let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+        (PromptPriority::High, Box::new(EngineerPreamble {})),
+        (PromptPriority::Low, Box::new(RepositoryContext {})),
+    ];
+    let chain = PromptChain::new(args, templates);
+
+    let prompt = chain.generate(true)?;
+    println!("{:?}", prompt);
+
     const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
     const RESERVED_TOKENS_FOR_GENERATION: usize = 1000;
 
@@ -183,7 +162,7 @@ pub fn generate_content_prompt(
     let range = range.to_offset(buffer);
     // General Preamble
-    if let Some(language_name) = language_name {
+    if let Some(language_name) = language_name.clone() {
         prompts.push(format!("You're an expert {language_name} engineer.\n"));
     } else {
         prompts.push("You're an expert engineer.\n".to_string());
     }
@@ -297,7 +276,7 @@ pub fn generate_content_prompt(
         }
     }
 
-    prompts.join("\n")
+    anyhow::Ok(prompts.join("\n"))
 }
 
 #[cfg(test)]
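
Note (not part of the commit): the sketch below pulls the pieces introduced by this patch together in one place, showing roughly how generate_content_prompt now assembles a prompt. It only uses names that appear in the hunks above (PromptArguments, PromptPriority, PromptChain, EngineerPreamble, RepositoryContext, OpenAILanguageModel, PromptCodeSnippet); the helper name build_prompt and the concrete argument values are hypothetical, and the exact behavior of PromptChain::generate is assumed from its call site.

use std::sync::Arc;

use ai::models::{LanguageModel, OpenAILanguageModel};
use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
use ai::templates::preamble::EngineerPreamble;
use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext};

// Hypothetical helper, for illustration only: mirrors the wiring in generate_content_prompt.
fn build_prompt(model_name: &str, snippets: Vec<PromptCodeSnippet>) -> anyhow::Result<()> {
    // Token counting is backed by the OpenAI model name (see OpenAILanguageModel::load).
    let model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load(model_name));

    // Bundle the model, optional language/project names, retrieved snippets, and a token
    // budget reserved for the generation itself.
    let args = PromptArguments {
        model,
        language_name: Some("rust".to_string()),
        project_name: None,
        snippets,
        reserved_tokens: 1000,
    };

    // Each template gets a priority; presumably lower-priority sections (the repository
    // snippets) are the first to be truncated when the context window runs out.
    let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
        (PromptPriority::High, Box::new(EngineerPreamble {})),
        (PromptPriority::Low, Box::new(RepositoryContext {})),
    ];

    // `generate(true)` matches the call in the patch; the flag presumably enables truncation.
    let chain = PromptChain::new(args, templates);
    let prompt = chain.generate(true)?;
    println!("{:?}", prompt);
    anyhow::Ok(())
}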