move truncation to parsing step, leveraging the EmbeddingProvider trait

Created by KCaverly

Change summary

crates/semantic_index/src/embedding.rs            | 78 ++++++++--------
crates/semantic_index/src/parsing.rs              |  4 ++++
crates/semantic_index/src/semantic_index_tests.rs |  4 ++++
3 files changed, 45 insertions(+), 41 deletions(-)

Detailed changes

crates/semantic_index/src/embedding.rs

@@ -56,7 +56,7 @@ pub trait EmbeddingProvider: Sync + Send {
     async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
     fn count_tokens(&self, span: &str) -> usize;
     fn should_truncate(&self, span: &str) -> bool;
-    // fn truncate(&self, span: &str) -> Result<&str>;
+    fn truncate(&self, span: &str) -> String;
 }
 
 pub struct DummyEmbeddings {}
@@ -78,36 +78,27 @@ impl EmbeddingProvider for DummyEmbeddings {
 
     fn should_truncate(&self, span: &str) -> bool {
         self.count_tokens(span) > OPENAI_INPUT_LIMIT
+    }
 
-        // let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
-        // let Ok(output) = {
-        //     if tokens.len() > OPENAI_INPUT_LIMIT {
-        //         tokens.truncate(OPENAI_INPUT_LIMIT);
-        //         OPENAI_BPE_TOKENIZER.decode(tokens)
-        //     } else {
-        //         Ok(span)
-        //     }
-        // };
+    fn truncate(&self, span: &str) -> String {
+        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
+        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
+            tokens.truncate(OPENAI_INPUT_LIMIT);
+            OPENAI_BPE_TOKENIZER
+                .decode(tokens)
+                .ok()
+                .unwrap_or_else(|| span.to_string())
+        } else {
+            span.to_string()
+        };
+
+        output
     }
 }
 
 const OPENAI_INPUT_LIMIT: usize = 8190;
 
 impl OpenAIEmbeddings {
-    fn truncate(span: String) -> String {
-        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref());
-        if tokens.len() > OPENAI_INPUT_LIMIT {
-            tokens.truncate(OPENAI_INPUT_LIMIT);
-            let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
-            if result.is_ok() {
-                let transformed = result.unwrap();
-                return transformed;
-            }
-        }
-
-        span
-    }
-
     async fn send_request(
         &self,
         api_key: &str,
@@ -144,6 +135,21 @@ impl EmbeddingProvider for OpenAIEmbeddings {
         self.count_tokens(span) > OPENAI_INPUT_LIMIT
     }
 
+    fn truncate(&self, span: &str) -> String {
+        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
+        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
+            tokens.truncate(OPENAI_INPUT_LIMIT);
+            OPENAI_BPE_TOKENIZER
+                .decode(tokens)
+                .ok()
+                .unwrap_or_else(|| span.to_string())
+        } else {
+            span.to_string()
+        };
+
+        output
+    }
+
     async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
         const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
         const MAX_RETRIES: usize = 4;
@@ -214,23 +220,13 @@ impl EmbeddingProvider for OpenAIEmbeddings {
                     self.executor.timer(delay_duration).await;
                 }
                 _ => {
-                    // TODO: Move this to parsing step
-                    // Only truncate if it hasnt been truncated before
-                    if !truncated {
-                        for span in spans.iter_mut() {
-                            *span = Self::truncate(span.clone());
-                        }
-                        truncated = true;
-                    } else {
-                        // If failing once already truncated, log the error and break the loop
-                        let mut body = String::new();
-                        response.body_mut().read_to_string(&mut body).await?;
-                        return Err(anyhow!(
-                            "open ai bad request: {:?} {:?}",
-                            &response.status(),
-                            body
-                        ));
-                    }
+                    let mut body = String::new();
+                    response.body_mut().read_to_string(&mut body).await?;
+                    return Err(anyhow!(
+                        "open ai bad request: {:?} {:?}",
+                        &response.status(),
+                        body
+                    ));
                 }
             }
         }
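
With truncate promoted onto the trait, callers holding any EmbeddingProvider can clamp a span to the model's input limit without knowing which backend they are talking to. A minimal sketch of that shape, assuming a whitespace "tokenizer" and a tiny LIMIT in place of the real OPENAI_BPE_TOKENIZER and OPENAI_INPUT_LIMIT = 8190 (MockProvider is an illustrative name, and the async embed_batch side of the trait is omitted):

    const LIMIT: usize = 4;

    trait EmbeddingProvider {
        fn count_tokens(&self, span: &str) -> usize;
        fn should_truncate(&self, span: &str) -> bool;
        fn truncate(&self, span: &str) -> String;
    }

    struct MockProvider;

    impl EmbeddingProvider for MockProvider {
        fn count_tokens(&self, span: &str) -> usize {
            span.split_whitespace().count()
        }

        fn should_truncate(&self, span: &str) -> bool {
            self.count_tokens(span) > LIMIT
        }

        fn truncate(&self, span: &str) -> String {
            // Mirrors the shape of the real impls: encode, clamp, decode,
            // falling back to the original span if decoding fails.
            let tokens: Vec<&str> = span.split_whitespace().collect();
            if tokens.len() > LIMIT {
                tokens[..LIMIT].join(" ")
            } else {
                span.to_string()
            }
        }
    }

    fn main() {
        let provider = MockProvider;
        let span = "one two three four five six";
        assert!(provider.should_truncate(span));
        assert_eq!(provider.truncate(span), "one two three four");
    }

The other half of the embedding.rs change is in embed_batch: since spans now arrive pre-truncated from the parsing step, the retry loop drops its "truncate once, then retry" arm and any non-rate-limit failure returns immediately. A simplified, synchronous sketch of the new control flow, with BACKOFF_SECONDS and MAX_RETRIES taken from the diff (the Response struct and send_request closure are stand-ins for the real async HTTP client and executor timer):

    use anyhow::{anyhow, Result};
    use std::{thread, time::Duration};

    const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
    const MAX_RETRIES: usize = 4;

    // Stand-in for an HTTP response.
    struct Response {
        status: u16,
        body: String,
    }

    fn embed_batch(mut send_request: impl FnMut() -> Response) -> Result<String> {
        let mut request_number = 0;
        while request_number < MAX_RETRIES {
            let response = send_request();
            request_number += 1;
            match response.status {
                429 => {
                    // Rate limited: back off and retry.
                    let delay = BACKOFF_SECONDS[request_number - 1];
                    thread::sleep(Duration::from_secs(delay as u64));
                }
                200 => return Ok(response.body),
                status => {
                    // Previously this arm truncated each span and retried once;
                    // spans are now truncated at parse time, so fail fast instead.
                    return Err(anyhow!("open ai bad request: {:?} {:?}", status, response.body));
                }
            }
        }
        Err(anyhow!("openai max retries"))
    }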

crates/semantic_index/src/parsing.rs

@@ -73,6 +73,7 @@ impl CodeContextRetriever {
         sha1.update(&document_span);
 
         let token_count = self.embedding_provider.count_tokens(&document_span);
+        let document_span = self.embedding_provider.truncate(&document_span);
 
         Ok(vec![Document {
             range: 0..content.len(),
@@ -93,6 +94,7 @@ impl CodeContextRetriever {
         sha1.update(&document_span);
 
         let token_count = self.embedding_provider.count_tokens(&document_span);
+        let document_span = self.embedding_provider.truncate(&document_span);
 
         Ok(vec![Document {
             range: 0..content.len(),
@@ -182,6 +184,8 @@ impl CodeContextRetriever {
                 .replace("item", &document.content);
 
             let token_count = self.embedding_provider.count_tokens(&document_content);
+            let document_content = self.embedding_provider.truncate(&document_content);
+
             document.content = document_content;
             document.token_count = token_count;
         }
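
Each parsing path now follows the same sequence: hash the raw span, count its tokens, truncate, and store the truncated text in the Document. A condensed sketch of that ordering, reusing the whitespace stand-in tokenizer from above with a limit of 4 (the Document here is reduced to the two fields involved, and parse_file is a hypothetical stand-in for the retriever methods); note that, as in the diff, token_count is computed from the span before truncation:

    struct Document {
        content: String,
        token_count: usize,
    }

    const LIMIT: usize = 4;

    // Stand-ins for EmbeddingProvider::count_tokens and ::truncate.
    fn count_tokens(span: &str) -> usize {
        span.split_whitespace().count()
    }

    fn truncate(span: &str) -> String {
        let tokens: Vec<&str> = span.split_whitespace().collect();
        tokens[..tokens.len().min(LIMIT)].join(" ")
    }

    fn parse_file(content: &str) -> Document {
        let document_span = content.to_string();
        // Order matters: the count reflects the untruncated span,
        // while the stored content is the truncated one.
        let token_count = count_tokens(&document_span);
        let document_span = truncate(&document_span);
        Document {
            content: document_span,
            token_count,
        }
    }

    fn main() {
        let doc = parse_file("a b c d e f");
        assert_eq!(doc.token_count, 6);
        assert_eq!(doc.content, "a b c d");
    }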

crates/semantic_index/src/semantic_index_tests.rs

@@ -1232,6 +1232,10 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
         false
     }
 
+    fn truncate(&self, span: &str) -> String {
+        span.to_string()
+    }
+
     async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
         self.embedding_count
             .fetch_add(spans.len(), atomic::Ordering::SeqCst);
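
In the tests, the fake provider's should_truncate always returns false, so the identity truncate leaves spans byte-for-byte intact and the parsing step can call it unconditionally. A standalone check in that spirit (FakeProvider is a stripped-down stand-in for the test file's FakeEmbeddingProvider, covering only the two methods involved):

    struct FakeProvider;

    impl FakeProvider {
        fn should_truncate(&self, _span: &str) -> bool {
            false
        }

        fn truncate(&self, span: &str) -> String {
            span.to_string()
        }
    }

    #[test]
    fn fake_provider_never_truncates() {
        let provider = FakeProvider;
        let span = "fn main() {}";
        assert!(!provider.should_truncate(span));
        // Parsing calls truncate unconditionally; for the fake it is a no-op.
        assert_eq!(provider.truncate(span), span);
    }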