move OpenAILanguageModel to providers folder

KCaverly created

Change summary

crates/ai/src/ai.rs                      |  1 
crates/ai/src/models.rs                  | 55 -------------------------
crates/ai/src/providers/mod.rs           |  1 
crates/ai/src/providers/open_ai/mod.rs   |  2 
crates/ai/src/providers/open_ai/model.rs | 56 ++++++++++++++++++++++++++
crates/assistant/src/prompts.rs          |  3 
6 files changed, 62 insertions(+), 56 deletions(-)

Detailed changes

crates/ai/src/ai.rs 🔗

@@ -2,3 +2,4 @@ pub mod completion;
 pub mod embedding;
 pub mod models;
 pub mod prompts;
+pub mod providers;

crates/ai/src/models.rs 🔗

@@ -1,7 +1,3 @@
-use anyhow::anyhow;
-use tiktoken_rs::CoreBPE;
-use util::ResultExt;
-
 pub enum TruncationDirection {
     Start,
     End,
@@ -18,54 +14,3 @@ pub trait LanguageModel {
     ) -> anyhow::Result<String>;
     fn capacity(&self) -> anyhow::Result<usize>;
 }
-
-pub struct OpenAILanguageModel {
-    name: String,
-    bpe: Option<CoreBPE>,
-}
-
-impl OpenAILanguageModel {
-    pub fn load(model_name: &str) -> Self {
-        let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
-        OpenAILanguageModel {
-            name: model_name.to_string(),
-            bpe,
-        }
-    }
-}
-
-impl LanguageModel for OpenAILanguageModel {
-    fn name(&self) -> String {
-        self.name.clone()
-    }
-    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
-        if let Some(bpe) = &self.bpe {
-            anyhow::Ok(bpe.encode_with_special_tokens(content).len())
-        } else {
-            Err(anyhow!("bpe for open ai model was not retrieved"))
-        }
-    }
-    fn truncate(
-        &self,
-        content: &str,
-        length: usize,
-        direction: TruncationDirection,
-    ) -> anyhow::Result<String> {
-        if let Some(bpe) = &self.bpe {
-            let tokens = bpe.encode_with_special_tokens(content);
-            if tokens.len() > length {
-                match direction {
-                    TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
-                    TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
-                }
-            } else {
-                bpe.decode(tokens)
-            }
-        } else {
-            Err(anyhow!("bpe for open ai model was not retrieved"))
-        }
-    }
-    fn capacity(&self) -> anyhow::Result<usize> {
-        anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
-    }
-}

crates/ai/src/providers/open_ai/model.rs 🔗

@@ -0,0 +1,56 @@
+use anyhow::anyhow;
+use tiktoken_rs::CoreBPE;
+use util::ResultExt;
+
+use crate::models::{LanguageModel, TruncationDirection};
+
+pub struct OpenAILanguageModel {
+    name: String,
+    bpe: Option<CoreBPE>,
+}
+
+impl OpenAILanguageModel {
+    pub fn load(model_name: &str) -> Self {
+        let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
+        OpenAILanguageModel {
+            name: model_name.to_string(),
+            bpe,
+        }
+    }
+}
+
+impl LanguageModel for OpenAILanguageModel {
+    fn name(&self) -> String {
+        self.name.clone()
+    }
+    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+        if let Some(bpe) = &self.bpe {
+            anyhow::Ok(bpe.encode_with_special_tokens(content).len())
+        } else {
+            Err(anyhow!("bpe for open ai model was not retrieved"))
+        }
+    }
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: TruncationDirection,
+    ) -> anyhow::Result<String> {
+        if let Some(bpe) = &self.bpe {
+            let tokens = bpe.encode_with_special_tokens(content);
+            if tokens.len() > length {
+                match direction {
+                    TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
+                    TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
+                }
+            } else {
+                bpe.decode(tokens)
+            }
+        } else {
+            Err(anyhow!("bpe for open ai model was not retrieved"))
+        }
+    }
+    fn capacity(&self) -> anyhow::Result<usize> {
+        anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
+    }
+}

crates/assistant/src/prompts.rs 🔗

@@ -1,9 +1,10 @@
-use ai::models::{LanguageModel, OpenAILanguageModel};
+use ai::models::LanguageModel;
 use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
 use ai::prompts::file_context::FileContext;
 use ai::prompts::generate::GenerateInlineContent;
 use ai::prompts::preamble::EngineerPreamble;
 use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
+use ai::providers::open_ai::OpenAILanguageModel;
 use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
 use std::cmp::{self, Reverse};
 use std::ops::Range;