move embedding truncation to base model

KCaverly created

Change summary

crates/ai/src/embedding.rs                        |  1 
crates/ai/src/providers/dummy.rs                  | 11 ++---
crates/ai/src/providers/open_ai/embedding.rs      | 30 +++++++--------
crates/semantic_index/src/parsing.rs              | 33 ++++++++++++++--
crates/semantic_index/src/semantic_index_tests.rs |  9 +---
5 files changed, 49 insertions(+), 35 deletions(-)

Detailed changes

crates/ai/src/embedding.rs 🔗

@@ -72,7 +72,6 @@ pub trait EmbeddingProvider: Sync + Send {
     fn is_authenticated(&self) -> bool;
     async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
     fn max_tokens_per_batch(&self) -> usize;
-    fn truncate(&self, span: &str) -> (String, usize);
     fn rate_limit_expiration(&self) -> Option<Instant>;
 }
 

crates/ai/src/providers/dummy.rs 🔗

@@ -23,6 +23,10 @@ impl LanguageModel for DummyLanguageModel {
         length: usize,
         direction: crate::models::TruncationDirection,
     ) -> anyhow::Result<String> {
+        if content.len() < length {
+            return anyhow::Ok(content.to_string());
+        }
+
         let truncated = match direction {
             TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
                 .iter()
@@ -73,11 +77,4 @@ impl EmbeddingProvider for DummyEmbeddingProvider {
     fn max_tokens_per_batch(&self) -> usize {
         8190
     }
-
-    fn truncate(&self, span: &str) -> (String, usize) {
-        let truncated = span.chars().collect::<Vec<char>>()[..8190]
-            .iter()
-            .collect::<String>();
-        (truncated, 8190)
-    }
 }

crates/ai/src/providers/open_ai/embedding.rs 🔗

@@ -61,8 +61,6 @@ struct OpenAIEmbeddingUsage {
     total_tokens: usize,
 }
 
-const OPENAI_INPUT_LIMIT: usize = 8190;
-
 impl OpenAIEmbeddingProvider {
     pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
         let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
@@ -151,20 +149,20 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider {
     fn rate_limit_expiration(&self) -> Option<Instant> {
         *self.rate_limit_count_rx.borrow()
     }
-    fn truncate(&self, span: &str) -> (String, usize) {
-        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
-        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
-            tokens.truncate(OPENAI_INPUT_LIMIT);
-            OPENAI_BPE_TOKENIZER
-                .decode(tokens.clone())
-                .ok()
-                .unwrap_or_else(|| span.to_string())
-        } else {
-            span.to_string()
-        };
-
-        (output, tokens.len())
-    }
+    // fn truncate(&self, span: &str) -> (String, usize) {
+    //     let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
+    //     let output = if tokens.len() > OPENAI_INPUT_LIMIT {
+    //         tokens.truncate(OPENAI_INPUT_LIMIT);
+    //         OPENAI_BPE_TOKENIZER
+    //             .decode(tokens.clone())
+    //             .ok()
+    //             .unwrap_or_else(|| span.to_string())
+    //     } else {
+    //         span.to_string()
+    //     };
+
+    //     (output, tokens.len())
+    // }
 
     async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
         const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];

crates/semantic_index/src/parsing.rs 🔗

@@ -1,4 +1,7 @@
-use ai::embedding::{Embedding, EmbeddingProvider};
+use ai::{
+    embedding::{Embedding, EmbeddingProvider},
+    models::TruncationDirection,
+};
 use anyhow::{anyhow, Result};
 use language::{Grammar, Language};
 use rusqlite::{
@@ -108,7 +111,14 @@ impl CodeContextRetriever {
             .replace("<language>", language_name.as_ref())
             .replace("<item>", &content);
         let digest = SpanDigest::from(document_span.as_str());
-        let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
+        let model = self.embedding_provider.base_model();
+        let document_span = model.truncate(
+            &document_span,
+            model.capacity()?,
+            ai::models::TruncationDirection::End,
+        )?;
+        let token_count = model.count_tokens(&document_span)?;
+
         Ok(vec![Span {
             range: 0..content.len(),
             content: document_span,
@@ -131,7 +141,15 @@ impl CodeContextRetriever {
             )
             .replace("<item>", &content);
         let digest = SpanDigest::from(document_span.as_str());
-        let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
+
+        let model = self.embedding_provider.base_model();
+        let document_span = model.truncate(
+            &document_span,
+            model.capacity()?,
+            ai::models::TruncationDirection::End,
+        )?;
+        let token_count = model.count_tokens(&document_span)?;
+
         Ok(vec![Span {
             range: 0..content.len(),
             content: document_span,
@@ -222,8 +240,13 @@ impl CodeContextRetriever {
                 .replace("<language>", language_name.as_ref())
                 .replace("item", &span.content);
 
-            let (document_content, token_count) =
-                self.embedding_provider.truncate(&document_content);
+            let model = self.embedding_provider.base_model();
+            let document_content = model.truncate(
+                &document_content,
+                model.capacity()?,
+                TruncationDirection::End,
+            )?;
+            let token_count = model.count_tokens(&document_content)?;
 
             span.content = document_content;
             span.token_count = token_count;

crates/semantic_index/src/semantic_index_tests.rs 🔗

@@ -1291,12 +1291,8 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
     fn is_authenticated(&self) -> bool {
         true
     }
-    fn truncate(&self, span: &str) -> (String, usize) {
-        (span.to_string(), 1)
-    }
-
     fn max_tokens_per_batch(&self) -> usize {
-        200
+        1000
     }
 
     fn rate_limit_expiration(&self) -> Option<Instant> {
@@ -1306,7 +1302,8 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
     async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
         self.embedding_count
             .fetch_add(spans.len(), atomic::Ordering::SeqCst);
-        Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
+
+        anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
     }
 }