KCaverly created

crates/ai/src/ai.rs                      |  1 +
crates/ai/src/models.rs                  | 55 -------------------------
crates/ai/src/providers/mod.rs           |  1 +
crates/ai/src/providers/open_ai/mod.rs   |  2 ++
crates/ai/src/providers/open_ai/model.rs | 56 ++++++++++++++++++++++++++
crates/assistant/src/prompts.rs          |  3 ++-
6 files changed, 62 insertions(+), 56 deletions(-)

crates/ai/src/ai.rs
@@ -2,3 +2,4 @@ pub mod completion;
pub mod embedding;
pub mod models;
pub mod prompts;
+pub mod providers;

crates/ai/src/models.rs
@@ -1,7 +1,3 @@
-use anyhow::anyhow;
-use tiktoken_rs::CoreBPE;
-use util::ResultExt;
-
pub enum TruncationDirection {
Start,
End,
@@ -18,54 +14,3 @@ pub trait LanguageModel {
) -> anyhow::Result<String>;
fn capacity(&self) -> anyhow::Result<usize>;
}
-
-pub struct OpenAILanguageModel {
- name: String,
- bpe: Option<CoreBPE>,
-}
-
-impl OpenAILanguageModel {
- pub fn load(model_name: &str) -> Self {
- let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
- OpenAILanguageModel {
- name: model_name.to_string(),
- bpe,
- }
- }
-}
-
-impl LanguageModel for OpenAILanguageModel {
- fn name(&self) -> String {
- self.name.clone()
- }
- fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
- if let Some(bpe) = &self.bpe {
- anyhow::Ok(bpe.encode_with_special_tokens(content).len())
- } else {
- Err(anyhow!("bpe for open ai model was not retrieved"))
- }
- }
- fn truncate(
- &self,
- content: &str,
- length: usize,
- direction: TruncationDirection,
- ) -> anyhow::Result<String> {
- if let Some(bpe) = &self.bpe {
- let tokens = bpe.encode_with_special_tokens(content);
- if tokens.len() > length {
- match direction {
- TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
- TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
- }
- } else {
- bpe.decode(tokens)
- }
- } else {
- Err(anyhow!("bpe for open ai model was not retrieved"))
- }
- }
- fn capacity(&self) -> anyhow::Result<usize> {
- anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
- }
-}
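
The relocated code is unchanged; models.rs now keeps only the provider-agnostic pieces. For reference, here is the trait as it remains after this hunk, a sketch reconstructed from the hunk header above and the impl in the new file below; doc comments or exact layout in models.rs may differ:

pub enum TruncationDirection {
    Start,
    End,
}

pub trait LanguageModel {
    fn name(&self) -> String;
    fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
    fn truncate(
        &self,
        content: &str,
        length: usize,
        direction: TruncationDirection,
    ) -> anyhow::Result<String>;
    fn capacity(&self) -> anyhow::Result<usize>;
}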

crates/ai/src/providers/mod.rs
@@ -0,0 +1 @@
+pub mod open_ai;

crates/ai/src/providers/open_ai/mod.rs
@@ -0,0 +1,2 @@
+pub mod model;
+pub use model::OpenAILanguageModel;
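
The re-export keeps call sites short. Either of these imports resolves to the same type; the second, fully qualified path is what the re-export saves callers from writing:

use ai::providers::open_ai::OpenAILanguageModel;
// equivalently: use ai::providers::open_ai::model::OpenAILanguageModel;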

crates/ai/src/providers/open_ai/model.rs
@@ -0,0 +1,56 @@
+use anyhow::anyhow;
+use tiktoken_rs::CoreBPE;
+use util::ResultExt;
+
+use crate::models::{LanguageModel, TruncationDirection};
+
+pub struct OpenAILanguageModel {
+ name: String,
+ bpe: Option<CoreBPE>,
+}
+
+impl OpenAILanguageModel {
+ pub fn load(model_name: &str) -> Self {
+ let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
+ OpenAILanguageModel {
+ name: model_name.to_string(),
+ bpe,
+ }
+ }
+}
+
+impl LanguageModel for OpenAILanguageModel {
+ fn name(&self) -> String {
+ self.name.clone()
+ }
+ fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+ if let Some(bpe) = &self.bpe {
+ anyhow::Ok(bpe.encode_with_special_tokens(content).len())
+ } else {
+ Err(anyhow!("bpe for open ai model was not retrieved"))
+ }
+ }
+ fn truncate(
+ &self,
+ content: &str,
+ length: usize,
+ direction: TruncationDirection,
+ ) -> anyhow::Result<String> {
+ if let Some(bpe) = &self.bpe {
+ let tokens = bpe.encode_with_special_tokens(content);
+ if tokens.len() > length {
+ match direction {
+ TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
+ TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
+ }
+ } else {
+ bpe.decode(tokens)
+ }
+ } else {
+ Err(anyhow!("bpe for open ai model was not retrieved"))
+ }
+ }
+ fn capacity(&self) -> anyhow::Result<usize> {
+ anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
+ }
+}
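
A minimal usage sketch of the relocated type. The model name and prompt below are arbitrary examples, not from the commit; note that `load` never fails outright, it only logs a failed tokenizer lookup, so the error surfaces later from the token methods:

use ai::models::{LanguageModel, TruncationDirection};
use ai::providers::open_ai::OpenAILanguageModel;

fn main() -> anyhow::Result<()> {
    // "gpt-3.5-turbo" is just an example; any model name that tiktoken_rs
    // recognizes works. On an unknown name, `bpe` is None and the
    // count_tokens/truncate calls below return Err rather than panic.
    let model = OpenAILanguageModel::load("gpt-3.5-turbo");

    let prompt = "fn main() { println!(\"hello\"); }";
    println!("{} tokens, capacity {}", model.count_tokens(prompt)?, model.capacity()?);

    // TruncationDirection::End keeps the first 16 tokens and drops the rest;
    // ::Start instead drops tokens from the front of the prompt.
    let clipped = model.truncate(prompt, 16, TruncationDirection::End)?;
    println!("{clipped}");
    Ok(())
}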

crates/assistant/src/prompts.rs
@@ -1,9 +1,10 @@
-use ai::models::{LanguageModel, OpenAILanguageModel};
+use ai::models::LanguageModel;
use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
use ai::prompts::file_context::FileContext;
use ai::prompts::generate::GenerateInlineContent;
use ai::prompts::preamble::EngineerPreamble;
use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
+use ai::providers::open_ai::OpenAILanguageModel;
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
use std::cmp::{self, Reverse};
use std::ops::Range;
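
The assistant crate consumes the model through the `LanguageModel` trait, so only this import line changes. As a sketch (the helper below is hypothetical, not from the commit), any code written against the trait object is untouched by the move:

use std::sync::Arc;
use ai::models::{LanguageModel, TruncationDirection};
use ai::providers::open_ai::OpenAILanguageModel;

// Hypothetical helper: it depends only on the trait, so it compiles the same
// whether the concrete type lives in ai::models or ai::providers::open_ai.
fn clip_to_capacity(model: &dyn LanguageModel, content: &str) -> anyhow::Result<String> {
    model.truncate(content, model.capacity()?, TruncationDirection::End)
}

fn main() -> anyhow::Result<()> {
    let model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load("gpt-4"));
    println!("{}", clip_to_capacity(model.as_ref(), "some long prompt")?);
    Ok(())
}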