Parallel vector db (#2792)

Kyle Caverly created

Parallelize Vector Database calls for project semantic search.

Release Notes: (Preview-only)

- Parallelize Vector database calls for project semantic search. Cuts
query time by 2/3rds.
- Removed default keymap for old semantic search modal.

Change summary

assets/keymaps/default.json                 |  1 
crates/semantic_index/src/db.rs             | 48 ++++++-------
crates/semantic_index/src/semantic_index.rs | 81 +++++++++++++++++-----
3 files changed, 84 insertions(+), 46 deletions(-)

Detailed changes

assets/keymaps/default.json 🔗

@@ -411,7 +411,6 @@
       "cmd-k cmd-t": "theme_selector::Toggle",
       "cmd-k cmd-s": "zed::OpenKeymap",
       "cmd-t": "project_symbols::Toggle",
-      "cmd-ctrl-t": "semantic_search::Toggle",
       "cmd-p": "file_finder::Toggle",
       "cmd-shift-p": "command_palette::Toggle",
       "cmd-shift-m": "diagnostics::Deploy",

crates/semantic_index/src/db.rs 🔗

@@ -267,41 +267,32 @@ impl VectorDatabase {
 
     pub fn top_k_search(
         &self,
-        worktree_ids: &[i64],
         query_embedding: &Vec<f32>,
         limit: usize,
-        include_globs: Vec<GlobMatcher>,
-        exclude_globs: Vec<GlobMatcher>,
-    ) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
+        file_ids: &[i64],
+    ) -> Result<Vec<(i64, f32)>> {
         let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
-        self.for_each_document(
-            &worktree_ids,
-            include_globs,
-            exclude_globs,
-            |id, embedding| {
-                let similarity = dot(&embedding, &query_embedding);
-                let ix = match results.binary_search_by(|(_, s)| {
-                    similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
-                }) {
-                    Ok(ix) => ix,
-                    Err(ix) => ix,
-                };
-                results.insert(ix, (id, similarity));
-                results.truncate(limit);
-            },
-        )?;
+        self.for_each_document(file_ids, |id, embedding| {
+            let similarity = dot(&embedding, &query_embedding);
+            let ix = match results
+                .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
+            {
+                Ok(ix) => ix,
+                Err(ix) => ix,
+            };
+            results.insert(ix, (id, similarity));
+            results.truncate(limit);
+        })?;
 
-        let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
-        self.get_documents_by_ids(&ids)
+        Ok(results)
     }
 
-    fn for_each_document(
+    pub fn retrieve_included_file_ids(
         &self,
         worktree_ids: &[i64],
         include_globs: Vec<GlobMatcher>,
         exclude_globs: Vec<GlobMatcher>,
-        mut f: impl FnMut(i64, Vec<f32>),
-    ) -> Result<()> {
+    ) -> Result<Vec<i64>> {
         let mut file_query = self.db.prepare(
             "
             SELECT
@@ -315,6 +306,7 @@ impl VectorDatabase {
 
         let mut file_ids = Vec::<i64>::new();
         let mut rows = file_query.query([ids_to_sql(worktree_ids)])?;
+
         while let Some(row) = rows.next()? {
             let file_id = row.get(0)?;
             let relative_path = row.get_ref(1)?.as_str()?;
@@ -330,6 +322,10 @@ impl VectorDatabase {
             }
         }
 
+        Ok(file_ids)
+    }
+
+    fn for_each_document(&self, file_ids: &[i64], mut f: impl FnMut(i64, Vec<f32>)) -> Result<()> {
         let mut query_statement = self.db.prepare(
             "
             SELECT
@@ -350,7 +346,7 @@ impl VectorDatabase {
         Ok(())
     }
 
-    fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
+    pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
         let mut statement = self.db.prepare(
             "
                 SELECT

crates/semantic_index/src/semantic_index.rs 🔗

@@ -20,6 +20,7 @@ use postage::watch;
 use project::{Fs, Project, WorktreeId};
 use smol::channel;
 use std::{
+    cmp::Ordering,
     collections::HashMap,
     mem,
     ops::Range,
@@ -704,27 +705,69 @@ impl SemanticIndex {
         let database_url = self.database_url.clone();
         let fs = self.fs.clone();
         cx.spawn(|this, mut cx| async move {
-            let documents = cx
-                .background()
-                .spawn(async move {
-                    let database = VectorDatabase::new(fs, database_url).await?;
+            let database = VectorDatabase::new(fs.clone(), database_url.clone()).await?;
 
-                    let phrase_embedding = embedding_provider
-                        .embed_batch(vec![&phrase])
-                        .await?
-                        .into_iter()
-                        .next()
-                        .unwrap();
+            let phrase_embedding = embedding_provider
+                .embed_batch(vec![&phrase])
+                .await?
+                .into_iter()
+                .next()
+                .unwrap();
 
-                    database.top_k_search(
-                        &worktree_db_ids,
-                        &phrase_embedding,
-                        limit,
-                        include_globs,
-                        exclude_globs,
-                    )
-                })
-                .await?;
+            let file_ids = database.retrieve_included_file_ids(
+                &worktree_db_ids,
+                include_globs,
+                exclude_globs,
+            )?;
+
+            let batch_n = cx.background().num_cpus();
+            let ids_len = file_ids.clone().len();
+            let batch_size = if ids_len <= batch_n {
+                ids_len
+            } else {
+                ids_len / batch_n
+            };
+
+            let mut result_tasks = Vec::new();
+            for batch in file_ids.chunks(batch_size) {
+                let batch = batch.into_iter().map(|v| *v).collect::<Vec<i64>>();
+                let limit = limit.clone();
+                let fs = fs.clone();
+                let database_url = database_url.clone();
+                let phrase_embedding = phrase_embedding.clone();
+                let task = cx.background().spawn(async move {
+                    let database = VectorDatabase::new(fs, database_url).await.log_err();
+                    if database.is_none() {
+                        return Err(anyhow!("failed to acquire database connection"));
+                    } else {
+                        database
+                            .unwrap()
+                            .top_k_search(&phrase_embedding, limit, batch.as_slice())
+                    }
+                });
+                result_tasks.push(task);
+            }
+
+            let batch_results = futures::future::join_all(result_tasks).await;
+
+            let mut results = Vec::new();
+            for batch_result in batch_results {
+                if batch_result.is_ok() {
+                    for (id, similarity) in batch_result.unwrap() {
+                        let ix = match results.binary_search_by(|(_, s)| {
+                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
+                        }) {
+                            Ok(ix) => ix,
+                            Err(ix) => ix,
+                        };
+                        results.insert(ix, (id, similarity));
+                        results.truncate(limit);
+                    }
+                }
+            }
+
+            let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>();
+            let documents = database.get_documents_by_ids(ids.as_slice())?;
 
             let mut tasks = Vec::new();
             let mut ranges = Vec::new();