Moved semantic index to use an embedding queue that batches work and manages atomic database writes

Created by KCaverly and Max

Co-authored-by: Max <max@zed.dev>

Change summary

crates/semantic_index/src/embedding_queue.rs      |  25 +
crates/semantic_index/src/semantic_index.rs       | 238 ++--------------
crates/semantic_index/src/semantic_index_tests.rs |  14 
3 files changed, 55 insertions(+), 222 deletions(-)

Detailed changes

crates/semantic_index/src/embedding_queue.rs 🔗

@@ -1,10 +1,8 @@
-use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime};
-
-use gpui::AppContext;
+use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle};
+use gpui::executor::Background;
 use parking_lot::Mutex;
 use smol::channel;
-
-use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle};
+use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime};
 
 #[derive(Clone)]
 pub struct FileToEmbed {
@@ -38,6 +36,7 @@ impl PartialEq for FileToEmbed {
 pub struct EmbeddingQueue {
     embedding_provider: Arc<dyn EmbeddingProvider>,
     pending_batch: Vec<FileToEmbedFragment>,
+    executor: Arc<Background>,
     pending_batch_token_count: usize,
     finished_files_tx: channel::Sender<FileToEmbed>,
     finished_files_rx: channel::Receiver<FileToEmbed>,
@@ -49,10 +48,11 @@ pub struct FileToEmbedFragment {
 }
 
 impl EmbeddingQueue {
-    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
+    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
         let (finished_files_tx, finished_files_rx) = channel::unbounded();
         Self {
             embedding_provider,
+            executor,
             pending_batch: Vec::new(),
             pending_batch_token_count: 0,
             finished_files_tx,
@@ -60,7 +60,12 @@ impl EmbeddingQueue {
         }
     }
 
-    pub fn push(&mut self, file: FileToEmbed, cx: &mut AppContext) {
+    pub fn push(&mut self, file: FileToEmbed) {
+        if file.documents.is_empty() {
+            self.finished_files_tx.try_send(file).unwrap();
+            return;
+        }
+
         let file = Arc::new(Mutex::new(file));
 
         self.pending_batch.push(FileToEmbedFragment {
@@ -73,7 +78,7 @@ impl EmbeddingQueue {
             let next_token_count = self.pending_batch_token_count + document.token_count;
             if next_token_count > self.embedding_provider.max_tokens_per_batch() {
                 let range_end = fragment_range.end;
-                self.flush(cx);
+                self.flush();
                 self.pending_batch.push(FileToEmbedFragment {
                     file: file.clone(),
                     document_range: range_end..range_end,
@@ -86,7 +91,7 @@ impl EmbeddingQueue {
         }
     }
 
-    pub fn flush(&mut self, cx: &mut AppContext) {
+    pub fn flush(&mut self) {
         let batch = mem::take(&mut self.pending_batch);
         self.pending_batch_token_count = 0;
         if batch.is_empty() {
@@ -95,7 +100,7 @@ impl EmbeddingQueue {
 
         let finished_files_tx = self.finished_files_tx.clone();
         let embedding_provider = self.embedding_provider.clone();
-        cx.background().spawn(async move {
+        self.executor.spawn(async move {
             let mut spans = Vec::new();
             for fragment in &batch {
                 let file = fragment.file.lock();

crates/semantic_index/src/semantic_index.rs 🔗

@@ -1,5 +1,6 @@
 mod db;
 mod embedding;
+mod embedding_queue;
 mod parsing;
 pub mod semantic_index_settings;
 
@@ -10,6 +11,7 @@ use crate::semantic_index_settings::SemanticIndexSettings;
 use anyhow::{anyhow, Result};
 use db::VectorDatabase;
 use embedding::{EmbeddingProvider, OpenAIEmbeddings};
+use embedding_queue::{EmbeddingQueue, FileToEmbed};
 use futures::{channel::oneshot, Future};
 use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
 use language::{Anchor, Buffer, Language, LanguageRegistry};
@@ -23,7 +25,6 @@ use smol::channel;
 use std::{
     cmp::Ordering,
     collections::{BTreeMap, HashMap},
-    mem,
     ops::Range,
     path::{Path, PathBuf},
     sync::{Arc, Weak},
@@ -38,7 +39,6 @@ use util::{
 use workspace::WorkspaceCreated;
 
 const SEMANTIC_INDEX_VERSION: usize = 7;
-const EMBEDDINGS_BATCH_SIZE: usize = 80;
 const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(600);
 
 pub fn init(
@@ -106,9 +106,8 @@ pub struct SemanticIndex {
     language_registry: Arc<LanguageRegistry>,
     db_update_tx: channel::Sender<DbOperation>,
     parsing_files_tx: channel::Sender<PendingFile>,
+    _embedding_task: Task<()>,
     _db_update_task: Task<()>,
-    _embed_batch_tasks: Vec<Task<()>>,
-    _batch_files_task: Task<()>,
     _parsing_files_tasks: Vec<Task<()>>,
     projects: HashMap<WeakModelHandle<Project>, ProjectState>,
 }
@@ -128,7 +127,7 @@ struct ChangedPathInfo {
 }
 
 #[derive(Clone)]
-struct JobHandle {
+pub struct JobHandle {
     /// The outer Arc is here to count the clones of a JobHandle instance;
     /// when the last handle to a given job is dropped, we decrement a counter (just once).
     tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
@@ -230,17 +229,6 @@ enum DbOperation {
     },
 }
 
-enum EmbeddingJob {
-    Enqueue {
-        worktree_id: i64,
-        path: PathBuf,
-        mtime: SystemTime,
-        documents: Vec<Document>,
-        job_handle: JobHandle,
-    },
-    Flush,
-}
-
 impl SemanticIndex {
     pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
         if cx.has_global::<ModelHandle<Self>>() {
@@ -287,52 +275,35 @@ impl SemanticIndex {
                 }
             });
 
-            // Group documents into batches and send them to the embedding provider.
-            let (embed_batch_tx, embed_batch_rx) =
-                channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>();
-            let mut _embed_batch_tasks = Vec::new();
-            for _ in 0..cx.background().num_cpus() {
-                let embed_batch_rx = embed_batch_rx.clone();
-                _embed_batch_tasks.push(cx.background().spawn({
-                    let db_update_tx = db_update_tx.clone();
-                    let embedding_provider = embedding_provider.clone();
-                    async move {
-                        while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
-                            Self::compute_embeddings_for_batch(
-                                embeddings_queue,
-                                &embedding_provider,
-                                &db_update_tx,
-                            )
-                            .await;
-                        }
+            let embedding_queue =
+                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
+            let _embedding_task = cx.background().spawn({
+                let embedded_files = embedding_queue.finished_files();
+                let db_update_tx = db_update_tx.clone();
+                async move {
+                    while let Ok(file) = embedded_files.recv().await {
+                        db_update_tx
+                            .try_send(DbOperation::InsertFile {
+                                worktree_id: file.worktree_id,
+                                documents: file.documents,
+                                path: file.path,
+                                mtime: file.mtime,
+                                job_handle: file.job_handle,
+                            })
+                            .ok();
                     }
-                }));
-            }
-
-            // Group documents into batches and send them to the embedding provider.
-            let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
-            let _batch_files_task = cx.background().spawn(async move {
-                let mut queue_len = 0;
-                let mut embeddings_queue = vec![];
-                while let Ok(job) = batch_files_rx.recv().await {
-                    Self::enqueue_documents_to_embed(
-                        job,
-                        &mut queue_len,
-                        &mut embeddings_queue,
-                        &embed_batch_tx,
-                    );
                 }
             });
 
             // Parse files into embeddable documents.
             let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
+            let embedding_queue = Arc::new(Mutex::new(embedding_queue));
             let mut _parsing_files_tasks = Vec::new();
             for _ in 0..cx.background().num_cpus() {
                 let fs = fs.clone();
                 let parsing_files_rx = parsing_files_rx.clone();
-                let batch_files_tx = batch_files_tx.clone();
-                let db_update_tx = db_update_tx.clone();
                 let embedding_provider = embedding_provider.clone();
+                let embedding_queue = embedding_queue.clone();
                 _parsing_files_tasks.push(cx.background().spawn(async move {
                     let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
                     while let Ok(pending_file) = parsing_files_rx.recv().await {
@@ -340,9 +311,8 @@ impl SemanticIndex {
                             &fs,
                             pending_file,
                             &mut retriever,
-                            &batch_files_tx,
+                            &embedding_queue,
                             &parsing_files_rx,
-                            &db_update_tx,
                         )
                         .await;
                     }
@@ -361,8 +331,7 @@ impl SemanticIndex {
                 db_update_tx,
                 parsing_files_tx,
                 _db_update_task,
-                _embed_batch_tasks,
-                _batch_files_task,
+                _embedding_task,
                 _parsing_files_tasks,
                 projects: HashMap::new(),
             }
@@ -403,136 +372,12 @@ impl SemanticIndex {
         }
     }
 
-    async fn compute_embeddings_for_batch(
-        mut embeddings_queue: Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
-        embedding_provider: &Arc<dyn EmbeddingProvider>,
-        db_update_tx: &channel::Sender<DbOperation>,
-    ) {
-        let mut batch_documents = vec![];
-        for (_, documents, _, _, _) in embeddings_queue.iter() {
-            batch_documents.extend(documents.iter().map(|document| document.content.as_str()));
-        }
-
-        if let Ok(embeddings) = embedding_provider.embed_batch(batch_documents).await {
-            log::trace!(
-                "created {} embeddings for {} files",
-                embeddings.len(),
-                embeddings_queue.len(),
-            );
-
-            let mut i = 0;
-            let mut j = 0;
-
-            for embedding in embeddings.iter() {
-                while embeddings_queue[i].1.len() == j {
-                    i += 1;
-                    j = 0;
-                }
-
-                embeddings_queue[i].1[j].embedding = embedding.to_owned();
-                j += 1;
-            }
-
-            for (worktree_id, documents, path, mtime, job_handle) in embeddings_queue.into_iter() {
-                db_update_tx
-                    .send(DbOperation::InsertFile {
-                        worktree_id,
-                        documents,
-                        path,
-                        mtime,
-                        job_handle,
-                    })
-                    .await
-                    .unwrap();
-            }
-        } else {
-            // Insert the file in spite of failure so that future attempts to index it do not take place (unless the file is changed).
-            for (worktree_id, _, path, mtime, job_handle) in embeddings_queue.into_iter() {
-                db_update_tx
-                    .send(DbOperation::InsertFile {
-                        worktree_id,
-                        documents: vec![],
-                        path,
-                        mtime,
-                        job_handle,
-                    })
-                    .await
-                    .unwrap();
-            }
-        }
-    }
-
-    fn enqueue_documents_to_embed(
-        job: EmbeddingJob,
-        queue_len: &mut usize,
-        embeddings_queue: &mut Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
-        embed_batch_tx: &channel::Sender<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>,
-    ) {
-        // Handle edge case where individual file has more documents than max batch size
-        let should_flush = match job {
-            EmbeddingJob::Enqueue {
-                documents,
-                worktree_id,
-                path,
-                mtime,
-                job_handle,
-            } => {
-                // If documents is greater than embeddings batch size, recursively batch existing rows.
-                if &documents.len() > &EMBEDDINGS_BATCH_SIZE {
-                    let first_job = EmbeddingJob::Enqueue {
-                        documents: documents[..EMBEDDINGS_BATCH_SIZE].to_vec(),
-                        worktree_id,
-                        path: path.clone(),
-                        mtime,
-                        job_handle: job_handle.clone(),
-                    };
-
-                    Self::enqueue_documents_to_embed(
-                        first_job,
-                        queue_len,
-                        embeddings_queue,
-                        embed_batch_tx,
-                    );
-
-                    let second_job = EmbeddingJob::Enqueue {
-                        documents: documents[EMBEDDINGS_BATCH_SIZE..].to_vec(),
-                        worktree_id,
-                        path: path.clone(),
-                        mtime,
-                        job_handle: job_handle.clone(),
-                    };
-
-                    Self::enqueue_documents_to_embed(
-                        second_job,
-                        queue_len,
-                        embeddings_queue,
-                        embed_batch_tx,
-                    );
-                    return;
-                } else {
-                    *queue_len += &documents.len();
-                    embeddings_queue.push((worktree_id, documents, path, mtime, job_handle));
-                    *queue_len >= EMBEDDINGS_BATCH_SIZE
-                }
-            }
-            EmbeddingJob::Flush => true,
-        };
-
-        if should_flush {
-            embed_batch_tx
-                .try_send(mem::take(embeddings_queue))
-                .unwrap();
-            *queue_len = 0;
-        }
-    }
-
     async fn parse_file(
         fs: &Arc<dyn Fs>,
         pending_file: PendingFile,
         retriever: &mut CodeContextRetriever,
-        batch_files_tx: &channel::Sender<EmbeddingJob>,
+        embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
         parsing_files_rx: &channel::Receiver<PendingFile>,
-        db_update_tx: &channel::Sender<DbOperation>,
     ) {
         let Some(language) = pending_file.language else {
             return;
@@ -549,33 +394,18 @@ impl SemanticIndex {
                     documents.len()
                 );
 
-                if documents.len() == 0 {
-                    db_update_tx
-                        .send(DbOperation::InsertFile {
-                            worktree_id: pending_file.worktree_db_id,
-                            documents,
-                            path: pending_file.relative_path,
-                            mtime: pending_file.modified_time,
-                            job_handle: pending_file.job_handle,
-                        })
-                        .await
-                        .unwrap();
-                } else {
-                    batch_files_tx
-                        .try_send(EmbeddingJob::Enqueue {
-                            worktree_id: pending_file.worktree_db_id,
-                            path: pending_file.relative_path,
-                            mtime: pending_file.modified_time,
-                            job_handle: pending_file.job_handle,
-                            documents,
-                        })
-                        .unwrap();
-                }
+                embedding_queue.lock().push(FileToEmbed {
+                    worktree_id: pending_file.worktree_db_id,
+                    path: pending_file.relative_path,
+                    mtime: pending_file.modified_time,
+                    job_handle: pending_file.job_handle,
+                    documents,
+                });
             }
         }
 
         if parsing_files_rx.len() == 0 {
-            batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
+            embedding_queue.lock().flush();
         }
     }
 
@@ -881,7 +711,7 @@ impl SemanticIndex {
             let database = VectorDatabase::new(fs.clone(), database_url.clone()).await?;
 
             let phrase_embedding = embedding_provider
-                .embed_batch(vec![&phrase])
+                .embed_batch(vec![phrase])
                 .await?
                 .into_iter()
                 .next()

crates/semantic_index/src/semantic_index_tests.rs 🔗

@@ -235,17 +235,15 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
         .collect::<Vec<_>>();
 
     let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
-    let mut queue = EmbeddingQueue::new(embedding_provider.clone());
 
-    let finished_files = cx.update(|cx| {
-        for file in &files {
-            queue.push(file.clone(), cx);
-        }
-        queue.flush(cx);
-        queue.finished_files()
-    });
+    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
+    for file in &files {
+        queue.push(file.clone());
+    }
+    queue.flush();
 
     cx.foreground().run_until_parked();
+    let finished_files = queue.finished_files();
     let mut embedded_files: Vec<_> = files
         .iter()
         .map(|_| finished_files.try_recv().expect("no finished file"))