From a874a09b7e3b30696dad650bc997342fd8a53a61 Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Tue, 17 Oct 2023 16:21:03 -0400
Subject: [PATCH] added openai language model tokenizer and LanguageModel trait

---
 crates/ai/src/ai.rs                 |  1 +
 crates/ai/src/models.rs             | 49 ++++++++++++++++++++++++++
 crates/ai/src/templates/base.rs     | 54 ++++++++++++-----------------
 crates/ai/src/templates/preamble.rs | 42 +++++++++++++++-------
 4 files changed, 102 insertions(+), 44 deletions(-)
 create mode 100644 crates/ai/src/models.rs

diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs
index 04e9e14536c16d80de133940db6723349e8d2371..f168c157934f6b70be775f7e17e9ba27ef9b3103 100644
--- a/crates/ai/src/ai.rs
+++ b/crates/ai/src/ai.rs
@@ -1,3 +1,4 @@
 pub mod completion;
 pub mod embedding;
+pub mod models;
 pub mod templates;
diff --git a/crates/ai/src/models.rs b/crates/ai/src/models.rs
new file mode 100644
index 0000000000000000000000000000000000000000..4fe96d44f33f10ad1e6ee8572a8cceb02fca8fd4
--- /dev/null
+++ b/crates/ai/src/models.rs
@@ -0,0 +1,49 @@
+use anyhow::anyhow;
+use tiktoken_rs::CoreBPE;
+use util::ResultExt;
+
+pub trait LanguageModel {
+    fn name(&self) -> String;
+    fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
+    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
+    fn capacity(&self) -> anyhow::Result<usize>;
+}
+
+struct OpenAILanguageModel {
+    name: String,
+    bpe: Option<CoreBPE>,
+}
+
+impl OpenAILanguageModel {
+    pub fn load(model_name: String) -> Self {
+        let bpe = tiktoken_rs::get_bpe_from_model(&model_name).log_err();
+        OpenAILanguageModel {
+            name: model_name,
+            bpe,
+        }
+    }
+}
+
+impl LanguageModel for OpenAILanguageModel {
+    fn name(&self) -> String {
+        self.name.clone()
+    }
+    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+        if let Some(bpe) = &self.bpe {
+            anyhow::Ok(bpe.encode_with_special_tokens(content).len())
+        } else {
+            Err(anyhow!("bpe for open ai model was not retrieved"))
+        }
+    }
+    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
+        if let Some(bpe) = &self.bpe {
+            let tokens = bpe.encode_with_special_tokens(content);
+            bpe.decode(tokens[..length].to_vec())
+        } else {
+            Err(anyhow!("bpe for open ai model was not retrieved"))
+        }
+    }
+    fn capacity(&self) -> anyhow::Result<usize> {
+        anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
+    }
+}
diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/templates/base.rs
index 74a4c424ae93b46da34d3f5493f6e2363b31c2f5..b5f9da3586f7793e601ca8f5bf7a3158da5949c8 100644
--- a/crates/ai/src/templates/base.rs
+++ b/crates/ai/src/templates/base.rs
@@ -1,17 +1,11 @@
-use std::fmt::Write;
-use std::{cmp::Reverse, sync::Arc};
+use std::cmp::Reverse;
+use std::sync::Arc;
 
 use util::ResultExt;
 
+use crate::models::LanguageModel;
 use crate::templates::repository_context::PromptCodeSnippet;
 
-pub trait LanguageModel {
-    fn name(&self) -> String;
-    fn count_tokens(&self, content: &str) -> usize;
-    fn truncate(&self, content: &str, length: usize) -> String;
-    fn capacity(&self) -> usize;
-}
-
 pub(crate) enum PromptFileType {
     Text,
     Code,
@@ -73,7 +67,7 @@ impl PromptChain {
     pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
         // Argsort based on Prompt Priority
         let seperator = "\n";
-        let seperator_tokens = self.args.model.count_tokens(seperator);
+        let seperator_tokens = self.args.model.count_tokens(seperator)?;
         let mut sorted_indices = (0..self.templates.len()).collect::<Vec<usize>>();
         sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
@@ -81,7 +75,7 @@ impl PromptChain {
         // If Truncate
         let mut tokens_outstanding = if truncate {
-            Some(self.args.model.capacity() - self.args.reserved_tokens)
+            Some(self.args.model.capacity()? - self.args.reserved_tokens)
         } else {
             None
         };
 
@@ -111,7 +105,7 @@ impl PromptChain {
         }
 
         let full_prompt = prompts.join(seperator);
-        let total_token_count = self.args.model.count_tokens(&full_prompt);
+        let total_token_count = self.args.model.count_tokens(&full_prompt)?;
         anyhow::Ok((prompts.join(seperator), total_token_count))
     }
 }
@@ -131,10 +125,10 @@ pub(crate) mod tests {
         ) -> anyhow::Result<(String, usize)> {
             let mut content = "This is a test prompt template".to_string();
 
-            let mut token_count = args.model.count_tokens(&content);
+            let mut token_count = args.model.count_tokens(&content)?;
             if let Some(max_token_length) = max_token_length {
                 if token_count > max_token_length {
-                    content = args.model.truncate(&content, max_token_length);
+                    content = args.model.truncate(&content, max_token_length)?;
                     token_count = max_token_length;
                 }
             }
@@ -152,10 +146,10 @@ pub(crate) mod tests {
         ) -> anyhow::Result<(String, usize)> {
             let mut content = "This is a low priority test prompt template".to_string();
 
-            let mut token_count = args.model.count_tokens(&content);
+            let mut token_count = args.model.count_tokens(&content)?;
             if let Some(max_token_length) = max_token_length {
                 if token_count > max_token_length {
-                    content = args.model.truncate(&content, max_token_length);
+                    content = args.model.truncate(&content, max_token_length)?;
                     token_count = max_token_length;
                 }
             }
@@ -169,26 +163,22 @@ pub(crate) mod tests {
         capacity: usize,
     }
 
-    impl DummyLanguageModel {
-        fn set_capacity(&mut self, capacity: usize) {
-            self.capacity = capacity
-        }
-    }
-
     impl LanguageModel for DummyLanguageModel {
         fn name(&self) -> String {
             "dummy".to_string()
         }
-        fn count_tokens(&self, content: &str) -> usize {
-            content.chars().collect::<Vec<char>>().len()
+        fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+            anyhow::Ok(content.chars().collect::<Vec<char>>().len())
         }
-        fn truncate(&self, content: &str, length: usize) -> String {
-            content.chars().collect::<Vec<char>>()[..length]
-                .into_iter()
-                .collect::<String>()
+        fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
+            anyhow::Ok(
+                content.chars().collect::<Vec<char>>()[..length]
+                    .into_iter()
+                    .collect::<String>(),
+            )
         }
-        fn capacity(&self) -> usize {
-            self.capacity
+        fn capacity(&self) -> anyhow::Result<usize> {
+            anyhow::Ok(self.capacity)
         }
     }
 
@@ -215,7 +205,7 @@ pub(crate) mod tests {
             .to_string()
         );
 
-        assert_eq!(model.count_tokens(&prompt), token_count);
+        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
 
         // Testing with Truncation Off
         // Should ignore capacity and return all prompts
@@ -242,7 +232,7 @@ pub(crate) mod tests {
             .to_string()
         );
 
-        assert_eq!(model.count_tokens(&prompt), token_count);
+        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
 
         // Testing with Truncation Off
         // Should ignore capacity and return all prompts
diff --git a/crates/ai/src/templates/preamble.rs b/crates/ai/src/templates/preamble.rs
index b1d33f885ea493f9488894154fe262e7ce177edc..f395dbf8beeabde2a703214cc0426900908be990 100644
--- a/crates/ai/src/templates/preamble.rs
+++ b/crates/ai/src/templates/preamble.rs
@@ -4,31 +4,49 @@ use std::fmt::Write;
 struct EngineerPreamble {}
 
 impl PromptTemplate for EngineerPreamble {
-    fn generate(&self, args: &PromptArguments, max_token_length: Option<usize>) -> String {
-        let mut prompt = String::new();
+    fn generate(
+        &self,
+        args: &PromptArguments,
+        max_token_length: Option<usize>,
+    ) -> anyhow::Result<(String, usize)> {
+        let mut prompts = Vec::new();
 
         match args.get_file_type() {
             PromptFileType::Code => {
-                writeln!(
-                    prompt,
+                prompts.push(format!(
                     "You are an expert {} engineer.",
                     args.language_name.clone().unwrap_or("".to_string())
-                )
-                .unwrap();
+                ));
             }
             PromptFileType::Text => {
-                writeln!(prompt, "You are an expert engineer.").unwrap();
+                prompts.push("You are an expert engineer.".to_string());
             }
         }
 
         if let Some(project_name) = args.project_name.clone() {
-            writeln!(
-                prompt,
+            prompts.push(format!(
                 "You are currently working inside the '{project_name}' in Zed the code editor."
-            )
-            .unwrap();
+            ));
         }
 
-        prompt
+        if let Some(mut remaining_tokens) = max_token_length {
+            let mut prompt = String::new();
+            let mut total_count = 0;
+            for prompt_piece in prompts {
+                let prompt_token_count =
+                    args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?;
+                if remaining_tokens > prompt_token_count {
+                    writeln!(prompt, "{prompt_piece}").unwrap();
+                    remaining_tokens -= prompt_token_count;
+                    total_count += prompt_token_count;
+                }
+            }
+
+            anyhow::Ok((prompt, total_count))
+        } else {
+            let prompt = prompts.join("\n");
+            let token_count = args.model.count_tokens(&prompt)?;
+            anyhow::Ok((prompt, token_count))
+        }
     }
 }
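
Usage sketch (not part of the commit above): a minimal caller for the new LanguageModel trait in crates/ai/src/models.rs. It assumes OpenAILanguageModel is exported publicly from the ai crate, which this patch does not yet do, and that the model name shown is one tiktoken-rs recognizes; everything else follows the anyhow::Result signatures introduced here.

use ai::models::{LanguageModel, OpenAILanguageModel};

fn main() -> anyhow::Result<()> {
    // Loading never fails outright: a bad model name just leaves `bpe` as None
    // (get_bpe_from_model's error is logged via log_err), and the trait methods
    // surface that as an Err later.
    let model = OpenAILanguageModel::load("gpt-3.5-turbo".to_string());

    let content = "fn main() { println!(\"hello\"); }";
    let tokens = model.count_tokens(content)?; // Err if the BPE was never loaded
    let capacity = model.capacity()?; // context size looked up by model name
    println!("{tokens} tokens against a {capacity} token context window");

    // Truncation works in token space; clamp the length so the slice cannot panic.
    let truncated = model.truncate(content, tokens.min(8))?;
    println!("first tokens decode back to: {truncated}");

    Ok(())
}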
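
On the template side, the EngineerPreamble rewrite follows the same shape as the test templates in base.rs: build the text, then count and truncate through args.model so the reported token count stays consistent. A further template in that style might look like the sketch below; the PromptTemplate signature and the args.model field are inferred from the hunks in this patch, the code is assumed to live inside the ai crate so the crate:: paths resolve, and FileSummaryPreamble itself is hypothetical.

use crate::models::LanguageModel;
use crate::templates::base::{PromptArguments, PromptTemplate};

struct FileSummaryPreamble {}

impl PromptTemplate for FileSummaryPreamble {
    fn generate(
        &self,
        args: &PromptArguments,
        max_token_length: Option<usize>,
    ) -> anyhow::Result<(String, usize)> {
        let mut content = "Summarize the file currently open in the editor.".to_string();

        // Count first, then truncate through the model, mirroring
        // TestPromptTemplate in crates/ai/src/templates/base.rs.
        let mut token_count = args.model.count_tokens(&content)?;
        if let Some(max_token_length) = max_token_length {
            if token_count > max_token_length {
                content = args.model.truncate(&content, max_token_length)?;
                token_count = max_token_length;
            }
        }

        anyhow::Ok((content, token_count))
    }
}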