Enabled batching for embedding calls

KCaverly created

Change summary

crates/vector_store/src/vector_store.rs | 157 ++++++++++++++++++++------
1 file changed, 120 insertions(+), 37 deletions(-)

Detailed changes

crates/vector_store/src/vector_store.rs 🔗

@@ -22,7 +22,7 @@ use std::{
     collections::HashMap,
     path::{Path, PathBuf},
     sync::Arc,
-    time::SystemTime,
+    time::{Instant, SystemTime},
 };
 use tree_sitter::{Parser, QueryCursor};
 use util::{
@@ -34,8 +34,9 @@ use util::{
 use workspace::{Workspace, WorkspaceCreated};
 
 const REINDEXING_DELAY: u64 = 30;
+const EMBEDDINGS_BATCH_SIZE: usize = 25;
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct Document {
     pub offset: usize,
     pub name: String,
@@ -110,7 +111,7 @@ pub fn init(
     .detach();
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct IndexedFile {
     path: PathBuf,
     mtime: SystemTime,
@@ -126,6 +127,7 @@ pub struct VectorStore {
     paths_tx: channel::Sender<(i64, PathBuf, Arc<Language>, SystemTime)>,
     _db_update_task: Task<()>,
     _paths_update_task: Task<()>,
+    _embeddings_update_task: Task<()>,
     projects: HashMap<WeakModelHandle<Project>, ProjectState>,
 }
 
@@ -184,7 +186,14 @@ impl VectorStore {
             .await?;
 
         Ok(cx.add_model(|cx| {
+            // paths_tx -> embeddings_tx -> db_update_tx
+
             let (db_update_tx, db_update_rx) = channel::unbounded();
+            let (paths_tx, paths_rx) =
+                channel::unbounded::<(i64, PathBuf, Arc<Language>, SystemTime)>();
+            let (embeddings_tx, embeddings_rx) =
+                channel::unbounded::<(i64, IndexedFile, Vec<String>)>();
+
             let _db_update_task = cx.background().spawn(async move {
                 while let Ok(job) = db_update_rx.recv().await {
                     match job {
@@ -192,11 +201,9 @@ impl VectorStore {
                             worktree_id,
                             indexed_file,
                         } => {
-                            log::info!("Inserting File: {:?}", &indexed_file.path);
                             db.insert_file(worktree_id, indexed_file).log_err();
                         }
                         DbWrite::Delete { worktree_id, path } => {
-                            log::info!("Deleting File: {:?}", &path);
                             db.delete_file(worktree_id, path).log_err();
                         }
                         DbWrite::FindOrCreateWorktree { path, sender } => {
@@ -207,35 +214,116 @@ impl VectorStore {
                 }
             });
 
-            let (paths_tx, paths_rx) =
-                channel::unbounded::<(i64, PathBuf, Arc<Language>, SystemTime)>();
+            async fn embed_batch(
+                embeddings_queue: Vec<(i64, IndexedFile, Vec<String>)>,
+                embedding_provider: &Arc<dyn EmbeddingProvider>,
+                db_update_tx: channel::Sender<DbWrite>,
+            ) -> Result<()> {
+                let mut embeddings_queue = embeddings_queue.clone();
+
+                let mut document_spans = vec![];
+                for (_, _, document_span) in embeddings_queue.clone().into_iter() {
+                    document_spans.extend(document_span);
+                }
+
+                let mut embeddings = embedding_provider
+                    .embed_batch(document_spans.iter().map(|x| &**x).collect())
+                    .await?;
+
+                // This assumes the embeddings are returned in order
+                let t0 = Instant::now();
+                let mut i = 0;
+                let mut j = 0;
+                while let Some(embedding) = embeddings.pop() {
+                    // This has to accommodate multiple indexed_files in a row without documents
+                    while embeddings_queue[i].1.documents.len() == j {
+                        i += 1;
+                        j = 0;
+                    }
+
+                    embeddings_queue[i].1.documents[j].embedding = embedding;
+                    j += 1;
+                }
+
+                for (worktree_id, indexed_file, _) in embeddings_queue.into_iter() {
+                    // TODO: Update this so it doesn't panic
+                    for document in indexed_file.documents.iter() {
+                        assert!(
+                            document.embedding.len() > 0,
+                            "Document Embedding not Complete"
+                        );
+                    }
+
+                    db_update_tx
+                        .send(DbWrite::InsertFile {
+                            worktree_id,
+                            indexed_file,
+                        })
+                        .await
+                        .unwrap();
+                }
+
+                anyhow::Ok(())
+            }
 
-            let fs_clone = fs.clone();
-            let db_update_tx_clone = db_update_tx.clone();
             let embedding_provider_clone = embedding_provider.clone();
 
+            let db_update_tx_clone = db_update_tx.clone();
+            let _embeddings_update_task = cx.background().spawn(async move {
+                let mut queue_len = 0;
+                let mut embeddings_queue = vec![];
+                let mut request_count = 0;
+                while let Ok((worktree_id, indexed_file, document_spans)) =
+                    embeddings_rx.recv().await
+                {
+                    queue_len += &document_spans.len();
+                    embeddings_queue.push((worktree_id, indexed_file, document_spans));
+
+                    if queue_len >= EMBEDDINGS_BATCH_SIZE {
+                        let _ = embed_batch(
+                            embeddings_queue,
+                            &embedding_provider_clone,
+                            db_update_tx_clone.clone(),
+                        )
+                        .await;
+
+                        embeddings_queue = vec![];
+                        queue_len = 0;
+
+                        request_count += 1;
+                    }
+                }
+
+                if queue_len > 0 {
+                    let _ = embed_batch(
+                        embeddings_queue,
+                        &embedding_provider_clone,
+                        db_update_tx_clone.clone(),
+                    )
+                    .await;
+                    request_count += 1;
+                }
+            });
+
+            let fs_clone = fs.clone();
+
             let _paths_update_task = cx.background().spawn(async move {
                 let mut parser = Parser::new();
                 let mut cursor = QueryCursor::new();
                 while let Ok((worktree_id, file_path, language, mtime)) = paths_rx.recv().await {
-                    log::info!("Parsing File: {:?}", &file_path);
-                    if let Some(indexed_file) = Self::index_file(
+                    if let Some((indexed_file, document_spans)) = Self::index_file(
                         &mut cursor,
                         &mut parser,
-                        embedding_provider_clone.as_ref(),
                         &fs_clone,
                         language,
-                        file_path,
+                        file_path.clone(),
                         mtime,
                     )
                     .await
                     .log_err()
                     {
-                        db_update_tx_clone
-                            .try_send(DbWrite::InsertFile {
-                                worktree_id,
-                                indexed_file,
-                            })
+                        embeddings_tx
+                            .try_send((worktree_id, indexed_file, document_spans))
                             .unwrap();
                     }
                 }
@@ -251,6 +339,7 @@ impl VectorStore {
                 projects: HashMap::new(),
                 _db_update_task,
                 _paths_update_task,
+                _embeddings_update_task,
             }
         }))
     }
@@ -258,12 +347,11 @@ impl VectorStore {
     async fn index_file(
         cursor: &mut QueryCursor,
         parser: &mut Parser,
-        embedding_provider: &dyn EmbeddingProvider,
         fs: &Arc<dyn Fs>,
         language: Arc<Language>,
         file_path: PathBuf,
         mtime: SystemTime,
-    ) -> Result<IndexedFile> {
+    ) -> Result<(IndexedFile, Vec<String>)> {
         let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
         let embedding_config = grammar
             .embedding_config
@@ -298,7 +386,7 @@ impl VectorStore {
                 if let Some((item, name)) =
                     content.get(item_range.clone()).zip(content.get(name_range))
                 {
-                    context_spans.push(item);
+                    context_spans.push(item.to_string());
                     documents.push(Document {
                         name: name.to_string(),
                         offset: item_range.start,
@@ -308,18 +396,14 @@ impl VectorStore {
             }
         }
 
-        if !documents.is_empty() {
-            let embeddings = embedding_provider.embed_batch(context_spans).await?;
-            for (document, embedding) in documents.iter_mut().zip(embeddings) {
-                document.embedding = embedding;
-            }
-        }
-
-        return Ok(IndexedFile {
-            path: file_path,
-            mtime,
-            documents,
-        });
+        return Ok((
+            IndexedFile {
+                path: file_path,
+                mtime,
+                documents,
+            },
+            context_spans,
+        ));
     }
 
     fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
@@ -454,6 +538,9 @@ impl VectorStore {
                 .detach();
 
             this.update(&mut cx, |this, cx| {
+                // The code below manages re-indexing on save.
+                // Each time a file is saved this runs, and for every file that changed, if the time
+                // elapsed since it was last embedded exceeds REINDEXING_DELAY, we send the file off to be indexed.
                 let _subscription = cx.subscribe(&project, |this, project, event, _cx| {
                     if let Some(project_state) = this.projects.get(&project.downgrade()) {
                         let worktree_db_ids = project_state.worktree_db_ids.clone();
@@ -554,8 +641,6 @@ impl VectorStore {
                 );
             });
 
-            log::info!("Semantic Indexing Complete!");
-
             anyhow::Ok(())
         })
     }
@@ -591,8 +676,6 @@ impl VectorStore {
             })
             .collect::<Vec<_>>();
 
-        log::info!("Searching for: {:?}", phrase);
-
         let embedding_provider = self.embedding_provider.clone();
         let database_url = self.database_url.clone();
         cx.spawn(|this, cx| async move {