added token count to documents during parsing

KCaverly created

Change summary

crates/semantic_index/src/embedding.rs            | 14 +++++++
crates/semantic_index/src/parsing.rs              | 19 +++++++++-
crates/semantic_index/src/semantic_index.rs       |  3 +
crates/semantic_index/src/semantic_index_tests.rs | 30 +++++++++++-----
4 files changed, 54 insertions(+), 12 deletions(-)

Detailed changes

crates/semantic_index/src/embedding.rs 🔗

@@ -54,6 +54,8 @@ struct OpenAIEmbeddingUsage {
 #[async_trait]
 pub trait EmbeddingProvider: Sync + Send {
     async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
+    fn count_tokens(&self, span: &str) -> usize;
+    // fn truncate(&self, span: &str) -> Result<&str>;
 }
 
 pub struct DummyEmbeddings {}
@@ -66,6 +68,12 @@ impl EmbeddingProvider for DummyEmbeddings {
         let dummy_vec = vec![0.32 as f32; 1536];
         return Ok(vec![dummy_vec; spans.len()]);
     }
+
+    fn count_tokens(&self, span: &str) -> usize {
+        // For Dummy Providers, we are going to use OpenAI tokenization for ease
+        let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
+        tokens.len()
+    }
 }
 
 const OPENAI_INPUT_LIMIT: usize = 8190;
@@ -111,6 +119,12 @@ impl OpenAIEmbeddings {
 
 #[async_trait]
 impl EmbeddingProvider for OpenAIEmbeddings {
+    fn count_tokens(&self, span: &str) -> usize {
+    // Tokenize with the OpenAI BPE tokenizer to count tokens for this span
+        let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
+        tokens.len()
+    }
+
     async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
         const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
         const MAX_RETRIES: usize = 4;

crates/semantic_index/src/parsing.rs 🔗

@@ -1,3 +1,4 @@
+use crate::embedding::EmbeddingProvider;
 use anyhow::{anyhow, Ok, Result};
 use language::{Grammar, Language};
 use sha1::{Digest, Sha1};
@@ -17,6 +18,7 @@ pub struct Document {
     pub content: String,
     pub embedding: Vec<f32>,
     pub sha1: [u8; 20],
+    pub token_count: usize,
 }
 
 const CODE_CONTEXT_TEMPLATE: &str =
@@ -30,6 +32,7 @@ pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] =
 pub struct CodeContextRetriever {
     pub parser: Parser,
     pub cursor: QueryCursor,
+    pub embedding_provider: Arc<dyn EmbeddingProvider>,
 }
 
 // Every match has an item, this represents the fundamental treesitter symbol and anchors the search
@@ -47,10 +50,11 @@ pub struct CodeContextMatch {
 }
 
 impl CodeContextRetriever {
-    pub fn new() -> Self {
+    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
         Self {
             parser: Parser::new(),
             cursor: QueryCursor::new(),
+            embedding_provider,
         }
     }
 
@@ -68,12 +72,15 @@ impl CodeContextRetriever {
         let mut sha1 = Sha1::new();
         sha1.update(&document_span);
 
+        let token_count = self.embedding_provider.count_tokens(&document_span);
+
         Ok(vec![Document {
             range: 0..content.len(),
             content: document_span,
             embedding: Vec::new(),
             name: language_name.to_string(),
             sha1: sha1.finalize().into(),
+            token_count,
         }])
     }
 
@@ -85,12 +92,15 @@ impl CodeContextRetriever {
         let mut sha1 = Sha1::new();
         sha1.update(&document_span);
 
+        let token_count = self.embedding_provider.count_tokens(&document_span);
+
         Ok(vec![Document {
             range: 0..content.len(),
             content: document_span,
             embedding: Vec::new(),
             name: "Markdown".to_string(),
             sha1: sha1.finalize().into(),
+            token_count,
         }])
     }
 
@@ -166,10 +176,14 @@ impl CodeContextRetriever {
 
         let mut documents = self.parse_file(content, language)?;
         for document in &mut documents {
-            document.content = CODE_CONTEXT_TEMPLATE
+            let document_content = CODE_CONTEXT_TEMPLATE
                 .replace("<path>", relative_path.to_string_lossy().as_ref())
                 .replace("<language>", language_name.as_ref())
                 .replace("item", &document.content);
+
+            let token_count = self.embedding_provider.count_tokens(&document_content);
+            document.content = document_content;
+            document.token_count = token_count;
         }
         Ok(documents)
     }
@@ -272,6 +286,7 @@ impl CodeContextRetriever {
                 range: item_range.clone(),
                 embedding: vec![],
                 sha1: sha1.finalize().into(),
+                token_count: 0,
             })
         }
 

crates/semantic_index/src/semantic_index.rs 🔗

@@ -332,8 +332,9 @@ impl SemanticIndex {
                 let parsing_files_rx = parsing_files_rx.clone();
                 let batch_files_tx = batch_files_tx.clone();
                 let db_update_tx = db_update_tx.clone();
+                let embedding_provider = embedding_provider.clone();
                 _parsing_files_tasks.push(cx.background().spawn(async move {
-                    let mut retriever = CodeContextRetriever::new();
+                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
                     while let Ok(pending_file) = parsing_files_rx.recv().await {
                         Self::parse_file(
                             &fs,

crates/semantic_index/src/semantic_index_tests.rs 🔗

@@ -1,6 +1,6 @@
 use crate::{
     db::dot,
-    embedding::EmbeddingProvider,
+    embedding::{DummyEmbeddings, EmbeddingProvider},
     parsing::{subtract_ranges, CodeContextRetriever, Document},
     semantic_index_settings::SemanticIndexSettings,
     SearchResult, SemanticIndex,
@@ -227,7 +227,8 @@ fn assert_search_results(
 #[gpui::test]
 async fn test_code_context_retrieval_rust() {
     let language = rust_lang();
-    let mut retriever = CodeContextRetriever::new();
+    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = "
         /// A doc comment
@@ -314,7 +315,8 @@ async fn test_code_context_retrieval_rust() {
 #[gpui::test]
 async fn test_code_context_retrieval_json() {
     let language = json_lang();
-    let mut retriever = CodeContextRetriever::new();
+    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
         {
@@ -397,7 +399,8 @@ fn assert_documents_eq(
 #[gpui::test]
 async fn test_code_context_retrieval_javascript() {
     let language = js_lang();
-    let mut retriever = CodeContextRetriever::new();
+    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = "
         /* globals importScripts, backend */
@@ -495,7 +498,8 @@ async fn test_code_context_retrieval_javascript() {
 #[gpui::test]
 async fn test_code_context_retrieval_lua() {
     let language = lua_lang();
-    let mut retriever = CodeContextRetriever::new();
+    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
         -- Creates a new class
@@ -568,7 +572,8 @@ async fn test_code_context_retrieval_lua() {
 #[gpui::test]
 async fn test_code_context_retrieval_elixir() {
     let language = elixir_lang();
-    let mut retriever = CodeContextRetriever::new();
+    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
         defmodule File.Stream do
@@ -684,7 +689,8 @@ async fn test_code_context_retrieval_elixir() {
 #[gpui::test]
 async fn test_code_context_retrieval_cpp() {
     let language = cpp_lang();
-    let mut retriever = CodeContextRetriever::new();
+    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = "
     /**
@@ -836,7 +842,8 @@ async fn test_code_context_retrieval_cpp() {
 #[gpui::test]
 async fn test_code_context_retrieval_ruby() {
     let language = ruby_lang();
-    let mut retriever = CodeContextRetriever::new();
+    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
         # This concern is inspired by "sudo mode" on GitHub. It
@@ -1026,7 +1033,8 @@ async fn test_code_context_retrieval_ruby() {
 #[gpui::test]
 async fn test_code_context_retrieval_php() {
     let language = php_lang();
-    let mut retriever = CodeContextRetriever::new();
+    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
         <?php
@@ -1216,6 +1224,10 @@ impl FakeEmbeddingProvider {
 
 #[async_trait]
 impl EmbeddingProvider for FakeEmbeddingProvider {
+    fn count_tokens(&self, span: &str) -> usize {
+        span.len()
+    }
+
     async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
         self.embedding_count
             .fetch_add(spans.len(), atomic::Ordering::SeqCst);