Detailed changes
@@ -9,16 +9,16 @@ pub trait LanguageModel {
fn capacity(&self) -> anyhow::Result<usize>;
}
-struct OpenAILanguageModel {
+pub struct OpenAILanguageModel {
name: String,
bpe: Option<CoreBPE>,
}
impl OpenAILanguageModel {
- pub fn load(model_name: String) -> Self {
- let bpe = tiktoken_rs::get_bpe_from_model(&model_name).log_err();
+ pub fn load(model_name: &str) -> Self {
+ let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
OpenAILanguageModel {
- name: model_name,
+ name: model_name.to_string(),
bpe,
}
}
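For orientation, a minimal sketch of how the now-public OpenAILanguageModel might be consumed through the LanguageModel trait. The count_tokens signature is inferred from the call sites in the RepositoryContext hunk further down and is not confirmed by this diff:

    use ai::models::{LanguageModel, OpenAILanguageModel};

    // Hypothetical helper: estimate how many tokens remain for snippet content.
    fn snippet_budget(model_name: &str) -> anyhow::Result<usize> {
        let model = OpenAILanguageModel::load(model_name);
        let context_size = model.capacity()?; // trait method shown above
        let newline_cost = model.count_tokens("\n")?; // assumed: fn count_tokens(&self, &str) -> anyhow::Result<usize>
        Ok(context_size.saturating_sub(newline_cost))
    }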
@@ -1,149 +0,0 @@
-use gpui::{AsyncAppContext, ModelHandle};
-use language::{Anchor, Buffer};
-use std::{fmt::Write, ops::Range, path::PathBuf};
-
-pub struct PromptCodeSnippet {
- path: Option<PathBuf>,
- language_name: Option<String>,
- content: String,
-}
-
-impl PromptCodeSnippet {
- pub fn new(buffer: ModelHandle<Buffer>, range: Range<Anchor>, cx: &AsyncAppContext) -> Self {
- let (content, language_name, file_path) = buffer.read_with(cx, |buffer, _| {
- let snapshot = buffer.snapshot();
- let content = snapshot.text_for_range(range.clone()).collect::<String>();
-
- let language_name = buffer
- .language()
- .and_then(|language| Some(language.name().to_string()));
-
- let file_path = buffer
- .file()
- .and_then(|file| Some(file.path().to_path_buf()));
-
- (content, language_name, file_path)
- });
-
- PromptCodeSnippet {
- path: file_path,
- language_name,
- content,
- }
- }
-}
-
-impl ToString for PromptCodeSnippet {
- fn to_string(&self) -> String {
- let path = self
- .path
- .as_ref()
- .and_then(|path| Some(path.to_string_lossy().to_string()))
- .unwrap_or("".to_string());
- let language_name = self.language_name.clone().unwrap_or("".to_string());
- let content = self.content.clone();
-
- format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
- }
-}
-
-enum PromptFileType {
- Text,
- Code,
-}
-
-#[derive(Default)]
-struct PromptArguments {
- pub language_name: Option<String>,
- pub project_name: Option<String>,
- pub snippets: Vec<PromptCodeSnippet>,
- pub model_name: String,
-}
-
-impl PromptArguments {
- pub fn get_file_type(&self) -> PromptFileType {
- if self
- .language_name
- .as_ref()
- .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
- .unwrap_or(true)
- {
- PromptFileType::Code
- } else {
- PromptFileType::Text
- }
- }
-}
-
-trait PromptTemplate {
- fn generate(args: PromptArguments, max_token_length: Option<usize>) -> String;
-}
-
-struct EngineerPreamble {}
-
-impl PromptTemplate for EngineerPreamble {
- fn generate(args: PromptArguments, max_token_length: Option<usize>) -> String {
- let mut prompt = String::new();
-
- match args.get_file_type() {
- PromptFileType::Code => {
- writeln!(
- prompt,
- "You are an expert {} engineer.",
- args.language_name.unwrap_or("".to_string())
- )
- .unwrap();
- }
- PromptFileType::Text => {
- writeln!(prompt, "You are an expert engineer.").unwrap();
- }
- }
-
- if let Some(project_name) = args.project_name {
- writeln!(
- prompt,
- "You are currently working inside the '{project_name}' in Zed the code editor."
- )
- .unwrap();
- }
-
- prompt
- }
-}
-
-struct RepositorySnippets {}
-
-impl PromptTemplate for RepositorySnippets {
- fn generate(args: PromptArguments, max_token_length: Option<usize>) -> String {
- const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
- let mut template = "You are working inside a large repository, here are a few code snippets that may be useful";
- let mut prompt = String::new();
-
- if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(args.model_name.as_str()) {
- let default_token_count =
- tiktoken_rs::model::get_context_size(args.model_name.as_str());
- let mut remaining_token_count = max_token_length.unwrap_or(default_token_count);
-
- for snippet in args.snippets {
- let mut snippet_prompt = template.to_string();
- let content = snippet.to_string();
- writeln!(snippet_prompt, "{content}").unwrap();
-
- let token_count = encoding
- .encode_with_special_tokens(snippet_prompt.as_str())
- .len();
- if token_count <= remaining_token_count {
- if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
- writeln!(prompt, "{snippet_prompt}").unwrap();
- remaining_token_count -= token_count;
- template = "";
- }
- } else {
- break;
- }
- }
- }
-
- prompt
- }
-}
@@ -1,7 +1,7 @@
use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
use std::fmt::Write;
-struct EngineerPreamble {}
+pub struct EngineerPreamble {}
impl PromptTemplate for EngineerPreamble {
fn generate(
@@ -14,8 +14,8 @@ impl PromptTemplate for EngineerPreamble {
match args.get_file_type() {
PromptFileType::Code => {
prompts.push(format!(
- "You are an expert {} engineer.",
- args.language_name.clone().unwrap_or("".to_string())
+ "You are an expert {}engineer.",
+ args.language_name.clone().unwrap_or("".to_string()) + " "
));
}
PromptFileType::Text => {
@@ -1,8 +1,11 @@
+use crate::templates::base::{PromptArguments, PromptTemplate};
+use std::fmt::Write;
use std::{ops::Range, path::PathBuf};
use gpui::{AsyncAppContext, ModelHandle};
use language::{Anchor, Buffer};
+#[derive(Clone)]
pub struct PromptCodeSnippet {
path: Option<PathBuf>,
language_name: Option<String>,
@@ -17,7 +20,7 @@ impl PromptCodeSnippet {
let language_name = buffer
.language()
- .and_then(|language| Some(language.name().to_string()));
+ .and_then(|language| Some(language.name().to_string().to_lowercase()));
let file_path = buffer
.file()
@@ -47,3 +50,45 @@ impl ToString for PromptCodeSnippet {
format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
}
}
+
+pub struct RepositoryContext {}
+
+impl PromptTemplate for RepositoryContext {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)> {
+ const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
+ let mut template = "You are working inside a large repository, here are a few code snippets that may be useful.";
+ let mut prompt = String::new();
+
+ let mut remaining_tokens = max_token_length.clone();
+ let seperator_token_length = args.model.count_tokens("\n")?;
+ for snippet in &args.snippets {
+ let mut snippet_prompt = template.to_string();
+ let content = snippet.to_string();
+ writeln!(snippet_prompt, "{content}").unwrap();
+
+ let token_count = args.model.count_tokens(&snippet_prompt)?;
+ if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
+ if let Some(tokens_left) = remaining_tokens {
+ if tokens_left >= token_count {
+ writeln!(prompt, "{snippet_prompt}").unwrap();
+ remaining_tokens = if tokens_left >= (token_count + seperator_token_length)
+ {
+ Some(tokens_left - token_count - seperator_token_length)
+ } else {
+ Some(0)
+ };
+ }
+ } else {
+ writeln!(prompt, "{snippet_prompt}").unwrap();
+ }
+ }
+ }
+
+ let total_token_count = args.model.count_tokens(&prompt)?;
+ anyhow::Ok((prompt, total_token_count))
+ }
+}
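A hedged usage sketch of the new RepositoryContext template on its own; the model name, token budget, and field values are illustrative, and PromptArguments may carry fields beyond the ones visible in this diff:

    use std::sync::Arc;

    use ai::models::OpenAILanguageModel;
    use ai::templates::base::{PromptArguments, PromptTemplate};
    use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext};

    fn snippet_section(snippets: Vec<PromptCodeSnippet>) -> anyhow::Result<String> {
        let args = PromptArguments {
            model: Arc::new(OpenAILanguageModel::load("gpt-4")), // illustrative model name
            language_name: Some("rust".to_string()),
            project_name: None,
            snippets,
            reserved_tokens: 1000,
        };
        let template = RepositoryContext {};
        // generate returns the rendered prompt together with its measured token count.
        let (prompt, _token_count) = template.generate(&args, Some(2000))?;
        Ok(prompt)
    }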
@@ -1,12 +1,15 @@
use crate::{
assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
codegen::{self, Codegen, CodegenKind},
- prompts::{generate_content_prompt, PromptCodeSnippet},
+ prompts::generate_content_prompt,
MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
SavedMessage,
};
-use ai::completion::{
- stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
+use ai::{
+ completion::{
+ stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
+ },
+ templates::repository_context::PromptCodeSnippet,
};
use anyhow::{anyhow, Result};
use chrono::{DateTime, Local};
@@ -668,14 +671,7 @@ impl AssistantPanel {
let snippets = cx.spawn(|_, cx| async move {
let mut snippets = Vec::new();
for result in search_results.await {
- snippets.push(PromptCodeSnippet::new(result, &cx));
-
- // snippets.push(result.buffer.read_with(&cx, |buffer, _| {
- // buffer
- // .snapshot()
- // .text_for_range(result.range)
- // .collect::<String>()
- // }));
+ snippets.push(PromptCodeSnippet::new(result.buffer, result.range, &cx));
}
snippets
});
@@ -717,7 +713,8 @@ impl AssistantPanel {
}
cx.spawn(|_, mut cx| async move {
- let prompt = prompt.await;
+ // I don't know if we want to return a ? here.
+ let prompt = prompt.await?;
messages.push(RequestMessage {
role: Role::User,
@@ -729,6 +726,7 @@ impl AssistantPanel {
stream: true,
};
codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx));
+ anyhow::Ok(())
})
.detach();
}
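The inline comment above leaves the error-handling question open. As a sketch of the alternative it hints at, the task could stay infallible and log the failure instead, reusing the log_err helper seen in the model-loading hunk (its availability in this crate is assumed):

    cx.spawn(|_, mut cx| async move {
        // Log and bail out rather than propagating the error with `?`.
        let Some(prompt) = prompt.await.log_err() else {
            return;
        };
        // ... push the RequestMessage and start codegen as above ...
    })
    .detach();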
@@ -1,61 +1,15 @@
use crate::codegen::CodegenKind;
-use gpui::AsyncAppContext;
+use ai::models::{LanguageModel, OpenAILanguageModel};
+use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
+use ai::templates::preamble::EngineerPreamble;
+use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext};
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
-use semantic_index::SearchResult;
use std::cmp::{self, Reverse};
use std::fmt::Write;
use std::ops::Range;
-use std::path::PathBuf;
+use std::sync::Arc;
use tiktoken_rs::ChatCompletionRequestMessage;
-pub struct PromptCodeSnippet {
- path: Option<PathBuf>,
- language_name: Option<String>,
- content: String,
-}
-
-impl PromptCodeSnippet {
- pub fn new(search_result: SearchResult, cx: &AsyncAppContext) -> Self {
- let (content, language_name, file_path) =
- search_result.buffer.read_with(cx, |buffer, _| {
- let snapshot = buffer.snapshot();
- let content = snapshot
- .text_for_range(search_result.range.clone())
- .collect::<String>();
-
- let language_name = buffer
- .language()
- .and_then(|language| Some(language.name().to_string()));
-
- let file_path = buffer
- .file()
- .and_then(|file| Some(file.path().to_path_buf()));
-
- (content, language_name, file_path)
- });
-
- PromptCodeSnippet {
- path: file_path,
- language_name,
- content,
- }
- }
-}
-
-impl ToString for PromptCodeSnippet {
- fn to_string(&self) -> String {
- let path = self
- .path
- .as_ref()
- .and_then(|path| Some(path.to_string_lossy().to_string()))
- .unwrap_or("".to_string());
- let language_name = self.language_name.clone().unwrap_or("".to_string());
- let content = self.content.clone();
-
- format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
- }
-}
-
#[allow(dead_code)]
fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
#[derive(Debug)]
@@ -175,7 +129,32 @@ pub fn generate_content_prompt(
kind: CodegenKind,
search_results: Vec<PromptCodeSnippet>,
model: &str,
-) -> String {
+) -> anyhow::Result<String> {
+ // Using new Prompt Templates
+ let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load(model));
+ let lang_name = if let Some(language_name) = language_name {
+ Some(language_name.to_string())
+ } else {
+ None
+ };
+
+ let args = PromptArguments {
+ model: openai_model,
+ language_name: lang_name.clone(),
+ project_name: None,
+ snippets: search_results.clone(),
+ reserved_tokens: 1000,
+ };
+
+ let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+ (PromptPriority::High, Box::new(EngineerPreamble {})),
+ (PromptPriority::Low, Box::new(RepositoryContext {})),
+ ];
+ let chain = PromptChain::new(args, templates);
+
+ let prompt = chain.generate(true)?;
+ println!("{:?}", prompt);
+
const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
const RESERVED_TOKENS_FOR_GENERATION: usize = 1000;
@@ -183,7 +162,7 @@ pub fn generate_content_prompt(
let range = range.to_offset(buffer);
// General Preamble
- if let Some(language_name) = language_name {
+ if let Some(language_name) = language_name.clone() {
prompts.push(format!("You're an expert {language_name} engineer.\n"));
} else {
prompts.push("You're an expert engineer.\n".to_string());
@@ -297,7 +276,7 @@ pub fn generate_content_prompt(
}
}
- prompts.join("\n")
+ anyhow::Ok(prompts.join("\n"))
}
#[cfg(test)]
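To illustrate the extension point the new plumbing creates, here is a hypothetical extra template. Everything in it apart from the PromptTemplate trait shape, PromptArguments, and PromptPriority is invented for illustration, and the value returned by chain.generate is not shown in this diff:

    use ai::templates::base::{PromptArguments, PromptTemplate};

    // Hypothetical template carrying a fixed instruction string.
    struct GenerationInstruction {
        text: String,
    }

    impl PromptTemplate for GenerationInstruction {
        fn generate(
            &self,
            args: &PromptArguments,
            max_token_length: Option<usize>,
        ) -> anyhow::Result<(String, usize)> {
            let token_count = args.model.count_tokens(&self.text)?;
            if let Some(limit) = max_token_length {
                if token_count > limit {
                    // Assumed convention: a template that does not fit its budget contributes nothing.
                    return anyhow::Ok((String::new(), 0));
                }
            }
            anyhow::Ok((self.text.clone(), token_count))
        }
    }

    // It would then slot into the chain next to the templates constructed above, e.g.:
    // (PromptPriority::High, Box::new(GenerationInstruction { text: "...".into() }))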