Add base model to `EmbeddingProvider`; not yet leveraged for truncation

KCaverly created

Change summary

crates/ai/src/embedding.rs                        |  3 +
crates/ai/src/providers/dummy.rs                  | 35 +++++++++++++++++
crates/ai/src/providers/open_ai/embedding.rs      | 10 ++++
crates/ai/src/providers/open_ai/model.rs          |  1 +
crates/semantic_index/src/semantic_index_tests.rs | 10 +++
5 files changed, 57 insertions(+), 2 deletions(-)

Detailed changes

crates/ai/src/embedding.rs 🔗

@@ -5,6 +5,8 @@ use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
 use rusqlite::ToSql;
 use std::time::Instant;
 
+use crate::models::LanguageModel;
+
 #[derive(Debug, PartialEq, Clone)]
 pub struct Embedding(pub Vec<f32>);
 
@@ -66,6 +68,7 @@ impl Embedding {
 
 #[async_trait]
 pub trait EmbeddingProvider: Sync + Send {
+    fn base_model(&self) -> Box<dyn LanguageModel>;
     fn is_authenticated(&self) -> bool;
     async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
     fn max_tokens_per_batch(&self) -> usize;

crates/ai/src/providers/dummy.rs 🔗

@@ -3,10 +3,42 @@ use std::time::Instant;
 use crate::{
     completion::CompletionRequest,
     embedding::{Embedding, EmbeddingProvider},
+    models::{LanguageModel, TruncationDirection},
 };
 use async_trait::async_trait;
 use serde::Serialize;
 
+pub struct DummyLanguageModel {}
+
+impl LanguageModel for DummyLanguageModel {
+    fn name(&self) -> String {
+        "dummy".to_string()
+    }
+    fn capacity(&self) -> anyhow::Result<usize> {
+        anyhow::Ok(1000)
+    }
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: crate::models::TruncationDirection,
+    ) -> anyhow::Result<String> {
+        let truncated = match direction {
+            TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
+                .iter()
+                .collect::<String>(),
+            TruncationDirection::Start => content.chars().collect::<Vec<char>>()[..length]
+                .iter()
+                .collect::<String>(),
+        };
+
+        anyhow::Ok(truncated)
+    }
+    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+        anyhow::Ok(content.chars().collect::<Vec<char>>().len())
+    }
+}
+
 #[derive(Serialize)]
 pub struct DummyCompletionRequest {
     pub name: String,
@@ -22,6 +54,9 @@ pub struct DummyEmbeddingProvider {}
 
 #[async_trait]
 impl EmbeddingProvider for DummyEmbeddingProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel> {
+        Box::new(DummyLanguageModel {})
+    }
     fn is_authenticated(&self) -> bool {
         true
     }

crates/ai/src/providers/open_ai/embedding.rs 🔗

@@ -19,6 +19,8 @@ use tiktoken_rs::{cl100k_base, CoreBPE};
 use util::http::{HttpClient, Request};
 
 use crate::embedding::{Embedding, EmbeddingProvider};
+use crate::models::LanguageModel;
+use crate::providers::open_ai::OpenAILanguageModel;
 
 lazy_static! {
     static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
@@ -27,6 +29,7 @@ lazy_static! {
 
 #[derive(Clone)]
 pub struct OpenAIEmbeddingProvider {
+    model: OpenAILanguageModel,
     pub client: Arc<dyn HttpClient>,
     pub executor: Arc<Background>,
     rate_limit_count_rx: watch::Receiver<Option<Instant>>,
@@ -65,7 +68,10 @@ impl OpenAIEmbeddingProvider {
         let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
         let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
 
+        let model = OpenAILanguageModel::load("text-embedding-ada-002");
+
         OpenAIEmbeddingProvider {
+            model,
             client,
             executor,
             rate_limit_count_rx,
@@ -131,6 +137,10 @@ impl OpenAIEmbeddingProvider {
 
 #[async_trait]
 impl EmbeddingProvider for OpenAIEmbeddingProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel> {
+        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
+        model
+    }
     fn is_authenticated(&self) -> bool {
         OPENAI_API_KEY.as_ref().is_some()
     }

crates/ai/src/providers/open_ai/model.rs 🔗

@@ -4,6 +4,7 @@ use util::ResultExt;
 
 use crate::models::{LanguageModel, TruncationDirection};
 
+#[derive(Clone)]
 pub struct OpenAILanguageModel {
     name: String,
     bpe: Option<CoreBPE>,

crates/semantic_index/src/semantic_index_tests.rs 🔗

@@ -4,8 +4,11 @@ use crate::{
     semantic_index_settings::SemanticIndexSettings,
     FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
 };
-use ai::embedding::{Embedding, EmbeddingProvider};
-use ai::providers::dummy::DummyEmbeddingProvider;
+use ai::providers::dummy::{DummyEmbeddingProvider, DummyLanguageModel};
+use ai::{
+    embedding::{Embedding, EmbeddingProvider},
+    models::LanguageModel,
+};
 use anyhow::Result;
 use async_trait::async_trait;
 use gpui::{executor::Deterministic, Task, TestAppContext};
@@ -1282,6 +1285,9 @@ impl FakeEmbeddingProvider {
 
 #[async_trait]
 impl EmbeddingProvider for FakeEmbeddingProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel> {
+        Box::new(DummyLanguageModel {})
+    }
     fn is_authenticated(&self) -> bool {
         true
     }