added prompt template for repository context

KCaverly created

Change summary

crates/ai/src/models.rs                       |   8 
crates/ai/src/prompts.rs                      | 149 ---------------------
crates/ai/src/templates/preamble.rs           |   6 
crates/ai/src/templates/repository_context.rs |  47 ++++++
crates/assistant/src/assistant_panel.rs       |  22 +-
crates/assistant/src/prompts.rs               |  87 ++++-------
6 files changed, 96 insertions(+), 223 deletions(-)

Detailed changes

crates/ai/src/models.rs 🔗

@@ -9,16 +9,16 @@ pub trait LanguageModel {
     fn capacity(&self) -> anyhow::Result<usize>;
 }
 
-struct OpenAILanguageModel {
+pub struct OpenAILanguageModel {
     name: String,
     bpe: Option<CoreBPE>,
 }
 
 impl OpenAILanguageModel {
-    pub fn load(model_name: String) -> Self {
-        let bpe = tiktoken_rs::get_bpe_from_model(&model_name).log_err();
+    pub fn load(model_name: &str) -> Self {
+        let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
         OpenAILanguageModel {
-            name: model_name,
+            name: model_name.to_string(),
             bpe,
         }
     }

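The load constructor now borrows the model name, so call sites can hand it a &str and keep ownership of any String they already hold; the struct is also made public so other crates (the assistant, further down) can construct it. A minimal sketch of the new call shape, mirroring the construction in crates/assistant/src/prompts.rs below; the function name and the model name string are illustrative only:

    use std::sync::Arc;
    use ai::models::{LanguageModel, OpenAILanguageModel};

    fn build_model() -> anyhow::Result<Arc<dyn LanguageModel>> {
        // A &str literal works now that load borrows the name; "gpt-3.5-turbo"
        // is just an example value, not something this change pins down.
        let model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load("gpt-3.5-turbo"));
        let _capacity = model.capacity()?; // trait method shown in the hunk above
        Ok(model)
    }
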
crates/ai/src/prompts.rs 🔗

@@ -1,149 +0,0 @@
-use gpui::{AsyncAppContext, ModelHandle};
-use language::{Anchor, Buffer};
-use std::{fmt::Write, ops::Range, path::PathBuf};
-
-pub struct PromptCodeSnippet {
-    path: Option<PathBuf>,
-    language_name: Option<String>,
-    content: String,
-}
-
-impl PromptCodeSnippet {
-    pub fn new(buffer: ModelHandle<Buffer>, range: Range<Anchor>, cx: &AsyncAppContext) -> Self {
-        let (content, language_name, file_path) = buffer.read_with(cx, |buffer, _| {
-            let snapshot = buffer.snapshot();
-            let content = snapshot.text_for_range(range.clone()).collect::<String>();
-
-            let language_name = buffer
-                .language()
-                .and_then(|language| Some(language.name().to_string()));
-
-            let file_path = buffer
-                .file()
-                .and_then(|file| Some(file.path().to_path_buf()));
-
-            (content, language_name, file_path)
-        });
-
-        PromptCodeSnippet {
-            path: file_path,
-            language_name,
-            content,
-        }
-    }
-}
-
-impl ToString for PromptCodeSnippet {
-    fn to_string(&self) -> String {
-        let path = self
-            .path
-            .as_ref()
-            .and_then(|path| Some(path.to_string_lossy().to_string()))
-            .unwrap_or("".to_string());
-        let language_name = self.language_name.clone().unwrap_or("".to_string());
-        let content = self.content.clone();
-
-        format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
-    }
-}
-
-enum PromptFileType {
-    Text,
-    Code,
-}
-
-#[derive(Default)]
-struct PromptArguments {
-    pub language_name: Option<String>,
-    pub project_name: Option<String>,
-    pub snippets: Vec<PromptCodeSnippet>,
-    pub model_name: String,
-}
-
-impl PromptArguments {
-    pub fn get_file_type(&self) -> PromptFileType {
-        if self
-            .language_name
-            .as_ref()
-            .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
-            .unwrap_or(true)
-        {
-            PromptFileType::Code
-        } else {
-            PromptFileType::Text
-        }
-    }
-}
-
-trait PromptTemplate {
-    fn generate(args: PromptArguments, max_token_length: Option<usize>) -> String;
-}
-
-struct EngineerPreamble {}
-
-impl PromptTemplate for EngineerPreamble {
-    fn generate(args: PromptArguments, max_token_length: Option<usize>) -> String {
-        let mut prompt = String::new();
-
-        match args.get_file_type() {
-            PromptFileType::Code => {
-                writeln!(
-                    prompt,
-                    "You are an expert {} engineer.",
-                    args.language_name.unwrap_or("".to_string())
-                )
-                .unwrap();
-            }
-            PromptFileType::Text => {
-                writeln!(prompt, "You are an expert engineer.").unwrap();
-            }
-        }
-
-        if let Some(project_name) = args.project_name {
-            writeln!(
-                prompt,
-                "You are currently working inside the '{project_name}' in Zed the code editor."
-            )
-            .unwrap();
-        }
-
-        prompt
-    }
-}
-
-struct RepositorySnippets {}
-
-impl PromptTemplate for RepositorySnippets {
-    fn generate(args: PromptArguments, max_token_length: Option<usize>) -> String {
-        const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
-        let mut template = "You are working inside a large repository, here are a few code snippets that may be useful";
-        let mut prompt = String::new();
-
-        if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(args.model_name.as_str()) {
-            let default_token_count =
-                tiktoken_rs::model::get_context_size(args.model_name.as_str());
-            let mut remaining_token_count = max_token_length.unwrap_or(default_token_count);
-
-            for snippet in args.snippets {
-                let mut snippet_prompt = template.to_string();
-                let content = snippet.to_string();
-                writeln!(snippet_prompt, "{content}").unwrap();
-
-                let token_count = encoding
-                    .encode_with_special_tokens(snippet_prompt.as_str())
-                    .len();
-                if token_count <= remaining_token_count {
-                    if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
-                        writeln!(prompt, "{snippet_prompt}").unwrap();
-                        remaining_token_count -= token_count;
-                        template = "";
-                    }
-                } else {
-                    break;
-                }
-            }
-        }
-
-        prompt
-    }
-}

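This file is deleted wholesale: PromptCodeSnippet moves to crates/ai/src/templates/repository_context.rs, and the ad-hoc PromptTemplate, EngineerPreamble, and RepositorySnippets types are superseded by the templates module. Judging from the impl blocks in the hunks below, the replacement trait takes &self and the arguments by reference and reports the token count of what it rendered; its actual definition lives in crates/ai/src/templates/base.rs, which this diff does not touch, so the following is a sketch of its shape rather than the verbatim source:

    pub trait PromptTemplate {
        // Returns the rendered prompt fragment together with the number of tokens
        // it consumes, so a PromptChain can budget templates against the model's
        // context window. PromptArguments is also defined in templates/base.rs.
        fn generate(
            &self,
            args: &PromptArguments,
            max_token_length: Option<usize>,
        ) -> anyhow::Result<(String, usize)>;
    }
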
crates/ai/src/templates/preamble.rs 🔗

@@ -1,7 +1,7 @@
 use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
 use std::fmt::Write;
 
-struct EngineerPreamble {}
+pub struct EngineerPreamble {}
 
 impl PromptTemplate for EngineerPreamble {
     fn generate(
@@ -14,8 +14,8 @@ impl PromptTemplate for EngineerPreamble {
         match args.get_file_type() {
             PromptFileType::Code => {
                 prompts.push(format!(
-                    "You are an expert {} engineer.",
-                    args.language_name.clone().unwrap_or("".to_string())
+                    "You are an expert {}engineer.",
+                    args.language_name.clone().unwrap_or("".to_string()) + " "
                 ));
             }
             PromptFileType::Text => {

crates/ai/src/templates/repository_context.rs 🔗

@@ -1,8 +1,11 @@
+use crate::templates::base::{PromptArguments, PromptTemplate};
+use std::fmt::Write;
 use std::{ops::Range, path::PathBuf};
 
 use gpui::{AsyncAppContext, ModelHandle};
 use language::{Anchor, Buffer};
 
+#[derive(Clone)]
 pub struct PromptCodeSnippet {
     path: Option<PathBuf>,
     language_name: Option<String>,
@@ -17,7 +20,7 @@ impl PromptCodeSnippet {
 
             let language_name = buffer
                 .language()
-                .and_then(|language| Some(language.name().to_string()));
+                .and_then(|language| Some(language.name().to_string().to_lowercase()));
 
             let file_path = buffer
                 .file()
@@ -47,3 +50,45 @@ impl ToString for PromptCodeSnippet {
         format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
     }
 }
+
+pub struct RepositoryContext {}
+
+impl PromptTemplate for RepositoryContext {
+    fn generate(
+        &self,
+        args: &PromptArguments,
+        max_token_length: Option<usize>,
+    ) -> anyhow::Result<(String, usize)> {
+        const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
+        let mut template = "You are working inside a large repository, here are a few code snippets that may be useful.";
+        let mut prompt = String::new();
+
+        let mut remaining_tokens = max_token_length.clone();
+        let seperator_token_length = args.model.count_tokens("\n")?;
+        for snippet in &args.snippets {
+            let mut snippet_prompt = template.to_string();
+            let content = snippet.to_string();
+            writeln!(snippet_prompt, "{content}").unwrap();
+
+            let token_count = args.model.count_tokens(&snippet_prompt)?;
+            if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
+                if let Some(tokens_left) = remaining_tokens {
+                    if tokens_left >= token_count {
+                        writeln!(prompt, "{snippet_prompt}").unwrap();
+                        remaining_tokens = if tokens_left >= (token_count + seperator_token_length)
+                        {
+                            Some(tokens_left - token_count - seperator_token_length)
+                        } else {
+                            Some(0)
+                        };
+                    }
+                } else {
+                    writeln!(prompt, "{snippet_prompt}").unwrap();
+                }
+            }
+        }
+
+        let total_token_count = args.model.count_tokens(&prompt)?;
+        anyhow::Ok((prompt, total_token_count))
+    }
+}

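RepositoryContext::generate budgets snippets in two stages: a rendered block (the introductory sentence plus the snippet's fenced code) is skipped outright if it exceeds MAXIMUM_SNIPPET_TOKEN_COUNT, and otherwise it is appended only while the optional overall budget still has room, with one newline-separator token charged per appended block. A worked illustration with made-up token counts; real numbers come from args.model.count_tokens():

    // max_token_length = Some(800), separator = 1 token
    //   block A counts 450 tokens -> under the 500-token cap and under the budget,
    //                                appended; remaining = 800 - 450 - 1 = 349
    //   block B counts 620 tokens -> over the 500-token per-snippet cap, skipped
    //   block C counts 400 tokens -> under the cap but over the 349 tokens left,
    //                                skipped; the remaining budget is unchanged
    //
    // With max_token_length = None, every block under the per-snippet cap is appended.
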
crates/assistant/src/assistant_panel.rs 🔗

@@ -1,12 +1,15 @@
 use crate::{
     assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
     codegen::{self, Codegen, CodegenKind},
-    prompts::{generate_content_prompt, PromptCodeSnippet},
+    prompts::generate_content_prompt,
     MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
     SavedMessage,
 };
-use ai::completion::{
-    stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
+use ai::{
+    completion::{
+        stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
+    },
+    templates::repository_context::PromptCodeSnippet,
 };
 use anyhow::{anyhow, Result};
 use chrono::{DateTime, Local};
@@ -668,14 +671,7 @@ impl AssistantPanel {
             let snippets = cx.spawn(|_, cx| async move {
                 let mut snippets = Vec::new();
                 for result in search_results.await {
-                    snippets.push(PromptCodeSnippet::new(result, &cx));
-
-                    // snippets.push(result.buffer.read_with(&cx, |buffer, _| {
-                    //     buffer
-                    //         .snapshot()
-                    //         .text_for_range(result.range)
-                    //         .collect::<String>()
-                    // }));
+                    snippets.push(PromptCodeSnippet::new(result.buffer, result.range, &cx));
                 }
                 snippets
             });
@@ -717,7 +713,8 @@ impl AssistantPanel {
         }
 
         cx.spawn(|_, mut cx| async move {
-            let prompt = prompt.await;
+            // I don't know if we want to return a ? here.
+            let prompt = prompt.await?;
 
             messages.push(RequestMessage {
                 role: Role::User,
@@ -729,6 +726,7 @@ impl AssistantPanel {
                 stream: true,
             };
             codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx));
+            anyhow::Ok(())
         })
         .detach();
     }

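Because prompt.await is now fallible, the spawned future returns anyhow::Result<()>, and the trailing detach() silently drops that result, which is what the inline comment in the hunk is asking about. A sketch of one way to surface the error, assuming this tree's gpui exposes detach_and_log_err on tasks (this diff does not confirm that):

    // Sketch only: swap detach() for detach_and_log_err(cx) so a failed prompt
    // generation is logged rather than discarded.
    cx.spawn(|_, mut cx| async move {
        let prompt = prompt.await?; // propagate the prompt-generation error
        // ...same RequestMessage / OpenAIRequest construction and codegen.start
        //    as in the hunk above...
        anyhow::Ok(())
    })
    .detach_and_log_err(cx);
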
crates/assistant/src/prompts.rs 🔗

@@ -1,61 +1,15 @@
 use crate::codegen::CodegenKind;
-use gpui::AsyncAppContext;
+use ai::models::{LanguageModel, OpenAILanguageModel};
+use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
+use ai::templates::preamble::EngineerPreamble;
+use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext};
 use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
-use semantic_index::SearchResult;
 use std::cmp::{self, Reverse};
 use std::fmt::Write;
 use std::ops::Range;
-use std::path::PathBuf;
+use std::sync::Arc;
 use tiktoken_rs::ChatCompletionRequestMessage;
 
-pub struct PromptCodeSnippet {
-    path: Option<PathBuf>,
-    language_name: Option<String>,
-    content: String,
-}
-
-impl PromptCodeSnippet {
-    pub fn new(search_result: SearchResult, cx: &AsyncAppContext) -> Self {
-        let (content, language_name, file_path) =
-            search_result.buffer.read_with(cx, |buffer, _| {
-                let snapshot = buffer.snapshot();
-                let content = snapshot
-                    .text_for_range(search_result.range.clone())
-                    .collect::<String>();
-
-                let language_name = buffer
-                    .language()
-                    .and_then(|language| Some(language.name().to_string()));
-
-                let file_path = buffer
-                    .file()
-                    .and_then(|file| Some(file.path().to_path_buf()));
-
-                (content, language_name, file_path)
-            });
-
-        PromptCodeSnippet {
-            path: file_path,
-            language_name,
-            content,
-        }
-    }
-}
-
-impl ToString for PromptCodeSnippet {
-    fn to_string(&self) -> String {
-        let path = self
-            .path
-            .as_ref()
-            .and_then(|path| Some(path.to_string_lossy().to_string()))
-            .unwrap_or("".to_string());
-        let language_name = self.language_name.clone().unwrap_or("".to_string());
-        let content = self.content.clone();
-
-        format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
-    }
-}
-
 #[allow(dead_code)]
 fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
     #[derive(Debug)]
@@ -175,7 +129,32 @@ pub fn generate_content_prompt(
     kind: CodegenKind,
     search_results: Vec<PromptCodeSnippet>,
     model: &str,
-) -> String {
+) -> anyhow::Result<String> {
+    // Using new Prompt Templates
+    let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load(model));
+    let lang_name = if let Some(language_name) = language_name {
+        Some(language_name.to_string())
+    } else {
+        None
+    };
+
+    let args = PromptArguments {
+        model: openai_model,
+        language_name: lang_name.clone(),
+        project_name: None,
+        snippets: search_results.clone(),
+        reserved_tokens: 1000,
+    };
+
+    let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+        (PromptPriority::High, Box::new(EngineerPreamble {})),
+        (PromptPriority::Low, Box::new(RepositoryContext {})),
+    ];
+    let chain = PromptChain::new(args, templates);
+
+    let prompt = chain.generate(true)?;
+    println!("{:?}", prompt);
+
     const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
     const RESERVED_TOKENS_FOR_GENERATION: usize = 1000;
 
@@ -183,7 +162,7 @@ pub fn generate_content_prompt(
     let range = range.to_offset(buffer);
 
     // General Preamble
-    if let Some(language_name) = language_name {
+    if let Some(language_name) = language_name.clone() {
         prompts.push(format!("You're an expert {language_name} engineer.\n"));
     } else {
         prompts.push("You're an expert engineer.\n".to_string());
@@ -297,7 +276,7 @@ pub fn generate_content_prompt(
         }
     }
 
-    prompts.join("\n")
+    anyhow::Ok(prompts.join("\n"))
 }
 
 #[cfg(test)]