move queuing to embedding_queue functionality and update embedding provider to include trait items for max tokens per batch"

KCaverly and Max created

Co-authored-by: Max <max@zed.dev>

Change summary

crates/semantic_index/src/embedding.rs            |  47 +---
crates/semantic_index/src/embedding_queue.rs      | 140 +++++++++++++++
crates/semantic_index/src/parsing.rs              |  10 
crates/semantic_index/src/semantic_index_tests.rs | 154 ++++++++++++----
crates/util/src/util.rs                           |  35 ++-
5 files changed, 295 insertions(+), 91 deletions(-)

Detailed changes

crates/semantic_index/src/embedding.rs 🔗

@@ -53,36 +53,30 @@ struct OpenAIEmbeddingUsage {
 
 #[async_trait]
 pub trait EmbeddingProvider: Sync + Send {
-    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
-    fn count_tokens(&self, span: &str) -> usize;
-    fn should_truncate(&self, span: &str) -> bool;
-    fn truncate(&self, span: &str) -> String;
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>>;
+    fn max_tokens_per_batch(&self) -> usize;
+    fn truncate(&self, span: &str) -> (String, usize);
 }
 
 pub struct DummyEmbeddings {}
 
 #[async_trait]
 impl EmbeddingProvider for DummyEmbeddings {
-    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
         // 1024 is the OpenAI Embeddings size for ada models.
         // the model we will likely be starting with.
         let dummy_vec = vec![0.32 as f32; 1536];
         return Ok(vec![dummy_vec; spans.len()]);
     }
 
-    fn count_tokens(&self, span: &str) -> usize {
-        // For Dummy Providers, we are going to use OpenAI tokenization for ease
-        let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
-        tokens.len()
+    fn max_tokens_per_batch(&self) -> usize {
+        OPENAI_INPUT_LIMIT
     }
 
-    fn should_truncate(&self, span: &str) -> bool {
-        self.count_tokens(span) > OPENAI_INPUT_LIMIT
-    }
-
-    fn truncate(&self, span: &str) -> String {
+    fn truncate(&self, span: &str) -> (String, usize) {
         let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
-        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
+        let token_count = tokens.len();
+        let output = if token_count > OPENAI_INPUT_LIMIT {
             tokens.truncate(OPENAI_INPUT_LIMIT);
             OPENAI_BPE_TOKENIZER
                 .decode(tokens)
@@ -92,7 +86,7 @@ impl EmbeddingProvider for DummyEmbeddings {
             span.to_string()
         };
 
-        output
+        (output, token_count)
     }
 }
 
@@ -125,19 +119,14 @@ impl OpenAIEmbeddings {
 
 #[async_trait]
 impl EmbeddingProvider for OpenAIEmbeddings {
-    fn count_tokens(&self, span: &str) -> usize {
-        // For Dummy Providers, we are going to use OpenAI tokenization for ease
-        let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
-        tokens.len()
-    }
-
-    fn should_truncate(&self, span: &str) -> bool {
-        self.count_tokens(span) > OPENAI_INPUT_LIMIT
+    fn max_tokens_per_batch(&self) -> usize {
+        OPENAI_INPUT_LIMIT
     }
 
-    fn truncate(&self, span: &str) -> String {
+    fn truncate(&self, span: &str) -> (String, usize) {
         let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
-        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
+        let token_count = tokens.len();
+        let output = if token_count > OPENAI_INPUT_LIMIT {
             tokens.truncate(OPENAI_INPUT_LIMIT);
             OPENAI_BPE_TOKENIZER
                 .decode(tokens)
@@ -147,10 +136,10 @@ impl EmbeddingProvider for OpenAIEmbeddings {
             span.to_string()
         };
 
-        output
+        (output, token_count)
     }
 
-    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
         const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
         const MAX_RETRIES: usize = 4;
 
@@ -160,9 +149,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
 
         let mut request_number = 0;
         let mut request_timeout: u64 = 10;
-        let mut truncated = false;
         let mut response: Response<AsyncBody>;
-        let mut spans: Vec<String> = spans.iter().map(|x| x.to_string()).collect();
         while request_number < MAX_RETRIES {
             response = self
                 .send_request(

crates/semantic_index/src/embedding_queue.rs 🔗

@@ -0,0 +1,140 @@
+use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime};
+
+use gpui::AppContext;
+use parking_lot::Mutex;
+use smol::channel;
+
+use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle};
+
+#[derive(Clone)]
+pub struct FileToEmbed {
+    pub worktree_id: i64,
+    pub path: PathBuf,
+    pub mtime: SystemTime,
+    pub documents: Vec<Document>,
+    pub job_handle: JobHandle,
+}
+
+impl std::fmt::Debug for FileToEmbed {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("FileToEmbed")
+            .field("worktree_id", &self.worktree_id)
+            .field("path", &self.path)
+            .field("mtime", &self.mtime)
+            .field("document", &self.documents)
+            .finish_non_exhaustive()
+    }
+}
+
+impl PartialEq for FileToEmbed {
+    fn eq(&self, other: &Self) -> bool {
+        self.worktree_id == other.worktree_id
+            && self.path == other.path
+            && self.mtime == other.mtime
+            && self.documents == other.documents
+    }
+}
+
+pub struct EmbeddingQueue {
+    embedding_provider: Arc<dyn EmbeddingProvider>,
+    pending_batch: Vec<FileToEmbedFragment>,
+    pending_batch_token_count: usize,
+    finished_files_tx: channel::Sender<FileToEmbed>,
+    finished_files_rx: channel::Receiver<FileToEmbed>,
+}
+
+pub struct FileToEmbedFragment {
+    file: Arc<Mutex<FileToEmbed>>,
+    document_range: Range<usize>,
+}
+
+impl EmbeddingQueue {
+    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
+        let (finished_files_tx, finished_files_rx) = channel::unbounded();
+        Self {
+            embedding_provider,
+            pending_batch: Vec::new(),
+            pending_batch_token_count: 0,
+            finished_files_tx,
+            finished_files_rx,
+        }
+    }
+
+    pub fn push(&mut self, file: FileToEmbed, cx: &mut AppContext) {
+        let file = Arc::new(Mutex::new(file));
+
+        self.pending_batch.push(FileToEmbedFragment {
+            file: file.clone(),
+            document_range: 0..0,
+        });
+
+        let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
+        for (ix, document) in file.lock().documents.iter().enumerate() {
+            let next_token_count = self.pending_batch_token_count + document.token_count;
+            if next_token_count > self.embedding_provider.max_tokens_per_batch() {
+                let range_end = fragment_range.end;
+                self.flush(cx);
+                self.pending_batch.push(FileToEmbedFragment {
+                    file: file.clone(),
+                    document_range: range_end..range_end,
+                });
+                fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
+            }
+
+            fragment_range.end = ix + 1;
+            self.pending_batch_token_count += document.token_count;
+        }
+    }
+
+    pub fn flush(&mut self, cx: &mut AppContext) {
+        let batch = mem::take(&mut self.pending_batch);
+        self.pending_batch_token_count = 0;
+        if batch.is_empty() {
+            return;
+        }
+
+        let finished_files_tx = self.finished_files_tx.clone();
+        let embedding_provider = self.embedding_provider.clone();
+        cx.background().spawn(async move {
+            let mut spans = Vec::new();
+            for fragment in &batch {
+                let file = fragment.file.lock();
+                spans.extend(
+                    file.documents[fragment.document_range.clone()]
+                        .iter()
+                        .map(|d| d.content.clone()),
+                );
+            }
+
+            match embedding_provider.embed_batch(spans).await {
+                Ok(embeddings) => {
+                    let mut embeddings = embeddings.into_iter();
+                    for fragment in batch {
+                        for document in
+                            &mut fragment.file.lock().documents[fragment.document_range.clone()]
+                        {
+                            if let Some(embedding) = embeddings.next() {
+                                document.embedding = embedding;
+                            } else {
+                                //
+                                log::error!("number of embeddings returned different from number of documents");
+                            }
+                        }
+
+                        if let Some(file) = Arc::into_inner(fragment.file) {
+                            finished_files_tx.try_send(file.into_inner()).unwrap();
+                        }
+                    }
+                }
+                Err(error) => {
+                    log::error!("{:?}", error);
+                }
+            }
+        })
+        .detach();
+    }
+
+    pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {
+        self.finished_files_rx.clone()
+    }
+}

crates/semantic_index/src/parsing.rs 🔗

@@ -72,8 +72,7 @@ impl CodeContextRetriever {
         let mut sha1 = Sha1::new();
         sha1.update(&document_span);
 
-        let token_count = self.embedding_provider.count_tokens(&document_span);
-        let document_span = self.embedding_provider.truncate(&document_span);
+        let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
 
         Ok(vec![Document {
             range: 0..content.len(),
@@ -93,8 +92,7 @@ impl CodeContextRetriever {
         let mut sha1 = Sha1::new();
         sha1.update(&document_span);
 
-        let token_count = self.embedding_provider.count_tokens(&document_span);
-        let document_span = self.embedding_provider.truncate(&document_span);
+        let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
 
         Ok(vec![Document {
             range: 0..content.len(),
@@ -183,8 +181,8 @@ impl CodeContextRetriever {
                 .replace("<language>", language_name.as_ref())
                 .replace("item", &document.content);
 
-            let token_count = self.embedding_provider.count_tokens(&document_content);
-            let document_content = self.embedding_provider.truncate(&document_content);
+            let (document_content, token_count) =
+                self.embedding_provider.truncate(&document_content);
 
             document.content = document_content;
             document.token_count = token_count;

crates/semantic_index/src/semantic_index_tests.rs 🔗

@@ -1,14 +1,16 @@
 use crate::{
     db::dot,
     embedding::{DummyEmbeddings, EmbeddingProvider},
+    embedding_queue::EmbeddingQueue,
     parsing::{subtract_ranges, CodeContextRetriever, Document},
     semantic_index_settings::SemanticIndexSettings,
-    SearchResult, SemanticIndex,
+    FileToEmbed, JobHandle, SearchResult, SemanticIndex,
 };
 use anyhow::Result;
 use async_trait::async_trait;
 use gpui::{Task, TestAppContext};
 use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
+use parking_lot::Mutex;
 use pretty_assertions::assert_eq;
 use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project};
 use rand::{rngs::StdRng, Rng};
@@ -20,8 +22,10 @@ use std::{
         atomic::{self, AtomicUsize},
         Arc,
     },
+    time::SystemTime,
 };
 use unindent::Unindent;
+use util::RandomCharIter;
 
 #[ctor::ctor]
 fn init_logger() {
@@ -32,11 +36,7 @@ fn init_logger() {
 
 #[gpui::test]
 async fn test_semantic_index(cx: &mut TestAppContext) {
-    cx.update(|cx| {
-        cx.set_global(SettingsStore::test(cx));
-        settings::register::<SemanticIndexSettings>(cx);
-        settings::register::<ProjectSettings>(cx);
-    });
+    init_test(cx);
 
     let fs = FakeFs::new(cx.background());
     fs.insert_tree(
@@ -75,7 +75,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
     let db_path = db_dir.path().join("db.sqlite");
 
     let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
-    let store = SemanticIndex::new(
+    let semantic_index = SemanticIndex::new(
         fs.clone(),
         db_path,
         embedding_provider.clone(),
@@ -87,13 +87,13 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
 
     let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
 
-    let _ = store
+    let _ = semantic_index
         .update(cx, |store, cx| {
             store.initialize_project(project.clone(), cx)
         })
         .await;
 
-    let (file_count, outstanding_file_count) = store
+    let (file_count, outstanding_file_count) = semantic_index
         .update(cx, |store, cx| store.index_project(project.clone(), cx))
         .await
         .unwrap();
@@ -101,7 +101,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
     cx.foreground().run_until_parked();
     assert_eq!(*outstanding_file_count.borrow(), 0);
 
-    let search_results = store
+    let search_results = semantic_index
         .update(cx, |store, cx| {
             store.search_project(
                 project.clone(),
@@ -129,7 +129,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
     // Test Include Files Functonality
     let include_files = vec![PathMatcher::new("*.rs").unwrap()];
     let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
-    let rust_only_search_results = store
+    let rust_only_search_results = semantic_index
         .update(cx, |store, cx| {
             store.search_project(
                 project.clone(),
@@ -153,7 +153,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
         cx,
     );
 
-    let no_rust_search_results = store
+    let no_rust_search_results = semantic_index
         .update(cx, |store, cx| {
             store.search_project(
                 project.clone(),
@@ -189,7 +189,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
     cx.foreground().run_until_parked();
 
     let prev_embedding_count = embedding_provider.embedding_count();
-    let (file_count, outstanding_file_count) = store
+    let (file_count, outstanding_file_count) = semantic_index
         .update(cx, |store, cx| store.index_project(project.clone(), cx))
         .await
         .unwrap();
@@ -204,6 +204,69 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
     );
 }
 
+#[gpui::test(iterations = 10)]
+async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
+    let (outstanding_job_count, _) = postage::watch::channel_with(0);
+    let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
+
+    let files = (1..=3)
+        .map(|file_ix| FileToEmbed {
+            worktree_id: 5,
+            path: format!("path-{file_ix}").into(),
+            mtime: SystemTime::now(),
+            documents: (0..rng.gen_range(4..22))
+                .map(|document_ix| {
+                    let content_len = rng.gen_range(10..100);
+                    Document {
+                        range: 0..10,
+                        embedding: Vec::new(),
+                        name: format!("document {document_ix}"),
+                        content: RandomCharIter::new(&mut rng)
+                            .with_simple_text()
+                            .take(content_len)
+                            .collect(),
+                        sha1: rng.gen(),
+                        token_count: rng.gen_range(10..30),
+                    }
+                })
+                .collect(),
+            job_handle: JobHandle::new(&outstanding_job_count),
+        })
+        .collect::<Vec<_>>();
+
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
+    let mut queue = EmbeddingQueue::new(embedding_provider.clone());
+
+    let finished_files = cx.update(|cx| {
+        for file in &files {
+            queue.push(file.clone(), cx);
+        }
+        queue.flush(cx);
+        queue.finished_files()
+    });
+
+    cx.foreground().run_until_parked();
+    let mut embedded_files: Vec<_> = files
+        .iter()
+        .map(|_| finished_files.try_recv().expect("no finished file"))
+        .collect();
+
+    let expected_files: Vec<_> = files
+        .iter()
+        .map(|file| {
+            let mut file = file.clone();
+            for doc in &mut file.documents {
+                doc.embedding = embedding_provider.embed_sync(doc.content.as_ref());
+            }
+            file
+        })
+        .collect();
+
+    embedded_files.sort_by_key(|f| f.path.clone());
+
+    assert_eq!(embedded_files, expected_files);
+}
+
 #[track_caller]
 fn assert_search_results(
     actual: &[SearchResult],
@@ -1220,47 +1283,42 @@ impl FakeEmbeddingProvider {
     fn embedding_count(&self) -> usize {
         self.embedding_count.load(atomic::Ordering::SeqCst)
     }
+
+    fn embed_sync(&self, span: &str) -> Vec<f32> {
+        let mut result = vec![1.0; 26];
+        for letter in span.chars() {
+            let letter = letter.to_ascii_lowercase();
+            if letter as u32 >= 'a' as u32 {
+                let ix = (letter as u32) - ('a' as u32);
+                if ix < 26 {
+                    result[ix as usize] += 1.0;
+                }
+            }
+        }
+
+        let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
+        for x in &mut result {
+            *x /= norm;
+        }
+
+        result
+    }
 }
 
 #[async_trait]
 impl EmbeddingProvider for FakeEmbeddingProvider {
-    fn count_tokens(&self, span: &str) -> usize {
-        span.len()
-    }
-
-    fn should_truncate(&self, span: &str) -> bool {
-        false
+    fn truncate(&self, span: &str) -> (String, usize) {
+        (span.to_string(), 1)
     }
 
-    fn truncate(&self, span: &str) -> String {
-        span.to_string()
+    fn max_tokens_per_batch(&self) -> usize {
+        200
     }
 
-    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
         self.embedding_count
             .fetch_add(spans.len(), atomic::Ordering::SeqCst);
-        Ok(spans
-            .iter()
-            .map(|span| {
-                let mut result = vec![1.0; 26];
-                for letter in span.chars() {
-                    let letter = letter.to_ascii_lowercase();
-                    if letter as u32 >= 'a' as u32 {
-                        let ix = (letter as u32) - ('a' as u32);
-                        if ix < 26 {
-                            result[ix as usize] += 1.0;
-                        }
-                    }
-                }
-
-                let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
-                for x in &mut result {
-                    *x /= norm;
-                }
-
-                result
-            })
-            .collect())
+        Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
     }
 }
 
@@ -1704,3 +1762,11 @@ fn test_subtract_ranges() {
 
     assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
 }
+
+fn init_test(cx: &mut TestAppContext) {
+    cx.update(|cx| {
+        cx.set_global(SettingsStore::test(cx));
+        settings::register::<SemanticIndexSettings>(cx);
+        settings::register::<ProjectSettings>(cx);
+    });
+}

crates/util/src/util.rs 🔗

@@ -260,11 +260,22 @@ pub fn defer<F: FnOnce()>(f: F) -> impl Drop {
     Defer(Some(f))
 }
 
-pub struct RandomCharIter<T: Rng>(T);
+pub struct RandomCharIter<T: Rng> {
+    rng: T,
+    simple_text: bool,
+}
 
 impl<T: Rng> RandomCharIter<T> {
     pub fn new(rng: T) -> Self {
-        Self(rng)
+        Self {
+            rng,
+            simple_text: std::env::var("SIMPLE_TEXT").map_or(false, |v| !v.is_empty()),
+        }
+    }
+
+    pub fn with_simple_text(mut self) -> Self {
+        self.simple_text = true;
+        self
     }
 }
 
@@ -272,25 +283,27 @@ impl<T: Rng> Iterator for RandomCharIter<T> {
     type Item = char;
 
     fn next(&mut self) -> Option<Self::Item> {
-        if std::env::var("SIMPLE_TEXT").map_or(false, |v| !v.is_empty()) {
-            return if self.0.gen_range(0..100) < 5 {
+        if self.simple_text {
+            return if self.rng.gen_range(0..100) < 5 {
                 Some('\n')
             } else {
-                Some(self.0.gen_range(b'a'..b'z' + 1).into())
+                Some(self.rng.gen_range(b'a'..b'z' + 1).into())
             };
         }
 
-        match self.0.gen_range(0..100) {
+        match self.rng.gen_range(0..100) {
             // whitespace
-            0..=19 => [' ', '\n', '\r', '\t'].choose(&mut self.0).copied(),
+            0..=19 => [' ', '\n', '\r', '\t'].choose(&mut self.rng).copied(),
             // two-byte greek letters
-            20..=32 => char::from_u32(self.0.gen_range(('α' as u32)..('ω' as u32 + 1))),
+            20..=32 => char::from_u32(self.rng.gen_range(('α' as u32)..('ω' as u32 + 1))),
             // // three-byte characters
-            33..=45 => ['✋', '✅', '❌', '❎', '⭐'].choose(&mut self.0).copied(),
+            33..=45 => ['✋', '✅', '❌', '❎', '⭐']
+                .choose(&mut self.rng)
+                .copied(),
             // // four-byte characters
-            46..=58 => ['🍐', '🏀', '🍗', '🎉'].choose(&mut self.0).copied(),
+            46..=58 => ['🍐', '🏀', '🍗', '🎉'].choose(&mut self.rng).copied(),
             // ascii letters
-            _ => Some(self.0.gen_range(b'a'..b'z' + 1).into()),
+            _ => Some(self.rng.gen_range(b'a'..b'z' + 1).into()),
         }
     }
 }