diff --git a/crates/ai/src/embedding.rs b/crates/ai/src/embedding.rs index f792406c8b9367baf825af50b01f077f119b3860..4e67f44cae72b9b7778e05895964f42f8f78e535 100644 --- a/crates/ai/src/embedding.rs +++ b/crates/ai/src/embedding.rs @@ -72,7 +72,6 @@ pub trait EmbeddingProvider: Sync + Send { fn is_authenticated(&self) -> bool; async fn embed_batch(&self, spans: Vec) -> Result>; fn max_tokens_per_batch(&self) -> usize; - fn truncate(&self, span: &str) -> (String, usize); fn rate_limit_expiration(&self) -> Option; } diff --git a/crates/ai/src/providers/dummy.rs b/crates/ai/src/providers/dummy.rs index 9df5547da1fcdd0f0473f2b5371ff85a046b46a3..7eef16111d9789975128dee1e2183094908e84f2 100644 --- a/crates/ai/src/providers/dummy.rs +++ b/crates/ai/src/providers/dummy.rs @@ -23,6 +23,10 @@ impl LanguageModel for DummyLanguageModel { length: usize, direction: crate::models::TruncationDirection, ) -> anyhow::Result { + if content.chars().count() < length { + return anyhow::Ok(content.to_string()); + } + let truncated = match direction { TruncationDirection::End => content.chars().collect::>()[..length] .iter() .collect::(), @@ -73,11 +77,4 @@ impl EmbeddingProvider for DummyEmbeddingProvider { fn max_tokens_per_batch(&self) -> usize { 8190 } - - fn truncate(&self, span: &str) -> (String, usize) { - let truncated = span.chars().collect::>()[..8190] - .iter() - .collect::(); - (truncated, 8190) - } } diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs index ed028177f68d96bf84576c7dbcba7d4bb4888907..3689cb36f41d34ca51d39478eba14ebef21b5c00 100644 --- a/crates/ai/src/providers/open_ai/embedding.rs +++ b/crates/ai/src/providers/open_ai/embedding.rs @@ -61,8 +61,6 @@ struct OpenAIEmbeddingUsage { total_tokens: usize, } -const OPENAI_INPUT_LIMIT: usize = 8190; - impl OpenAIEmbeddingProvider { pub fn new(client: Arc, executor: Arc) -> Self { let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None); @@ -151,20 +149,20 @@ impl 
EmbeddingProvider for OpenAIEmbeddingProvider { fn rate_limit_expiration(&self) -> Option { *self.rate_limit_count_rx.borrow() } - fn truncate(&self, span: &str) -> (String, usize) { - let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); - let output = if tokens.len() > OPENAI_INPUT_LIMIT { - tokens.truncate(OPENAI_INPUT_LIMIT); - OPENAI_BPE_TOKENIZER - .decode(tokens.clone()) - .ok() - .unwrap_or_else(|| span.to_string()) - } else { - span.to_string() - }; - - (output, tokens.len()) - } + // fn truncate(&self, span: &str) -> (String, usize) { + // let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); + // let output = if tokens.len() > OPENAI_INPUT_LIMIT { + // tokens.truncate(OPENAI_INPUT_LIMIT); + // OPENAI_BPE_TOKENIZER + // .decode(tokens.clone()) + // .ok() + // .unwrap_or_else(|| span.to_string()) + // } else { + // span.to_string() + // }; + + // (output, tokens.len()) + // } async fn embed_batch(&self, spans: Vec) -> Result> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; diff --git a/crates/semantic_index/src/parsing.rs b/crates/semantic_index/src/parsing.rs index f9b8bac9a484bfae48c62683ee096e2e49420622..cb15ca453b2c0640739bd44a95482ca527b8d91b 100644 --- a/crates/semantic_index/src/parsing.rs +++ b/crates/semantic_index/src/parsing.rs @@ -1,4 +1,7 @@ -use ai::embedding::{Embedding, EmbeddingProvider}; +use ai::{ + embedding::{Embedding, EmbeddingProvider}, + models::TruncationDirection, +}; use anyhow::{anyhow, Result}; use language::{Grammar, Language}; use rusqlite::{ @@ -108,7 +111,14 @@ impl CodeContextRetriever { .replace("", language_name.as_ref()) .replace("", &content); let digest = SpanDigest::from(document_span.as_str()); - let (document_span, token_count) = self.embedding_provider.truncate(&document_span); + let model = self.embedding_provider.base_model(); + let document_span = model.truncate( + &document_span, + model.capacity()?, + ai::models::TruncationDirection::End, + )?; + let token_count = 
model.count_tokens(&document_span)?; + Ok(vec![Span { range: 0..content.len(), content: document_span, @@ -131,7 +141,15 @@ impl CodeContextRetriever { ) .replace("", &content); let digest = SpanDigest::from(document_span.as_str()); - let (document_span, token_count) = self.embedding_provider.truncate(&document_span); + + let model = self.embedding_provider.base_model(); + let document_span = model.truncate( + &document_span, + model.capacity()?, + ai::models::TruncationDirection::End, + )?; + let token_count = model.count_tokens(&document_span)?; + Ok(vec![Span { range: 0..content.len(), content: document_span, @@ -222,8 +240,13 @@ impl CodeContextRetriever { .replace("", language_name.as_ref()) .replace("item", &span.content); - let (document_content, token_count) = - self.embedding_provider.truncate(&document_content); + let model = self.embedding_provider.base_model(); + let document_content = model.truncate( + &document_content, + model.capacity()?, + TruncationDirection::End, + )?; + let token_count = model.count_tokens(&document_content)?; span.content = document_content; span.token_count = token_count; diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 43779f5b6ccf23cf18fad232a6a4db2f33ce0b2c..002dee33e33c9a253ad9c2baf51c6c4dcdb6f2a4 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -1291,12 +1291,8 @@ impl EmbeddingProvider for FakeEmbeddingProvider { fn is_authenticated(&self) -> bool { true } - fn truncate(&self, span: &str) -> (String, usize) { - (span.to_string(), 1) - } - fn max_tokens_per_batch(&self) -> usize { - 200 + 1000 } fn rate_limit_expiration(&self) -> Option { @@ -1306,7 +1302,8 @@ impl EmbeddingProvider for FakeEmbeddingProvider { async fn embed_batch(&self, spans: Vec) -> Result> { self.embedding_count .fetch_add(spans.len(), atomic::Ordering::SeqCst); - Ok(spans.iter().map(|span| 
self.embed_sync(span)).collect()) + + anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) } }