add should_truncate to embedding providers

KCaverly created

Change summary

crates/semantic_index/src/embedding.rs            | 19 +++++++++++++++++
crates/semantic_index/src/semantic_index_tests.rs |  4 +++
2 files changed, 23 insertions(+)

Detailed changes

crates/semantic_index/src/embedding.rs 🔗

@@ -55,6 +55,7 @@ struct OpenAIEmbeddingUsage {
 pub trait EmbeddingProvider: Sync + Send {
     async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
     fn count_tokens(&self, span: &str) -> usize;
+    fn should_truncate(&self, span: &str) -> bool;
     // fn truncate(&self, span: &str) -> Result<&str>;
 }
 
@@ -74,6 +75,20 @@ impl EmbeddingProvider for DummyEmbeddings {
         let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
         tokens.len()
     }
+
+    fn should_truncate(&self, span: &str) -> bool {
+        self.count_tokens(span) > OPENAI_INPUT_LIMIT
+
+        // let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
+        // let Ok(output) = {
+        //     if tokens.len() > OPENAI_INPUT_LIMIT {
+        //         tokens.truncate(OPENAI_INPUT_LIMIT);
+        //         OPENAI_BPE_TOKENIZER.decode(tokens)
+        //     } else {
+        //         Ok(span)
+        //     }
+        // };
+    }
 }
 
 const OPENAI_INPUT_LIMIT: usize = 8190;
@@ -125,6 +140,10 @@ impl EmbeddingProvider for OpenAIEmbeddings {
         tokens.len()
     }
 
+    fn should_truncate(&self, span: &str) -> bool {
+        self.count_tokens(span) > OPENAI_INPUT_LIMIT
+    }
+
     async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
         const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
         const MAX_RETRIES: usize = 4;

crates/semantic_index/src/semantic_index_tests.rs 🔗

@@ -1228,6 +1228,10 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
         span.len()
     }
 
+    fn should_truncate(&self, span: &str) -> bool {
+        false
+    }
+
     async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
         self.embedding_count
             .fetch_add(spans.len(), atomic::Ordering::SeqCst);