From aa1825681c60176d391ba497a9d28b0e5703fa60 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 18 Oct 2023 14:20:12 -0400 Subject: [PATCH] update the assistant panel to use new prompt templates --- crates/ai/src/templates/base.rs | 4 - crates/ai/src/templates/file_context.rs | 10 +- crates/ai/src/templates/preamble.rs | 2 +- crates/assistant/src/assistant_panel.rs | 17 ++- crates/assistant/src/prompts.rs | 146 +++--------------------- 5 files changed, 33 insertions(+), 146 deletions(-) diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/templates/base.rs index db437a029cd73ec620385362ed83061103d82078..aaf08d755efb4746192bb75e64f0f7cc7e7a4e83 100644 --- a/crates/ai/src/templates/base.rs +++ b/crates/ai/src/templates/base.rs @@ -90,10 +90,6 @@ impl PromptChain { if let Some((template_prompt, prompt_token_count)) = template.generate(&self.args, tokens_outstanding).log_err() { - println!( - "GENERATED PROMPT ({:?}): {:?}", - &prompt_token_count, &template_prompt - ); if template_prompt != "" { prompts[idx] = template_prompt; diff --git a/crates/ai/src/templates/file_context.rs b/crates/ai/src/templates/file_context.rs index 68bf424db1ddb6c3cd11907688ee5080e8f41c5f..6d0630504983fbe90597525ea8f49dd23e0a1036 100644 --- a/crates/ai/src/templates/file_context.rs +++ b/crates/ai/src/templates/file_context.rs @@ -44,22 +44,22 @@ impl PromptTemplate for FileContext { .unwrap(); if start == end { - writeln!(prompt, "<|START|>").unwrap(); + write!(prompt, "<|START|>").unwrap(); } else { - writeln!(prompt, "<|START|").unwrap(); + write!(prompt, "<|START|").unwrap(); } - writeln!( + write!( prompt, "{}", buffer.text_for_range(start..end).collect::() ) .unwrap(); if start != end { - writeln!(prompt, "|END|>").unwrap(); + write!(prompt, "|END|>").unwrap(); } - writeln!( + write!( prompt, "{}", buffer.text_for_range(end..buffer.len()).collect::() diff --git a/crates/ai/src/templates/preamble.rs b/crates/ai/src/templates/preamble.rs index 5834fa1b21b2011fbbc82d781493c4e4e523b685..9eabaaeb97fe4216c6bac44cf4eabfb7c129ecf2 100644 --- a/crates/ai/src/templates/preamble.rs +++ b/crates/ai/src/templates/preamble.rs @@ -25,7 +25,7 @@ impl PromptTemplate for EngineerPreamble { if let Some(project_name) = args.project_name.clone() { prompts.push(format!( - "You are currently working inside the '{project_name}' in Zed the code editor." + "You are currently working inside the '{project_name}' project in code editor Zed." )); } diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 3a0f05379e1d06ce9900bcb2179ed1a347c96f70..4dd4e2a98315c042d74c7ef6bde78200287ab6ad 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -612,6 +612,18 @@ impl AssistantPanel { let project = pending_assist.project.clone(); + let project_name = if let Some(project) = project.upgrade(cx) { + Some( + project + .read(cx) + .worktree_root_names(cx) + .collect::>() + .join("/"), + ) + } else { + None + }; + self.inline_prompt_history .retain(|prompt| prompt != user_prompt); self.inline_prompt_history.push_back(user_prompt.into()); @@ -649,7 +661,6 @@ impl AssistantPanel { None }; - let codegen_kind = codegen.read(cx).kind().clone(); let user_prompt = user_prompt.to_string(); let snippets = if retrieve_context { @@ -692,11 +703,11 @@ impl AssistantPanel { generate_content_prompt( user_prompt, language_name, - &buffer, + buffer, range, - codegen_kind, snippets, model_name, + project_name, ) }); diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs index 333742aa0525afe4d362523868ada9cb187cc363..1457d28fff22c83c29090dcded37aa9a915918bd 100644 --- a/crates/assistant/src/prompts.rs +++ b/crates/assistant/src/prompts.rs @@ -1,6 +1,8 @@ use crate::codegen::CodegenKind; use ai::models::{LanguageModel, OpenAILanguageModel}; use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate}; +use ai::templates::file_context::FileContext; +use ai::templates::generate::GenerateInlineContent; use ai::templates::preamble::EngineerPreamble; use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext}; use language::{BufferSnapshot, OffsetRangeExt, ToOffset}; @@ -124,11 +126,11 @@ fn summarize(buffer: &BufferSnapshot, selected_range: Range) -> S pub fn generate_content_prompt( user_prompt: String, language_name: Option<&str>, - buffer: &BufferSnapshot, - range: Range, - kind: CodegenKind, + buffer: BufferSnapshot, + range: Range, search_results: Vec, model: &str, + project_name: Option, ) -> anyhow::Result { // Using new Prompt Templates let openai_model: Arc = Arc::new(OpenAILanguageModel::load(model)); @@ -141,146 +143,24 @@ pub fn generate_content_prompt( let args = PromptArguments { model: openai_model, language_name: lang_name.clone(), - project_name: None, + project_name, snippets: search_results.clone(), reserved_tokens: 1000, + buffer: Some(buffer), + selected_range: Some(range), + user_prompt: Some(user_prompt.clone()), }; let templates: Vec<(PromptPriority, Box)> = vec![ (PromptPriority::High, Box::new(EngineerPreamble {})), (PromptPriority::Low, Box::new(RepositoryContext {})), + (PromptPriority::Medium, Box::new(FileContext {})), + (PromptPriority::High, Box::new(GenerateInlineContent {})), ]; let chain = PromptChain::new(args, templates); + let (prompt, _) = chain.generate(true)?; - let prompt = chain.generate(true)?; - println!("{:?}", prompt); - - const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500; - const RESERVED_TOKENS_FOR_GENERATION: usize = 1000; - - let mut prompts = Vec::new(); - let range = range.to_offset(buffer); - - // General Preamble - if let Some(language_name) = language_name.clone() { - prompts.push(format!("You're an expert {language_name} engineer.\n")); - } else { - prompts.push("You're an expert engineer.\n".to_string()); - } - - // Snippets - let mut snippet_position = prompts.len() - 1; - - let mut content = String::new(); - content.extend(buffer.text_for_range(0..range.start)); - if range.start == range.end { - content.push_str("<|START|>"); - } else { - content.push_str("<|START|"); - } - content.extend(buffer.text_for_range(range.clone())); - if range.start != range.end { - content.push_str("|END|>"); - } - content.extend(buffer.text_for_range(range.end..buffer.len())); - - prompts.push("The file you are currently working on has the following content:\n".to_string()); - - if let Some(language_name) = language_name { - let language_name = language_name.to_lowercase(); - prompts.push(format!("```{language_name}\n{content}\n```")); - } else { - prompts.push(format!("```\n{content}\n```")); - } - - match kind { - CodegenKind::Generate { position: _ } => { - prompts.push("In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.".to_string()); - prompts - .push("Assume the cursor is located where the `<|START|` marker is.".to_string()); - prompts.push( - "Text can't be replaced, so assume your answer will be inserted at the cursor." - .to_string(), - ); - prompts.push(format!( - "Generate text based on the users prompt: {user_prompt}" - )); - } - CodegenKind::Transform { range: _ } => { - prompts.push("In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.".to_string()); - prompts.push(format!( - "Modify the users code selected text based upon the users prompt: '{user_prompt}'" - )); - prompts.push("You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file.".to_string()); - } - } - - if let Some(language_name) = language_name { - prompts.push(format!( - "Your answer MUST always and only be valid {language_name}" - )); - } - prompts.push("Never make remarks about the output.".to_string()); - prompts.push("Do not return any text, except the generated code.".to_string()); - prompts.push("Always wrap your code in a Markdown block".to_string()); - - let current_messages = [ChatCompletionRequestMessage { - role: "user".to_string(), - content: Some(prompts.join("\n")), - function_call: None, - name: None, - }]; - - let mut remaining_token_count = if let Ok(current_token_count) = - tiktoken_rs::num_tokens_from_messages(model, ¤t_messages) - { - let max_token_count = tiktoken_rs::model::get_context_size(model); - let intermediate_token_count = if max_token_count > current_token_count { - max_token_count - current_token_count - } else { - 0 - }; - - if intermediate_token_count < RESERVED_TOKENS_FOR_GENERATION { - 0 - } else { - intermediate_token_count - RESERVED_TOKENS_FOR_GENERATION - } - } else { - // If tiktoken fails to count token count, assume we have no space remaining. - 0 - }; - - // TODO: - // - add repository name to snippet - // - add file path - // - add language - if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model) { - let mut template = "You are working inside a large repository, here are a few code snippets that may be useful"; - - for search_result in search_results { - let mut snippet_prompt = template.to_string(); - let snippet = search_result.to_string(); - writeln!(snippet_prompt, "```\n{snippet}\n```").unwrap(); - - let token_count = encoding - .encode_with_special_tokens(snippet_prompt.as_str()) - .len(); - if token_count <= remaining_token_count { - if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT { - prompts.insert(snippet_position, snippet_prompt); - snippet_position += 1; - remaining_token_count -= token_count; - // If you have already added the template to the prompt, remove the template. - template = ""; - } - } else { - break; - } - } - } - - anyhow::Ok(prompts.join("\n")) + anyhow::Ok(prompt) } #[cfg(test)]