From 3712794e561b7aa6068ee3de0fc411d5cb311566 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Sun, 22 Oct 2023 13:47:28 +0200 Subject: [PATCH] move OpenAILanguageModel to providers folder --- crates/ai/src/ai.rs | 1 + crates/ai/src/models.rs | 55 ----------------------- crates/ai/src/providers/mod.rs | 1 + crates/ai/src/providers/open_ai/mod.rs | 2 + crates/ai/src/providers/open_ai/model.rs | 56 ++++++++++++++++++++++++ crates/assistant/src/prompts.rs | 3 +- 6 files changed, 62 insertions(+), 56 deletions(-) create mode 100644 crates/ai/src/providers/mod.rs create mode 100644 crates/ai/src/providers/open_ai/mod.rs create mode 100644 crates/ai/src/providers/open_ai/model.rs diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs index c0b78b74cf2a24dded65671303ccade5cc4261e3..a3ae2fcf7ffb5075c70b607f1cdf34279cd063a3 100644 --- a/crates/ai/src/ai.rs +++ b/crates/ai/src/ai.rs @@ -2,3 +2,4 @@ pub mod completion; pub mod embedding; pub mod models; pub mod prompts; +pub mod providers; diff --git a/crates/ai/src/models.rs b/crates/ai/src/models.rs index afb8496156f6521eb3125b6a0ba6d703d5d0fe50..1db3d58c6f54ad613cb98fc3f425df3d47e5e97f 100644 --- a/crates/ai/src/models.rs +++ b/crates/ai/src/models.rs @@ -1,7 +1,3 @@ -use anyhow::anyhow; -use tiktoken_rs::CoreBPE; -use util::ResultExt; - pub enum TruncationDirection { Start, End, @@ -18,54 +14,3 @@ pub trait LanguageModel { ) -> anyhow::Result<String>; fn capacity(&self) -> anyhow::Result<usize>; } - -pub struct OpenAILanguageModel { - name: String, - bpe: Option<CoreBPE>, -} - -impl OpenAILanguageModel { - pub fn load(model_name: &str) -> Self { - let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err(); - OpenAILanguageModel { - name: model_name.to_string(), - bpe, - } - } -} - -impl LanguageModel for OpenAILanguageModel { - fn name(&self) -> String { - self.name.clone() - } - fn count_tokens(&self, content: &str) -> anyhow::Result<usize> { - if let Some(bpe) = &self.bpe { - anyhow::Ok(bpe.encode_with_special_tokens(content).len()) - } else { -
Err(anyhow!("bpe for open ai model was not retrieved")) - } - } - fn truncate( - &self, - content: &str, - length: usize, - direction: TruncationDirection, - ) -> anyhow::Result<String> { - if let Some(bpe) = &self.bpe { - let tokens = bpe.encode_with_special_tokens(content); - if tokens.len() > length { - match direction { - TruncationDirection::End => bpe.decode(tokens[..length].to_vec()), - TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()), - } - } else { - bpe.decode(tokens) - } - } else { - Err(anyhow!("bpe for open ai model was not retrieved")) - } - } - fn capacity(&self) -> anyhow::Result<usize> { - anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name)) - } -} diff --git a/crates/ai/src/providers/mod.rs b/crates/ai/src/providers/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..acd0f9d91053869e3999ef0c1a23326480a7cbdd --- /dev/null +++ b/crates/ai/src/providers/mod.rs @@ -0,0 +1 @@ +pub mod open_ai; diff --git a/crates/ai/src/providers/open_ai/mod.rs b/crates/ai/src/providers/open_ai/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..8d8489e187b26c092cd2043b42f7339b4a43d794 --- /dev/null +++ b/crates/ai/src/providers/open_ai/mod.rs @@ -0,0 +1,2 @@ +pub mod model; +pub use model::OpenAILanguageModel; diff --git a/crates/ai/src/providers/open_ai/model.rs b/crates/ai/src/providers/open_ai/model.rs new file mode 100644 index 0000000000000000000000000000000000000000..42523f3df48951d8674b33409105f8d802fd6c25 --- /dev/null +++ b/crates/ai/src/providers/open_ai/model.rs @@ -0,0 +1,56 @@ +use anyhow::anyhow; +use tiktoken_rs::CoreBPE; +use util::ResultExt; + +use crate::models::{LanguageModel, TruncationDirection}; + +pub struct OpenAILanguageModel { + name: String, + bpe: Option<CoreBPE>, +} + +impl OpenAILanguageModel { + pub fn load(model_name: &str) -> Self { + let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err(); + OpenAILanguageModel { + name: model_name.to_string(), + bpe, + } + } +} + +impl
LanguageModel for OpenAILanguageModel { + fn name(&self) -> String { + self.name.clone() + } + fn count_tokens(&self, content: &str) -> anyhow::Result<usize> { + if let Some(bpe) = &self.bpe { + anyhow::Ok(bpe.encode_with_special_tokens(content).len()) + } else { + Err(anyhow!("bpe for open ai model was not retrieved")) + } + } + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result<String> { + if let Some(bpe) = &self.bpe { + let tokens = bpe.encode_with_special_tokens(content); + if tokens.len() > length { + match direction { + TruncationDirection::End => bpe.decode(tokens[..length].to_vec()), + TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()), + } + } else { + bpe.decode(tokens) + } + } else { + Err(anyhow!("bpe for open ai model was not retrieved")) + } + } + fn capacity(&self) -> anyhow::Result<usize> { + anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name)) + } +} diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs index 8fff232fdbf9358de9a97892bdd58d070871f64a..25af023c40072cbc56ba2865658544999e0fa7c7 100644 --- a/crates/assistant/src/prompts.rs +++ b/crates/assistant/src/prompts.rs @@ -1,9 +1,10 @@ -use ai::models::{LanguageModel, OpenAILanguageModel}; +use ai::models::LanguageModel; use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate}; use ai::prompts::file_context::FileContext; use ai::prompts::generate::GenerateInlineContent; use ai::prompts::preamble::EngineerPreamble; use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext}; +use ai::providers::open_ai::OpenAILanguageModel; use language::{BufferSnapshot, OffsetRangeExt, ToOffset}; use std::cmp::{self, Reverse}; use std::ops::Range;