update the assistant panel to use new prompt templates

KCaverly created

Change summary

crates/ai/src/templates/base.rs         |   4 
crates/ai/src/templates/file_context.rs |  10 
crates/ai/src/templates/preamble.rs     |   2 
crates/assistant/src/assistant_panel.rs |  17 ++
crates/assistant/src/prompts.rs         | 146 ++------------------------
5 files changed, 33 insertions(+), 146 deletions(-)

Detailed changes

crates/ai/src/templates/base.rs 🔗

@@ -90,10 +90,6 @@ impl PromptChain {
             if let Some((template_prompt, prompt_token_count)) =
                 template.generate(&self.args, tokens_outstanding).log_err()
             {
-                println!(
-                    "GENERATED PROMPT ({:?}): {:?}",
-                    &prompt_token_count, &template_prompt
-                );
                 if template_prompt != "" {
                     prompts[idx] = template_prompt;
 

crates/ai/src/templates/file_context.rs 🔗

@@ -44,22 +44,22 @@ impl PromptTemplate for FileContext {
                 .unwrap();
 
                 if start == end {
-                    writeln!(prompt, "<|START|>").unwrap();
+                    write!(prompt, "<|START|>").unwrap();
                 } else {
-                    writeln!(prompt, "<|START|").unwrap();
+                    write!(prompt, "<|START|").unwrap();
                 }
 
-                writeln!(
+                write!(
                     prompt,
                     "{}",
                     buffer.text_for_range(start..end).collect::<String>()
                 )
                 .unwrap();
                 if start != end {
-                    writeln!(prompt, "|END|>").unwrap();
+                    write!(prompt, "|END|>").unwrap();
                 }
 
-                writeln!(
+                write!(
                     prompt,
                     "{}",
                     buffer.text_for_range(end..buffer.len()).collect::<String>()

crates/ai/src/templates/preamble.rs 🔗

@@ -25,7 +25,7 @@ impl PromptTemplate for EngineerPreamble {
 
         if let Some(project_name) = args.project_name.clone() {
             prompts.push(format!(
-                "You are currently working inside the '{project_name}' in Zed the code editor."
+                "You are currently working inside the '{project_name}' project in code editor Zed."
             ));
         }
 

crates/assistant/src/assistant_panel.rs 🔗

@@ -612,6 +612,18 @@ impl AssistantPanel {
 
         let project = pending_assist.project.clone();
 
+        let project_name = if let Some(project) = project.upgrade(cx) {
+            Some(
+                project
+                    .read(cx)
+                    .worktree_root_names(cx)
+                    .collect::<Vec<&str>>()
+                    .join("/"),
+            )
+        } else {
+            None
+        };
+
         self.inline_prompt_history
             .retain(|prompt| prompt != user_prompt);
         self.inline_prompt_history.push_back(user_prompt.into());
@@ -649,7 +661,6 @@ impl AssistantPanel {
             None
         };
 
-        let codegen_kind = codegen.read(cx).kind().clone();
         let user_prompt = user_prompt.to_string();
 
         let snippets = if retrieve_context {
@@ -692,11 +703,11 @@ impl AssistantPanel {
             generate_content_prompt(
                 user_prompt,
                 language_name,
-                &buffer,
+                buffer,
                 range,
-                codegen_kind,
                 snippets,
                 model_name,
+                project_name,
             )
         });
 

crates/assistant/src/prompts.rs 🔗

@@ -1,6 +1,8 @@
 use crate::codegen::CodegenKind;
 use ai::models::{LanguageModel, OpenAILanguageModel};
 use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
+use ai::templates::file_context::FileContext;
+use ai::templates::generate::GenerateInlineContent;
 use ai::templates::preamble::EngineerPreamble;
 use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext};
 use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
@@ -124,11 +126,11 @@ fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> S
 pub fn generate_content_prompt(
     user_prompt: String,
     language_name: Option<&str>,
-    buffer: &BufferSnapshot,
-    range: Range<impl ToOffset>,
-    kind: CodegenKind,
+    buffer: BufferSnapshot,
+    range: Range<usize>,
     search_results: Vec<PromptCodeSnippet>,
     model: &str,
+    project_name: Option<String>,
 ) -> anyhow::Result<String> {
     // Using new Prompt Templates
     let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load(model));
@@ -141,146 +143,24 @@ pub fn generate_content_prompt(
     let args = PromptArguments {
         model: openai_model,
         language_name: lang_name.clone(),
-        project_name: None,
+        project_name,
         snippets: search_results.clone(),
         reserved_tokens: 1000,
+        buffer: Some(buffer),
+        selected_range: Some(range),
+        user_prompt: Some(user_prompt.clone()),
     };
 
     let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
         (PromptPriority::High, Box::new(EngineerPreamble {})),
         (PromptPriority::Low, Box::new(RepositoryContext {})),
+        (PromptPriority::Medium, Box::new(FileContext {})),
+        (PromptPriority::High, Box::new(GenerateInlineContent {})),
     ];
     let chain = PromptChain::new(args, templates);
+    let (prompt, _) = chain.generate(true)?;
 
-    let prompt = chain.generate(true)?;
-    println!("{:?}", prompt);
-
-    const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
-    const RESERVED_TOKENS_FOR_GENERATION: usize = 1000;
-
-    let mut prompts = Vec::new();
-    let range = range.to_offset(buffer);
-
-    // General Preamble
-    if let Some(language_name) = language_name.clone() {
-        prompts.push(format!("You're an expert {language_name} engineer.\n"));
-    } else {
-        prompts.push("You're an expert engineer.\n".to_string());
-    }
-
-    // Snippets
-    let mut snippet_position = prompts.len() - 1;
-
-    let mut content = String::new();
-    content.extend(buffer.text_for_range(0..range.start));
-    if range.start == range.end {
-        content.push_str("<|START|>");
-    } else {
-        content.push_str("<|START|");
-    }
-    content.extend(buffer.text_for_range(range.clone()));
-    if range.start != range.end {
-        content.push_str("|END|>");
-    }
-    content.extend(buffer.text_for_range(range.end..buffer.len()));
-
-    prompts.push("The file you are currently working on has the following content:\n".to_string());
-
-    if let Some(language_name) = language_name {
-        let language_name = language_name.to_lowercase();
-        prompts.push(format!("```{language_name}\n{content}\n```"));
-    } else {
-        prompts.push(format!("```\n{content}\n```"));
-    }
-
-    match kind {
-        CodegenKind::Generate { position: _ } => {
-            prompts.push("In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.".to_string());
-            prompts
-                .push("Assume the cursor is located where the `<|START|` marker is.".to_string());
-            prompts.push(
-                "Text can't be replaced, so assume your answer will be inserted at the cursor."
-                    .to_string(),
-            );
-            prompts.push(format!(
-                "Generate text based on the users prompt: {user_prompt}"
-            ));
-        }
-        CodegenKind::Transform { range: _ } => {
-            prompts.push("In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.".to_string());
-            prompts.push(format!(
-                "Modify the users code selected text based upon the users prompt: '{user_prompt}'"
-            ));
-            prompts.push("You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file.".to_string());
-        }
-    }
-
-    if let Some(language_name) = language_name {
-        prompts.push(format!(
-            "Your answer MUST always and only be valid {language_name}"
-        ));
-    }
-    prompts.push("Never make remarks about the output.".to_string());
-    prompts.push("Do not return any text, except the generated code.".to_string());
-    prompts.push("Always wrap your code in a Markdown block".to_string());
-
-    let current_messages = [ChatCompletionRequestMessage {
-        role: "user".to_string(),
-        content: Some(prompts.join("\n")),
-        function_call: None,
-        name: None,
-    }];
-
-    let mut remaining_token_count = if let Ok(current_token_count) =
-        tiktoken_rs::num_tokens_from_messages(model, &current_messages)
-    {
-        let max_token_count = tiktoken_rs::model::get_context_size(model);
-        let intermediate_token_count = if max_token_count > current_token_count {
-            max_token_count - current_token_count
-        } else {
-            0
-        };
-
-        if intermediate_token_count < RESERVED_TOKENS_FOR_GENERATION {
-            0
-        } else {
-            intermediate_token_count - RESERVED_TOKENS_FOR_GENERATION
-        }
-    } else {
-        // If tiktoken fails to count token count, assume we have no space remaining.
-        0
-    };
-
-    // TODO:
-    //   - add repository name to snippet
-    //   - add file path
-    //   - add language
-    if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model) {
-        let mut template = "You are working inside a large repository, here are a few code snippets that may be useful";
-
-        for search_result in search_results {
-            let mut snippet_prompt = template.to_string();
-            let snippet = search_result.to_string();
-            writeln!(snippet_prompt, "```\n{snippet}\n```").unwrap();
-
-            let token_count = encoding
-                .encode_with_special_tokens(snippet_prompt.as_str())
-                .len();
-            if token_count <= remaining_token_count {
-                if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
-                    prompts.insert(snippet_position, snippet_prompt);
-                    snippet_position += 1;
-                    remaining_token_count -= token_count;
-                    // If you have already added the template to the prompt, remove the template.
-                    template = "";
-                }
-            } else {
-                break;
-            }
-        }
-    }
-
-    anyhow::Ok(prompts.join("\n"))
+    anyhow::Ok(prompt)
 }
 
 #[cfg(test)]