diff --git a/crates/ai/src/embedding.rs b/crates/ai/src/embedding.rs
index 05798c3f5d1a3ed2a88739fd3a5a911ed708d560..f792406c8b9367baf825af50b01f077f119b3860 100644
--- a/crates/ai/src/embedding.rs
+++ b/crates/ai/src/embedding.rs
@@ -5,6 +5,8 @@
 use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
 use rusqlite::ToSql;
 use std::time::Instant;
 
+use crate::models::LanguageModel;
+
 #[derive(Debug, PartialEq, Clone)]
 pub struct Embedding(pub Vec<f32>);
@@ -66,6 +68,7 @@ impl Embedding {
 
 #[async_trait]
 pub trait EmbeddingProvider: Sync + Send {
+    fn base_model(&self) -> Box<dyn LanguageModel>;
     fn is_authenticated(&self) -> bool;
     async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
     fn max_tokens_per_batch(&self) -> usize;
diff --git a/crates/ai/src/providers/dummy.rs b/crates/ai/src/providers/dummy.rs
index 8061a2ca6b1c21d853db152fa431d9a2674c8740..9df5547da1fcdd0f0473f2b5371ff85a046b46a3 100644
--- a/crates/ai/src/providers/dummy.rs
+++ b/crates/ai/src/providers/dummy.rs
@@ -3,10 +3,42 @@ use std::time::Instant;
 use crate::{
     completion::CompletionRequest,
     embedding::{Embedding, EmbeddingProvider},
+    models::{LanguageModel, TruncationDirection},
 };
 use async_trait::async_trait;
 use serde::Serialize;
 
+pub struct DummyLanguageModel {}
+
+impl LanguageModel for DummyLanguageModel {
+    fn name(&self) -> String {
+        "dummy".to_string()
+    }
+    fn capacity(&self) -> anyhow::Result<usize> {
+        anyhow::Ok(1000)
+    }
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: crate::models::TruncationDirection,
+    ) -> anyhow::Result<String> {
+        let truncated = match direction {
+            TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
+                .iter()
+                .collect::<String>(),
+            TruncationDirection::Start => content.chars().collect::<Vec<char>>()[..length]
+                .iter()
+                .collect::<String>(),
+        };
+
+        anyhow::Ok(truncated)
+    }
+    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+        anyhow::Ok(content.chars().collect::<Vec<char>>().len())
+    }
+}
+
 #[derive(Serialize)]
 pub struct DummyCompletionRequest {
     pub name: String,
@@ -22,6 +54,9 @@ pub struct DummyEmbeddingProvider {}
 
 #[async_trait]
 impl EmbeddingProvider for DummyEmbeddingProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel> {
+        Box::new(DummyLanguageModel {})
+    }
     fn is_authenticated(&self) -> bool {
         true
     }
diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs
index 35398394dc8fdbef09f74dab5a82307d7b4b0aaf..ed028177f68d96bf84576c7dbcba7d4bb4888907 100644
--- a/crates/ai/src/providers/open_ai/embedding.rs
+++ b/crates/ai/src/providers/open_ai/embedding.rs
@@ -19,6 +19,8 @@ use tiktoken_rs::{cl100k_base, CoreBPE};
 use util::http::{HttpClient, Request};
 
 use crate::embedding::{Embedding, EmbeddingProvider};
+use crate::models::LanguageModel;
+use crate::providers::open_ai::OpenAILanguageModel;
 
 lazy_static! {
     static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
@@ -27,6 +29,7 @@ lazy_static! {
 
 #[derive(Clone)]
 pub struct OpenAIEmbeddingProvider {
+    model: OpenAILanguageModel,
     pub client: Arc<dyn HttpClient>,
     pub executor: Arc<Background>,
     rate_limit_count_rx: watch::Receiver<Option<Instant>>,
@@ -65,7 +68,10 @@ impl OpenAIEmbeddingProvider {
         let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
         let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
 
+        let model = OpenAILanguageModel::load("text-embedding-ada-002");
+
         OpenAIEmbeddingProvider {
+            model,
             client,
             executor,
             rate_limit_count_rx,
@@ -131,6 +137,10 @@ impl OpenAIEmbeddingProvider {
 
 #[async_trait]
 impl EmbeddingProvider for OpenAIEmbeddingProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel> {
+        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
+        model
+    }
     fn is_authenticated(&self) -> bool {
         OPENAI_API_KEY.as_ref().is_some()
     }
diff --git a/crates/ai/src/providers/open_ai/model.rs b/crates/ai/src/providers/open_ai/model.rs
index 42523f3df48951d8674b33409105f8d802fd6c25..6e306c80b905865c011c9064934827085ca126d6 100644
--- a/crates/ai/src/providers/open_ai/model.rs
+++ b/crates/ai/src/providers/open_ai/model.rs
@@ -4,6 +4,7 @@ use util::ResultExt;
 
 use crate::models::{LanguageModel, TruncationDirection};
 
+#[derive(Clone)]
 pub struct OpenAILanguageModel {
     name: String,
     bpe: Option<CoreBPE>,
diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs
index 6842ce5c5d45711c45883c348028463e8429aa64..43779f5b6ccf23cf18fad232a6a4db2f33ce0b2c 100644
--- a/crates/semantic_index/src/semantic_index_tests.rs
+++ b/crates/semantic_index/src/semantic_index_tests.rs
@@ -4,8 +4,11 @@ use crate::{
     semantic_index_settings::SemanticIndexSettings, FileToEmbed, JobHandle, SearchResult,
     SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
 };
-use ai::embedding::{Embedding, EmbeddingProvider};
-use ai::providers::dummy::DummyEmbeddingProvider;
+use ai::providers::dummy::{DummyEmbeddingProvider, DummyLanguageModel};
+use ai::{
+    embedding::{Embedding, EmbeddingProvider},
+    models::LanguageModel,
+};
 use anyhow::Result;
 use async_trait::async_trait;
 use gpui::{executor::Deterministic, Task, TestAppContext};
@@ -1282,6 +1285,9 @@ impl FakeEmbeddingProvider {
 
 #[async_trait]
 impl EmbeddingProvider for FakeEmbeddingProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel> {
+        Box::new(DummyLanguageModel {})
+    }
     fn is_authenticated(&self) -> bool {
         true
     }
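The hunks above expose the backing LanguageModel (name, capacity, token counting, truncation) through EmbeddingProvider::base_model() instead of leaving tokenization details hard-coded in each caller. Below is a minimal sketch of how a caller might combine that accessor with the trait methods shown in this diff; embed_within_capacity is a hypothetical helper and the anyhow::Result return type is an assumption, not part of the change.

// Illustrative sketch only, not part of the diff above. The ai::embedding and
// ai::models paths come from the hunks themselves; the helper name and the
// anyhow::Result return type are assumptions.
use ai::embedding::EmbeddingProvider;
use ai::models::{LanguageModel, TruncationDirection};

async fn embed_within_capacity(
    provider: &dyn EmbeddingProvider,
    content: &str,
) -> anyhow::Result<()> {
    // base_model() is the accessor introduced by this change.
    let model = provider.base_model();
    let capacity = model.capacity()?;

    // Truncate only when the span exceeds the model's token budget.
    let span = if model.count_tokens(content)? > capacity {
        model.truncate(content, capacity, TruncationDirection::End)?
    } else {
        content.to_string()
    };

    let _embeddings = provider.embed_batch(vec![span]).await?;
    Ok(())
}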